Pārlūkot izejas kodu

为字段添加中文别名,生成中文图形。

wangxq 1 mēnesi atpakaļ
vecāks
revīzija
39b86b08cb
2 mainītis faili ar 134 papildinājumiem un 4 dzēšanām
  1. 10 3
      app.py
  2. 124 1
      customqianwen/Custom_QianwenAI_chat.py

+ 10 - 3
app.py

@@ -65,7 +65,10 @@ async def execute_query(query):
         if df is None or df.empty:
             current_step.output = "查询执行成功,但没有返回数据"
             return None
-            
+        
+        # 调试信息:检查DataFrame的列名
+        print(f"[DEBUG] 执行后DataFrame列名: {list(df.columns)}")
+        print(f"[DEBUG] DataFrame数据预览:\n{df.head()}")
         current_step.output = df.head().to_markdown(index=False)
         print(f"[SUCCESS] SQL执行成功,返回 {len(df)} 行数据")
         return df
@@ -84,7 +87,9 @@ async def plot(human_query, sql, df):
             current_step.output = "无数据可用于生成图表"
             return None
             
-        plotly_code = vn.generate_plotly_code(question=human_query, sql=sql, df=df)
+        # 生成DataFrame的元数据信息
+        df_metadata = f"列名: {list(df.columns)}\n数据类型: {df.dtypes.to_dict()}\n数据形状: {df.shape}\n前几行数据:\n{df.head().to_string()}"
+        plotly_code = vn.generate_plotly_code(question=human_query, sql=sql, df_metadata=df_metadata)
         fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
         current_step.output = plotly_code
         return fig
@@ -239,6 +244,8 @@ async def chain(human_query: str):
         fig = await plot(human_query, sql_query, df)
 
         # 创建返回元素
+        # 确保表格显示中文列名
+        print(f"[DEBUG] DataFrame列名: {list(df.columns)}")
         elements = [
             cl.Text(name="data_table", content=df.to_markdown(index=False), display="inline")
         ]
@@ -292,7 +299,7 @@ async def on_chat_start():
 
 请直接输入您的问题,例如:
 - "交易次数最多的前5位客户是谁?"
-- "查看过去30天的交易趋势"
+- "请统计不同类型的卡持卡人数所占的百分比?"
 - "你好,今天天气怎么样?"
 
 让我们开始吧!✨

+ 124 - 1
customqianwen/Custom_QianwenAI_chat.py

@@ -49,7 +49,130 @@ class QianWenAI_Chat(VannaBase):
             else:
                 self.client = OpenAI(api_key=config["api_key"],
                                      base_url=config["base_url"])
