|
@@ -51,7 +51,21 @@ class QianWenAI_Chat(VannaBase):
|
|
else:
|
|
else:
|
|
self.client = OpenAI(api_key=config["api_key"],
|
|
self.client = OpenAI(api_key=config["api_key"],
|
|
base_url=config["base_url"])
|
|
base_url=config["base_url"])
|
|
-
|
|
|
|
|
|
+
|
|
|
|
+ # 新增:加载错误SQL提示配置
|
|
|
|
+ self.enable_error_sql_prompt = self._load_error_sql_prompt_config()
|
|
|
|
+
|
|
|
|
+ def _load_error_sql_prompt_config(self) -> bool:
|
|
|
|
+ """从app_config.py加载错误SQL提示配置"""
|
|
|
|
+ try:
|
|
|
|
+ import app_config
|
|
|
|
+ enable_error_sql = getattr(app_config, 'ENABLE_ERROR_SQL_PROMPT', False)
|
|
|
|
+ print(f"[DEBUG] 错误SQL提示配置: ENABLE_ERROR_SQL_PROMPT = {enable_error_sql}")
|
|
|
|
+ return enable_error_sql
|
|
|
|
+ except (ImportError, AttributeError) as e:
|
|
|
|
+ print(f"[WARNING] 无法加载错误SQL提示配置: {e},使用默认值 False")
|
|
|
|
+ return False
|
|
|
|
+
|
|
# 生成SQL的时候,使用中文别名 - 基于VannaBase源码直接实现
|
|
# 生成SQL的时候,使用中文别名 - 基于VannaBase源码直接实现
|
|
def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
|
|
def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
|
|
"""
|
|
"""
|
|
@@ -92,6 +106,31 @@ class QianWenAI_Chat(VannaBase):
|
|
initial_prompt, doc_content_list, max_tokens=self.max_tokens
|
|
initial_prompt, doc_content_list, max_tokens=self.max_tokens
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ # 新增:添加错误SQL示例作为负面示例(放在Response Guidelines之前)
|
|
|
|
+ if self.enable_error_sql_prompt:
|
|
|
|
+ try:
|
|
|
|
+ error_sql_list = self.get_related_error_sql(question, **kwargs)
|
|
|
|
+ if error_sql_list:
|
|
|
|
+ print(f"[DEBUG] 找到 {len(error_sql_list)} 个相关的错误SQL示例")
|
|
|
|
+
|
|
|
|
+ # 构建格式化的负面提示内容
|
|
|
|
+ negative_prompt_content = "===Negative Examples\n"
|
|
|
|
+ negative_prompt_content += "下面是错误的SQL示例,请分析这些错误SQL的问题所在,并在生成新SQL时避免类似错误:\n\n"
|
|
|
|
+
|
|
|
|
+ for i, error_example in enumerate(error_sql_list, 1):
|
|
|
|
+ if "question" in error_example and "sql" in error_example:
|
|
|
|
+ similarity = error_example.get('similarity', 'N/A')
|
|
|
|
+ print(f"[DEBUG] 错误SQL示例 {i}: 相似度={similarity}")
|
|
|
|
+ negative_prompt_content += f"问题: {error_example['question']}\n"
|
|
|
|
+ negative_prompt_content += f"错误的SQL: {error_example['sql']}\n\n"
|
|
|
|
+
|
|
|
|
+ # 将负面提示添加到初始提示中
|
|
|
|
+ initial_prompt += negative_prompt_content
|
|
|
|
+ else:
|
|
|
|
+ print("[DEBUG] 未找到相关的错误SQL示例")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"[WARNING] 获取错误SQL示例失败: {e}")
|
|
|
|
+
|
|
initial_prompt += (
|
|
initial_prompt += (
|
|
"===Response Guidelines \n"
|
|
"===Response Guidelines \n"
|
|
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
|
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
|
@@ -102,14 +141,9 @@ class QianWenAI_Chat(VannaBase):
|
|
f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
|
f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
|
"7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
|
|
"7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
|
|
" - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
|
|
" - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
|
|
- " - 包括原始字段名也要添加中文别名,例如:gender AS 性别, card_category AS 卡片类型\n"
|
|
|
|
- " - 计算字段也要有中文别名,例如:COUNT(*) AS 持卡人数\n"
|
|
|
|
- " - 中文别名要准确反映字段的业务含义\n"
|
|
|
|
- " - 绝对不能有任何字段没有中文别名,这会影响表格的可读性\n"
|
|
|
|
- " - 这样可以提高图表的可读性和用户体验\n"
|
|
|
|
- " - 不要在where条件中使用中文别名,比如: WHERE gender = 'F' AS 性别, 这是错误的语法\n"
|
|
|
|
- " 正确示例:SELECT gender AS 性别, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
|
|
|
|
- " 错误示例:SELECT gender, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
|
|
|
|
|
|
+ " - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
|
|
|
|
+ " - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"
|
|
|
|
+ " - 中文别名要准确反映字段的业务含义"
|
|
)
|
|
)
|
|
|
|
|
|
message_log = [self.system_message(initial_prompt)]
|
|
message_log = [self.system_message(initial_prompt)]
|
|
@@ -124,7 +158,6 @@ class QianWenAI_Chat(VannaBase):
|
|
|
|
|
|
message_log.append(self.user_message(question))
|
|
message_log.append(self.user_message(question))
|
|
|
|
|
|
- print(f"[DEBUG] SQL提示词生成完成,消息数量: {len(message_log)}")
|
|
|
|
return message_log
|
|
return message_log
|
|
|
|
|
|
# 生成图形的时候,使用中文标注
|
|
# 生成图形的时候,使用中文标注
|