Browse Source

将数据库判断的提示词抽取为txt文件,由生成训练数据的时候,一起生成,然后动态提供给代码使用。另外,修复了无法生成SQL时,无法传递LLM response的问题。

wangxq 1 tuần trước cách đây
mục cha
commit
5642f785ff

+ 19 - 6
agent/citu_agent.py

@@ -227,7 +227,8 @@ class CituLangGraphAgent:
                 error_message = sql_result.get("error", "")
                 error_type = sql_result.get("error_type", "")
                 
-                print(f"[SQL_GENERATION] SQL生成失败: {error_message}")
+                #print(f"[SQL_GENERATION] SQL生成失败: {error_message}")
+                print(f"[DEBUG] error_type = '{error_type}'")
                 
                 # 根据错误类型生成用户提示
                 if "no relevant tables" in error_message.lower() or "table not found" in error_message.lower():
@@ -236,9 +237,15 @@ class CituLangGraphAgent:
                 elif "ambiguous" in error_message.lower() or "more information" in error_message.lower():
                     user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
                     failure_reason = "ambiguous_question"
-                elif error_type == "llm_explanation":
-                    user_prompt = error_message + " 请尝试重新描述您的问题或询问其他内容。"
-                    failure_reason = "llm_explanation"
+                elif error_type == "llm_explanation" or error_type == "generation_failed_with_explanation":
+                    # 对于解释性文本,直接设置为聊天响应
+                    state["chat_response"] = error_message + " 请尝试提问其它问题。"
+                    state["sql_generation_success"] = False
+                    state["validation_error_type"] = "llm_explanation"
+                    state["current_step"] = "sql_generation_completed"
+                    state["execution_path"].append("agent_sql_generation")
+                    print(f"[SQL_GENERATION] 返回LLM解释性答案: {error_message}")
+                    return state
                 else:
                     user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
                     failure_reason = "unknown_generation_failure"
@@ -255,11 +262,10 @@ class CituLangGraphAgent:
             
             sql = sql_result.get("sql")
             state["sql"] = sql
-            print(f"[SQL_GENERATION] SQL生成成功: {sql}")
             
             # 步骤1.5:检查是否为解释性响应而非SQL
             error_type = sql_result.get("error_type")
-            if error_type == "llm_explanation":
+            if error_type == "llm_explanation" or error_type == "generation_failed_with_explanation":
                 # LLM返回了解释性文本,直接作为最终答案
                 explanation = sql_result.get("error", "")
                 state["chat_response"] = explanation + " 请尝试提问其它问题。"
@@ -270,6 +276,13 @@ class CituLangGraphAgent:
                 print(f"[SQL_GENERATION] 返回LLM解释性答案: {explanation}")
                 return state
             
+            if sql:
+                print(f"[SQL_GENERATION] SQL生成成功: {sql}")
+            else:
+                print(f"[SQL_GENERATION] SQL为空,但不是解释性响应")
+                # 这种情况应该很少见,但为了安全起见保留原有的错误处理
+                return state
+            
             # 额外验证:检查SQL格式(防止工具误判)
             from agent.tools.utils import _is_valid_sql_format
             if not _is_valid_sql_format(sql):

+ 43 - 27
agent/classifier.py

@@ -409,44 +409,52 @@ class QuestionClassifier:
                 method="rule_based_uncertain"
             )
     
+    def _load_business_context(self) -> str:
+        """从文件中加载数据库业务范围描述"""
+        try:
+            import os
+            current_dir = os.path.dirname(os.path.abspath(__file__))
+            prompt_file = os.path.join(current_dir, "tools", "db_query_decision_prompt.txt")
+            
+            with open(prompt_file, 'r', encoding='utf-8') as f:
+                content = f.read().strip()
+                
+            if not content:
+                raise ValueError("业务上下文文件为空")
+                
+            return content
+            
+        except FileNotFoundError:
+            error_msg = f"无法找到业务上下文文件: {prompt_file}"
+            print(f"[ERROR] {error_msg}")
+            raise FileNotFoundError(error_msg)
+        except Exception as e:
+            error_msg = f"读取业务上下文文件失败: {str(e)}"
+            print(f"[ERROR] {error_msg}")
+            raise RuntimeError(error_msg)
+
     def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
         """增强的LLM分类:包含详细的业务上下文"""
         try:
             from common.vanna_instance import get_vanna_instance
             vn = get_vanna_instance()
             
