Browse Source

准备修改让 agent_sql_generation 无法生成SQL的时候转发至agent_chat.

wangxq 3 tuần trước cách đây
mục cha
commit
964e5daa7a

+ 30 - 9
agent/citu_agent.py

@@ -488,7 +488,7 @@ class CituLangGraphAgent:
             if enable_context_injection:
             if enable_context_injection:
                 # TODO: 在这里可以添加真实的对话历史上下文
                 # TODO: 在这里可以添加真实的对话历史上下文
                 # 例如从Redis或其他存储中获取最近的对话记录
                 # 例如从Redis或其他存储中获取最近的对话记录
-                # context = get_conversation_history(state.get("session_id"))
+                # context = get_conversation_history(state.get("conversation_id"))
                 pass
                 pass
             
             
             # 直接调用general_chat工具
             # 直接调用general_chat工具
@@ -646,6 +646,17 @@ class CituLangGraphAgent:
                 }
                 }
             
             
             self.logger.info("响应格式化完成")
             self.logger.info("响应格式化完成")
+            
+            # 输出完整的 STATE 内容用于调试
+            import json
+            try:
+                # 创建一个可序列化的 state 副本
+                debug_state = dict(state)
+                self.logger.debug(f"format_response_node 完整 STATE 内容: {json.dumps(debug_state, ensure_ascii=False, indent=2)}")
+            except Exception as debug_e:
+                self.logger.debug(f"STATE 序列化失败,使用简单输出: {debug_e}")
+                self.logger.debug(f"format_response_node STATE 内容: {state}")
+            
             return state
             return state
             
             
         except Exception as e:
         except Exception as e:
@@ -656,6 +667,16 @@ class CituLangGraphAgent:
                 "error_code": 500,
                 "error_code": 500,
                 "execution_path": state["execution_path"]
                 "execution_path": state["execution_path"]
             }
             }
+            
+            # 即使在异常情况下也输出 STATE 内容用于调试
+            import json
+            try:
+                debug_state = dict(state)
+                self.logger.debug(f"format_response_node 异常情况下的完整 STATE 内容: {json.dumps(debug_state, ensure_ascii=False, indent=2)}")
+            except Exception as debug_e:
+                self.logger.debug(f"异常情况下 STATE 序列化失败: {debug_e}")
+                self.logger.debug(f"format_response_node 异常情况下的 STATE 内容: {state}")
+            
             return state
             return state
     
     
     def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
     def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
@@ -697,13 +718,13 @@ class CituLangGraphAgent:
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             # 聊天Agent可以处理不确定的情况,并在必要时引导用户提供更多信息
             return "CHAT"
             return "CHAT"
     
     
-    async def process_question(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
+    async def process_question(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> Dict[str, Any]:
         """
         """
         统一的问题处理入口
         统一的问题处理入口
         
         
         Args:
         Args:
             question: 用户问题
             question: 用户问题
-            session_id: 会话ID
+            conversation_id: 对话ID
             context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
             context_type: 上下文类型 ("DATABASE" 或 "CHAT"),用于渐进式分类
             routing_mode: 路由模式,可选,用于覆盖配置文件设置
             routing_mode: 路由模式,可选,用于覆盖配置文件设置
             
             
@@ -722,14 +743,14 @@ class CituLangGraphAgent:
             workflow = self._create_workflow(routing_mode)
             workflow = self._create_workflow(routing_mode)
             
             
             # 初始化状态
             # 初始化状态
-            initial_state = self._create_initial_state(question, session_id, context_type, routing_mode)
+            initial_state = self._create_initial_state(question, conversation_id, context_type, routing_mode)
             
             
             # 执行工作流
             # 执行工作流
             final_state = await workflow.ainvoke(
             final_state = await workflow.ainvoke(
                 initial_state,
                 initial_state,
                 config={
                 config={
-                    "configurable": {"session_id": session_id}
-                } if session_id else None
+                    "configurable": {"conversation_id": conversation_id}
+                } if conversation_id else None
             )
             )
             
             
             # 提取最终结果
             # 提取最终结果
@@ -748,7 +769,7 @@ class CituLangGraphAgent:
                 "execution_path": ["error"]
                 "execution_path": ["error"]
             }
             }
     
     
