Переглянути джерело

准备修改让 agent_sql_generation 无法生成SQL的时候转发至agent_chat.

wangxq 3 тижнів тому
батько
коміт
964e5daa7a

+ 30 - 9
agent/citu_agent.py

@@ -488,7 +488,7 @@ class CituLangGraphAgent:
             if enable_context_injection:
                 # TODO: 在这里可以添加真实的对话历史上下文
                 # 例如从Redis或其他存储中获取最近的对话记录
-                # context = get_conversation_history(state.get("session_id"))
+                # context = get_conversation_history(state.get("conversation_id"))
                 pass
             
             # 直接调用general_chat工具
@@ -646,6 +646,17 @@ class CituLangGraphAgent:
                 }
             
             self.logger.info("响应格式化完成")
+            
+            # 输出完整的 STATE 内容用于调试
+            import json
+            try:
+                # 创建一个可序列化的 state 副本
+                debug_state = dict(state)
+                self.logger.debug(f"format_response_node 完整 STATE 内容: {json.dumps(debug_state, ensure_ascii=False, indent=2)}")
+            except Exception as debug_e:
+                self.logger.debug(f"STATE 序列化失败,使用简单输出: {debug_e}")
+                self.logger.debug(f"format_response_node STATE 内容: {state}")
+            
             return state
             
         except Exception as e:
@@ -656,6 +667,16 @@ class CituLangGraphAgent:
                 "error_code": 500,
                 "execution_path": state["execution_path"]
             }
+            
+            # 即使在异常情况下也输出 STATE 内容用于调试
+            import json
+            try:
+                debug_state = dict(state)
+                self.logger.debug(f"format_response_node 异常情况下的完整 STATE 内容: {json.dumps(debug_state, ensure_ascii=False, indent=2)}")
+            except Exception as debug_e:
+                self.logger.debug(f"异常情况下 STATE 序列化失败: {debug_e}")
+                self.logger.debug(f"format_response_node 异常情况下的 STATE 内容: {state}")
+            
             return state
     
     def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
@@ -697,13 +718,13 @@ class CituLangGraphAgent:
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             return "CHAT"
     