+            # 动态加载业务上下文(如果失败会抛出异常)
+            business_context = self._load_business_context()
+            
             # 构建包含业务上下文的分类提示词
             classification_prompt = f"""
-请判断以下用户问题是否需要查询我们的高速公路服务区管理数据库。
+请判断以下用户问题是否需要查询我们的数据库。
 
 用户问题:{question}
 
-=== 数据库业务范围 ===
-本系统是高速公路服务区商业管理系统,包含以下业务数据:
-
-核心业务实体:
-- 服务区(bss_service_area):服务区基础信息、位置、状态,如"鄱阳湖服务区"、"信丰西服务区"
-- 档口/商铺(bss_branch):档口信息、品类(餐饮/小吃/便利店)、品牌,如"驿美餐饮"、"加水机"
-- 营业数据(bss_business_day_data):每日支付金额、订单数量,包含微信、支付宝、现金等支付方式
-- 车流量(bss_car_day_count):按车型统计的日流量数据,包含客车、货车、过境、危化品等
-- 公司信息(bss_company):服务区管理公司,如"驿美运营公司"
-
-关键业务指标:
-- 支付方式:微信支付(wx)、支付宝支付(zfb)、现金支付(rmb)、行吧支付(xs)、金豆支付(jd)
-- 营业数据:支付金额、订单数量、营业额、收入统计
-- 车流统计:按车型(客车/货车/过境/危化品/城际)的流量分析
-- 经营分析:餐饮、小吃、便利店、整体租赁等品类收入
-- 地理分区:北区、南区、西区、东区、两区
-
-高速线路:
-- 线路信息:大广、昌金、昌栗等高速线路
-- 路段管理:按线路统计服务区分布
+{business_context}
 
 === 判断标准 ===
 1. **DATABASE类型** - 需要查询数据库:
    - 涉及上述业务实体和指标的查询、统计、分析、报表
-   - 包含业务相关的时间查询,如"本月服务区营业额"、"上月档口收入"
-   - 例如:"本月营业额统计"、"档口收入排行"、"车流量分析"、"支付方式占比"
+   - 包含业务相关的时间查询
+   - 例如:业务数据统计、收入排行、流量分析、占比分析等
 
 2. **CHAT类型** - 不需要查询数据库:
    - 生活常识:水果蔬菜上市时间、动植物知识、天气等
@@ -458,7 +466,7 @@ class QuestionClassifier:
    - 商业:股票、基金、理财、投资、经济、通货膨胀、上市
    - 哲学:人生意义、价值观、道德、信仰、宗教、爱情
    - 政策:政策、法规、法律、条例、指南、手册、规章制度、实施细则
-   - 地理:全球、国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
+   - 地理:全球、国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
    - 体育:足球、NBA、篮球、乒乓球、冠军、夺冠
    - 文学:小说、新闻、政治、战争、足球、NBA、篮球、乒乓球、冠军、夺冠
    - 娱乐:游戏、小说、新闻、政治、战争、足球、NBA、篮球、乒乓球、冠军、夺冠、电影、电视剧、音乐、舞蹈、绘画、书法、摄影、雕塑、建筑、设计、
@@ -466,7 +474,6 @@ class QuestionClassifier:
    - 其他:高考、人生意义、价值观、道德、信仰、宗教、爱情、全球、全国、亚洲、发展中、欧洲、美洲、东亚、东南亚、南美、非洲、大洋
    - 例如:"荔枝几月份上市"、"今天天气如何"、"你是什么AI"、"怎么使用平台"
 
-
 **重要提示:**
 - 只有涉及高速公路服务区业务数据的问题才分类为DATABASE
 - 只要不是涉及高速公路服务区业务数据的问题都应分类为CHAT
@@ -480,8 +487,8 @@ class QuestionClassifier:
 """
             
             # 专业的系统提示词
-            system_prompt = """你是一个专业的业务问题分类助手,专门负责高速公路服务区管理系统的问题分类。你具有以下特长:
-1. 深度理解高速公路服务区业务领域和数据范围
+            system_prompt = """你是一个专业的业务问题分类助手。你具有以下特长:
+1. 深度理解业务领域和数据范围
 2. 准确区分业务数据查询需求和一般性问题  
 3. 基于具体业务上下文进行精准分类,而不仅仅依赖关键词匹配
 4. 对边界情况能够给出合理的置信度评估
@@ -497,6 +504,15 @@ class QuestionClassifier:
             # 解析响应
             return self._parse_llm_response(response)
             