-    def _create_initial_state(self, question: str, session_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
+    def _create_initial_state(self, question: str, conversation_id: str = None, context_type: str = None, routing_mode: str = None) -> AgentState:
         """创建初始状态 - 支持渐进式分类"""
         """创建初始状态 - 支持渐进式分类"""
         # 确定使用的路由模式
         # 确定使用的路由模式
         if routing_mode:
         if routing_mode:
@@ -763,7 +784,7 @@ class CituLangGraphAgent:
         return AgentState(
         return AgentState(
             # 输入信息
             # 输入信息
             question=question,
             question=question,
-            session_id=session_id,
+            conversation_id=conversation_id,
             
             
             # 上下文信息
             # 上下文信息
             context_type=context_type,
             context_type=context_type,
@@ -1043,7 +1064,7 @@ class CituLangGraphAgent:
             
             
             if enable_full_test:
             if enable_full_test:
                 # 完整流程测试
                 # 完整流程测试
-                test_result = await self.process_question(test_question, "health_check")
+                test_result = await self.process_question(test_question, conversation_id="health_check")
                 
                 
                 return {
                 return {
                     "status": "healthy" if test_result.get("success") else "degraded",
                     "status": "healthy" if test_result.get("success") else "degraded",

+ 1 - 1
agent/state.py

@@ -7,7 +7,7 @@ class AgentState(TypedDict):
     
     
     # 输入信息
     # 输入信息
     question: str
     question: str
-    session_id: Optional[str]
+    conversation_id: Optional[str]
     
     
     # 上下文信息
     # 上下文信息
     context_type: Optional[str]  # 上下文类型 ("DATABASE" 或 "CHAT")
     context_type: Optional[str]  # 上下文类型 ("DATABASE" 或 "CHAT")

+ 1 - 1
app_config.py

@@ -169,7 +169,7 @@ REDIS_PASSWORD = None
 
 
 # 缓存开关配置
 # 缓存开关配置
 ENABLE_CONVERSATION_CONTEXT = True      # 是否启用对话上下文
 ENABLE_CONVERSATION_CONTEXT = True      # 是否启用对话上下文
-ENABLE_QUESTION_ANSWER_CACHE = False     # 是否启用问答结果缓存
+ENABLE_QUESTION_ANSWER_CACHE = True     # 是否启用问答结果缓存
 ENABLE_EMBEDDING_CACHE = True           # 是否启用embedding向量缓存
 ENABLE_EMBEDDING_CACHE = True           # 是否启用embedding向量缓存
 
 
 # TTL配置(单位:秒)
 # TTL配置(单位:秒)

+ 58 - 30
citu_app.py

@@ -9,7 +9,7 @@ from flask import request, jsonify
 import pandas as pd
 import pandas as pd
 import common.result as result
 import common.result as result
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from common.session_aware_cache import WebSessionAwareMemoryCache
+from common.session_aware_cache import ConversationAwareMemoryCache
 from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
 from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
 import re
 import re
 import chainlit as cl
 import chainlit as cl
@@ -41,13 +41,13 @@ MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DE
 
 
 vn = create_vanna_instance()
 vn = create_vanna_instance()
 
 
-# 创建带时间戳的缓存
-timestamped_cache = WebSessionAwareMemoryCache()
+# 创建对话感知的缓存
+conversation_cache = ConversationAwareMemoryCache()
 
 
 # 实例化 VannaFlaskApp,使用自定义缓存
 # 实例化 VannaFlaskApp,使用自定义缓存
 app = VannaFlaskApp(
 app = VannaFlaskApp(
     vn,
     vn,
-    cache=timestamped_cache,  # 使用带时间戳的缓存
+    cache=conversation_cache,  # 使用对话感知的缓存
     title="辞图智能数据问答平台",
     title="辞图智能数据问答平台",
     logo = "https://www.citupro.com/img/logo-black-2.png",
     logo = "https://www.citupro.com/img/logo-black-2.png",
     subtitle="让 AI 为你写 SQL",
     subtitle="让 AI 为你写 SQL",
@@ -61,12 +61,13 @@ app = VannaFlaskApp(
 # 创建Redis对话管理器实例
 # 创建Redis对话管理器实例
 redis_conversation_manager = RedisConversationManager()
 redis_conversation_manager = RedisConversationManager()
 
 
-# 修改ask接口,支持前端传递session_id
+# 修改ask接口,支持前端传递conversation_id
 @app.flask_app.route('/api/v0/ask', methods=['POST'])
 @app.flask_app.route('/api/v0/ask', methods=['POST'])
 def ask_full():
 def ask_full():
     req = request.get_json(force=True)
     req = request.get_json(force=True)
     question = req.get("question", None)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)  # 前端传递的会话ID
+    conversation_id = req.get("conversation_id", None)  # 前端传递的对话ID
+    user_id = req.get("user_id", None)  # 前端传递的用户ID
     
     
     if not question:
     if not question:
         from common.result import bad_request_response
         from common.result import bad_request_response
@@ -75,14 +76,13 @@ def ask_full():
             missing_params=["question"]
             missing_params=["question"]
         )), 400
         )), 400
 
 
-    # 如果使用WebSessionAwareMemoryCache
-    if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
-        # 这里需要修改vanna的ask方法来支持传递session_id
-        # 或者预先调用generate_id来建立会话关联
-        conversation_id = app.cache.generate_id_with_browser_session(
-            question=question, 
-            browser_session_id=browser_session_id
-        )
+    # 如果没有传递user_id,使用默认值guest
+    if not user_id:
+        user_id = "guest"
+
+    # 如果前端没有传递conversation_id,则生成新的
+    if not conversation_id:
+        conversation_id = app.cache.generate_id(question=question, user_id=user_id)
 
 
     try:
     try:
         sql, df, _ = vn.ask(
         sql, df, _ = vn.ask(
@@ -144,8 +144,7 @@ def ask_full():
         response_data = {
         response_data = {
             "sql": sql,
             "sql": sql,
             "query_result": query_result,
             "query_result": query_result,
-            "conversation_id": conversation_id if 'conversation_id' in locals() else None,
-            "session_id": browser_session_id
+            "conversation_id": conversation_id
         }
         }
         
         
         # 添加摘要(如果启用且生成成功)
         # 添加摘要(如果启用且生成成功)
@@ -237,7 +236,8 @@ def ask_cached():
     """
     """
     req = request.get_json(force=True)
     req = request.get_json(force=True)
     question = req.get("question", None)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)
+    conversation_id = req.get("conversation_id", None)
+    user_id = req.get("user_id", None)
     
     
     if not question:
     if not question:
         from common.result import bad_request_response
         from common.result import bad_request_response
@@ -246,15 +246,19 @@ def ask_cached():
             missing_params=["question"]
             missing_params=["question"]
         )), 400
         )), 400
 
 