-    async def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
+    async def process_question(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
         """
         统一的问题处理入口
         
         Args:
             question: 用户问题
-            session_id: 会话ID
+            conversation_id: 对话ID
             context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
             routing_mode: 路由模式,可选,用于覆盖配置文件设置
             
@@ -722,14 +743,14 @@ class CituLangGraphAgent:
             workflow = self._create_workflow(routing_mode)
             
             # 初始化状态
-            initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)
+            initial_state = self._create_initial_state(question, conversation_id, context_type, routing_mode)
             
             # 执行工作流
             final_state = await workflow.ainvoke(
                 initial_state,
                 config={
-                    "configurable": {"session_id": session_id}
-                } if session_id else None
+                    "configurable": {"conversation_id": conversation_id}
+                } if conversation_id else None
             )
             
             # 提取最终结果
@@ -748,7 +769,7 @@ class CituLangGraphAgent:
                 "execution_path": ["error"]
             }
     
-    def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
+    def _create_initial_state(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
         """创建初始状态 - 支持渐进式分类"""
         # 确定使用的路由模式
         if routing_mode:
@@ -763,7 +784,7 @@ class CituLangGraphAgent:
         return AgentState(
             # 输入信息
             question=question,
-            session_id=session_id,
+            conversation_id=conversation_id,
             
             # 上下文信息
             context_type=context_type,
@@ -1043,7 +1064,7 @@ class CituLangGraphAgent:
             
             if enable_full_test:
                 # 完整流程测试
-                test_result = await self.process_question(test_question, "health_check")
+                test_result = await self.process_question(test_question, conversation_id="health_check")
                 
                 return {
                     "status": "healthy" if test_result.get("success") else "degraded",

+ 1 - 1
agent/state.py

@@ -7,7 +7,7 @@ class AgentState(TypedDict):
     
     # 输入信息
     question: str
-    session_id: Optional[str]
+    conversation_id: Optional[str]
     
     # 上下文信息
     context_type: Optional[str]  # 上下文类型 ("DATABASE" 或 "CHAT")

+ 1 - 1
app_config.py

@@ -169,7 +169,7 @@ REDIS_PASSWORD = None
 
 # 缓存开关配置
 ENABLE_CONVERSATION_CONTEXT = True      # 是否启用对话上下文
-ENABLE_QUESTION_ANSWER_CACHE = False     # 是否启用问答结果缓存
+ENABLE_QUESTION_ANSWER_CACHE = True     # 是否启用问答结果缓存
 ENABLE_EMBEDDING_CACHE = True           # 是否启用embedding向量缓存
 
 # TTL配置(单位:秒)

+ 58 - 30
citu_app.py

@@ -9,7 +9,7 @@ from flask import request, jsonify
 import pandas as pd
 import common.result as result
 from datetime import datetime, timedelta
-from common.session_aware_cache import WebSessionAwareMemoryCache
+from common.session_aware_cache import ConversationAwareMemoryCache
 from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
 import re
 import chainlit as cl
@@ -41,13 +41,13 @@ MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DE
 
 vn = create_vanna_instance()
 
-# 创建带时间戳的缓存
-timestamped_cache = WebSessionAwareMemoryCache()
+# 创建对话感知的缓存
+conversation_cache = ConversationAwareMemoryCache()
 
 # 实例化 VannaFlaskApp,使用自定义缓存
 app = VannaFlaskApp(
     vn,
-    cache=timestamped_cache,  # 使用带时间戳的缓存
+    cache=conversation_cache,  # 使用对话感知的缓存
     title="辞图智能数据问答平台",
     logo = "https://www.citupro.com/img/logo-black-2.png",
     subtitle="让 AI 为你写 SQL",
@@ -61,12 +61,13 @@ app = VannaFlaskApp(
 # 创建Redis对话管理器实例
 redis_conversation_manager = RedisConversationManager()
 
-# 修改ask接口,支持前端传递session_id
+# 修改ask接口,支持前端传递conversation_id
 @app.flask_app.route('/api/v0/ask', methods=['POST'])
 def ask_full():
     req = request.get_json(force=True)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)  # 前端传递的会话ID
+    conversation_id = req.get("conversation_id", None)  # 前端传递的对话ID
+    user_id = req.get("user_id", None)  # 前端传递的用户ID
     
     if not question:
         from common.result import bad_request_response
@@ -75,14 +76,13 @@ def ask_full():
             missing_params=["question"]
         )), 400
 
-    # 如果使用WebSessionAwareMemoryCache
-    if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
-        # 这里需要修改vanna的ask方法来支持传递session_id
-        # 或者预先调用generate_id来建立会话关联
-        conversation_id = app.cache.generate_id_with_browser_session(
-            question=question, 
-            browser_session_id=browser_session_id
-        )
+    # 如果没有传递user_id,使用默认值guest
+    if not user_id:
+        user_id = "guest"
+
+    # 如果前端没有传递conversation_id,则生成新的
+    if not conversation_id:
+        conversation_id = app.cache.generate_id(question=question, user_id=user_id)
 
     try:
         sql, df, _ = vn.ask(
@@ -144,8 +144,7 @@ def ask_full():
         response_data = {
             "sql": sql,
             "query_result": query_result,
-            "conversation_id": conversation_id if 'conversation_id' in locals() else None,
-            "session_id": browser_session_id
+            "conversation_id": conversation_id
         }
         
         # 添加摘要(如果启用且生成成功)
@@ -237,7 +236,8 @@ def ask_cached():
     """
     req = request.get_json(force=True)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)
+    conversation_id = req.get("conversation_id", None)
+    user_id = req.get("user_id", None)
     
     if not question:
         from common.result import bad_request_response
@@ -246,15 +246,19 @@ def ask_cached():
             missing_params=["question"]
         )), 400
 
+    # 如果没有传递user_id,使用默认值guest
+    if not user_id:
+        user_id = "guest"
+
     try:
         # 生成conversation_id
         # 调试:查看generate_id的实际行为
         logger.debug(f"输入问题: '{question}'")
-        conversation_id = app.cache.generate_id(question=question)
+        conversation_id = app.cache.generate_id(question=question, user_id=user_id)
         logger.debug(f"生成的conversation_id: {conversation_id}")
         
         # 再次用相同问题测试
-        conversation_id2 = app.cache.generate_id(question=question)
+        conversation_id2 = app.cache.generate_id(question=question, user_id=user_id)
         logger.debug(f"再次生成的conversation_id: {conversation_id2}")
         logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
         