-   
+                
+    # 生成SQL的时候,使用中文别名 - 基于VannaBase源码直接实现
+    def get_sql_prompt(self, initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
+        """
+        基于VannaBase源码实现,在第7点添加中文别名指令
+        """
+        print(f"[DEBUG] 开始生成SQL提示词,问题: {question}")
+        
+        if initial_prompt is None:
+            initial_prompt = f"You are a {self.dialect} expert. " + \
+            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
+
+        initial_prompt = self.add_ddl_to_prompt(
+            initial_prompt, ddl_list, max_tokens=self.max_tokens
+        )
+
+        if self.static_documentation != "":
+            doc_list.append(self.static_documentation)
+
+        initial_prompt = self.add_documentation_to_prompt(
+            initial_prompt, doc_list, max_tokens=self.max_tokens
+        )
+
+        initial_prompt += (
+            "===Response Guidelines \n"
+            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
+            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
+            "3. If the provided context is insufficient, please explain why it can't be generated. \n"
+            "4. Please use the most relevant table(s). \n"
+            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
+            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
+            "7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
+            "   - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
+            "   - 包括原始字段名也要添加中文别名,例如:gender AS 性别, card_category AS 卡片类型\n"
+            "   - 计算字段也要有中文别名,例如:COUNT(*) AS 持卡人数\n"
+            "   - 中文别名要准确反映字段的业务含义\n"
+            "   - 绝对不能有任何字段没有中文别名,这会影响表格的可读性\n"
+            "   - 这样可以提高图表的可读性和用户体验\n"
+            "   正确示例:SELECT gender AS 性别, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
+            "   错误示例:SELECT gender, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
+        )
+
+        message_log = [self.system_message(initial_prompt)]
+
+        for example in question_sql_list:
+            if example is None:
+                print("example is None")
+            else:
+                if example is not None and "question" in example and "sql" in example:
+                    message_log.append(self.user_message(example["question"]))
+                    message_log.append(self.assistant_message(example["sql"]))
+
+        message_log.append(self.user_message(question))
+        
+        print(f"[DEBUG] SQL提示词生成完成,消息数量: {len(message_log)}")
+        return message_log
+
+    # 生成图形的时候,使用中文标注
+    def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
+        """
+        重写父类方法,添加明确的中文图表指令
+        """
+        # 构建更智能的中文图表指令,根据问题和数据内容生成有意义的标签
+        chinese_chart_instructions = (
+            "使用中文创建图表,要求:\n"
+            "1. 根据用户问题和数据内容,为图表生成有意义的中文标题\n"
+            "2. 根据数据列的实际含义,为X轴和Y轴生成准确的中文标签\n"
+            "3. 如果有图例,确保图例标签使用中文\n"
+            "4. 所有文本(包括标题、轴标签、图例、数据标签等)都必须使用中文\n"
+            "5. 标题应该简洁明了地概括图表要展示的内容\n"
+            "6. 轴标签应该准确反映对应数据列的业务含义\n"
+            "7. 选择最适合数据特点的图表类型(柱状图、折线图、饼图等)"
+        )
+
+        # 构建父类方法要求的message_log
+        system_msg_parts = []
+
+        if question:
+            system_msg_parts.append(
+                f"用户问题:'{question}'"
+            )
+            system_msg_parts.append(
+                f"以下是回答用户问题的pandas DataFrame数据:"
+            )
+        else:
+            system_msg_parts.append("以下是一个pandas DataFrame数据:")
+
+        if sql:
+            system_msg_parts.append(f"数据来源SQL查询:\n{sql}")
+
+        system_msg_parts.append(f"DataFrame结构信息:\n{df_metadata}")
+
+        system_msg = "\n\n".join(system_msg_parts)
+
+        # 构建更详细的用户消息,强调中文标签的重要性
+        user_msg = (
+            "请为这个DataFrame生成Python Plotly可视化代码。要求:\n\n"
+            "1. 假设数据存储在名为'df'的pandas DataFrame中\n"
+            "2. 如果DataFrame只有一个值,使用Indicator图表\n"
+            "3. 只返回Python代码,不要任何解释\n"
+            "4. 代码必须可以直接运行\n\n"
+            f"{chinese_chart_instructions}\n\n"
+            "特别注意:\n"
+            "- 不要使用'图表标题'、'X轴标签'、'Y轴标签'这样的通用标签\n"
+            "- 要根据实际数据内容和用户问题生成具体、有意义的中文标签\n"
+            "- 例如:如果是性别统计,X轴可能是'性别',Y轴可能是'人数'或'占比'\n"
+            "- 标题应该概括图表的主要内容,如'男女持卡比例分布'\n\n"
+            "数据标签和悬停信息要求:\n"
+            "- 不要使用%{text}这样的占位符变量\n"
+            "- 使用具体的数据值和中文单位,例如:text=df['列名'].astype(str) + '人'\n"
+            "- 悬停信息要清晰易懂,使用中文描述\n"
+            "- 确保所有显示的文本都是实际的数据值,不是变量占位符"
+        )
+
+        message_log = [
+            self.system_message(system_msg),
+            self.user_message(user_msg),
+        ]
+
+        # 调用父类submit_prompt方法,并清理结果
+        plotly_code = self.submit_prompt(message_log, kwargs=kwargs)
+
+        return self._sanitize_plotly_code(self._extract_python_code(plotly_code))
+    
     def system_message(self, message: str) -> any:
         print(f"system_content: {message}")
         return {"role": "system", "content": message}