Bladeren bron

增加了一个存储error_sql的API.

wangxq 3 weken geleden
bovenliggende
commit
1f07ec6c8b
2 gewijzigde bestanden met toevoegingen van 150 en 9 verwijderingen
  1. 45 0
      citu_app.py
  2. 105 9
      custompgvector/pgvector.py

+ 45 - 0
citu_app.py

@@ -904,6 +904,51 @@ def cache_cleanup():
             message=f"清理缓存失败: {str(e)}", 
             code=500
         )), 500
+    
+
+@app.flask_app.route('/api/v0/training_error_question_sql', methods=['POST'])
+def training_error_question_sql():
+    """
+    Store an incorrect question/SQL pair in the ``error_sql`` collection.
+
+    Expects a JSON object body with two required, non-empty fields:
+    ``question`` (the user's question) and ``sql`` (the wrong SQL produced
+    for it). The pair is handed to ``vn.train_error_sql``.
+
+    Returns:
+        The stored record id on success; a 400 response when the body is not
+        a JSON object or a field is missing/empty; a 500 on any exception.
+    """
+    try:
+        # silent=True yields None instead of raising on a missing or malformed
+        # JSON body, so bad client input becomes a 400 rather than a 500.
+        data = request.get_json(silent=True)
+        if not isinstance(data, dict):
+            data = {}
+        question = data.get('question')
+        sql = data.get('sql')
+
+        print(f"[DEBUG] 接收到错误SQL训练请求: question={question}, sql={sql}")
+
+        if not question or not sql:
+            return jsonify(result.failed(
+                message="question和sql参数都是必需的",
+                code=400
+            )), 400
+
+        # Named `record_id` rather than `id` to avoid shadowing the builtin.
+        record_id = vn.train_error_sql(question=question, sql=sql)
+        print(f"[INFO] 成功存储错误SQL,ID: {record_id}")
+        return jsonify(result.success(data={
+            "id": record_id,
+            "message": "错误SQL对已成功存储到error_sql集合"
+        }))
+    except Exception as e:
+        print(f"[ERROR] 存储错误SQL失败: {str(e)}")
+        return jsonify(result.failed(
+            message=f"存储错误SQL失败: {str(e)}",
+            code=500
+        )), 500
 
 
 

+ 105 - 9
custompgvector/pgvector.py

@@ -47,6 +47,11 @@ class PG_VectorStore(VannaBase):
             collection_name="documentation",
             connection=self.connection_string,
         )
+        self.error_sql_collection = PGVector(
+            embeddings=self.embedding_function,
+            collection_name="error_sql",
+            connection=self.connection_string,
+        )
 
     def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
         question_sql_json = json.dumps(
@@ -92,6 +97,8 @@ class PG_VectorStore(VannaBase):
                 return self.ddl_collection
             case "documentation":
                 return self.documentation_collection
+            case "error_sql":
+                return self.error_sql_collection
             case _:
                 raise ValueError("Specified collection does not exist.")
 
@@ -300,9 +307,22 @@ class PG_VectorStore(VannaBase):
         for _, row in df_embedding.iterrows():
             custom_id = row["cmetadata"]["id"]
             document = row["document"]
-            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]
+            
+            # 处理不同类型的ID后缀
+            if custom_id.endswith("-doc"):
+                training_data_type = "documentation"
+            elif custom_id.endswith("-error_sql"):
+                training_data_type = "error_sql"
+            elif custom_id.endswith("-sql"):
+                training_data_type = "sql"
+            elif custom_id.endswith("-ddl"):
+                training_data_type = "ddl"
+            else:
+                # If the suffix is not recognized, skip this row
+                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
+                continue
 
-            if training_data_type == "sql":
+            if training_data_type in ["sql", "error_sql"]:
                 # Convert the document string to a dictionary
                 try:
                     doc_dict = ast.literal_eval(document)
@@ -314,10 +334,6 @@ class PG_VectorStore(VannaBase):
             elif training_data_type in ["documentation", "ddl"]:
                 question = None  # Default value for question
                 content = document