@@ -336,7 +340,6 @@ def ask_cached():
             "sql": sql,
             "query_result": query_result,
             "conversation_id": conversation_id,
-            "session_id": browser_session_id,
             "cached": cached_sql is not None  # 标识是否来自缓存
         }
         
@@ -449,11 +452,10 @@ def ask_agent():
     """
     req = request.get_json(force=True)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)
+    conversation_id_input = req.get("conversation_id", None)
     
     # 新增参数解析
     user_id_input = req.get("user_id", None)
-    conversation_id_input = req.get("conversation_id", None)
     continue_conversation = req.get("continue_conversation", False)
     
     # 新增:路由模式参数解析和验证
@@ -477,9 +479,38 @@ def ask_agent():
         # 1. 获取登录用户ID(修正:在函数中获取session信息)
         login_user_id = session.get('user_id') if 'user_id' in session else None
         
-        # 2. 智能ID解析(修正:传入登录用户ID)
+        # 2. 用户ID和对话ID一致性校验
+        from common.session_aware_cache import ConversationAwareMemoryCache
+        
+        # 2.1 如果传递了conversation_id,从中解析user_id
+        extracted_user_id = None
+        if conversation_id_input:
+            extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
+            
+            # 如果同时传递了user_id和conversation_id,进行一致性校验
+            if user_id_input:
+                is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
+                    conversation_id_input, user_id_input
+                )
+                if not is_valid:
+                    return jsonify(bad_request_response(
+                        response_text=error_msg,
+                        invalid_params=["user_id", "conversation_id"]
+                    )), 400
+            
+            # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
+            elif not user_id_input and extracted_user_id:
+                user_id_input = extracted_user_id
+                logger.info(f"从conversation_id解析出user_id: {user_id_input}")
+        
+        # 2.2 如果没有传递user_id,使用默认值guest
+        if not user_id_input:
+            user_id_input = "guest"
+            logger.info("未传递user_id,使用默认值: guest")
+        
+        # 3. 智能ID解析(修正:传入登录用户ID)
         user_id = redis_conversation_manager.resolve_user_id(
-            user_id_input, browser_session_id, request.remote_addr, login_user_id
+            user_id_input, None, request.remote_addr, login_user_id
         )
         conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
             user_id, conversation_id_input, continue_conversation
@@ -552,10 +583,9 @@ def ask_agent():
                 sql=cached_answer.get("sql"),
                 records=cached_answer.get("query_result"),  # 修改:query_result改为records
                 summary=cached_answer.get("summary"),
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=cached_answer.get("execution_path", []),
                 classification_info=cached_answer.get("classification_info", {}),
-                conversation_id=conversation_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
@@ -605,7 +635,7 @@ def ask_agent():
         import asyncio
         agent_result = asyncio.run(agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
-            session_id=browser_session_id,
+            conversation_id=conversation_id,
             context_type=context_type,  # 传递上下文类型
             routing_mode=effective_routing_mode  # 新增:传递路由模式
         ))
@@ -662,10 +692,9 @@ def ask_agent():
                 sql=sql,
                 records=query_result,  # 修改:query_result改为records
                 summary=summary,
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=execution_path,
                 classification_info=classification_info,
-                conversation_id=conversation_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
@@ -685,7 +714,6 @@ def ask_agent():
                 response_text=error_message,
                 error_type="agent_processing_failed",
                 code=error_code,
-                session_id=browser_session_id,
                 conversation_id=conversation_id,
                 user_id=user_id
             )), error_code

+ 4 - 3
common/redis_conversation_manager.py

@@ -155,9 +155,10 @@ class RedisConversationManager:
     
     def create_conversation(self, user_id: str) -> str:
         """创建新对话"""
-        # 生成包含时间戳的conversation_id
-        timestamp = int(datetime.now().timestamp())
-        conversation_id = f"conv_{timestamp}_{uuid.uuid4().hex[:8]}"
+        # 生成包含时间戳的conversation_id,格式:{user_id}:YYYYMMDDHHMMSSsss
+        now = datetime.now()
+        timestamp = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"
+        conversation_id = f"{user_id}:{timestamp}"
         
         if not self.is_available():
             return conversation_id  # Redis不可用时返回ID,但不存储

+ 6 - 6
common/result.py

@@ -101,14 +101,14 @@ def error_response(response_text, error_type=None, message=MessageTemplate.PROCE
 
 # ===== Ask Agent API 专用响应方法 =====
 
-def agent_success_response(response_type, session_id=None, execution_path=None, 
+def agent_success_response(response_type, conversation_id=None, execution_path=None, 
                           classification_info=None, agent_version="langgraph_v1", **kwargs):
     """
     Ask Agent API 成功响应格式
     
     Args:
         response_type: 响应类型 ("DATABASE" 或 "CHAT")
