Ver código fonte

增加了反向提示词的功能,在app_config.py中增加了开关,会查询error_sql类型的记录,做为反向提示词。

wangxq 3 semanas atrás
pai
commit
23ce2d195a
4 arquivos alterados com 187 adições e 15 exclusões
  1. 3 0
      app_config.py
  2. 64 5
      custompgvector/pgvector.py
  3. 43 10
      customqianwen/Custom_QianwenAI_chat.py
  4. 77 0
      docs/release.md

+ 3 - 0
app_config.py

@@ -129,3 +129,6 @@ ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True
 RESULT_VECTOR_SQL_SCORE_THRESHOLD = 0.65
 RESULT_VECTOR_DDL_SCORE_THRESHOLD = 0.5
 RESULT_VECTOR_DOC_SCORE_THRESHOLD = 0.5
+RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD = 0.5
+
+ENABLE_ERROR_SQL_PROMPT = True

+ 64 - 5
custompgvector/pgvector.py

@@ -260,6 +260,62 @@ class PG_VectorStore(VannaBase):
         
         return filtered_results
 
+    def _apply_error_sql_threshold_filter(self, results: list) -> list:
+        """
+        应用错误SQL特有的相似度阈值过滤逻辑
+        
+        与其他方法不同,错误SQL的过滤逻辑是:
+        - 只返回相似度高于阈值的结果
+        - 不设置最低返回数量
+        - 如果都低于阈值,返回空列表
+        
+        Args:
+            results: 原始结果列表,每个元素包含 similarity 字段
+            
+        Returns:
+            过滤后的结果列表
+        """
+        if not results:
+            return results
+            
+        # 导入配置
+        try:
+            import app_config
+            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
+            threshold = getattr(app_config, 'RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD', 0.5)
+        except (ImportError, AttributeError) as e:
+            print(f"[WARNING] 无法加载错误SQL阈值配置: {e},使用默认值")
+            enable_threshold = False
+            threshold = 0.5
+        
+        # 如果未启用阈值过滤,直接返回原结果
+        if not enable_threshold:
+            print(f"[DEBUG] Error SQL 阈值过滤未启用,返回全部 {len(results)} 条结果")
+            return results
+        
+        total_count = len(results)
+        print(f"[DEBUG] Error SQL 阈值过滤: 总数={total_count}, 阈值={threshold}")
+        
+        # 按相似度降序排序(确保最相似的在前面)
+        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)
+        
+        # 只保留满足阈值的结果,不设置最低返回数量
+        filtered_results = [r for r in sorted_results if r.get('similarity', 0) >= threshold]
+        
+        filtered_count = len(filtered_results)
+        filtered_out_count = total_count - filtered_count
+        
+        if filtered_count > 0:
+            print(f"[DEBUG] Error SQL 过滤结果: 保留 {filtered_count} 条, 过滤掉 {filtered_out_count} 条")
+            # 打印保留的结果详情
+            for i, result in enumerate(filtered_results):
+                similarity = result.get('similarity', 0)
+                print(f"[DEBUG] Error SQL 保留 {i+1}: similarity={similarity} ✓")
+        else:
+            print(f"[DEBUG] Error SQL 过滤结果: 所有 {total_count} 条结果都低于阈值 {threshold},返回空列表")
+        
+        return filtered_results
+
     def train(
         self,
         question: str | None = None,
@@ -455,10 +511,10 @@ class PG_VectorStore(VannaBase):
         
         return id
     
-    # 3. 获取相的错误SQL示例
-    def get_error_sql_examples(self, question: str, limit: int = 5) -> list:
+    # 3. 获取相的错误SQL示例
+    def get_related_error_sql(self, question: str, **kwargs) -> list:
         """
-        获取相的错误SQL示例
+        获取相的错误SQL示例
         """
         # 确保集合存在
         self._ensure_error_sql_collection()
@@ -466,7 +522,7 @@ class PG_VectorStore(VannaBase):
         try:
             docs_with_scores = self.error_sql_collection.similarity_search_with_score(
                 query=question,
-                k=limit
+                k=self.n_results
             )
             
             results = []
@@ -489,7 +545,10 @@ class PG_VectorStore(VannaBase):
                     print(f"Error parsing error SQL document: {e}")
                     continue
             
-            return results
+            # 应用错误SQL特有的阈值过滤逻辑
+            filtered_results = self._apply_error_sql_threshold_filter(results)
+            
+            return filtered_results
             
         except Exception as e:
             print(f"Error retrieving error SQL examples: {e}")

+ 43 - 10
customqianwen/Custom_QianwenAI_chat.py

@@ -51,7 +51,21 @@ class QianWenAI_Chat(VannaBase):
             else:
                 self.client = OpenAI(api_key=config["api_key"],
                                      base_url=config["base_url"])
-                
+        
+        # 新增:加载错误SQL提示配置
+        self.enable_error_sql_prompt = self._load_error_sql_prompt_config()
+
+    def _load_error_sql_prompt_config(self) -> bool:
+        """从app_config.py加载错误SQL提示配置"""
+        try:
+            import app_config
+            enable_error_sql = getattr(app_config, 'ENABLE_ERROR_SQL_PROMPT', False)
+            print(f"[DEBUG] 错误SQL提示配置: ENABLE_ERROR_SQL_PROMPT = {enable_error_sql}")
+            return enable_error_sql
+        except (ImportError, AttributeError) as e:
+            print(f"[WARNING] 无法加载错误SQL提示配置: {e},使用默认值 False")
+            return False
+                 
     # 生成SQL的时候,使用中文别名 - 基于VannaBase源码直接实现
     def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
         """
@@ -92,6 +106,31 @@ class QianWenAI_Chat(VannaBase):
             initial_prompt, doc_content_list, max_tokens=self.max_tokens
         )
 
+        # 新增:添加错误SQL示例作为负面示例(放在Response Guidelines之前)
+        if self.enable_error_sql_prompt:
+            try:
+                error_sql_list = self.get_related_error_sql(question, **kwargs)
+                if error_sql_list:
+                    print(f"[DEBUG] 找到 {len(error_sql_list)} 个相关的错误SQL示例")
+                    
+                    # 构建格式化的负面提示内容
+                    negative_prompt_content = "===Negative Examples\n"
+                    negative_prompt_content += "下面是错误的SQL示例,请分析这些错误SQL的问题所在,并在生成新SQL时避免类似错误:\n\n"
+                    
+                    for i, error_example in enumerate(error_sql_list, 1):
+                        if "question" in error_example and "sql" in error_example:
+                            similarity = error_example.get('similarity', 'N/A')
+                            print(f"[DEBUG] 错误SQL示例 {i}: 相似度={similarity}")
+                            negative_prompt_content += f"问题: {error_example['question']}\n"
+                            negative_prompt_content += f"错误的SQL: {error_example['sql']}\n\n"
+                    
+                    # 将负面提示添加到初始提示中
+                    initial_prompt += negative_prompt_content
+                else:
+                    print("[DEBUG] 未找到相关的错误SQL示例")
+            except Exception as e:
+                print(f"[WARNING] 获取错误SQL示例失败: {e}")
+
         initial_prompt += (
             "===Response Guidelines \n"
             "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
@@ -102,14 +141,9 @@ class QianWenAI_Chat(VannaBase):
             f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
             "7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
             "   - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
-            "   - 包括原始字段名也要添加中文别名,例如:gender AS 性别, card_category AS 卡片类型\n"
-            "   - 计算字段也要有中文别名,例如:COUNT(*) AS 持卡人数\n"
-            "   - 中文别名要准确反映字段的业务含义\n"
-            "   - 绝对不能有任何字段没有中文别名,这会影响表格的可读性\n"
-            "   - 这样可以提高图表的可读性和用户体验\n"
-            "   - 不要在where条件中使用中文别名,比如: WHERE gender = 'F' AS 性别, 这是错误的语法\n"
-            "   正确示例:SELECT gender AS 性别, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
-            "   错误示例:SELECT gender, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
+            "   - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
+            "   - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"
+            "   - 中文别名要准确反映字段的业务含义"
         )
 
         message_log = [self.system_message(initial_prompt)]
@@ -124,7 +158,6 @@ class QianWenAI_Chat(VannaBase):
 
         message_log.append(self.user_message(question))
         
-        print(f"[DEBUG] SQL提示词生成完成,消息数量: {len(message_log)}")
         return message_log
 
     # 生成图形的时候,使用中文标注

+ 77 - 0
docs/release.md

@@ -12,3 +12,80 @@ RESULT_VECTOR_SQL_SCORE_THRESHOLD = 0.65
 RESULT_VECTOR_DDL_SCORE_THRESHOLD = 0.5
 RESULT_VECTOR_DOC_SCORE_THRESHOLD = 0.5
 
+
+3.增加了错误SQL负面示例提示功能,用于提高SQL生成质量。
+
+## 功能概述
+通过向LLM提供相关的错误SQL示例作为负面示例,帮助LLM避免生成类似的错误SQL,从而提高SQL生成的准确性和质量。
+
+## 配置参数
+```python
+# 是否启用错误SQL提示功能
+ENABLE_ERROR_SQL_PROMPT = True
+
+# 错误SQL相似度阈值(仅返回相似度高于此阈值的错误SQL示例)
+RESULT_VECTOR_ERROR_SQL_SCORE_THRESHOLD = 0.5
+```
+
+## 实现细节
+
+### 1. 数据存储
+- 在PgVector数据库中新增 `error_sql` 集合,用于存储错误的question-sql对
+- 错误SQL数据格式:`{"question": "用户问题", "sql": "错误的SQL", "type": "error_sql"}`
+- 支持通过训练接口添加错误SQL示例
+
+### 2. 向量查询
+- 新增 `get_related_error_sql()` 方法,基于问题相似度查找相关的错误SQL示例
+- 使用与其他向量查询一致的相似度计算和阈值过滤机制
+- 错误SQL的阈值过滤逻辑:严格按阈值过滤,不设置最低返回数量
+
+### 3. 提示词集成
+- 在SQL生成过程中,如果找到相关的错误SQL示例,会在Response Guidelines之前添加负面示例
+- 负面示例采用与现有提示词一致的格式结构,位于Response Guidelines之前:
+  ```
+  ===Tables
+  [DDL信息]
+
+  ===Additional Context
+  [文档信息]
+
+  ===Negative Examples
+  下面是错误的SQL示例,请分析这些错误SQL的问题所在,并在生成新SQL时避免类似错误:
+
+  问题: [用户问题]
+  错误的SQL: [错误SQL]
+
+  问题: [用户问题]
+  错误的SQL: [错误SQL]
+
+  ===Response Guidelines
+  [响应指南和中文别名要求]
+  ```
+
+### 4. 智能过滤
+- 只有当 `ENABLE_ERROR_SQL_PROMPT = True` 且找到相关错误SQL示例时,才会添加负面提示词
+- 如果未找到相关错误SQL示例(返回空列表),不会添加任何负面提示词
+- 支持相似度阈值过滤,只使用高质量的相关错误示例
+
+## 使用方式
+
+### 1. 训练错误SQL示例
+```python
+# 通过API接口添加错误SQL示例
+POST /api/v0/training_error_question_sql
+{
+    "question": "查询所有用户信息",
+    "sql": "SELECT * FROM users WHERE id = 'all'"  // 错误的SQL
+}
+```
+
+### 2. 自动应用
+- 配置启用后,系统会在每次SQL生成时自动查找并应用相关的错误SQL示例
+- 无需额外操作,透明集成到现有的SQL生成流程中
+
+## 技术优势
+1. **智能相关性匹配**:基于向量相似度找到与当前问题最相关的错误示例
+2. **质量控制**:通过相似度阈值确保只使用高质量的相关错误示例
+3. **性能优化**:只在找到相关错误示例时才添加负面提示词,避免不必要的token消耗
+4. **灵活配置**:支持通过配置参数灵活控制功能开关和阈值设置
+