+    # 如果没有传递user_id,使用默认值guest
+    if not user_id:
+        user_id = "guest"
+
     try:
     try:
         # 生成conversation_id
         # 生成conversation_id
         # 调试:查看generate_id的实际行为
         # 调试:查看generate_id的实际行为
         logger.debug(f"输入问题: '{question}'")
         logger.debug(f"输入问题: '{question}'")
-        conversation_id = app.cache.generate_id(question=question)
+        conversation_id = app.cache.generate_id(question=question, user_id=user_id)
         logger.debug(f"生成的conversation_id: {conversation_id}")
         logger.debug(f"生成的conversation_id: {conversation_id}")
         
         
         # 再次用相同问题测试
         # 再次用相同问题测试
-        conversation_id2 = app.cache.generate_id(question=question)
+        conversation_id2 = app.cache.generate_id(question=question, user_id=user_id)
         logger.debug(f"再次生成的conversation_id: {conversation_id2}")
         logger.debug(f"再次生成的conversation_id: {conversation_id2}")
         logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
         logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
         
         
@@ -336,7 +340,6 @@ def ask_cached():
             "sql": sql,
             "sql": sql,
             "query_result": query_result,
             "query_result": query_result,
             "conversation_id": conversation_id,
             "conversation_id": conversation_id,
-            "session_id": browser_session_id,
             "cached": cached_sql is not None  # 标识是否来自缓存
             "cached": cached_sql is not None  # 标识是否来自缓存
         }
         }
         
         
@@ -449,11 +452,10 @@ def ask_agent():
     """
     """
     req = request.get_json(force=True)
     req = request.get_json(force=True)
     question = req.get("question", None)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)
+    conversation_id_input = req.get("conversation_id", None)
     
     
     # 新增参数解析
     # 新增参数解析
     user_id_input = req.get("user_id", None)
     user_id_input = req.get("user_id", None)
-    conversation_id_input = req.get("conversation_id", None)
     continue_conversation = req.get("continue_conversation", False)
     continue_conversation = req.get("continue_conversation", False)
     
     
     # 新增:路由模式参数解析和验证
     # 新增:路由模式参数解析和验证
@@ -477,9 +479,38 @@ def ask_agent():
         # 1. 获取登录用户ID(修正:在函数中获取session信息)
         # 1. 获取登录用户ID(修正:在函数中获取session信息)
         login_user_id = session.get('user_id') if 'user_id' in session else None
         login_user_id = session.get('user_id') if 'user_id' in session else None
         
         
-        # 2. 智能ID解析(修正:传入登录用户ID)
+        # 2. 用户ID和对话ID一致性校验
+        from common.session_aware_cache import ConversationAwareMemoryCache
+        
+        # 2.1 如果传递了conversation_id,从中解析user_id
+        extracted_user_id = None
+        if conversation_id_input:
+            extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
+            
+            # 如果同时传递了user_id和conversation_id,进行一致性校验
+            if user_id_input:
+                is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
+                    conversation_id_input, user_id_input
+                )
+                if not is_valid:
+                    return jsonify(bad_request_response(
+                        response_text=error_msg,
+                        invalid_params=["user_id", "conversation_id"]
+                    )), 400
+            
+            # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
+            elif not user_id_input and extracted_user_id:
+                user_id_input = extracted_user_id
+                logger.info(f"从conversation_id解析出user_id: {user_id_input}")
+        
+        # 2.2 如果没有传递user_id,使用默认值guest
+        if not user_id_input:
+            user_id_input = "guest"
+            logger.info("未传递user_id,使用默认值: guest")
+        
+        # 3. 智能ID解析(修正:传入登录用户ID)
         user_id = redis_conversation_manager.resolve_user_id(
         user_id = redis_conversation_manager.resolve_user_id(
-            user_id_input, browser_session_id, request.remote_addr, login_user_id
+            user_id_input, None, request.remote_addr, login_user_id
         )
         )
         conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
         conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
             user_id, conversation_id_input, continue_conversation
             user_id, conversation_id_input, continue_conversation