-        session_id: 会话ID
+        conversation_id: 对话ID
         execution_path: 执行路径
         classification_info: 分类信息
         agent_version: Agent版本
@@ -119,7 +119,7 @@ def agent_success_response(response_type, session_id=None, execution_path=None,
     """
     data = {
         "type": response_type,
-        "session_id": session_id,
+        "conversation_id": conversation_id,
         "execution_path": execution_path or [],
         "classification_info": classification_info or {},
         "agent_version": agent_version,
@@ -139,7 +139,7 @@ def agent_success_response(response_type, session_id=None, execution_path=None,
     }
 
 def agent_error_response(response_text, error_type=None, message=MessageTemplate.PROCESSING_FAILED,
-                        code=500, session_id=None, execution_path=None, 
+                        code=500, conversation_id=None, execution_path=None, 
                         classification_info=None, agent_version="langgraph_v1", **kwargs):
     """
     Ask Agent API 错误响应格式
@@ -149,7 +149,7 @@ def agent_error_response(response_text, error_type=None, message=MessageTemplate
         error_type: 错误类型标识
         message: 高层级描述信息
         code: HTTP状态码
-        session_id: 会话ID
+        conversation_id: 对话ID
         execution_path: 执行路径
         classification_info: 分类信息
         agent_version: Agent版本
@@ -160,7 +160,7 @@ def agent_error_response(response_text, error_type=None, message=MessageTemplate
     """
     data = {
         "response": response_text,
-        "session_id": session_id,
+        "conversation_id": conversation_id,
         "execution_path": execution_path or [],
         "classification_info": classification_info or {},
         "agent_version": agent_version,

+ 103 - 128
common/session_aware_cache.py

@@ -1,164 +1,139 @@
-# 修正后的 custom_cache.py
+# 简化后的对话感知缓存
 from datetime import datetime
 from vanna.flask import MemoryCache
 import uuid
 
-class SessionAwareMemoryCache(MemoryCache):
-    """区分会话(Session)和对话(Conversation)的缓存实现"""
+class ConversationAwareMemoryCache(MemoryCache):
+    """基于对话ID的简单时间感知缓存实现"""
     
     def __init__(self):
         super().__init__()
-        self.conversation_start_times = {}  # 每个对话的开始时间
-        self.session_info = {}  # 会话信息: {session_id: {'start_time': datetime, 'conversations': []}}
-        self.conversation_to_session = {}  # 对话ID到会话ID的映射
+        self.conversation_start_times = {}  # 每个对话的开始时间: {conversation_id: datetime}
     
-    def create_or_get_session_id(self, user_identifier=None):
-        """
-        创建或获取会话ID
-        在实际应用中,这可以通过以下方式确定:
-        1. HTTP请求中的session cookie
-        2. JWT token中的session信息
-        3. 前端传递的session_id
-        4. IP地址 + User-Agent的组合
-        """
-        # 简化实现:使用时间窗口来判断是否为同一会话
-        # 实际应用中应该从HTTP请求中获取session信息
-        current_time = datetime.now()
+    def generate_id(self, question: str = None, user_id: str = None) -> str:
+        """生成对话ID并记录时间,格式为 {user_id}:YYYYMMDDHHMMSSsss"""
+        # 如果没有传递user_id,使用默认值
+        if not user_id:
+            user_id = "guest"
         
-        # 检查是否有近期的会话(比如30分钟内)
-        for session_id, session_data in self.session_info.items():
-            last_activity = session_data.get('last_activity', session_data['start_time'])
-            if (current_time - last_activity).total_seconds() < 1800:  # 30分钟内
-                # 更新最后活动时间
-                session_data['last_activity'] = current_time
-                return session_id
+        # 生成时间戳:年月日时分秒毫秒格式
+        now = datetime.now()
+        timestamp = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"
         
-        # 创建新会话
-        new_session_id = str(uuid.uuid4())
-        self.session_info[new_session_id] = {
-            'start_time': current_time,
-            'last_activity': current_time,
-            'conversations': []
-        }
-        return new_session_id
-    
-    def generate_id(self, question: str = None, session_id: str = None) -> str:
-        """重载generate_id方法,关联会话和对话"""
-        conversation_id = super().generate_id(question=question)
-        
-        # 确定会话ID
-        if not session_id:
-            session_id = self.create_or_get_session_id()
+        # 生成对话ID:{user_id}:{timestamp}
+        conversation_id = f"{user_id}:{timestamp}"
         
         # 记录对话开始时间
-        conversation_start_time = datetime.now()
-        self.conversation_start_times[conversation_id] = conversation_start_time
-        
-        # 建立对话与会话的关联
-        self.conversation_to_session[conversation_id] = session_id
-        self.session_info[session_id]['conversations'].append(conversation_id)
-        self.session_info[session_id]['last_activity'] = conversation_start_time
+        self.conversation_start_times[conversation_id] = now
         
         return conversation_id
     
-    def set(self, id: str, field: str, value, session_id: str = None):
+    def set(self, id: str, field: str, value, **kwargs):
         """重载set方法,确保时间信息正确"""
         # 如果这是新对话,初始化时间信息
         if id not in self.conversation_start_times:
-            if not session_id:
-                session_id = self.create_or_get_session_id()
-            
-            conversation_start_time = datetime.now()
-            self.conversation_start_times[id] = conversation_start_time
-            self.conversation_to_session[id] = session_id
-            self.session_info[session_id]['conversations'].append(id)
-            self.session_info[session_id]['last_activity'] = conversation_start_time
+            self.conversation_start_times[id] = datetime.now()
         
         # 调用父类的set方法
         super().set(id=id, field=field, value=value)
         
-        # 设置时间相关字段
-        if field != 'conversation_start_time' and field != 'session_start_time':
-            # 设置对话开始时间
+        # 自动设置对话开始时间字段
+        if field != 'conversation_start_time':
             super().set(id=id, field='conversation_start_time', 
                        value=self.conversation_start_times[id])
-            
-            # 设置会话开始时间
-            session_id = self.conversation_to_session.get(id)
-            if session_id and session_id in self.session_info:
-                super().set(id=id, field='session_start_time', 
-                           value=self.session_info[session_id]['start_time'])
-                super().set(id=id, field='session_id', value=session_id)
     
     def get_conversation_start_time(self, conversation_id: str) -> datetime:
         """获取对话开始时间"""
         return self.conversation_start_times.get(conversation_id)
     
-    def get_session_start_time(self, conversation_id: str) -> datetime:
-        """获取会话开始时间"""
-        session_id = self.conversation_to_session.get(conversation_id)
-        if session_id and session_id in self.session_info:
-            return self.session_info[session_id]['start_time']
-        return None
-    
-    def get_session_info(self, session_id: str = None, conversation_id: str = None):
-        """获取会话信息"""
-        if conversation_id:
-            session_id = self.conversation_to_session.get(conversation_id)
-        
-        if session_id and session_id in self.session_info:
-            session_data = self.session_info[session_id].copy()
-            session_data['conversation_count'] = len(session_data['conversations'])
-            if session_data['conversations']:
-                # 计算会话持续时间
-                duration = datetime.now() - session_data['start_time']
-                session_data['session_duration_seconds'] = duration.total_seconds()
-                session_data['session_duration_formatted'] = str(duration)
-            return session_data
+    def get_conversation_info(self, conversation_id: str):
+        """获取对话信息"""
+        start_time = self.get_conversation_start_time(conversation_id)
+        if start_time:
+            duration = datetime.now() - start_time
+            
+            # 从conversation_id解析user_id
+            user_id = "unknown"
+            if ":" in conversation_id:
+                user_id = conversation_id.split(":")[0]
+            
+            return {
+                'conversation_id': conversation_id,
+                'user_id': user_id,
+                'start_time': start_time,
+                'duration_seconds': duration.total_seconds(),
+                'duration_formatted': str(duration)
+            }
         return None
     
-    def get_all_sessions(self):
-        """获取所有会话信息"""
+    def get_all_conversations(self):
+        """获取所有对话信息"""
         result = {}
-        for session_id, session_data in self.session_info.items():
-            session_info = session_data.copy()
-            session_info['conversation_count'] = len(session_data['conversations'])
-            if session_data['conversations']:
-                duration = datetime.now() - session_data['start_time']
-                session_info['session_duration_seconds'] = duration.total_seconds()
-                session_info['session_duration_formatted'] = str(duration)
-            result[session_id] = session_info
+        for conversation_id, start_time in self.conversation_start_times.items():
+            duration = datetime.now() - start_time
+            
+            # 从conversation_id解析user_id
+            user_id = "unknown"
+            if ":" in conversation_id:
+                user_id = conversation_id.split(":")[0]
+                
+            result[conversation_id] = {
+                'user_id': user_id,
+                'start_time': start_time,
+                'duration_seconds': duration.total_seconds(),
+                'duration_formatted': str(duration)
+            }
         return result
 
+    @staticmethod
+    def parse_conversation_id(conversation_id: str):
+        """解析conversation_id,返回user_id和timestamp"""
+        if ":" not in conversation_id:
+            return None, None
+        
+        parts = conversation_id.split(":", 1)
+        user_id = parts[0]
+        timestamp_str = parts[1]
+        
+        try:
+            # 解析时间戳:YYYYMMDDHHMMSSsss
+            if len(timestamp_str) == 17:  # 20250722204550155
+                timestamp = datetime.strptime(timestamp_str[:14], "%Y%m%d%H%M%S")
+                # 添加毫秒
+                milliseconds = int(timestamp_str[14:])
+                timestamp = timestamp.replace(microsecond=milliseconds * 1000)
+                return user_id, timestamp
+        except ValueError:
+            pass
+        
+        return user_id, None
 
-# 升级版:支持前端传递会话ID
-class WebSessionAwareMemoryCache(SessionAwareMemoryCache):
-    """支持从前端获取会话ID的版本"""
-    
-    def __init__(self):
-        super().__init__()
-        self.browser_sessions = {}  # browser_session_id -> our_session_id
-    
-    def register_browser_session(self, browser_session_id: str, user_info: dict = None):
-        """注册浏览器会话"""
-        if browser_session_id not in self.browser_sessions:
-            our_session_id = str(uuid.uuid4())
-            self.browser_sessions[browser_session_id] = our_session_id
-            
-            self.session_info[our_session_id] = {
-                'start_time': datetime.now(),
-                'last_activity': datetime.now(),
-                'conversations': [],
-                'browser_session_id': browser_session_id,
-                'user_info': user_info or {}
-            }
-        return self.browser_sessions[browser_session_id]
-    
-    def generate_id_with_browser_session(self, question: str = None, browser_session_id: str = None) -> str:
-        """使用浏览器会话ID生成对话ID"""
-        if browser_session_id:
-            our_session_id = self.register_browser_session(browser_session_id)
-        else:
-            our_session_id = self.create_or_get_session_id()
+    @staticmethod
+    def extract_user_id(conversation_id: str) -> str:
+        """从conversation_id中提取user_id"""
+        if ":" not in conversation_id:
+            return "unknown"
+        return conversation_id.split(":", 1)[0]
+
+    @staticmethod
+    def validate_user_id_consistency(conversation_id: str, provided_user_id: str) -> tuple[bool, str]:
+        """
+        校验conversation_id中的user_id与提供的user_id是否一致
+        
+        Returns:
+            tuple: (is_valid, error_message)
+        """
+        if not conversation_id or not provided_user_id:
+            return True, ""  # 如果任一为空,跳过校验
         
-        return super().generate_id(question=question, session_id=our_session_id)
+        extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id)
+        
+        if extracted_user_id != provided_user_id:
+            return False, f"用户ID不匹配:conversation_id中的用户ID '{extracted_user_id}' 与提供的用户ID '{provided_user_id}' 不一致"
+        
+        return True, ""
+
+
+# 保持向后兼容的别名
+WebSessionAwareMemoryCache = ConversationAwareMemoryCache
+SessionAwareMemoryCache = ConversationAwareMemoryCache

+ 2 - 2
config/logging_config_windows.yaml

@@ -59,7 +59,7 @@ modules:
     level: DEBUG
     console:
       enabled: true
-      level: INFO
+      level: DEBUG
       format: "%(asctime)s [%(levelname)s] Agent: %(message)s"
     file:
       enabled: true
@@ -75,7 +75,7 @@ modules:
     level: DEBUG
     console:
       enabled: true
-      level: INFO
+      level: DEBUG
       format: "%(asctime)s [%(levelname)s] Vanna: %(message)s"
     file:
       enabled: true

+ 2 - 2
custompgvector/pgvector.py

@@ -663,7 +663,7 @@ class PG_VectorStore(VannaBase):
             
             # 检查原始查询结果是否为空
             if not results:
-                self.logger.warning(f"向量查询未找到任何相关的错误SQL示例,问题: {question}")
+                self.logger.warning(f"向量查询未找到任何相关的错误SQL示例")
 
             # 应用错误SQL特有的阈值过滤逻辑
             filtered_results = self._apply_error_sql_threshold_filter(results)
@@ -671,7 +671,7 @@ class PG_VectorStore(VannaBase):
             # 检查过滤后结果是否为空
             if results and not filtered_results:
                 self.logger.warning(f"向量查询找到了 {len(results)} 条错误SQL示例,但全部被阈值过滤掉.")
-                self.logger.warning(f"问题: {question}")
+                # self.logger.warning(f"问题: {question}")
 
             return filtered_results
             

+ 33 - 8
unified_api.py

@@ -711,7 +711,6 @@ def ask_agent():
     """支持对话上下文的ask_agent API"""
     req = request.get_json(force=True)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)
     user_id_input = req.get("user_id", None)
     conversation_id_input = req.get("conversation_id", None)
     continue_conversation = req.get("continue_conversation", False)
@@ -735,9 +734,38 @@ def ask_agent():
         # 获取登录用户ID
         login_user_id = session.get('user_id') if 'user_id' in session else None
         
+        # 用户ID和对话ID一致性校验
+        from common.session_aware_cache import ConversationAwareMemoryCache
+        
+        # 如果传递了conversation_id,从中解析user_id
+        extracted_user_id = None
+        if conversation_id_input:
+            extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
+            
+            # 如果同时传递了user_id和conversation_id,进行一致性校验
+            if user_id_input:
+                is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
+                    conversation_id_input, user_id_input
+                )
+                if not is_valid:
+                    return jsonify(bad_request_response(
+                        response_text=error_msg,
+                        invalid_params=["user_id", "conversation_id"]
+                    )), 400
+            
+            # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
+            elif not user_id_input and extracted_user_id:
+                user_id_input = extracted_user_id
+                logger.info(f"从conversation_id解析出user_id: {user_id_input}")
+        
+        # 如果没有传递user_id,使用默认值guest
+        if not user_id_input:
+            user_id_input = "guest"
+            logger.info("未传递user_id,使用默认值: guest")
+        
         # 智能ID解析
         user_id = redis_conversation_manager.resolve_user_id(
-            user_id_input, browser_session_id, request.remote_addr, login_user_id
+            user_id_input, None, request.remote_addr, login_user_id
         )
         conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
             user_id, conversation_id_input, continue_conversation
@@ -810,10 +838,9 @@ def ask_agent():
                 sql=cached_answer.get("sql"),
                 records=cached_answer.get("query_result"),
                 summary=cached_answer.get("summary"),
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=cached_answer.get("execution_path", []),
                 classification_info=cached_answer.get("classification_info", {}),
-                conversation_id=conversation_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
@@ -863,7 +890,7 @@ def ask_agent():
         import asyncio
         agent_result = asyncio.run(agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
-            session_id=browser_session_id,
+            conversation_id=conversation_id,
             context_type=context_type,  # 传递上下文类型
             routing_mode=effective_routing_mode  # 新增:传递路由模式
         ))
@@ -913,10 +940,9 @@ def ask_agent():
                 sql=sql,
                 records=query_result,
                 summary=summary,
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=execution_path,
                 classification_info=classification_info,
-                conversation_id=conversation_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
@@ -936,7 +962,6 @@ def ask_agent():
                 response_text=error_message,
                 error_type="agent_processing_failed",
                 code=error_code,
-                session_id=browser_session_id,
                 conversation_id=conversation_id,
                 user_id=user_id
             )), error_code