
Fix issues with the returned results, in preparation for re-normalizing the response format.

wangxq committed 2 weeks ago
parent
commit 58703f5215
4 files changed, 233 insertions and 107 deletions
  1. agent/citu_agent.py (+23 -8)
  2. agent/classifier.py (+205 -94)
  3. agent/config.py (+3 -3)
  4. app_config.py (+2 -2)

+ 23 - 8
agent/citu_agent.py

@@ -119,7 +119,7 @@ class CituLangGraphAgent:
             if error_type == "llm_explanation":
                 # LLM返回了解释性文本,直接作为最终答案
                 explanation = sql_result.get("error", "")
-                state["summary"] = explanation + " 请尝试提问其它问题。"
+                state["chat_response"] = explanation + " 请尝试提问其它问题。"
                 state["current_step"] = "database_completed"
                 state["execution_path"].append("agent_database")
                 print(f"[DATABASE_AGENT] 返回LLM解释性答案: {explanation}")
@@ -129,7 +129,7 @@ class CituLangGraphAgent:
             from agent.utils import _is_valid_sql_format
             if not _is_valid_sql_format(sql):
                 # 内容看起来不是SQL,当作解释性响应处理
-                state["summary"] = sql + " 请尝试提问其它问题。"
+                state["chat_response"] = sql + " 请尝试提问其它问题。"
                 state["current_step"] = "database_completed"  
                 state["execution_path"].append("agent_database")
                 print(f"[DATABASE_AGENT] 内容不是有效SQL,当作解释返回: {sql}")
@@ -256,14 +256,29 @@ class CituLangGraphAgent:
             
             elif state["question_type"] == "DATABASE":
                 # 数据库查询类型
-                if state.get("summary"):
-                    # 有摘要的情况(包括解释性响应和完整查询结果
+                if state.get("chat_response"):
+                    # SQL生成失败的解释性响应(不受ENABLE_RESULT_SUMMARY配置影响)
                     state["final_response"] = {
                         "success": True,
-                        "response": state["summary"],
+                        "response": state["chat_response"],
                         "type": "DATABASE",
                         "sql": state.get("sql"),
-                        "data_result": state.get("data_result"),  # 可能为None(解释性响应)
+                        "data_result": state.get("data_result"),
+                        "execution_path": state["execution_path"],
+                        "classification_info": {
+                            "confidence": state["classification_confidence"],
+                            "reason": state["classification_reason"],
+                            "method": state["classification_method"]
+                        }
+                    }
+                elif state.get("summary"):
+                    # 正常的数据库查询结果,有摘要的情况
+                    # 不将summary复制到response,让response保持为空
+                    state["final_response"] = {
+                        "success": True,
+                        "type": "DATABASE",
+                        "sql": state.get("sql"),
+                        "data_result": state.get("data_result"),
                         "summary": state["summary"],
                         "execution_path": state["execution_path"],
                         "classification_info": {
@@ -277,10 +292,10 @@ class CituLangGraphAgent:
                     data_result = state.get("data_result")
                     row_count = data_result.get("row_count", 0)
                     
-                    # 构建基本响应,不包含summary字段
+                    # 构建基本响应,不包含summary字段和response字段
+                    # 用户应该直接从data_result.columns和data_result.rows获取数据
                     state["final_response"] = {
                         "success": True,
-                        "response": f"查询执行完成,共返回 {row_count} 条记录。",
                         "type": "DATABASE",
                         "sql": state.get("sql"),
                         "data_result": data_result,

+ 205 - 94
agent/classifier.py

@@ -12,180 +12,289 @@ class ClassificationResult:
 
 class QuestionClassifier:
     """
-    多策略融合的问题分类器
-    策略:规则优先 + LLM fallback
+    增强版问题分类器:基于高速公路服务区业务上下文的智能分类
     """
     
     def __init__(self):
-    # 从配置文件加载阈值参数
+        # 从配置文件加载阈值参数
         try:
             from agent.config import get_current_config, get_nested_config
             config = get_current_config()
-            self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.8)
+            self.high_confidence_threshold = get_nested_config(config, "classification.high_confidence_threshold", 0.7)
             self.low_confidence_threshold = get_nested_config(config, "classification.low_confidence_threshold", 0.4)
             self.max_confidence = get_nested_config(config, "classification.max_confidence", 0.9)
-            self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.5)
-            self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.1)
+            self.base_confidence = get_nested_config(config, "classification.base_confidence", 0.4)
+            self.confidence_increment = get_nested_config(config, "classification.confidence_increment", 0.08)
             self.llm_fallback_confidence = get_nested_config(config, "classification.llm_fallback_confidence", 0.5)
             self.uncertain_confidence = get_nested_config(config, "classification.uncertain_confidence", 0.2)
             print("[CLASSIFIER] 从配置文件加载分类器参数完成")
         except ImportError:
-            # 配置文件不可用时的默认值
-            self.high_confidence_threshold = 0.8
+            self.high_confidence_threshold = 0.7
             self.low_confidence_threshold = 0.4
             self.max_confidence = 0.9
-            self.base_confidence = 0.5
-            self.confidence_increment = 0.1
+            self.base_confidence = 0.4
+            self.confidence_increment = 0.08
             self.llm_fallback_confidence = 0.5
             self.uncertain_confidence = 0.2
             print("[CLASSIFIER] 配置文件不可用,使用默认分类器参数")
         
-        # 移除了 LLM 实例存储,现在使用 Vanna 实例
-        
-        self.db_keywords = {
-            "数据类": [
-                "收入", "销量", "数量", "平均", "总计", "统计", "合计", "累计",
-                "营业额", "利润", "成本", "费用", "金额", "价格", "单价", "服务区", "多少个"
+        # 基于高速公路服务区业务的精准关键词
+        self.strong_business_keywords = {
+            "核心业务实体": [
+                "服务区", "档口", "商铺", "收费站", "高速公路",
+                "驿美", "驿购",  # 业务系统名称
+                "北区", "南区", "西区", "东区", "两区",  # 物理分区
+                "停车区"
+            ],
+            "支付业务": [
+                "微信支付", "支付宝支付", "现金支付", "行吧支付", "金豆支付",
+                "支付金额", "订单数量", "营业额", "收入", "营业收入",
+                "微信", "支付宝", "现金", "行吧", "金豆",  # 简化形式
+                "wx", "zfb", "rmb", "xs", "jd"  # 系统字段名
             ],
-            "分析类": [
-                "分组", "排行", "排名", "增长率", "趋势", "对比", "比较", "占比",
-                "百分比", "比例", "环比", "同比", "最大", "最小", "最高", "最低"
+            "经营品类": [
+                "餐饮", "小吃", "便利店", "整体租赁",
+                "驿美餐饮", "品牌", "经营品类", "商业品类"
             ],
-            "时间类": [
-                "今天", "昨天", "本月", "上月", "去年", "季度", "年度", "月份",
-                "本年", "上年", "本周", "上周", "近期", "最近"
+            "车流业务": [
+                "车流量", "车辆数量", "客车", "货车", 
+                "过境", "危化品", "城际", "车辆统计",
+                "流量统计", "车型分布"
             ],
-            "业务类": [
-                "客户", "订单", "产品", "商品", "用户", "会员", "供应商", "库存",
-                "部门", "员工", "项目", "合同", "发票", "账单"
+            "地理路线": [
+                "大广", "昌金", "昌栗", "线路", "路段", "路线",
+                "高速线路", "公路线路"
             ]
         }
         
-        # SQL关键词
+        # 查询意图词(辅助判断)
+        self.query_intent_keywords = [
+            "统计", "查询", "分析", "排行", "排名",
+            "报表", "报告", "汇总", "计算", "对比",
+            "趋势", "占比", "百分比", "比例",
+            "最大", "最小", "最高", "最低", "平均",
+            "总计", "合计", "累计", "求和", "求平均",
+            "生成", "导出", "显示", "列出"
+        ]
+        
+        # 非业务实体词(包含则倾向CHAT)
+        self.non_business_keywords = [
+            # 农产品/食物
+            "荔枝", "苹果", "西瓜", "水果", "蔬菜", "大米", "小麦",
+            "橙子", "香蕉", "葡萄", "草莓", "樱桃", "桃子", "梨",
+            
+            # 技术概念  
+            "人工智能", "机器学习", "编程", "算法", "深度学习",
+            "AI", "神经网络", "模型训练", "数据挖掘",
+            
+            # 身份询问
+            "你是谁", "你是什么", "你叫什么", "你的名字", 
+            "什么模型", "大模型", "AI助手", "助手", "机器人",
+            
+            # 天气相关
+            "天气", "气温", "下雨", "晴天", "阴天", "温度",
+            "天气预报", "气候", "降雨", "雪天",
+            
+            # 其他生活常识
+            "怎么做饭", "如何减肥", "健康", "医疗", "病症",
+            "历史", "地理", "文学", "电影", "音乐", "体育",
+            "娱乐", "游戏", "小说", "新闻", "政治"
+        ]
+        
+        # SQL关键词(技术层面的数据库操作)
         self.sql_patterns = [
             r"\b(select|from|where|group by|order by|having|join)\b",
-            r"\b(查询|统计|汇总|计算|分析|有多少)\b",
-            r"\b(表|字段|数据库)\b"
+            r"\b(数据库|表名|字段名|SQL|sql)\b"
         ]
         
-        # 聊天关键词
+        # 聊天关键词(平台功能和帮助)
         self.chat_keywords = [
             "你好", "谢谢", "再见", "怎么样", "如何", "为什么", "什么是",
-            "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能"
+            "介绍", "解释", "说明", "帮助", "操作", "使用方法", "功能",
+            "教程", "指南", "手册"
         ]
     
     def classify(self, question: str) -> ClassificationResult:
         """
-        主分类方法:规则优先 + LLM fallback
+        主分类方法:规则预筛选 + 增强LLM分类
         """
-        # 第一步:规则分类
+        # 第一步:规则预筛选
         rule_result = self._rule_based_classify(question)
         
+        # 如果规则分类有高置信度,直接使用
         if rule_result.confidence >= self.high_confidence_threshold:
             return rule_result
         
-        # 第二步:LLM分类(针对不确定的情况)
-        if rule_result.confidence <= self.low_confidence_threshold:
-            llm_result = self._llm_classify(question)
-            
-            # 如果LLM也不确定,返回不确定状态
-            if llm_result.confidence <= self.low_confidence_threshold:
-                return ClassificationResult(
-                    question_type="UNCERTAIN",
-                    confidence=max(rule_result.confidence, llm_result.confidence),
-                    reason=f"规则和LLM都不确定: {rule_result.reason} | {llm_result.reason}",
-                    method="hybrid_uncertain"
-                )
-            
-            return llm_result
+        # 第二步:使用增强的LLM分类
+        llm_result = self._enhanced_llm_classify(question)
         
-        return rule_result
+        # 选择置信度更高的结果
+        if llm_result.confidence > rule_result.confidence:
+            return llm_result
+        else:
+            return rule_result
     
     def _rule_based_classify(self, question: str) -> ClassificationResult:
-        """基于规则的分类"""
+        """基于规则的预分类"""
         question_lower = question.lower()
         
-        # 检查数据库相关关键词
-        db_score = 0
-        matched_keywords = []
+        # 检查非业务实体词
+        non_business_matched = []
+        for keyword in self.non_business_keywords:
+            if keyword in question_lower:
+                non_business_matched.append(keyword)
+        
+        # 如果包含非业务实体词,直接分类为CHAT
+        if non_business_matched:
+            return ClassificationResult(
+                question_type="CHAT",
+                confidence=0.85,
+                reason=f"包含非业务实体词: {non_business_matched}",
+                method="rule_based_non_business"
+            )
+        
+        # 检查强业务关键词
+        business_score = 0
+        business_matched = []
         
-        for category, keywords in self.db_keywords.items():
+        for category, keywords in self.strong_business_keywords.items():
             for keyword in keywords:
                 if keyword in question_lower:
-                    db_score += 1
-                    matched_keywords.append(f"{category}:{keyword}")
+                    business_score += 2  # 业务实体词权重更高
+                    business_matched.append(f"{category}:{keyword}")
+        
+        # 检查查询意图词
+        intent_score = 0
+        intent_matched = []
+        for keyword in self.query_intent_keywords:
+            if keyword in question_lower:
+                intent_score += 1
+                intent_matched.append(keyword)
         
         # 检查SQL模式
         sql_patterns_matched = []
         for pattern in self.sql_patterns:
             if re.search(pattern, question_lower, re.IGNORECASE):
-                db_score += 2  # SQL模式权重更高
+                business_score += 3  # SQL模式权重最高
                 sql_patterns_matched.append(pattern)
         
         # 检查聊天关键词
         chat_score = 0
-        chat_keywords_matched = []
+        chat_matched = []
         for keyword in self.chat_keywords:
             if keyword in question_lower:
                 chat_score += 1
-                chat_keywords_matched.append(keyword)
+                chat_matched.append(keyword)
         
-        # 计算置信度和分类
-        total_score = db_score + chat_score
+        # 分类决策逻辑
+        total_business_score = business_score + intent_score
         
-        if db_score > chat_score and db_score >= 1:
-            confidence = min(self.max_confidence, self.base_confidence + (db_score * self.confidence_increment))
+        # 强业务特征:包含业务实体 + 查询意图
+        if business_score >= 2 and intent_score >= 1:
+            confidence = min(self.max_confidence, 0.8 + (total_business_score * 0.05))
             return ClassificationResult(
                 question_type="DATABASE",
                 confidence=confidence,
-                reason=f"匹配数据库关键词: {matched_keywords}, SQL模式: {sql_patterns_matched}",
-                method="rule_based"
+                reason=f"强业务特征 - 业务实体: {business_matched}, 查询意图: {intent_matched}, SQL: {sql_patterns_matched}",
+                method="rule_based_strong_business"
+            )
+        
+        # 中等业务特征:包含多个业务实体词
+        elif business_score >= 4:
+            confidence = min(self.max_confidence, 0.7 + (business_score * 0.03))
+            return ClassificationResult(
+                question_type="DATABASE", 
+                confidence=confidence,
+                reason=f"中等业务特征 - 业务实体: {business_matched}",
+                method="rule_based_medium_business"
             )
-        elif chat_score > db_score and chat_score >= 1:
+        
+        # 聊天特征
+        elif chat_score >= 1 and business_score == 0:
             confidence = min(self.max_confidence, self.base_confidence + (chat_score * self.confidence_increment))
             return ClassificationResult(
                 question_type="CHAT",
                 confidence=confidence,
-                reason=f"匹配聊天关键词: {chat_keywords_matched}",
-                method="rule_based"
+                reason=f"聊天特征: {chat_matched}",
+                method="rule_based_chat"
             )
+        
+        # 不确定情况
         else:
-            # 没有明确匹配
             return ClassificationResult(
                 question_type="UNCERTAIN",
                 confidence=self.uncertain_confidence,
-                reason="没有匹配到明确的关键词模式",
-                method="rule_based"
+                reason=f"规则分类不确定 - 业务分:{business_score}, 意图分:{intent_score}, 聊天分:{chat_score}",
+                method="rule_based_uncertain"
             )
     
-
-    def _llm_classify(self, question: str) -> ClassificationResult:
-        """基于LLM的分类"""
+    def _enhanced_llm_classify(self, question: str) -> ClassificationResult:
+        """增强的LLM分类:包含详细的业务上下文"""
         try:
-            # 使用 Vanna 实例进行分类
             from common.vanna_instance import get_vanna_instance
             vn = get_vanna_instance()
             
-            # 分类提示词
+            # 构建包含业务上下文的分类提示词
             classification_prompt = f"""
-    请判断以下问题是否需要查询数据库。
+请判断以下用户问题是否需要查询我们的高速公路服务区管理数据库。
 
-    问题: {question}
+用户问题:{question}
 
-    判断标准:
-    1. 如果问题涉及数据查询、统计、分析、报表等,返回 "DATABASE"
-    2. 如果问题是一般性咨询、概念解释、操作指导、闲聊等,返回 "CHAT"
+=== 数据库业务范围 ===
+本系统是高速公路服务区商业管理系统,包含以下业务数据:
 
-    请只返回 "DATABASE" 或 "CHAT",并在下一行简要说明理由。
+核心业务实体:
+- 服务区(bss_service_area):服务区基础信息、位置、状态,如"鄱阳湖服务区"、"信丰西服务区"
+- 档口/商铺(bss_branch):档口信息、品类(餐饮/小吃/便利店)、品牌,如"驿美餐饮"、"加水机"
+- 营业数据(bss_business_day_data):每日支付金额、订单数量,包含微信、支付宝、现金等支付方式
+- 车流量(bss_car_day_count):按车型统计的日流量数据,包含客车、货车、过境、危化品等
+- 公司信息(bss_company):服务区管理公司,如"驿美运营公司"
 
-    格式:
-    分类: [DATABASE/CHAT]
-    理由: [简要说明]
-    置信度: [0.0-1.0之间的数字]
-    """
+关键业务指标:
+- 支付方式:微信支付(wx)、支付宝支付(zfb)、现金支付(rmb)、行吧支付(xs)、金豆支付(jd)
+- 营业数据:支付金额、订单数量、营业额、收入统计
+- 车流统计:按车型(客车/货车/过境/危化品/城际)的流量分析
+- 经营分析:餐饮、小吃、便利店、整体租赁等品类收入
+- 地理分区:北区、南区、西区、东区、两区
+
+高速线路:
+- 线路信息:大广、昌金、昌栗等高速线路
+- 路段管理:按线路统计服务区分布
+
+=== 判断标准 ===
+1. **DATABASE类型** - 需要查询数据库:
+   - 涉及上述业务实体和指标的查询、统计、分析、报表
+   - 包含业务相关的时间查询,如"本月服务区营业额"、"上月档口收入"
+   - 例如:"本月营业额统计"、"档口收入排行"、"车流量分析"、"支付方式占比"
+
+2. **CHAT类型** - 不需要查询数据库:
+   - 生活常识:水果蔬菜上市时间、动植物知识、天气等
+   - 身份询问:你是谁、什么模型、AI助手等
+   - 技术概念:人工智能、编程、算法等
+   - 平台使用:功能介绍、操作帮助、使用教程等
+   - 例如:"荔枝几月份上市"、"今天天气如何"、"你是什么AI"、"怎么使用平台"
+
+**重要提示:**
+- 只有涉及高速公路服务区业务数据的问题才分类为DATABASE
+- 即使包含时间词汇(如"月份"、"时间"),也要看是否与我们的业务数据相关
+- 农产品上市时间、生活常识等都应分类为CHAT
+
+请基于问题与我们高速公路服务区业务数据的相关性来分类。
+
+格式:
+分类: [DATABASE/CHAT]
+理由: [详细说明问题与业务数据的相关性,具体分析涉及哪些业务实体或为什么不相关]
+置信度: [0.0-1.0之间的数字]
+"""
             
-            # 分类专用的系统提示词
-            system_prompt = "你是一个专业的问题分类助手,能准确判断问题类型。请严格按照要求的格式返回分类结果。"
+            # 专业的系统提示词
+            system_prompt = """你是一个专业的业务问题分类助手,专门负责高速公路服务区管理系统的问题分类。你具有以下特长:
+1. 深度理解高速公路服务区业务领域和数据范围
+2. 准确区分业务数据查询需求和一般性问题  
+3. 基于具体业务上下文进行精准分类,而不仅仅依赖关键词匹配
+4. 对边界情况能够给出合理的置信度评估
+
+请严格按照业务相关性进行分类,并提供详细的分类理由。"""
             
             # 使用 Vanna 实例的 chat_with_llm 方法
             response = vn.chat_with_llm(
@@ -197,20 +306,20 @@ class QuestionClassifier:
             return self._parse_llm_response(response)
             
         except Exception as e:
-            print(f"[WARNING] LLM分类失败: {str(e)}")
+            print(f"[WARNING] 增强LLM分类失败: {str(e)}")
             return ClassificationResult(
-                question_type="UNCERTAIN",
+                question_type="CHAT",  # 失败时默认为CHAT,更安全
                 confidence=self.llm_fallback_confidence,
-                reason=f"LLM分类异常: {str(e)}",
+                reason=f"LLM分类异常,默认为聊天: {str(e)}",
                 method="llm_error"
             )
-        
+    
     def _parse_llm_response(self, response: str) -> ClassificationResult:
         """解析LLM响应"""
         try:
             lines = response.strip().split('\n')
             
-            question_type = "UNCERTAIN"
+            question_type = "CHAT"  # 默认为CHAT
             reason = "LLM响应解析失败"
             confidence = self.llm_fallback_confidence
             
@@ -230,6 +339,8 @@ class QuestionClassifier:
                     try:
                         conf_str = line.split(":", 1)[1].strip()
                         confidence = float(conf_str)
+                        # 确保置信度在合理范围内
+                        confidence = max(0.0, min(1.0, confidence))
                     except:
                         confidence = self.llm_fallback_confidence
             
@@ -237,12 +348,12 @@ class QuestionClassifier:
                 question_type=question_type,
                 confidence=confidence,
                 reason=reason,
-                method="llm_based"
+                method="enhanced_llm"
             )
             
         except Exception as e:
             return ClassificationResult(
-                question_type="UNCERTAIN",
+                question_type="CHAT",  # 解析失败时默认为CHAT
                 confidence=self.llm_fallback_confidence,
                 reason=f"响应解析失败: {str(e)}",
                 method="llm_parse_error"

+ 3 - 3
agent/config.py

@@ -14,7 +14,7 @@ AGENT_CONFIG = {
     "classification": {
         # 高置信度阈值:当规则分类的置信度 >= 此值时,直接使用规则分类结果,不再调用LLM
         # 建议范围:0.7-0.9,过高可能错过需要LLM辅助的边界情况,过低会增加LLM调用成本
-        "high_confidence_threshold": 0.8,
+        "high_confidence_threshold": 0.7,
         
         # 低置信度阈值:当规则分类的置信度 <= 此值时,启用LLM二次分类进行辅助判断
         # 建议范围:0.2-0.5,过高会频繁调用LLM,过低可能错过需要LLM辅助的情况
@@ -26,11 +26,11 @@ AGENT_CONFIG = {
         
         # 基础置信度:规则分类的起始置信度,会根据匹配的关键词数量递增
         # 建议范围:0.3-0.6,这是匹配到1个关键词时的基础置信度
-        "base_confidence": 0.5,
+        "base_confidence": 0.4,
         
         # 置信度增量步长:每匹配一个额外关键词,置信度增加的数值
         # 建议范围:0.05-0.2,过大会导致置信度增长过快,过小则区分度不够
-        "confidence_increment": 0.1,
+        "confidence_increment": 0.08,
         
         # LLM分类失败时的默认置信度:当LLM调用异常或解析失败时使用
         # 建议范围:0.3-0.6,通常设为中等水平,避免过高或过低的错误影响
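
The chat branch of _rule_based_classify still uses base_confidence and confidence_increment, so these lowered defaults also lower the score a pure keyword match can reach. A worked example (not code from the repo):

# Chat-branch confidence under the new defaults from agent/config.py.
max_confidence = 0.9
base_confidence = 0.4
confidence_increment = 0.08

for chat_score in (1, 2, 3, 4, 7):
    confidence = min(max_confidence, base_confidence + chat_score * confidence_increment)
    print(chat_score, round(confidence, 2))
# 1 -> 0.48, 2 -> 0.56, 3 -> 0.64, 4 -> 0.72, 7 -> 0.9 (capped)

With these values, at least four chat keywords are needed to reach the lowered 0.7 high_confidence_threshold, so questions matched only by the generic chat keywords will usually still go through the enhanced LLM classifier.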

+ 2 - 2
app_config.py

@@ -43,7 +43,7 @@ API_QIANWEN_CONFIG = {
     "n_results": 6,
     "language": "Chinese",
     "stream": True,  # 是否使用流式模式
-    "enable_thinking": False  # 是否启用思考功能(要求stream=True)
+    "enable_thinking": True  # 是否启用思考功能(要求stream=True)
 }
 #qwen3-30b-a3b
 #qwen3-235b-a22b
@@ -143,7 +143,7 @@ ENABLE_RESULT_SUMMARY = True
 # True: 显示 <think></think> 内容
 # False: 隐藏 <think></think> 内容,只显示最终答案
 # 此参数影响:摘要生成、SQL生成解释性文本、API返回结果等所有输出内容
-DISPLAY_RESULT_THINKING = False
+DISPLAY_RESULT_THINKING = True
 
 # 是否启用向量查询结果得分阈值过滤
 # result = max((n + 1) // 2, 1)
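
For context on the DISPLAY_RESULT_THINKING toggle flipped in this file, a hypothetical sketch of what showing or hiding the <think></think> block amounts to; the function render_llm_output and its stripping logic are illustrative and not code from this repo:

# Hypothetical sketch -- the repo's actual stripping logic is not part of this diff.
import re

DISPLAY_RESULT_THINKING = True  # new default in this commit

def render_llm_output(text: str) -> str:
    """Return the LLM output, optionally hiding the <think></think> block."""
    if DISPLAY_RESULT_THINKING:
        return text
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

print(render_llm_output("<think>推理过程……</think>最终答案"))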