@@ -552,10 +583,9 @@ def ask_agent():
                 sql=cached_answer.get("sql"),
                 sql=cached_answer.get("sql"),
                 records=cached_answer.get("query_result"),  # 修改:query_result改为records
                 records=cached_answer.get("query_result"),  # 修改:query_result改为records
                 summary=cached_answer.get("summary"),
                 summary=cached_answer.get("summary"),
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=cached_answer.get("execution_path", []),
                 execution_path=cached_answer.get("execution_path", []),
                 classification_info=cached_answer.get("classification_info", {}),
                 classification_info=cached_answer.get("classification_info", {}),
-                conversation_id=conversation_id,
                 user_id=user_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
                 context_used=bool(context),
@@ -605,7 +635,7 @@ def ask_agent():
         import asyncio
         import asyncio
         agent_result = asyncio.run(agent.process_question(
         agent_result = asyncio.run(agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
             question=enhanced_question,  # 使用增强后的问题
-            session_id=browser_session_id,
+            conversation_id=conversation_id,
             context_type=context_type,  # 传递上下文类型
             context_type=context_type,  # 传递上下文类型
             routing_mode=effective_routing_mode  # 新增:传递路由模式
             routing_mode=effective_routing_mode  # 新增:传递路由模式
         ))
         ))
@@ -662,10 +692,9 @@ def ask_agent():
                 sql=sql,
                 sql=sql,
                 records=query_result,  # 修改:query_result改为records
                 records=query_result,  # 修改:query_result改为records
                 summary=summary,
                 summary=summary,
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=execution_path,
                 execution_path=execution_path,
                 classification_info=classification_info,
                 classification_info=classification_info,
-                conversation_id=conversation_id,
                 user_id=user_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
                 context_used=bool(context),
@@ -685,7 +714,6 @@ def ask_agent():
                 response_text=error_message,
                 response_text=error_message,
                 error_type="agent_processing_failed",
                 error_type="agent_processing_failed",
                 code=error_code,
                 code=error_code,
-                session_id=browser_session_id,
                 conversation_id=conversation_id,
                 conversation_id=conversation_id,
                 user_id=user_id
                 user_id=user_id
             )), error_code
             )), error_code

+ 4 - 3
common/redis_conversation_manager.py

@@ -155,9 +155,10 @@ class RedisConversationManager:
     
     
     def create_conversation(self, user_id: str) -> str:
     def create_conversation(self, user_id: str) -> str:
         """创建新对话"""
         """创建新对话"""
-        # 生成包含时间戳的conversation_id
-        timestamp = int(datetime.now().timestamp())
-        conversation_id = f"conv_{timestamp}_{uuid.uuid4().hex[:8]}"
+        # 生成包含时间戳的conversation_id,格式:{user_id}:YYYYMMDDHHMMSSsss
+        now = datetime.now()
+        timestamp = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"
+        conversation_id = f"{user_id}:{timestamp}"
         
         
         if not self.is_available():
         if not self.is_available():
             return conversation_id  # Redis不可用时返回ID,但不存储
             return conversation_id  # Redis不可用时返回ID,但不存储

+ 6 - 6
common/result.py

