@@ -47,6 +47,11 @@ class PG_VectorStore(VannaBase):
             collection_name="documentation",
             connection=self.connection_string,
         )
+        self.error_sql_collection = PGVector(
+            embeddings=self.embedding_function,
+            collection_name="error_sql",
+            connection=self.connection_string,
+        )
 
     def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
         question_sql_json = json.dumps(
@@ -92,6 +97,8 @@ class PG_VectorStore(VannaBase):
                 return self.ddl_collection
             case "documentation":
                 return self.documentation_collection
+            case "error_sql":
+                return self.error_sql_collection
             case _:
                 raise ValueError("Specified collection does not exist.")
 
@@ -300,9 +307,22 @@ class PG_VectorStore(VannaBase):
         for _, row in df_embedding.iterrows():
             custom_id = row["cmetadata"]["id"]
             document = row["document"]
-            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
+
+            # Map the custom_id suffix to a training data type
+            if custom_id.endswith("-doc"):
+                training_data_type = "documentation"
+            elif custom_id.endswith("-error_sql"):
+                training_data_type = "error_sql"
+            elif custom_id.endswith("-sql"):
+                training_data_type = "sql"
+            elif custom_id.endswith("-ddl"):
+                training_data_type = "ddl"
+            else:
+                # If the suffix is not recognized, skip this row
+                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
+                continue
 
-            if training_data_type == "sql":
+            if training_data_type in ["sql", "error_sql"]:
                 # Convert the document string to a dictionary
                 try:
                     doc_dict = ast.literal_eval(document)
@@ -314,10 +334,6 @@ class PG_VectorStore(VannaBase):
             elif training_data_type in ["documentation", "ddl"]:
                 question = None  # Default value for question
                 content = document
-            else:
-                # If the suffix is not recognized, skip this row
-                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
-                continue
 
             # Append the processed data to the list
             processed_rows.append(
@@ -361,11 +377,11 @@ class PG_VectorStore(VannaBase):
         engine = create_engine(self.connection_string)
 
         # Determine the suffix to look for based on the collection name
-        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
+        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc", "error_sql": "error_sql"}
         suffix = suffix_map.get(collection_name)
 
         if not suffix:
-            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
+            logging.info("Invalid collection name. Choose from 'ddl', 'sql', 'documentation', or 'error_sql'.")
             return False
 
         # SQL query to delete rows based on the condition
@@ -397,4 +413,84 @@ class PG_VectorStore(VannaBase):
         return False
 
     def generate_embedding(self, *args, **kwargs):
-        pass
+        pass
+
+    # Training and retrieval support for failed SQL queries
+    # 1. Ensure the error_sql collection exists
+    def _ensure_error_sql_collection(self):
+        """Ensure the error_sql collection exists."""
+        # The collection is already initialized in __init__; this method exists for API consistency.
+        pass
+
+    # 2. Store a failed question-SQL pair in the error_sql collection
+    def train_error_sql(self, question: str, sql: str, **kwargs) -> str:
+        """
+        Store a failed question-SQL pair in the error_sql collection.
+        """
+        # Make sure the collection exists
+        self._ensure_error_sql_collection()
+
+        # Build the document content in the same format as the existing SQL training data
+        question_sql_json = json.dumps(
+            {
+                "question": question,
+                "sql": sql,
+                "type": "error_sql"
+            },
+            ensure_ascii=False,
+        )
+
+        # Generate an ID using the same suffix convention as the existing methods
+        id = str(uuid.uuid4()) + "-error_sql"
+        createdat = kwargs.get("createdat")
+
+        # Create the Document object
+        doc = Document(
+            page_content=question_sql_json,
+            metadata={"id": id, "createdat": createdat},
+        )
+
+        # Add the document to the error_sql collection
+        self.error_sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
+
+        return id
+
+    # 3. Retrieve error SQL examples similar to a question
+    def get_error_sql_examples(self, question: str, limit: int = 5) -> list:
+        """
+        Retrieve stored error SQL examples that are similar to the given question.
+        """
+        # Make sure the collection exists
+        self._ensure_error_sql_collection()
+
+        try:
+            docs_with_scores = self.error_sql_collection.similarity_search_with_score(
+                query=question,
+                k=limit
+            )
+
+            results = []
+            for doc, score in docs_with_scores:
+                try:
+                    # Parse the document content into a dict, consistent with the existing methods
+                    base = ast.literal_eval(doc.page_content)
+
+                    # Convert the distance returned by pgvector into a similarity value
+                    similarity = round(1 - score, 4)
+
+                    # Print each match individually for debugging
+                    print(f"[DEBUG] Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
+
+                    # Attach the similarity field
+                    base["similarity"] = similarity
+                    results.append(base)
+
+                except (ValueError, SyntaxError) as e:
+                    print(f"Error parsing error SQL document: {e}")
+                    continue
+
+            return results
+
+        except Exception as e:
+            print(f"Error retrieving error SQL examples: {e}")
+            return []
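
For reviewers, a minimal usage sketch of the two new methods. This is not part of the patch: the subclass name, config keys, and connection URL below are illustrative assumptions, since PG_VectorStore is normally combined with a concrete embedding/LLM implementation elsewhere in vanna.

    # Hypothetical wiring; MyVectorStore and the config shape are assumptions.
    store = MyVectorStore(config={"connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna"})

    # Record a question whose generated SQL turned out to be wrong.
    error_id = store.train_error_sql(
        question="How many customers signed up last month?",
        sql="SELECT COUNT(*) FROM customers WHERE signup_date > NOW()",  # the failing query
    )

    # Later, fetch stored failures similar to a new question. Each result is the
    # stored dict ("question", "sql", "type") with an added "similarity" field.
    for example in store.get_error_sql_examples("How many customers joined in June?", limit=3):
        print(example["question"], example["similarity"])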