+        except (FileNotFoundError, RuntimeError) as e:
+            # 业务上下文加载失败,返回错误状态
+            print(f"[ERROR] LLM分类失败,业务上下文不可用: {str(e)}")
+            return ClassificationResult(
+                question_type="CHAT",  # 失败时默认为CHAT,更安全
+                confidence=0.1,  # 很低的置信度表示分类不可靠
+                reason=f"业务上下文加载失败,无法进行准确分类: {str(e)}",
+                method="llm_context_error"
+            )
         except Exception as e:
             print(f"[WARNING] 增强LLM分类失败: {str(e)}")
             return ClassificationResult(

+ 17 - 17
agent/tools/db_query_decision_prompt.txt

@@ -1,20 +1,20 @@
 === 数据库业务范围 ===
-当前数据库存储的是高速公路服务区管理系统的相关数据,主要涉及以下业务主题,包含以下业务数据:
+本系统是高速公路服务区商业管理系统,包含以下业务数据:
+
 核心业务实体:
-- 服务类型:服务类型相关的业务信息
-- 运营状态:运营状态相关的业务信息
-- 支付方式:微信、支付宝、现金等支付方式
-- 服务区:服务区基础信息、位置、状态等
-- 公司:公司相关的业务信息
-- 业务类型:业务类型相关的业务信息
-- 路段:路段相关的业务信息
-- 区域:区域相关的业务信息
+- 服务区(bss_service_area):服务区基础信息、位置、状态,如"鄱阳湖服务区"、"信丰西服务区"
+- 档口/商铺(bss_branch):档口信息、品类(餐饮/小吃/便利店)、品牌,如"驿美餐饮"、"加水机"
+- 营业数据(bss_business_day_data):每日支付金额、订单数量,包含微信、支付宝、现金等支付方式
+- 车流量(bss_car_day_count):按车型统计的日流量数据,包含客车、货车、过境、危化品等
+- 公司信息(bss_company):服务区管理公司,如"驿美运营公司"
+
 关键业务指标:
-- 支付方式占比:微信支付(wx)、支付宝支付(zfb)、现金支付(rmb)等
-- 车流转化率:车流转化率相关的分析指标
-- 路段流量对比:不同维度的横向比较分析
-- 服务区排名:服务区排名相关的分析指标
-- 日营收总额:支付金额、订单数量、营业额统计等
-- 区域效能排名:区域效能排名相关的分析指标
-- 公司营收排名:支付金额、订单数量、营业额统计等
-- 档口偏好度:档口偏好度相关的分析指标
+- 支付方式:微信支付(wx)、支付宝支付(zfb)、现金支付(rmb)、行吧支付(xs)、金豆支付(jd)
+- 营业数据:支付金额、订单数量、营业额、收入统计
+- 车流统计:按车型(客车/货车/过境/危化品/城际)的流量分析
+- 经营分析:餐饮、小吃、便利店、整体租赁等品类收入
+- 地理分区:北区、南区、西区、东区、两区
+
+高速线路:
+- 线路信息:大广、昌金、昌栗等高速线路
+- 路段管理:按线路统计服务区分布

+ 0 - 16
agent/tools/sql_generation.py

@@ -58,22 +58,6 @@ def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str,
                 "can_retry": True
             }
         
-        # 检查是否返回了错误信息而非SQL
-        error_indicators = [
-            "insufficient context", "无法生成", "sorry", "cannot generate",
-            "not enough information", "unclear", "unable to"
-        ]
-        
-        if any(indicator in sql_clean.lower() for indicator in error_indicators):
-            # 这是解释性文本(已在base_llm_chat.py中处理thinking内容)
-            return {
-                "success": False,
-                "sql": None,
-                "error": sql_clean,
-                "error_type": "llm_explanation",
-                "can_retry": False
-            }
-        
         print(f"[TOOL:generate_sql] 成功生成SQL: {sql}")
         return {
             "success": True,

+ 3 - 2
customllm/base_llm_chat.py

@@ -311,9 +311,10 @@ class BaseLLMChat(VannaBase, ABC):
             
             # 检查是否包含错误提示信息
             error_indicators = [
-                "insufficient context", "无法生成", "sorry", "cannot", "不能",
+                "insufficient context", "无法生成", "sorry", "cannot generate", "cannot", "不能",
                 "no relevant", "no suitable", "unable to", "无法", "抱歉",
-                "i don't have", "i cannot", "没有相关", "找不到", "不存在"
+                "i don't have", "i cannot", "没有相关", "找不到", "不存在", "上下文不足",
+                "没有直接存储", "无法直接查询", "没有存储", "not enough information", "unclear"
             ]
             
             for indicator in error_indicators: