|
@@ -7,6 +7,8 @@ from vanna.base import VannaBase
|
|
from core.logging import get_vanna_logger
|
|
from core.logging import get_vanna_logger
|
|
# 导入配置参数
|
|
# 导入配置参数
|
|
from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_RESULT_THINKING
|
|
from app_config import REWRITE_QUESTION_ENABLED, DISPLAY_RESULT_THINKING
|
|
|
|
+# 导入提示词加载器
|
|
|
|
+from .load_prompts import get_prompt_loader
|
|
|
|
|
|
|
|
|
|
class BaseLLMChat(VannaBase, ABC):
|
|
class BaseLLMChat(VannaBase, ABC):
|
|
@@ -21,6 +23,9 @@ class BaseLLMChat(VannaBase, ABC):
|
|
# 存储LLM解释性文本
|
|
# 存储LLM解释性文本
|
|
self.last_llm_explanation = None
|
|
self.last_llm_explanation = None
|
|
|
|
|
|
|
|
+ # 初始化提示词加载器
|
|
|
|
+ self.prompt_loader = get_prompt_loader()
|
|
|
|
+
|
|
self.logger.info("传入的 config 参数如下:")
|
|
self.logger.info("传入的 config 参数如下:")
|
|
for key, value in self.config.items():
|
|
for key, value in self.config.items():
|
|
self.logger.info(f" {key}: {value}")
|
|
self.logger.info(f" {key}: {value}")
|
|
@@ -46,6 +51,37 @@ class BaseLLMChat(VannaBase, ABC):
|
|
self.logger.warning(f"无法加载错误SQL提示配置: {e},使用默认值 False")
|
|
self.logger.warning(f"无法加载错误SQL提示配置: {e},使用默认值 False")
|
|
return False
|
|
return False
|
|
|
|
|
|
|
|
def log(self, message: str, title: str = "Info"):
    """Route Vanna's ``log`` output through the project logger.

    Overrides the parent class ``log`` so that messages go to
    ``self.logger`` rather than being printed.

    Args:
        message: The payload to log. Non-string values (for example a
            list of prompt messages) are converted with ``str()`` before
            being truncated.
        title: The category label for the message. ``"SQL Prompt"`` and
            ``"LLM Response"`` are logged at debug level and truncated to
            200 characters; every other title (including
            ``"Extracted SQL"``) is logged in full at info level.
    """
    # Prompts and raw LLM responses can be very long, so they are kept out
    # of the info-level log and only a short preview is recorded.
    if title in ("SQL Prompt", "LLM Response"):
        # Convert once: the payload may be a large list of messages and
        # repeated str() calls would rebuild the same text each time.
        preview = str(message)
        if len(preview) > 200:
            preview = preview[:200] + "..."
        self.logger.debug(f"[Vanna] {title}: {preview}")
        return

    # Everything else, including the extracted SQL, is logged in full.
    self.logger.info(f"[Vanna] {title}: {message}")
