|
@@ -9,7 +9,7 @@ from flask import request, jsonify
|
|
import pandas as pd
|
|
import pandas as pd
|
|
import common.result as result
|
|
import common.result as result
|
|
from datetime import datetime, timedelta
|
|
from datetime import datetime, timedelta
|
|
-from common.session_aware_cache import WebSessionAwareMemoryCache
|
|
|
|
|
|
+from common.session_aware_cache import ConversationAwareMemoryCache
|
|
from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
|
|
from app_config import API_MAX_RETURN_ROWS, ENABLE_RESULT_SUMMARY
|
|
import re
|
|
import re
|
|
import chainlit as cl
|
|
import chainlit as cl
|
|
@@ -41,13 +41,13 @@ MAX_RETURN_ROWS = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DE
|
|
|
|
|
|
vn = create_vanna_instance()
|
|
vn = create_vanna_instance()
|
|
|
|
|
|
-# 创建带时间戳的缓存
|
|
|
|
-timestamped_cache = WebSessionAwareMemoryCache()
|
|
|
|
|
|
+# 创建对话感知的缓存
|
|
|
|
+conversation_cache = ConversationAwareMemoryCache()
|
|
|
|
|
|
# 实例化 VannaFlaskApp,使用自定义缓存
|
|
# 实例化 VannaFlaskApp,使用自定义缓存
|
|
app = VannaFlaskApp(
|
|
app = VannaFlaskApp(
|
|
vn,
|
|
vn,
|
|
- cache=timestamped_cache, # 使用带时间戳的缓存
|
|
|
|
|
|
+ cache=conversation_cache, # 使用对话感知的缓存
|
|
title="辞图智能数据问答平台",
|
|
title="辞图智能数据问答平台",
|
|
logo = "https://www.citupro.com/img/logo-black-2.png",
|
|
logo = "https://www.citupro.com/img/logo-black-2.png",
|
|
subtitle="让 AI 为你写 SQL",
|
|
subtitle="让 AI 为你写 SQL",
|
|
@@ -61,12 +61,13 @@ app = VannaFlaskApp(
|
|
# 创建Redis对话管理器实例
|
|
# 创建Redis对话管理器实例
|
|
redis_conversation_manager = RedisConversationManager()
|
|
redis_conversation_manager = RedisConversationManager()
|
|
|
|
|
|
-# 修改ask接口,支持前端传递session_id
|
|
|
|
|
|
+# 修改ask接口,支持前端传递conversation_id
|
|
@app.flask_app.route('/api/v0/ask', methods=['POST'])
|
|
@app.flask_app.route('/api/v0/ask', methods=['POST'])
|
|
def ask_full():
|
|
def ask_full():
|
|
req = request.get_json(force=True)
|
|
req = request.get_json(force=True)
|
|
question = req.get("question", None)
|
|
question = req.get("question", None)
|
|
- browser_session_id = req.get("session_id", None) # 前端传递的会话ID
|
|
|
|
|
|
+ conversation_id = req.get("conversation_id", None) # 前端传递的对话ID
|
|
|
|
+ user_id = req.get("user_id", None) # 前端传递的用户ID
|
|
|
|
|
|
if not question:
|
|
if not question:
|
|
from common.result import bad_request_response
|
|
from common.result import bad_request_response
|
|
@@ -75,14 +76,13 @@ def ask_full():
|
|
missing_params=["question"]
|
|
missing_params=["question"]
|
|
)), 400
|
|
)), 400
|
|
|
|
|
|
- # 如果使用WebSessionAwareMemoryCache
|
|
|
|
- if hasattr(app.cache, 'generate_id_with_browser_session') and browser_session_id:
|
|
|
|
- # 这里需要修改vanna的ask方法来支持传递session_id
|
|
|
|
- # 或者预先调用generate_id来建立会话关联
|
|
|
|
- conversation_id = app.cache.generate_id_with_browser_session(
|
|
|
|
- question=question,
|
|
|
|
- browser_session_id=browser_session_id
|
|
|
|
- )
|
|
|
|
|
|
+ # 如果没有传递user_id,使用默认值guest
|
|
|
|
+ if not user_id:
|
|
|
|
+ user_id = "guest"
|
|
|
|
+
|
|
|
|
+ # 如果前端没有传递conversation_id,则生成新的
|
|
|
|
+ if not conversation_id:
|
|
|
|
+ conversation_id = app.cache.generate_id(question=question, user_id=user_id)
|
|
|
|
|
|
try:
|
|
try:
|
|
sql, df, _ = vn.ask(
|
|
sql, df, _ = vn.ask(
|
|
@@ -144,8 +144,7 @@ def ask_full():
|
|
response_data = {
|
|
response_data = {
|
|
"sql": sql,
|
|
"sql": sql,
|
|
"query_result": query_result,
|
|
"query_result": query_result,
|
|
- "conversation_id": conversation_id if 'conversation_id' in locals() else None,
|
|
|
|
- "session_id": browser_session_id
|
|
|
|
|
|
+ "conversation_id": conversation_id
|
|
}
|
|
}
|
|
|
|
|
|
# 添加摘要(如果启用且生成成功)
|
|
# 添加摘要(如果启用且生成成功)
|
|
@@ -237,7 +236,8 @@ def ask_cached():
|
|
"""
|
|
"""
|
|
req = request.get_json(force=True)
|
|
req = request.get_json(force=True)
|
|
question = req.get("question", None)
|
|
question = req.get("question", None)
|
|
- browser_session_id = req.get("session_id", None)
|
|
|
|
|
|
+ conversation_id = req.get("conversation_id", None)
|
|
|
|
+ user_id = req.get("user_id", None)
|
|
|
|
|
|
if not question:
|
|
if not question:
|
|
from common.result import bad_request_response
|
|
from common.result import bad_request_response
|
|
@@ -246,15 +246,19 @@ def ask_cached():
|
|
missing_params=["question"]
|
|
missing_params=["question"]
|
|
)), 400
|
|
)), 400
|
|
|
|
|
|
|
|
+ # 如果没有传递user_id,使用默认值guest
|
|
|
|
+ if not user_id:
|
|
|
|
+ user_id = "guest"
|
|
|
|
+
|
|
try:
|
|
try:
|
|
# 生成conversation_id
|
|
# 生成conversation_id
|
|
# 调试:查看generate_id的实际行为
|
|
# 调试:查看generate_id的实际行为
|
|
logger.debug(f"输入问题: '{question}'")
|
|
logger.debug(f"输入问题: '{question}'")
|
|
- conversation_id = app.cache.generate_id(question=question)
|
|
|
|
|
|
+ conversation_id = app.cache.generate_id(question=question, user_id=user_id)
|
|
logger.debug(f"生成的conversation_id: {conversation_id}")
|
|
logger.debug(f"生成的conversation_id: {conversation_id}")
|
|
|
|
|
|
# 再次用相同问题测试
|
|
# 再次用相同问题测试
|
|
- conversation_id2 = app.cache.generate_id(question=question)
|
|
|
|
|
|
+ conversation_id2 = app.cache.generate_id(question=question, user_id=user_id)
|
|
logger.debug(f"再次生成的conversation_id: {conversation_id2}")
|
|
logger.debug(f"再次生成的conversation_id: {conversation_id2}")
|
|
logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
|
|
logger.debug(f"两次ID是否相同: {conversation_id == conversation_id2}")
|
|
|
|
|
|
@@ -336,7 +340,6 @@ def ask_cached():
|
|
"sql": sql,
|
|
"sql": sql,
|
|
"query_result": query_result,
|
|
"query_result": query_result,
|
|
"conversation_id": conversation_id,
|
|
"conversation_id": conversation_id,
|
|
- "session_id": browser_session_id,
|
|
|
|
"cached": cached_sql is not None # 标识是否来自缓存
|
|
"cached": cached_sql is not None # 标识是否来自缓存
|
|
}
|
|
}
|
|
|
|
|
|
@@ -449,11 +452,10 @@ def ask_agent():
|
|
"""
|
|
"""
|
|
req = request.get_json(force=True)
|
|
req = request.get_json(force=True)
|
|
question = req.get("question", None)
|
|
question = req.get("question", None)
|
|
- browser_session_id = req.get("session_id", None)
|
|
|
|
|
|
+ conversation_id_input = req.get("conversation_id", None)
|
|
|
|
|
|
# 新增参数解析
|
|
# 新增参数解析
|
|
user_id_input = req.get("user_id", None)
|
|
user_id_input = req.get("user_id", None)
|
|
- conversation_id_input = req.get("conversation_id", None)
|
|
|
|
continue_conversation = req.get("continue_conversation", False)
|
|
continue_conversation = req.get("continue_conversation", False)
|
|
|
|
|
|
# 新增:路由模式参数解析和验证
|
|
# 新增:路由模式参数解析和验证
|
|
@@ -477,9 +479,38 @@ def ask_agent():
|
|
# 1. 获取登录用户ID(修正:在函数中获取session信息)
|
|
# 1. 获取登录用户ID(修正:在函数中获取session信息)
|
|
login_user_id = session.get('user_id') if 'user_id' in session else None
|
|
login_user_id = session.get('user_id') if 'user_id' in session else None
|
|
|
|
|
|
- # 2. 智能ID解析(修正:传入登录用户ID)
|
|
|
|
|
|
+ # 2. 用户ID和对话ID一致性校验
|
|
|
|
+ from common.session_aware_cache import ConversationAwareMemoryCache
|
|
|
|
+
|
|
|
|
+ # 2.1 如果传递了conversation_id,从中解析user_id
|
|
|
|
+ extracted_user_id = None
|
|
|
|
+ if conversation_id_input:
|
|
|
|
+ extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id_input)
|
|
|
|
+
|
|
|
|
+ # 如果同时传递了user_id和conversation_id,进行一致性校验
|
|
|
|
+ if user_id_input:
|
|
|
|
+ is_valid, error_msg = ConversationAwareMemoryCache.validate_user_id_consistency(
|
|
|
|
+ conversation_id_input, user_id_input
|
|
|
|
+ )
|
|
|
|
+ if not is_valid:
|
|
|
|
+ return jsonify(bad_request_response(
|
|
|
|
+ response_text=error_msg,
|
|
|
|
+ invalid_params=["user_id", "conversation_id"]
|
|
|
|
+ )), 400
|
|
|
|
+
|
|
|
|
+ # 如果没有传递user_id,但有conversation_id,则从conversation_id中解析
|
|
|
|
+ elif not user_id_input and extracted_user_id:
|
|
|
|
+ user_id_input = extracted_user_id
|
|
|
|
+ logger.info(f"从conversation_id解析出user_id: {user_id_input}")
|
|
|
|
+
|
|
|
|
+ # 2.2 如果没有传递user_id,使用默认值guest
|
|
|
|
+ if not user_id_input:
|
|
|
|
+ user_id_input = "guest"
|
|
|
|
+ logger.info("未传递user_id,使用默认值: guest")
|
|
|
|
+
|
|
|
|
+ # 3. 智能ID解析(修正:传入登录用户ID)
|
|
user_id = redis_conversation_manager.resolve_user_id(
|
|
user_id = redis_conversation_manager.resolve_user_id(
|
|
- user_id_input, browser_session_id, request.remote_addr, login_user_id
|
|
|
|
|
|
+ user_id_input, None, request.remote_addr, login_user_id
|
|
)
|
|
)
|
|
conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
|
|
conversation_id, conversation_status = redis_conversation_manager.resolve_conversation_id(
|
|
user_id, conversation_id_input, continue_conversation
|
|
user_id, conversation_id_input, continue_conversation
|
|
@@ -552,10 +583,9 @@ def ask_agent():
|
|
sql=cached_answer.get("sql"),
|
|
sql=cached_answer.get("sql"),
|
|
records=cached_answer.get("query_result"), # 修改:query_result改为records
|
|
records=cached_answer.get("query_result"), # 修改:query_result改为records
|
|
summary=cached_answer.get("summary"),
|
|
summary=cached_answer.get("summary"),
|
|
- session_id=browser_session_id,
|
|
|
|
|
|
+ conversation_id=conversation_id,
|
|
execution_path=cached_answer.get("execution_path", []),
|
|
execution_path=cached_answer.get("execution_path", []),
|
|
classification_info=cached_answer.get("classification_info", {}),
|
|
classification_info=cached_answer.get("classification_info", {}),
|
|
- conversation_id=conversation_id,
|
|
|
|
user_id=user_id,
|
|
user_id=user_id,
|
|
is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
|
|
is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
|
|
context_used=bool(context),
|
|
context_used=bool(context),
|
|
@@ -605,7 +635,7 @@ def ask_agent():
|
|
import asyncio
|
|
import asyncio
|
|
agent_result = asyncio.run(agent.process_question(
|
|
agent_result = asyncio.run(agent.process_question(
|
|
question=enhanced_question, # 使用增强后的问题
|
|
question=enhanced_question, # 使用增强后的问题
|
|
- session_id=browser_session_id,
|
|
|
|
|
|
+ conversation_id=conversation_id,
|
|
context_type=context_type, # 传递上下文类型
|
|
context_type=context_type, # 传递上下文类型
|
|
routing_mode=effective_routing_mode # 新增:传递路由模式
|
|
routing_mode=effective_routing_mode # 新增:传递路由模式
|
|
))
|
|
))
|
|
@@ -662,10 +692,9 @@ def ask_agent():
|
|
sql=sql,
|
|
sql=sql,
|
|
records=query_result, # 修改:query_result改为records
|
|
records=query_result, # 修改:query_result改为records
|
|
summary=summary,
|
|
summary=summary,
|
|
- session_id=browser_session_id,
|
|
|
|
|
|
+ conversation_id=conversation_id,
|
|
execution_path=execution_path,
|
|
execution_path=execution_path,
|
|
classification_info=classification_info,
|
|
classification_info=classification_info,
|
|
- conversation_id=conversation_id,
|
|
|
|
user_id=user_id,
|
|
user_id=user_id,
|
|
is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
|
|
is_guest_user=(user_id == DEFAULT_ANONYMOUS_USER),
|
|
context_used=bool(context),
|
|
context_used=bool(context),
|
|
@@ -685,7 +714,6 @@ def ask_agent():
|
|
response_text=error_message,
|
|
response_text=error_message,
|
|
error_type="agent_processing_failed",
|
|
error_type="agent_processing_failed",
|
|
code=error_code,
|
|
code=error_code,
|
|
- session_id=browser_session_id,
|
|
|
|
conversation_id=conversation_id,
|
|
conversation_id=conversation_id,
|
|
user_id=user_id
|
|
user_id=user_id
|
|
)), error_code
|
|
)), error_code
|