-            else:
-                # If the suffix is not recognized, skip this row
-                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
-                continue
 
             # Append the processed data to the list
             processed_rows.append(
@@ -361,11 +377,11 @@ class PG_VectorStore(VannaBase):
         engine = create_engine(self.connection_string)
 
         # Determine the suffix to look for based on the collection name
-        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
+        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc", "error_sql": "error_sql"}
         suffix = suffix_map.get(collection_name)
 
         if not suffix:
-            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
+            logging.info("Invalid collection name. Choose from 'ddl', 'sql', 'documentation', or 'error_sql'.")
             return False
 
         # SQL query to delete rows based on the condition
@@ -397,4 +413,84 @@ class PG_VectorStore(VannaBase):
                     return False
 
     def generate_embedding(self, *args, **kwargs):
-        pass
+        pass
+
+    # 增加错误SQL的训练和查询功能
+    # 1. 确保error_sql集合存在
+    def _ensure_error_sql_collection(self):
+        """No-op hook: the error_sql collection is already set up in ``__init__``."""
+        # Retained purely so the methods that use the collection share one flow.
+        return None
+    
+    # 2. 将错误的question-sql对存储到error_sql集合中
+    def train_error_sql(self, question: str, sql: str, **kwargs) -> str:
+        """
+        Persist a wrong question/SQL pair in the ``error_sql`` collection.
+
+        Args:
+            question: The user's question.
+            sql: The incorrect SQL associated with that question.
+            **kwargs: Only ``createdat`` is read; it is copied into the
+                document metadata (``None`` when absent).
+
+        Returns:
+            The new record id: a UUID4 string suffixed with ``-error_sql``.
+        """
+        self._ensure_error_sql_collection()
+
+        # Serialise the pair as JSON, tagged with its type; ensure_ascii=False
+        # keeps non-ASCII text unescaped in the stored document.
+        payload = {"question": question, "sql": sql, "type": "error_sql"}
+        serialized = json.dumps(payload, ensure_ascii=False)
+
+        # Other code in this file recognises the record type by this suffix.
+        record_id = f"{uuid.uuid4()}-error_sql"
+        metadata = {"id": record_id, "createdat": kwargs.get("createdat")}
+
+        # The same id is used for the stored row and inside its metadata.
+        self.error_sql_collection.add_documents(
+            [Document(page_content=serialized, metadata=metadata)],
+            ids=[record_id],
+        )
+
+        return record_id
+    
+    # 3. 获取相似的错误SQL示例
+    def get_error_sql_examples(self, question: str, limit: int = 5) -> list:
+        """
+        Return stored error question/SQL pairs most similar to ``question``.
+
+        Args:
+            question: Text to search the ``error_sql`` collection with.
+            limit: Maximum number of matches to request (passed as ``k``).
+
+        Returns:
+            Dicts parsed from the matching documents, each with a ``similarity``
+            key (``1 - score``, rounded to 4 places). Documents that cannot be
+            parsed into a dict are skipped; other errors print and return ``[]``.
+        """
+        self._ensure_error_sql_collection()
+        try:
+            docs_with_scores = self.error_sql_collection.similarity_search_with_score(
+                query=question, k=limit,
+            )
+            results = []
+            for doc, score in docs_with_scores:
+                try:
+                    # Stored as JSON by train_error_sql; fall back to a Python literal.
+                    try:
+                        base = json.loads(doc.page_content)
+                    except ValueError:
+                        base = ast.literal_eval(doc.page_content)
+                    if not isinstance(base, dict):
+                        raise ValueError("document is not a question/sql mapping")
+                except (ValueError, SyntaxError) as e:
+                    print(f"Error parsing error SQL document: {e}")
+                    continue
+                similarity = round(1 - score, 4)
+                print(f"[DEBUG] Error SQL Match: {base.get('question', '')} | similarity: {similarity}")
+                results.append({**base, "similarity": similarity})
+            return results
+        except Exception as e:
+            print(f"Error retrieving error SQL examples: {e}")
+            return []