# redis_conversation_manager.py (23 KB)


import hashlib
import json
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import redis

from app_config import (
    REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD,
    CONVERSATION_CONTEXT_COUNT, CONVERSATION_MAX_LENGTH, USER_MAX_CONVERSATIONS,
    CONVERSATION_TTL, USER_CONVERSATIONS_TTL, QUESTION_ANSWER_TTL,
    ENABLE_CONVERSATION_CONTEXT, ENABLE_QUESTION_ANSWER_CACHE,
    DEFAULT_ANONYMOUS_USER
)
  15. class RedisConversationManager:
  16. """Redis对话管理器 - 修正版"""
  17. def __init__(self):
  18. """初始化Redis连接"""
  19. try:
  20. self.redis_client = redis.Redis(
  21. host=REDIS_HOST,
  22. port=REDIS_PORT,
  23. db=REDIS_DB,
  24. password=REDIS_PASSWORD,
  25. decode_responses=True,
  26. socket_connect_timeout=5,
  27. socket_timeout=5
  28. )
  29. # 测试连接
  30. self.redis_client.ping()
  31. print(f"[REDIS_CONV] Redis连接成功: {REDIS_HOST}:{REDIS_PORT}")
  32. except Exception as e:
  33. print(f"[ERROR] Redis连接失败: {str(e)}")
  34. self.redis_client = None
  35. def is_available(self) -> bool:
  36. """检查Redis是否可用"""
  37. try:
  38. return self.redis_client is not None and self.redis_client.ping()
  39. except:
  40. return False
  41. # ==================== 用户ID解析(修正版)====================
  42. def resolve_user_id(self, user_id_from_request: Optional[str],
  43. session_id: Optional[str], request_ip: str,
  44. login_user_id: Optional[str] = None) -> str:
  45. """
  46. 智能解析用户ID - 统一版
  47. Args:
  48. user_id_from_request: 请求参数中的user_id
  49. session_id: 浏览器session_id
  50. request_ip: 请求IP地址
  51. login_user_id: 从Flask session中获取的登录用户ID
  52. """
  53. # 1. 优先使用登录用户ID
  54. if login_user_id:
  55. print(f"[REDIS_CONV] 使用登录用户ID: {login_user_id}")
  56. return login_user_id
  57. # 2. 如果没有登录,尝试从请求参数获取user_id
  58. if user_id_from_request:
  59. print(f"[REDIS_CONV] 使用请求参数user_id: {user_id_from_request}")
  60. return user_id_from_request
  61. # 3. 都没有则为匿名用户(统一为guest)
  62. print(f"[REDIS_CONV] 使用匿名用户: {DEFAULT_ANONYMOUS_USER}")
  63. return DEFAULT_ANONYMOUS_USER
  64. def resolve_conversation_id(self, user_id: str, conversation_id_input: Optional[str],
  65. continue_conversation: bool) -> tuple[str, dict]:
  66. """
  67. 智能解析对话ID - 改进版
  68. Returns:
  69. tuple: (conversation_id, status_info)
  70. status_info包含:
  71. - status: "existing" | "new" | "invalid_id_new"
  72. - message: 状态说明
  73. - requested_id: 原始请求的ID(如果有)
  74. """
  75. # 1. 如果指定了conversation_id,验证后使用
  76. if conversation_id_input:
  77. if self._is_valid_conversation(conversation_id_input, user_id):
  78. print(f"[REDIS_CONV] 使用指定对话: {conversation_id_input}")
  79. return conversation_id_input, {
  80. "status": "existing",
  81. "message": "继续已有对话"
  82. }
  83. else:
  84. print(f"[WARN] 无效的conversation_id: {conversation_id_input},创建新对话")
  85. new_conversation_id = self.create_conversation(user_id)
  86. return new_conversation_id, {
  87. "status": "invalid_id_new",
  88. "message": "您请求的对话不存在或无权访问,已为您创建新对话",
  89. "requested_id": conversation_id_input
  90. }
  91. # 2. 如果要继续最近对话
  92. if continue_conversation:
  93. recent_conversation = self._get_recent_conversation(user_id)
  94. if recent_conversation:
  95. print(f"[REDIS_CONV] 继续最近对话: {recent_conversation}")
  96. return recent_conversation, {
  97. "status": "existing",
  98. "message": "继续最近对话"
  99. }
  100. # 3. 创建新对话
  101. new_conversation_id = self.create_conversation(user_id)
  102. print(f"[REDIS_CONV] 创建新对话: {new_conversation_id}")
  103. return new_conversation_id, {
  104. "status": "new",
  105. "message": "创建新对话"
  106. }
  107. def _is_valid_conversation(self, conversation_id: str, user_id: str) -> bool:
  108. """验证对话是否存在且属于该用户"""
  109. if not self.is_available():
  110. return False
  111. try:
  112. # 检查对话元信息是否存在
  113. meta_data = self.redis_client.hgetall(f"conversation:{conversation_id}:meta")
  114. if not meta_data:
  115. return False
  116. # 检查是否属于该用户
  117. return meta_data.get('user_id') == user_id
  118. except Exception:
  119. return False
  120. def _get_recent_conversation(self, user_id: str) -> Optional[str]:
  121. """获取用户最近的对话ID"""
  122. if not self.is_available():
  123. return None
  124. try:
  125. conversations = self.redis_client.lrange(
  126. f"user:{user_id}:conversations", 0, 0
  127. )
  128. return conversations[0] if conversations else None
  129. except Exception:
  130. return None
  131. # ==================== 对话管理 ====================
  132. def create_conversation(self, user_id: str) -> str:
  133. """创建新对话"""
  134. # 生成包含时间戳的conversation_id
  135. timestamp = int(datetime.now().timestamp())
  136. conversation_id = f"conv_{timestamp}_{uuid.uuid4().hex[:8]}"
  137. if not self.is_available():
  138. return conversation_id # Redis不可用时返回ID,但不存储
  139. try:
  140. # 创建对话元信息
  141. meta_data = {
  142. "conversation_id": conversation_id,
  143. "user_id": user_id,
  144. "created_at": datetime.now().isoformat(),
  145. "updated_at": datetime.now().isoformat(),
  146. "message_count": "0"
  147. }
  148. # 保存对话元信息
  149. self.redis_client.hset(
  150. f"conversation:{conversation_id}:meta",
  151. mapping=meta_data
  152. )
  153. self.redis_client.expire(f"conversation:{conversation_id}:meta", CONVERSATION_TTL)
  154. # 添加到用户的对话列表
  155. self._add_conversation_to_user(user_id, conversation_id)
  156. print(f"[REDIS_CONV] 创建对话成功: {conversation_id}")
  157. return conversation_id
  158. except Exception as e:
  159. print(f"[ERROR] 创建对话失败: {str(e)}")
  160. return conversation_id # 返回ID但可能未存储
  161. def save_message(self, conversation_id: str, role: str, content: str,
  162. metadata: Optional[Dict] = None) -> bool:
  163. """保存消息到对话历史"""
  164. if not self.is_available() or not conversation_id:
  165. return False
  166. try:
  167. message_data = {
  168. "message_id": str(uuid.uuid4()),
  169. "timestamp": datetime.now().isoformat(),
  170. "role": role, # user, assistant
  171. "content": content,
  172. "metadata": metadata or {}
  173. }
  174. # 保存到消息列表(LPUSH添加到头部,最新消息在前)
  175. self.redis_client.lpush(
  176. f"conversation:{conversation_id}:messages",
  177. json.dumps(message_data)
  178. )
  179. # 设置TTL
  180. self.redis_client.expire(f"conversation:{conversation_id}:messages", CONVERSATION_TTL)
  181. # 限制消息数量
  182. self.redis_client.ltrim(
  183. f"conversation:{conversation_id}:messages",
  184. 0, CONVERSATION_MAX_LENGTH - 1
  185. )
  186. # 更新元信息
  187. self._update_conversation_meta(conversation_id)
  188. return True
  189. except Exception as e:
  190. print(f"[ERROR] 保存消息失败: {str(e)}")
  191. return False
  192. def get_context(self, conversation_id: str, count: Optional[int] = None) -> str:
  193. """获取对话上下文,格式化为prompt"""
  194. if not self.is_available() or not ENABLE_CONVERSATION_CONTEXT:
  195. return ""
  196. try:
  197. if count is None:
  198. count = CONVERSATION_CONTEXT_COUNT
  199. # 获取最近的消息(count*2 因为包含用户和助手消息)
  200. message_count = count * 2
  201. messages = self.redis_client.lrange(
  202. f"conversation:{conversation_id}:messages",
  203. 0, message_count - 1
  204. )
  205. if not messages:
  206. return ""
  207. # 解析消息并构建上下文(按时间正序)
  208. context_parts = []
  209. for msg_json in reversed(messages): # Redis返回倒序,需要反转
  210. try:
  211. msg_data = json.loads(msg_json)
  212. role = msg_data.get("role", "")
  213. content = msg_data.get("content", "")
  214. if role == "user":
  215. context_parts.append(f"User: {content}")
  216. elif role == "assistant":
  217. context_parts.append(f"Assistant: {content}")
  218. except json.JSONDecodeError:
  219. continue
  220. context = "\n".join(context_parts)
  221. print(f"[REDIS_CONV] 获取上下文成功: {len(context_parts)}条消息")
  222. return context
  223. except Exception as e:
  224. print(f"[ERROR] 获取上下文失败: {str(e)}")
  225. return ""
  226. def get_conversation_messages(self, conversation_id: str, limit: Optional[int] = None) -> List[Dict]:
  227. """获取对话的消息列表"""
  228. if not self.is_available():
  229. return []
  230. try:
  231. if limit:
  232. messages = self.redis_client.lrange(
  233. f"conversation:{conversation_id}:messages", 0, limit - 1
  234. )
  235. else:
  236. messages = self.redis_client.lrange(
  237. f"conversation:{conversation_id}:messages", 0, -1
  238. )
  239. # 解析并按时间正序返回
  240. parsed_messages = []
  241. for msg_json in reversed(messages): # 反转为时间正序
  242. try:
  243. parsed_messages.append(json.loads(msg_json))
  244. except json.JSONDecodeError:
  245. continue
  246. return parsed_messages
  247. except Exception as e:
  248. print(f"[ERROR] 获取对话消息失败: {str(e)}")
  249. return []
  250. def get_conversation_meta(self, conversation_id: str) -> Dict:
  251. """获取对话元信息"""
  252. if not self.is_available():
  253. return {}
  254. try:
  255. meta_data = self.redis_client.hgetall(f"conversation:{conversation_id}:meta")
  256. return meta_data if meta_data else {}
  257. except Exception as e:
  258. print(f"[ERROR] 获取对话元信息失败: {str(e)}")
  259. return {}
  260. def get_conversations(self, user_id: str, limit: int = None) -> List[Dict]:
  261. """获取用户的对话列表(按时间倒序)"""
  262. if not self.is_available():
  263. return []
  264. if limit is None:
  265. limit = USER_MAX_CONVERSATIONS
  266. try:
  267. # 获取对话ID列表(已经按时间倒序)
  268. conversation_ids = self.redis_client.lrange(
  269. f"user:{user_id}:conversations", 0, limit - 1
  270. )
  271. conversations = []
  272. for conv_id in conversation_ids:
  273. meta_data = self.get_conversation_meta(conv_id)
  274. if meta_data: # 只返回仍然存在的对话
  275. conversations.append(meta_data)
  276. return conversations
  277. except Exception as e:
  278. print(f"[ERROR] 获取用户对话列表失败: {str(e)}")
  279. return []
  280. # ==================== 智能缓存(修正版)====================
  281. def get_cached_answer(self, question: str, context: str = "") -> Optional[Dict]:
  282. """检查问答缓存 - 放宽使用条件,无论是否有上下文都可以查找缓存"""
  283. if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
  284. return None
  285. # 移除上下文检查,允许任何情况下查找缓存
  286. try:
  287. cache_key = self._get_cache_key(question)
  288. cached_answer = self.redis_client.get(cache_key)
  289. if cached_answer:
  290. context_info = "有上下文" if context else "无上下文"
  291. print(f"[REDIS_CONV] 缓存命中: {cache_key} ({context_info})")
  292. return json.loads(cached_answer)
  293. return None
  294. except Exception as e:
  295. print(f"[ERROR] 获取缓存答案失败: {str(e)}")
  296. return None
  297. def cache_answer(self, question: str, answer: Dict, context: str = ""):
  298. """缓存问答结果 - 只缓存无上下文的问答"""
  299. if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
  300. return
  301. # 新增:如果有上下文,不缓存
  302. if context:
  303. print(f"[REDIS_CONV] 跳过缓存存储:存在上下文")
  304. return
  305. try:
  306. cache_key = self._get_cache_key(question)
  307. # 添加缓存时间戳
  308. answer_with_meta = {
  309. **answer,
  310. "cached_at": datetime.now().isoformat(),
  311. "original_question": question
  312. }
  313. # 使用独立key,每个缓存项单独设置TTL
  314. self.redis_client.setex(
  315. cache_key,
  316. QUESTION_ANSWER_TTL,
  317. json.dumps(answer_with_meta)
  318. )
  319. print(f"[REDIS_CONV] 缓存答案成功: {cache_key}")
  320. except Exception as e:
  321. print(f"[ERROR] 缓存答案失败: {str(e)}")
  322. def _get_cache_key(self, question: str) -> str:
  323. """生成缓存键 - 简化版,只基于问题本身"""
  324. normalized = question.strip().lower()
  325. question_hash = hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]
  326. return f"qa_cache:{question_hash}"
  327. # ==================== 私有方法 ====================
  328. def _add_conversation_to_user(self, user_id: str, conversation_id: str):
  329. """添加对话到用户列表,按时间自动排序"""
  330. try:
  331. # LPUSH添加到列表头部(最新的)
  332. self.redis_client.lpush(f"user:{user_id}:conversations", conversation_id)
  333. # 统一限制数量
  334. self.redis_client.ltrim(
  335. f"user:{user_id}:conversations",
  336. 0, USER_MAX_CONVERSATIONS - 1
  337. )
  338. # 统一设置TTL
  339. self.redis_client.expire(
  340. f"user:{user_id}:conversations",
  341. USER_CONVERSATIONS_TTL
  342. )
  343. except Exception as e:
  344. print(f"[ERROR] 添加对话到用户列表失败: {str(e)}")
  345. def _update_conversation_meta(self, conversation_id: str):
  346. """更新对话元信息"""
  347. try:
  348. # 获取消息数量
  349. message_count = self.redis_client.llen(f"conversation:{conversation_id}:messages")
  350. # 更新元信息
  351. self.redis_client.hset(
  352. f"conversation:{conversation_id}:meta",
  353. mapping={
  354. "updated_at": datetime.now().isoformat(),
  355. "message_count": str(message_count)
  356. }
  357. )
  358. except Exception as e:
  359. print(f"[ERROR] 更新对话元信息失败: {str(e)}")
  360. # ==================== 管理方法 ====================
  361. def get_stats(self) -> Dict:
  362. """获取统计信息"""
  363. if not self.is_available():
  364. return {"available": False}
  365. try:
  366. # 获取Redis原始内存信息
  367. redis_info = self.redis_client.info()
  368. used_memory_bytes = redis_info.get("used_memory", 0)
  369. stats = {
  370. "available": True,
  371. "total_users": len(self.redis_client.keys("user:*:conversations")),
  372. "total_conversations": len(self.redis_client.keys("conversation:*:meta")),
  373. "cached_qa_count": len(self.redis_client.keys("qa_cache:*")), # 修正缓存统计
  374. "redis_info": {
  375. "memory_usage_mb": round(used_memory_bytes / (1024 * 1024), 2), # 统一使用MB格式
  376. "connected_clients": redis_info.get("connected_clients")
  377. }
  378. }
  379. return stats
  380. except Exception as e:
  381. print(f"[ERROR] 获取统计信息失败: {str(e)}")
  382. return {"available": False, "error": str(e)}
  383. def cleanup_expired_conversations(self):
  384. """清理过期对话(Redis TTL自动处理,这里可添加额外逻辑)"""
  385. if not self.is_available():
  386. return
  387. try:
  388. # 清理用户对话列表中的无效对话ID
  389. user_keys = self.redis_client.keys("user:*:conversations")
  390. cleaned_count = 0
  391. for user_key in user_keys:
  392. conversation_ids = self.redis_client.lrange(user_key, 0, -1)
  393. valid_ids = []
  394. for conv_id in conversation_ids:
  395. # 检查对话是否仍然存在
  396. if self.redis_client.exists(f"conversation:{conv_id}:meta"):
  397. valid_ids.append(conv_id)
  398. else:
  399. cleaned_count += 1
  400. # 如果有无效ID,重建列表
  401. if len(valid_ids) != len(conversation_ids):
  402. self.redis_client.delete(user_key)
  403. if valid_ids:
  404. self.redis_client.lpush(user_key, *reversed(valid_ids))
  405. # 重新设置TTL
  406. self.redis_client.expire(user_key, USER_CONVERSATIONS_TTL)
  407. print(f"[REDIS_CONV] 清理完成,移除了 {cleaned_count} 个无效对话引用")
  408. except Exception as e:
  409. print(f"[ERROR] 清理失败: {str(e)}")
  410. # ==================== 问答缓存管理方法 ====================
  411. def get_qa_cache_stats(self) -> Dict:
  412. """获取问答缓存详细统计信息"""
  413. if not self.is_available():
  414. return {"available": False}
  415. try:
  416. pattern = "qa_cache:*"
  417. keys = self.redis_client.keys(pattern)
  418. stats = {
  419. "available": True,
  420. "enabled": ENABLE_QUESTION_ANSWER_CACHE,
  421. "total_count": len(keys),
  422. "memory_usage_mb": 0,
  423. "ttl_seconds": QUESTION_ANSWER_TTL,
  424. "ttl_hours": QUESTION_ANSWER_TTL / 3600
  425. }
  426. # 估算内存使用量
  427. if keys:
  428. sample_key = keys[0]
  429. sample_data = self.redis_client.get(sample_key)
  430. if sample_data:
  431. avg_size_bytes = len(sample_data.encode('utf-8'))
  432. total_size_bytes = avg_size_bytes * len(keys)
  433. stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
  434. return stats
  435. except Exception as e:
  436. print(f"[ERROR] 获取问答缓存统计失败: {str(e)}")
  437. return {"available": False, "error": str(e)}
  438. def get_qa_cache_list(self, limit: int = 50) -> List[Dict]:
  439. """获取问答缓存列表(支持分页)"""
  440. if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
  441. return []
  442. try:
  443. pattern = "qa_cache:*"
  444. keys = self.redis_client.keys(pattern)
  445. # 限制返回数量
  446. if limit > 0:
  447. keys = keys[:limit]
  448. cache_list = []
  449. for key in keys:
  450. try:
  451. cached_data = self.redis_client.get(key)
  452. if cached_data:
  453. data = json.loads(cached_data)
  454. # 获取TTL信息
  455. ttl = self.redis_client.ttl(key)
  456. cache_item = {
  457. "cache_key": key,
  458. "question": data.get("original_question", "未知问题"),
  459. "cached_at": data.get("cached_at", "未知时间"),
  460. "ttl_seconds": ttl,
  461. "response_type": data.get("data", {}).get("type", "未知类型"),
  462. "has_sql": bool(data.get("data", {}).get("sql")),
  463. "has_summary": bool(data.get("data", {}).get("summary"))
  464. }
  465. cache_list.append(cache_item)
  466. except json.JSONDecodeError:
  467. # 跳过无效的JSON数据
  468. continue
  469. except Exception as e:
  470. print(f"[WARNING] 处理缓存项 {key} 失败: {e}")
  471. continue
  472. # 按缓存时间倒序排列
  473. cache_list.sort(key=lambda x: x.get("cached_at", ""), reverse=True)
  474. return cache_list
  475. except Exception as e:
  476. print(f"[ERROR] 获取问答缓存列表失败: {str(e)}")
  477. return []
  478. def clear_all_qa_cache(self) -> int:
  479. """清空所有问答缓存,返回删除的数量"""
  480. if not self.is_available():
  481. return 0
  482. try:
  483. pattern = "qa_cache:*"
  484. keys = self.redis_client.keys(pattern)
  485. if keys:
  486. deleted_count = self.redis_client.delete(*keys)
  487. print(f"[REDIS_CONV] 清空问答缓存成功,删除了 {deleted_count} 个缓存项")
  488. return deleted_count
  489. else:
  490. print(f"[REDIS_CONV] 没有找到问答缓存项")
  491. return 0
  492. except Exception as e:
  493. print(f"[ERROR] 清空问答缓存失败: {str(e)}")
  494. return 0