|
|
|
|
+
|
|
def system_message(self, message: str) -> dict:
|
|
def system_message(self, message: str) -> dict:
|
|
"""创建系统消息格式"""
|
|
"""创建系统消息格式"""
|
|
self.logger.debug(f"system_content: {message}")
|
|
self.logger.debug(f"system_content: {message}")
|
|
@@ -68,8 +104,7 @@ class BaseLLMChat(VannaBase, ABC):
|
|
self.logger.debug(f"开始生成SQL提示词,问题: {question}")
|
|
self.logger.debug(f"开始生成SQL提示词,问题: {question}")
|
|
|
|
|
|
if initial_prompt is None:
|
|
if initial_prompt is None:
|
|
- initial_prompt = f"You are a {self.dialect} expert. " + \
|
|
|
|
- "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions."
|
|
|
|
|
|
+ initial_prompt = self.prompt_loader.get_sql_initial_prompt(self.dialect)
|
|
|
|
|
|
# 提取DDL内容(适配新的字典格式)
|
|
# 提取DDL内容(适配新的字典格式)
|
|
ddl_content_list = []
|
|
ddl_content_list = []
|
|
@@ -125,30 +160,7 @@ class BaseLLMChat(VannaBase, ABC):
|
|
except Exception as e:
|
|
except Exception as e:
|
|
self.logger.warning(f"获取错误SQL示例失败: {e}")
|
|
self.logger.warning(f"获取错误SQL示例失败: {e}")
|
|
|
|
|
|
- initial_prompt += (
|
|
|
|
- "===Response Guidelines \n"
|
|
|
|
- "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
|
|
|
- "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
|
|
|
- "3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
|
|
|
- "4. **Context Understanding**: If the question follows [CONTEXT]...[CURRENT] format, replace pronouns in [CURRENT] with specific entities from [CONTEXT].\n"
|
|
|
|
- " - Example: If context mentions 'Nancheng Service Area has the most stalls', and current question is 'How many dining stalls does this service area have?', \n"
|
|
|
|
- " interpret it as 'How many dining stalls does Nancheng Service Area have?'\n"
|
|
|
|
- "5. Please use the most relevant table(s). \n"
|
|
|
|
- "6. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
|
|
|
- f"7. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
|
|
|
- "8. 在生成 SQL 查询时,如果出现 ORDER BY 子句,请遵循以下规则:\n"
|
|
|
|
- " - 对所有的排序字段(如聚合字段 SUM()、普通列等),请在 ORDER BY 中显式添加 NULLS LAST。\n"
|
|
|
|
- " - 不论是否使用 LIMIT,只要排序字段存在,都必须添加 NULLS LAST,以防止 NULL 排在结果顶部。\n"
|
|
|
|
- " - 示例参考:\n"
|
|
|
|
- " - ORDER BY total DESC NULLS LAST\n"
|
|
|
|
- " - ORDER BY zf_order DESC NULLS LAST\n"
|
|
|
|
- " - ORDER BY SUM(c.customer_count) DESC NULLS LAST \n"
|
|
|
|
- "9. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
|
|
|
|
- " - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
|
|
|
|
- " - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
|
|
|
|
- " - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"
|
|
|
|
- " - 中文别名要准确反映字段的业务含义"
|
|
|
|
- )
|
|
|
|
|
|
+ initial_prompt += self.prompt_loader.get_sql_response_guidelines(self.dialect)
|
|
|
|
|
|
message_log = [self.system_message(initial_prompt)]
|
|
message_log = [self.system_message(initial_prompt)]
|
|
|
|
|
|
@@ -168,57 +180,15 @@ class BaseLLMChat(VannaBase, ABC):
|
|
"""
|
|
"""
|
|
重写父类方法,添加明确的中文图表指令
|
|
重写父类方法,添加明确的中文图表指令
|
|
"""
|
|
"""
|
|
- # 构建更智能的中文图表指令,根据问题和数据内容生成有意义的标签
|
|
|
|
- chinese_chart_instructions = (
|
|
|
|
- "使用中文创建图表,要求:\n"
|
|
|
|
- "1. 根据用户问题和数据内容,为图表生成有意义的中文标题\n"
|
|
|
|
- "2. 根据数据列的实际含义,为X轴和Y轴生成准确的中文标签\n"
|
|
|
|
- "3. 如果有图例,确保图例标签使用中文\n"
|
|
|
|
- "4. 所有文本(包括标题、轴标签、图例、数据标签等)都必须使用中文\n"
|
|
|
|
- "5. 标题应该简洁明了地概括图表要展示的内容\n"
|
|
|
|
- "6. 轴标签应该准确反映对应数据列的业务含义\n"
|
|
|
|
- "7. 选择最适合数据特点的图表类型(柱状图、折线图、饼图等)"
|
|
|
|
|
|
+ # 构建系统消息
|
|
|
|
+ system_msg = self.prompt_loader.get_chart_system_message(
|
|
|
|
+ question=question,
|
|
|
|
+ sql=sql,
|
|
|
|
+ df_metadata=df_metadata
|
|
)
|
|
)
|
|
|
|
|
|
- # 构建父类方法要求的message_log
|
|
|
|
- system_msg_parts = []
|
|
|
|
-
|
|
|
|
- if question:
|
|
|
|
- system_msg_parts.append(
|
|
|
|
- f"用户问题:'{question}'"
|
|
|
|
- )
|
|
|
|
- system_msg_parts.append(
|
|
|
|
- f"以下是回答用户问题的pandas DataFrame数据:"
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
- system_msg_parts.append("以下是一个pandas DataFrame数据:")
|
|
|
|
-
|
|
|
|
- if sql:
|
|
|
|
- system_msg_parts.append(f"数据来源SQL查询:\n{sql}")
|
|
|
|
-
|
|
|
|
- system_msg_parts.append(f"DataFrame结构信息:\n{df_metadata}")
|
|
|
|
-
|
|
|
|
- system_msg = "\n\n".join(system_msg_parts)
|
|
|
|
-
|
|
|
|
- # 构建更详细的用户消息,强调中文标签的重要性
|
|
|
|
- user_msg = (
|
|
|
|
- "请为这个DataFrame生成Python Plotly可视化代码。要求:\n\n"
|
|
|
|
- "1. 假设数据存储在名为'df'的pandas DataFrame中\n"
|
|
|
|
- "2. 如果DataFrame只有一个值,使用Indicator图表\n"
|
|
|
|
- "3. 只返回Python代码,不要任何解释\n"
|
|
|
|
- "4. 代码必须可以直接运行\n\n"
|
|
|
|
- f"{chinese_chart_instructions}\n\n"
|
|
|
|
- "特别注意:\n"
|
|
|
|
- "- 不要使用'图表标题'、'X轴标签'、'Y轴标签'这样的通用标签\n"
|
|
|
|
- "- 要根据实际数据内容和用户问题生成具体、有意义的中文标签\n"
|
|
|
|
- "- 例如:如果是性别统计,X轴可能是'性别',Y轴可能是'人数'或'占比'\n"
|
|
|
|
- "- 标题应该概括图表的主要内容,如'男女持卡比例分布'\n\n"
|
|
|
|
- "数据标签和悬停信息要求:\n"
|
|
|
|
- "- 不要使用%{text}这样的占位符变量\n"
|
|
|
|
- "- 使用具体的数据值和中文单位,例如:text=df['列名'].astype(str) + '人'\n"
|
|
|
|
- "- 悬停信息要清晰易懂,使用中文描述\n"
|
|
|
|
- "- 确保所有显示的文本都是实际的数据值,不是变量占位符"
|
|
|
|
- )
|
|
|
|
|
|
+ # 构建用户消息
|
|
|
|
+ user_msg = self.prompt_loader.get_chart_user_message()
|
|
|
|
|
|
message_log = [
|
|
message_log = [
|
|
self.system_message(system_msg),
|
|
self.system_message(system_msg),
|
|
@@ -369,7 +339,7 @@ class BaseLLMChat(VannaBase, ABC):
|
|
"""根据SQL生成中文问题"""
|
|
"""根据SQL生成中文问题"""
|
|
prompt = [
|
|
prompt = [
|
|
self.system_message(
|
|
self.system_message(
|
|
- "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
|
|
|
|
|
|
+ self.prompt_loader.get_question_generation_prompt()
|
|
),
|
|
),
|
|
self.user_message(sql)
|
|
self.user_message(sql)
|
|
]
|
|
]
|
|
@@ -413,9 +383,7 @@ class BaseLLMChat(VannaBase, ABC):
|
|
try:
|
|
try:
|
|
# 如果没有提供自定义系统提示词,使用默认的
|
|
# 如果没有提供自定义系统提示词,使用默认的
|
|
if system_prompt is None:
|
|
if system_prompt is None:
|
|
- system_prompt = (
|
|
|
|
- "你是一个友好的AI助手,请用中文回答用户的问题。"
|
|
|
|
- )
|
|
|
|
|
|
+ system_prompt = self.prompt_loader.get_chat_default_prompt()
|
|
|
|
|
|
prompt = [
|
|
prompt = [
|
|
self.system_message(system_prompt),
|
|
self.system_message(system_prompt),
|
|
@@ -460,9 +428,7 @@ class BaseLLMChat(VannaBase, ABC):
|
|
try:
|
|
try:
|
|
prompt = [
|
|
prompt = [
|
|
self.system_message(
|
|
self.system_message(
|
|
- "你的目标是将一系列相关的问题合并成一个单一的问题。如果第二个问题与第一个问题无关且完全独立,则返回第二个问题。"
|
|
|
|
- "只返回新的合并问题,不要添加任何额外的解释。该问题理论上应该能够用一个SQL语句来回答。"
|
|
|
|
- "请用中文回答。"
|
|
|
|
|
|
+ self.prompt_loader.get_question_merge_prompt()
|
|
),
|
|
),
|
|
self.user_message(f"第一个问题: {last_question}\n第二个问题: {new_question}")
|
|
self.user_message(f"第一个问题: {last_question}\n第二个问题: {new_question}")
|
|
]
|
|
]
|
|
@@ -511,18 +477,13 @@ class BaseLLMChat(VannaBase, ABC):
|
|
self.logger.debug(f"DataFrame 形状: {df.shape}")
|
|
self.logger.debug(f"DataFrame 形状: {df.shape}")
|
|
|
|
|
|
# 构建包含中文指令的系统消息
|
|
# 构建包含中文指令的系统消息
|
|
- system_content = (
|
|
|
|
- f"你是一个专业的数据分析助手。用户提出了问题:'{question}'\n\n"
|
|
|
|
- f"以下是查询结果的 pandas DataFrame 数据:\n{df.to_markdown()}\n\n"
|
|
|
|
- "请用中文进行思考和分析,并用中文回答。"
|
|
|
|
|
|
+ system_content = self.prompt_loader.get_summary_system_message(
|
|
|
|
+ question=question,
|
|
|
|
+ df_markdown=df.to_markdown()
|
|
)
|
|
)
|
|
|
|
|
|
# 构建用户消息,强调中文思考和回答
|
|
# 构建用户消息,强调中文思考和回答
|
|
- user_content = (
|
|
|
|
- "请基于用户提出的问题,简要总结这些数据。要求:\n"
|
|
|
|
- "1. 只进行简要总结,不要添加额外的解释\n"
|
|
|
|
- "2. 如果数据中有数字,请保留适当的精度\n"
|
|
|
|
- )
|
|
|
|
|
|
+ user_content = self.prompt_loader.get_summary_user_instructions()
|
|
|
|
|
|
message_log = [
|
|
message_log = [
|
|
self.system_message(system_content),
|
|
self.system_message(system_content),
|