@@ -101,14 +101,14 @@ def error_response(response_text, error_type=None, message=MessageTemplate.PROCE
 
 
 # ===== Ask Agent API 专用响应方法 =====
 # ===== Ask Agent API 专用响应方法 =====
 
 
-def agent_success_response(response_type, session_id=None, execution_path=None, 
+def agent_success_response(response_type, conversation_id=None, execution_path=None, 
                           classification_info=None, agent_version="langgraph_v1", **kwargs):
                           classification_info=None, agent_version="langgraph_v1", **kwargs):
     """
     """
     Ask Agent API 成功响应格式
     Ask Agent API 成功响应格式
     
     
     Args:
     Args:
         response_type: 响应类型 ("DATABASE" 或 "CHAT")
         response_type: 响应类型 ("DATABASE" 或 "CHAT")
-        session_id: 会话ID
+        conversation_id: 对话ID
         execution_path: 执行路径
         execution_path: 执行路径
         classification_info: 分类信息
         classification_info: 分类信息
         agent_version: Agent版本
         agent_version: Agent版本
@@ -119,7 +119,7 @@ def agent_success_response(response_type, session_id=None, execution_path=None,
     """
     """
     data = {
     data = {
         "type": response_type,
         "type": response_type,
-        "session_id": session_id,
+        "conversation_id": conversation_id,
         "execution_path": execution_path or [],
         "execution_path": execution_path or [],
         "classification_info": classification_info or {},
         "classification_info": classification_info or {},
         "agent_version": agent_version,
         "agent_version": agent_version,
@@ -139,7 +139,7 @@ def agent_success_response(response_type, session_id=None, execution_path=None,
     }
     }
 
 
 def agent_error_response(response_text, error_type=None, message=MessageTemplate.PROCESSING_FAILED,
 def agent_error_response(response_text, error_type=None, message=MessageTemplate.PROCESSING_FAILED,
-                        code=500, session_id=None, execution_path=None, 
+                        code=500, conversation_id=None, execution_path=None, 
                         classification_info=None, agent_version="langgraph_v1", **kwargs):
                         classification_info=None, agent_version="langgraph_v1", **kwargs):
     """
     """
     Ask Agent API 错误响应格式
     Ask Agent API 错误响应格式
@@ -149,7 +149,7 @@ def agent_error_response(response_text, error_type=None, message=MessageTemplate
         error_type: 错误类型标识
         error_type: 错误类型标识
         message: 高层级描述信息
         message: 高层级描述信息
         code: HTTP状态码
         code: HTTP状态码
-        session_id: 会话ID
+        conversation_id: 对话ID
         execution_path: 执行路径
         execution_path: 执行路径
         classification_info: 分类信息
         classification_info: 分类信息
         agent_version: Agent版本
         agent_version: Agent版本
@@ -160,7 +160,7 @@ def agent_error_response(response_text, error_type=None, message=MessageTemplate
     """
     """
     data = {
     data = {
         "response": response_text,
         "response": response_text,
-        "session_id": session_id,
+        "conversation_id": conversation_id,
         "execution_path": execution_path or [],
         "execution_path": execution_path or [],
         "classification_info": classification_info or {},
         "classification_info": classification_info or {},
         "agent_version": agent_version,
         "agent_version": agent_version,

+ 103 - 128
common/session_aware_cache.py

@@ -1,164 +1,139 @@
-# 修正后的 custom_cache.py
+# 简化后的对话感知缓存
 from datetime import datetime
 from datetime import datetime
 from vanna.flask import MemoryCache
 from vanna.flask import MemoryCache
 import uuid
 import uuid
 
 
-class SessionAwareMemoryCache(MemoryCache):
-    """区分会话(Session)和对话(Conversation)的缓存实现"""
+class ConversationAwareMemoryCache(MemoryCache):
+    """基于对话ID的简单时间感知缓存实现"""
     
     
     def __init__(self):
     def __init__(self):
         super().__init__()
         super().__init__()
-        self.conversation_start_times = {}  # 每个对话的开始时间
-        self.session_info = {}  # 会话信息: {session_id: {'start_time': datetime, 'conversations': []}}
-        self.conversation_to_session = {}  # 对话ID到会话ID的映射
+        self.conversation_start_times = {}  # 每个对话的开始时间: {conversation_id: datetime}
     
     
-    def create_or_get_session_id(self, user_identifier=None):
-        """
-        创建或获取会话ID
-        在实际应用中,这可以通过以下方式确定:
-        1. HTTP请求中的session cookie
-        2. JWT token中的session信息
-        3. 前端传递的session_id
-        4. IP地址 + User-Agent的组合
-        """
-        # 简化实现:使用时间窗口来判断是否为同一会话
-        # 实际应用中应该从HTTP请求中获取session信息
-        current_time = datetime.now()
+    def generate_id(self, question: str = None, user_id: str = None) -> str:
+        """生成对话ID并记录时间,格式为 {user_id}:YYYYMMDDHHMMSSsss"""
+        # 如果没有传递user_id,使用默认值
+        if not user_id:
+            user_id = "guest"
         
         
-        # 检查是否有近期的会话(比如30分钟内)
-        for session_id, session_data in self.session_info.items():
-            last_activity = session_data.get('last_activity', session_data['start_time'])
-            if (current_time - last_activity).total_seconds() < 1800:  # 30分钟内
-                # 更新最后活动时间
-                session_data['last_activity'] = current_time
-                return session_id
+        # 生成时间戳:年月日时分秒毫秒格式
+        now = datetime.now()
+        timestamp = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"
         
         
-        # 创建新会话
-        new_session_id = str(uuid.uuid4())
-        self.session_info[new_session_id] = {
-            'start_time': current_time,
-            'last_activity': current_time,
-            'conversations': []
-        }
-        return new_session_id
-    
-    def generate_id(self, question: str = None, session_id: str = None) -> str:
-        """重载generate_id方法,关联会话和对话"""
-        conversation_id = super().generate_id(question=question)
-        
-        # 确定会话ID
-        if not session_id:
-            session_id = self.create_or_get_session_id()
+        # 生成对话ID:{user_id}:{timestamp}
+        conversation_id = f"{user_id}:{timestamp}"
         
         
         # 记录对话开始时间
         # 记录对话开始时间
-        conversation_start_time = datetime.now()
-        self.conversation_start_times[conversation_id] = conversation_start_time
-        
-        # 建立对话与会话的关联
-        self.conversation_to_session[conversation_id] = session_id
-        self.session_info[session_id]['conversations'].append(conversation_id)
-        self.session_info[session_id]['last_activity'] = conversation_start_time
+        self.conversation_start_times[conversation_id] = now
         
         
         return conversation_id
         return conversation_id
     
     
-    def set(self, id: str, field: str, value, session_id: str = None):
+    def set(self, id: str, field: str, value, **kwargs):
         """重载set方法,确保时间信息正确"""
         """重载set方法,确保时间信息正确"""
         # 如果这是新对话,初始化时间信息
         # 如果这是新对话,初始化时间信息
         if id not in self.conversation_start_times:
         if id not in self.conversation_start_times:
-            if not session_id:
-                session_id = self.create_or_get_session_id()
-            
-            conversation_start_time = datetime.now()
-            self.conversation_start_times[id] = conversation_start_time
-            self.conversation_to_session[id] = session_id
-            self.session_info[session_id]['conversations'].append(id)
-            self.session_info[session_id]['last_activity'] = conversation_start_time
+            self.conversation_start_times[id] = datetime.now()
         
         
         # 调用父类的set方法
         # 调用父类的set方法
         super().set(id=id, field=field, value=value)
         super().set(id=id, field=field, value=value)
         
         
-        # 设置时间相关字段
-        if field != 'conversation_start_time' and field != 'session_start_time':
-            # 设置对话开始时间
+        # 自动设置对话开始时间字段
+        if field != 'conversation_start_time':
             super().set(id=id, field='conversation_start_time', 
             super().set(id=id, field='conversation_start_time', 
                        value=self.conversation_start_times[id])
                        value=self.conversation_start_times[id])
-            
-            # 设置会话开始时间
-            session_id = self.conversation_to_session.get(id)
-            if session_id and session_id in self.session_info:
-                super().set(id=id, field='session_start_time', 
-                           value=self.session_info[session_id]['start_time'])
-                super().set(id=id, field='session_id', value=session_id)
     
     
     def get_conversation_start_time(self, conversation_id: str) -> datetime:
     def get_conversation_start_time(self, conversation_id: str) -> datetime:
         """获取对话开始时间"""
         """获取对话开始时间"""
         return self.conversation_start_times.get(conversation_id)
         return self.conversation_start_times.get(conversation_id)
     
     
-    def get_session_start_time(self, conversation_id: str) -> datetime:
-        """获取会话开始时间"""
-        session_id = self.conversation_to_session.get(conversation_id)
-        if session_id and session_id in self.session_info:
-            return self.session_info[session_id]['start_time']
-        return None
-    
-    def get_session_info(self, session_id: str = None, conversation_id: str = None):
-        """获取会话信息"""
-        if conversation_id:
-            session_id = self.conversation_to_session.get(conversation_id)
-        
-        if session_id and session_id in self.session_info:
-            session_data = self.session_info[session_id].copy()
-            session_data['conversation_count'] = len(session_data['conversations'])
-            if session_data['conversations']:
-                # 计算会话持续时间
-                duration = datetime.now() - session_data['start_time']
-                session_data['session_duration_seconds'] = duration.total_seconds()
-                session_data['session_duration_formatted'] = str(duration)
-            return session_data
+    def get_conversation_info(self, conversation_id: str):
+        """获取对话信息"""
+        start_time = self.get_conversation_start_time(conversation_id)
+        if start_time:
+            duration = datetime.now() - start_time
+            
+            # 从conversation_id解析user_id
+            user_id = "unknown"
+            if ":" in conversation_id:
+                user_id = conversation_id.split(":")[0]
+            
+            return {
+                'conversation_id': conversation_id,
+                'user_id': user_id,
+                'start_time': start_time,
+                'duration_seconds': duration.total_seconds(),
+                'duration_formatted': str(duration)
+            }
         return None
         return None
     
     
-    def get_all_sessions(self):
-        """获取所有话信息"""
+    def get_all_conversations(self):
+        """获取所有话信息"""
         result = {}
         result = {}
-        for session_id, session_data in self.session_info.items():
-            session_info = session_data.copy()
-            session_info['conversation_count'] = len(session_data['conversations'])
-            if session_data['conversations']:
-                duration = datetime.now() - session_data['start_time']
-                session_info['session_duration_seconds'] = duration.total_seconds()
-                session_info['session_duration_formatted'] = str(duration)
-            result[session_id] = session_info
+        for conversation_id, start_time in self.conversation_start_times.items():
+            duration = datetime.now() - start_time
+            
+            # 从conversation_id解析user_id
+            user_id = "unknown"
+            if ":" in conversation_id:
+                user_id = conversation_id.split(":")[0]
+                
+            result[conversation_id] = {
+                'user_id': user_id,
+                'start_time': start_time,
+                'duration_seconds': duration.total_seconds(),
+                'duration_formatted': str(duration)
+            }
         return result
         return result
 
 
+    @staticmethod
+    def parse_conversation_id(conversation_id: str):
+        """解析conversation_id,返回user_id和timestamp"""
+        if ":" not in conversation_id:
+            return None, None
+        
+        parts = conversation_id.split(":", 1)
+        user_id = parts[0]
+        timestamp_str = parts[1]
+        
+        try:
+            # 解析时间戳:YYYYMMDDHHMMSSsss
+            if len(timestamp_str) == 17:  # 20250722204550155
+                timestamp = datetime.strptime(timestamp_str[:14], "%Y%m%d%H%M%S")
+                # 添加毫秒
+                milliseconds = int(timestamp_str[14:])
+                timestamp = timestamp.replace(microsecond=milliseconds * 1000)
+                return user_id, timestamp
+        except ValueError:
+            pass
+        
+        return user_id, None
 
 
-# 升级版:支持前端传递会话ID
-class WebSessionAwareMemoryCache(SessionAwareMemoryCache):
-    """支持从前端获取会话ID的版本"""
-    
-    def __init__(self):
-        super().__init__()
-        self.browser_sessions = {}  # browser_session_id -> our_session_id
-    
-    def register_browser_session(self, browser_session_id: str, user_info: dict = None):
-        """注册浏览器会话"""
-        if browser_session_id not in self.browser_sessions:
-            our_session_id = str(uuid.uuid4())
-            self.browser_sessions[browser_session_id] = our_session_id
-            
-            self.session_info[our_session_id] = {
-                'start_time': datetime.now(),
-                'last_activity': datetime.now(),
-                'conversations': [],
-                'browser_session_id': browser_session_id,
-                'user_info': user_info or {}
-            }
-        return self.browser_sessions[browser_session_id]
-    
-    def generate_id_with_browser_session(self, question: str = None, browser_session_id: str = None) -> str:
-        """使用浏览器会话ID生成对话ID"""
-        if browser_session_id:
-            our_session_id = self.register_browser_session(browser_session_id)
-        else:
-            our_session_id = self.create_or_get_session_id()
+    @staticmethod
+    def extract_user_id(conversation_id: str) -> str:
+        """从conversation_id中提取user_id"""
+        if ":" not in conversation_id:
+            return "unknown"
+        return conversation_id.split(":", 1)[0]
+
+    @staticmethod
+    def validate_user_id_consistency(conversation_id: str, provided_user_id: str) -> tuple[bool, str]:
+        """
+        校验conversation_id中的user_id与提供的user_id是否一致
+        
+        Returns:
+            tuple: (is_valid, error_message)
+        """
+        if not conversation_id or not provided_user_id:
+            return True, ""  # 如果任一为空,跳过校验
         
         
-        return super().generate_id(question=question, session_id=our_session_id)
+        extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id)
+        
+        if extracted_user_id != provided_user_id:
+            return False, f"用户ID不匹配:conversation_id中的用户ID '{extracted_user_id}' 与提供的用户ID '{provided_user_id}' 不一致"
+        
+        return True, ""
+
+
+# 保持向后兼容的别名
+WebSessionAwareMemoryCache = ConversationAwareMemoryCache
+SessionAwareMemoryCache = ConversationAwareMemoryCache

+ 2 - 2
config/logging_config_windows.yaml

@@ -59,7 +59,7 @@ modules:
     level: DEBUG
     level: DEBUG
     console:
     console:
       enabled: true
       enabled: true
-      level: INFO
+      level: DEBUG
       format: "%(asctime)s [%(levelname)s] Agent: %(message)s"
       format: "%(asctime)s [%(levelname)s] Agent: %(message)s"
     file:
     file:
       enabled: true
       enabled: true
@@ -75,7 +75,7 @@ modules:
     level: DEBUG
     level: DEBUG
     console:
     console:
       enabled: true
       enabled: true
-      level: INFO
+      level: DEBUG
       format: "%(asctime)s [%(levelname)s] Vanna: %(message)s"
       format: "%(asctime)s [%(levelname)s] Vanna: %(message)s"
     file:
     file:
       enabled: true
       enabled: true

+ 2 - 2
custompgvector/pgvector.py

@@ -663,7 +663,7 @@ class PG_VectorStore(VannaBase):
             
             
             # 检查原始查询结果是否为空
             # 检查原始查询结果是否为空
             if not results:
             if not results:
-                self.logger.warning(f"向量查询未找到任何相关的错误SQL示例,问题: {question}")
+                self.logger.warning(f"向量查询未找到任何相关的错误SQL示例")
 
 
             # 应用错误SQL特有的阈值过滤逻辑
             # 应用错误SQL特有的阈值过滤逻辑
             filtered_results = self._apply_error_sql_threshold_filter(results)
             filtered_results = self._apply_error_sql_threshold_filter(results)
@@ -671,7 +671,7 @@ class PG_VectorStore(VannaBase):
             # 检查过滤后结果是否为空
             # 检查过滤后结果是否为空
             if results and not filtered_results:
             if results and not filtered_results:
                 self.logger.warning(f"向量查询找到了 {len(results)} 条错误SQL示例,但全部被阈值过滤掉.")
                 self.logger.warning(f"向量查询找到了 {len(results)} 条错误SQL示例,但全部被阈值过滤掉.")
-                self.logger.warning(f"问题: {question}")
+                # self.logger.warning(f"问题: {question}")
 
 
             return filtered_results
             return filtered_results
             
             

+ 33 - 8
unified_api.py

@@ -711,7 +711,6 @@ def ask_agent():
     """支持对话上下文的ask_agent API"""
     """支持对话上下文的ask_agent API"""
     req = request.get_json(force=True)
     req = request.get_json(force=True)
     question = req.get("question", None)
     question = req.get("question", None)
-    browser_session_id = req.get("session_id", None)
     user_id_input = req.get("user_id", None)
     user_id_input = req.get("user_id", None)
     conversation_id_input = req.get("conversation_id", None)
     conversation_id_input = req.get("conversation_id", None)
     continue_conversation = req.get("continue_conversation", False)
     continue_conversation = req.get("continue_conversation", False)
@@ -735,9 +734,38 @@ def ask_agent():
         # 获取登录用户ID
         # 获取登录用户ID
         login_user_id = session.get('user_id') if 'user_id' in session else None
         login_user_id = session.get('user_id') if 'user_id' in session else None
         
         
+        # 用户ID和对话ID一致性校验
+        from common.session_aware_cache import ConversationAwareMemoryCache
+        
+        # 如果传递了conversation_id,从中解析user_id
+        extracted_user_id = None
+        if conversation_id_input:
+            extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
+            
+            # 如果同时传递了user_id和conversation_id,进行一致性校验
+            if user_id_input:
+                is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
+                    conversation_id_input, user_id_input
+                )
+                if not is_valid:
+                    return jsonify(bad_request_response(
+                        response_text=error_msg,
+                        invalid_params=["user_id", "conversation_id"]
+                    )), 400
+            
+            # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
+            elif not user_id_input and extracted_user_id:
+                user_id_input = extracted_user_id
+                logger.info(f"从conversation_id解析出user_id: {user_id_input}")
+        
+        # 如果没有传递user_id,使用默认值guest
+        if not user_id_input:
+            user_id_input = "guest"
+            logger.info("未传递user_id,使用默认值: guest")
+        
         # 智能ID解析
         # 智能ID解析
         user_id = redis_conversation_manager.resolve_user_id(
         user_id = redis_conversation_manager.resolve_user_id(
-            user_id_input, browser_session_id, request.remote_addr, login_user_id
+            user_id_input, None, request.remote_addr, login_user_id
         )
         )
         conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
         conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
             user_id, conversation_id_input, continue_conversation
             user_id, conversation_id_input, continue_conversation
@@ -810,10 +838,9 @@ def ask_agent():
                 sql=cached_answer.get("sql"),
                 sql=cached_answer.get("sql"),
                 records=cached_answer.get("query_result"),
                 records=cached_answer.get("query_result"),
                 summary=cached_answer.get("summary"),
                 summary=cached_answer.get("summary"),
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=cached_answer.get("execution_path", []),
                 execution_path=cached_answer.get("execution_path", []),
                 classification_info=cached_answer.get("classification_info", {}),
                 classification_info=cached_answer.get("classification_info", {}),
-                conversation_id=conversation_id,
                 user_id=user_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
                 context_used=bool(context),
@@ -863,7 +890,7 @@ def ask_agent():
         import asyncio
         import asyncio
         agent_result = asyncio.run(agent.process_question(
         agent_result = asyncio.run(agent.process_question(
             question=enhanced_question,  # 使用增强后的问题
             question=enhanced_question,  # 使用增强后的问题
-            session_id=browser_session_id,
+            conversation_id=conversation_id,
             context_type=context_type,  # 传递上下文类型
             context_type=context_type,  # 传递上下文类型
             routing_mode=effective_routing_mode  # 新增:传递路由模式
             routing_mode=effective_routing_mode  # 新增:传递路由模式
         ))
         ))
@@ -913,10 +940,9 @@ def ask_agent():
                 sql=sql,
                 sql=sql,
                 records=query_result,
                 records=query_result,
                 summary=summary,
                 summary=summary,
-                session_id=browser_session_id,
+                conversation_id=conversation_id,
                 execution_path=execution_path,
                 execution_path=execution_path,
                 classification_info=classification_info,
                 classification_info=classification_info,
-                conversation_id=conversation_id,
                 user_id=user_id,
                 user_id=user_id,
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
                 context_used=bool(context),
                 context_used=bool(context),
@@ -936,7 +962,6 @@ def ask_agent():
                 response_text=error_message,
                 response_text=error_message,
                 error_type="agent_processing_failed",
                 error_type="agent_processing_failed",
                 code=error_code,
                 code=error_code,
-                session_id=browser_session_id,
                 conversation_id=conversation_id,
                 conversation_id=conversation_id,
                 user_id=user_id
                 user_id=user_id
             )), error_code
             )), error_code