redis_conversation_manager.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. import redis
  2. import json
  3. import hashlib
  4. import uuid
  5. import time
  6. from datetime import datetime
  7. from typing import List, Dict, Any, Optional
  8. from app_config import (
  9. REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD,
  10. CONVERSATION_CONTEXT_COUNT, CONVERSATION_MAX_LENGTH, USER_MAX_CONVERSATIONS,
  11. CONVERSATION_TTL, USER_CONVERSATIONS_TTL, QUESTION_ANSWER_TTL,
  12. ENABLE_CONVERSATION_CONTEXT, ENABLE_QUESTION_ANSWER_CACHE,
  13. DEFAULT_ANONYMOUS_USER
  14. )
  class RedisConversationManager:
      """Redis-backed conversation manager (revised).

      Redis key layout used by this class:

      * ``conversation:{id}:meta`` -- hash with conversation_id, user_id,
        created_at, updated_at and message_count.
      * ``conversation:{id}:messages`` -- list of JSON-encoded messages,
        newest first (written with LPUSH).
      * ``user:{user_id}:conversations`` -- list of conversation ids, newest
        first.
      * ``qa_cache:{hash}`` -- JSON-encoded cached answers, each with its own TTL.

      Methods degrade gracefully: when Redis is unreachable they return an
      empty/neutral value ("", [], {}, None, False or 0) instead of raising.
      """

      # COUNT hint for SCAN and chunk size for batched DEL. SCAN replaces KEYS,
      # which walks the whole keyspace in one blocking server-side call.
      _SCAN_COUNT = 500
      # Number of cached answers sampled when estimating QA-cache memory usage.
      _QA_SIZE_SAMPLE = 20

      def __init__(self):
          """Create the Redis client and check connectivity with PING.

          On any failure ``redis_client`` is set to None, so ``is_available()``
          short-circuits to False without further network round trips.
          """
          try:
              self.redis_client = redis.Redis(
                  host=REDIS_HOST,
                  port=REDIS_PORT,
                  db=REDIS_DB,
                  password=REDIS_PASSWORD,
                  decode_responses=True,
                  socket_connect_timeout=5,
                  socket_timeout=5,
              )
              # Fail fast at startup if the server cannot be reached.
              self.redis_client.ping()
              print(f"[REDIS_CONV] Redis连接成功: {REDIS_HOST}:{REDIS_PORT}")
          except Exception as e:
              print(f"[ERROR] Redis连接失败: {str(e)}")
              self.redis_client = None

      def is_available(self) -> bool:
          """Return True if a client exists and the server answers PING."""
          if self.redis_client is None:
              return False
          try:
              return bool(self.redis_client.ping())
          except Exception:
              # Narrowed from a bare ``except:`` so KeyboardInterrupt and
              # SystemExit are no longer swallowed.
              return False

      # ==================== User ID resolution ====================

      def resolve_user_id(self, user_id_from_request: Optional[str],
                          session_id: Optional[str], request_ip: str,
                          login_user_id: Optional[str] = None) -> str:
          """Resolve the effective user ID for a request.

          Priority: logged-in user ID, then the ``user_id`` request parameter,
          then ``DEFAULT_ANONYMOUS_USER`` (all anonymous users share one ID).

          Args:
              user_id_from_request: ``user_id`` taken from the request parameters.
              session_id: Browser session ID (currently unused).
              request_ip: Client IP address (currently unused).
              login_user_id: Logged-in user ID taken from the Flask session.

          Returns:
              The resolved user ID.

          NOTE(review): when nobody is logged in, the client-supplied
          ``user_id_from_request`` is trusted as-is, so an unauthenticated
          caller can act as any user ID (including reaching that user's
          conversations through ``resolve_conversation_id``). Confirm this is
          intended.
          """
          # 1. A logged-in user always wins.
          if login_user_id:
              print(f"[REDIS_CONV] 使用登录用户ID: {login_user_id}")
              return login_user_id
          # 2. Not logged in: fall back to the request parameter.
          if user_id_from_request:
              print(f"[REDIS_CONV] 使用请求参数user_id: {user_id_from_request}")
              return user_id_from_request
          # 3. Neither available: shared anonymous user.
          print(f"[REDIS_CONV] 使用匿名用户: {DEFAULT_ANONYMOUS_USER}")
          return DEFAULT_ANONYMOUS_USER

      def resolve_conversation_id(self, user_id: str, conversation_id_input: Optional[str],
                                  continue_conversation: bool) -> tuple[str, dict]:
          """Pick the conversation to use for this request, creating one if needed.

          Order of preference: an explicitly requested conversation the user
          owns, then (if ``continue_conversation``) the user's most recent live
          conversation, otherwise a brand-new conversation.

          Returns:
              ``(conversation_id, status_info)`` where ``status_info`` has:
                - ``status``: "existing" | "new" | "invalid_id_new"
                - ``message``: human-readable explanation
                - ``requested_id``: the rejected ID (only for "invalid_id_new")
          """
          # 1. An explicit ID is used only if it exists and belongs to the user.
          if conversation_id_input:
              if self._is_valid_conversation(conversation_id_input, user_id):
                  print(f"[REDIS_CONV] 使用指定对话: {conversation_id_input}")
                  return conversation_id_input, {
                      "status": "existing",
                      "message": "继续已有对话"
                  }
              print(f"[WARN] 无效的conversation_id: {conversation_id_input},创建新对话")
              new_conversation_id = self.create_conversation(user_id)
              return new_conversation_id, {
                  "status": "invalid_id_new",
                  "message": "您请求的对话不存在或无权访问,已为您创建新对话",
                  "requested_id": conversation_id_input
              }
          # 2. Optionally continue the most recent conversation.
          if continue_conversation:
              recent_conversation = self._get_recent_conversation(user_id)
              if recent_conversation:
                  print(f"[REDIS_CONV] 继续最近对话: {recent_conversation}")
                  return recent_conversation, {
                      "status": "existing",
                      "message": "继续最近对话"
                  }
          # 3. Otherwise start a new conversation.
          new_conversation_id = self.create_conversation(user_id)
          print(f"[REDIS_CONV] 创建新对话: {new_conversation_id}")
          return new_conversation_id, {
              "status": "new",
              "message": "创建新对话"
          }

      def _is_valid_conversation(self, conversation_id: str, user_id: str) -> bool:
          """Return True if the conversation's meta exists and its owner is ``user_id``."""
          if not conversation_id or not self.is_available():
              return False
          try:
              # A single-field HGET is enough; a missing hash yields None.
              owner = self.redis_client.hget(f"conversation:{conversation_id}:meta", "user_id")
              return owner is not None and owner == user_id
          except Exception:
              return False

      def _get_recent_conversation(self, user_id: str) -> Optional[str]:
          """Return the user's newest conversation ID whose meta still exists, or None."""
          if not self.is_available():
              return None
          try:
              conversation_ids = self.redis_client.lrange(
                  f"user:{user_id}:conversations", 0, -1
              )
              for conv_id in conversation_ids:
                  # The user index and each conversation have separate TTLs,
                  # so the index can still list conversations that have
                  # expired. Skip those instead of "continuing" a dead one.
                  if self.redis_client.exists(f"conversation:{conv_id}:meta"):
                      return conv_id
              return None
          except Exception:
              return None

      # ==================== Conversation management ====================

      def create_conversation(self, user_id: str) -> str:
          """Create a conversation for ``user_id`` and return its ID.

          The ID is ``conv_<unix seconds>_<8 hex chars>``. It is returned even
          when Redis is unavailable or the write fails, in which case nothing
          is stored.
          """
          # One timestamp for the ID, created_at and updated_at, so the
          # three values agree.
          now = datetime.now()
          conversation_id = f"conv_{int(now.timestamp())}_{uuid.uuid4().hex[:8]}"
          if not self.is_available():
              return conversation_id  # Redis unavailable: return the ID without storing it
          meta_key = f"conversation:{conversation_id}:meta"
          try:
              now_iso = now.isoformat()
              meta_data = {
                  "conversation_id": conversation_id,
                  "user_id": user_id,
                  "created_at": now_iso,
                  "updated_at": now_iso,
                  "message_count": "0"
              }
              self.redis_client.hset(meta_key, mapping=meta_data)
              self.redis_client.expire(meta_key, CONVERSATION_TTL)
              # Register the conversation in the user's index.
              self._add_conversation_to_user(user_id, conversation_id)
              print(f"[REDIS_CONV] 创建对话成功: {conversation_id}")
              return conversation_id
          except Exception as e:
              print(f"[ERROR] 创建对话失败: {str(e)}")
              return conversation_id  # ID returned, but it may not have been stored

      def save_message(self, conversation_id: str, role: str, content: str,
                       metadata: Optional[Dict] = None) -> bool:
          """Append a message to the conversation history.

          Args:
              conversation_id: Target conversation.
              role: Message role; ``get_context`` recognises "user" and "assistant".
              content: Message text.
              metadata: Optional extra data stored alongside the message.

          Returns:
              True if the message was written, False otherwise.
          """
          if not self.is_available() or not conversation_id:
              return False
          messages_key = f"conversation:{conversation_id}:messages"
          try:
              message_data = {
                  "message_id": str(uuid.uuid4()),
                  "timestamp": datetime.now().isoformat(),
                  "role": role,
                  "content": content,
                  "metadata": metadata or {}
              }
              # LPUSH keeps the newest message at the head of the list.
              self.redis_client.lpush(messages_key, json.dumps(message_data))
              # Bound the history length, then refresh its TTL.
              self.redis_client.ltrim(messages_key, 0, CONVERSATION_MAX_LENGTH - 1)
              self.redis_client.expire(messages_key, CONVERSATION_TTL)
              # Also refreshes the meta TTL so meta and messages expire together.
              self._update_conversation_meta(conversation_id)
              return True
          except Exception as e:
              print(f"[ERROR] 保存消息失败: {str(e)}")
              return False

      def get_context(self, conversation_id: str, count: Optional[int] = None) -> str:
          """Build a prompt-ready transcript of the most recent messages.

          Args:
              conversation_id: Conversation to read.
              count: Number of exchanges to include; defaults to
                  ``CONVERSATION_CONTEXT_COUNT``. Up to ``count * 2`` messages
                  are fetched, on the assumption that each exchange is one user
                  message plus one assistant message.

          Returns:
              ``User: ...`` / ``Assistant: ...`` lines in chronological order,
              joined by newlines. Returns "" when context is disabled, Redis is
              unavailable, there are no messages, ``count`` is not positive, or
              an error occurs. Messages with other roles or undecodable JSON are
              skipped.
          """
          if not self.is_available() or not ENABLE_CONVERSATION_CONTEXT:
              return ""
          try:
              if count is None:
                  count = CONVERSATION_CONTEXT_COUNT
              message_count = count * 2
              if message_count <= 0:
                  # LRANGE key 0 -1 would return the *entire* history, so a
                  # zero/negative count must produce no context instead.
                  return ""
              messages = self.redis_client.lrange(
                  f"conversation:{conversation_id}:messages",
                  0, message_count - 1
              )
              if not messages:
                  return ""
              context_parts = []
              # Stored newest-first; reverse to chronological order.
              for msg_json in reversed(messages):
                  try:
                      msg_data = json.loads(msg_json)
                  except json.JSONDecodeError:
                      continue
                  if not isinstance(msg_data, dict):
                      # Valid JSON but not a message object; skip it rather
                      # than letting .get() abort the whole context.
                      continue
                  role = msg_data.get("role", "")
                  content = msg_data.get("content", "")
                  if role == "user":
                      context_parts.append(f"User: {content}")
                  elif role == "assistant":
                      context_parts.append(f"Assistant: {content}")
              context = "\n".join(context_parts)
              print(f"[REDIS_CONV] 获取上下文成功: {len(context_parts)}条消息")
              return context
          except Exception as e:
              print(f"[ERROR] 获取上下文失败: {str(e)}")
              return ""

      def get_conversation_messages(self, conversation_id: str, limit: Optional[int] = None) -> List[Dict]:
          """Return decoded messages in chronological order.

          Args:
              conversation_id: Conversation to read.
              limit: If truthy, only the ``limit`` newest messages are read;
                  otherwise the whole stored history is returned.
          """
          if not self.is_available():
              return []
          try:
              end = limit - 1 if limit else -1
              messages = self.redis_client.lrange(
                  f"conversation:{conversation_id}:messages", 0, end
              )
              parsed_messages = []
              # Stored newest-first; reverse to chronological order.
              for msg_json in reversed(messages):
                  try:
                      parsed_messages.append(json.loads(msg_json))
                  except json.JSONDecodeError:
                      continue
              return parsed_messages
          except Exception as e:
              print(f"[ERROR] 获取对话消息失败: {str(e)}")
              return []

      def get_conversation_meta(self, conversation_id: str) -> Dict:
          """Return the conversation's meta hash, or {} if missing or on error."""
          if not self.is_available():
              return {}
          try:
              meta_data = self.redis_client.hgetall(f"conversation:{conversation_id}:meta")
              return meta_data if meta_data else {}
          except Exception as e:
              print(f"[ERROR] 获取对话元信息失败: {str(e)}")
              return {}

      def get_conversations(self, user_id: str, limit: Optional[int] = None) -> List[Dict]:
          """Return meta hashes of the user's conversations, newest first.

          Args:
              user_id: Owner of the conversations.
              limit: Maximum number of index entries to read; defaults to
                  ``USER_MAX_CONVERSATIONS``. Conversations whose meta has
                  expired are skipped, so fewer entries may be returned.
          """
          if not self.is_available():
              return []
          if limit is None:
              limit = USER_MAX_CONVERSATIONS
          try:
              # The index is already newest-first (LPUSH).
              conversation_ids = self.redis_client.lrange(
                  f"user:{user_id}:conversations", 0, limit - 1
              )
              conversations = []
              for conv_id in conversation_ids:
                  # Read the hash directly: get_conversation_meta() would add a
                  # PING round trip for every conversation.
                  meta_data = self.redis_client.hgetall(f"conversation:{conv_id}:meta")
                  if meta_data:  # only conversations that still exist
                      conversations.append(meta_data)
              return conversations
          except Exception as e:
              print(f"[ERROR] 获取用户对话列表失败: {str(e)}")
              return []

      # ==================== Question/answer cache ====================

      def get_cached_answer(self, question: str, context: str = "") -> Optional[Dict]:
          """Look up a cached answer for ``question``.

          The lookup happens whether or not ``context`` is present; ``context``
          only affects the log line. Note the asymmetry with ``cache_answer``,
          which stores only context-free answers. As a result, a follow-up
          question whose text matches an earlier context-free question gets
          that earlier answer.
          """
          if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
              return None
          try:
              cache_key = self._get_cache_key(question)
              cached_answer = self.redis_client.get(cache_key)
              if cached_answer:
                  context_info = "有上下文" if context else "无上下文"
                  print(f"[REDIS_CONV] 缓存命中: {cache_key} ({context_info})")
                  return json.loads(cached_answer)
              return None
          except Exception as e:
              print(f"[ERROR] 获取缓存答案失败: {str(e)}")
              return None

      def cache_answer(self, question: str, answer: Dict, context: str = ""):
          """Cache ``answer`` for ``question``, but only when there is no context.

          The stored JSON is ``answer`` plus ``cached_at`` and
          ``original_question``, written with TTL ``QUESTION_ANSWER_TTL``.
          """
          if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
              return
          # Context-dependent answers are not reusable, so skip storing them.
          if context:
              print("[REDIS_CONV] 跳过缓存存储:存在上下文")
              return
          try:
              cache_key = self._get_cache_key(question)
              answer_with_meta = {
                  **answer,
                  "cached_at": datetime.now().isoformat(),
                  "original_question": question
              }
              # One key per entry, so each entry carries its own TTL.
              self.redis_client.setex(
                  cache_key,
                  QUESTION_ANSWER_TTL,
                  json.dumps(answer_with_meta)
              )
              print(f"[REDIS_CONV] 缓存答案成功: {cache_key}")
          except Exception as e:
              print(f"[ERROR] 缓存答案失败: {str(e)}")

      def _get_cache_key(self, question: str) -> str:
          """Return ``qa_cache:<16 hex chars>`` derived from the normalised question.

          Normalisation is ``strip().lower()``. The md5 digest is used only as
          a key, not for security, hence ``usedforsecurity=False``. The digest
          is unchanged, so existing cache keys remain valid.
          """
          normalized = question.strip().lower()
          question_hash = hashlib.md5(
              normalized.encode('utf-8'), usedforsecurity=False
          ).hexdigest()[:16]
          return f"qa_cache:{question_hash}"

      # ==================== Private helpers ====================

      def _scan_keys(self, pattern: str) -> List[str]:
          """Return the distinct keys matching ``pattern``, found via incremental SCAN.

          SCAN may report a key more than once, so results are de-duplicated
          while keeping first-seen order.
          """
          return list(dict.fromkeys(
              self.redis_client.scan_iter(match=pattern, count=self._SCAN_COUNT)
          ))

      def _add_conversation_to_user(self, user_id: str, conversation_id: str):
          """Push ``conversation_id`` onto the head of the user's index.

          The index is trimmed to ``USER_MAX_CONVERSATIONS`` and its TTL is
          refreshed.
          """
          user_key = f"user:{user_id}:conversations"
          try:
              # Head of the list = newest conversation.
              self.redis_client.lpush(user_key, conversation_id)
              self.redis_client.ltrim(user_key, 0, USER_MAX_CONVERSATIONS - 1)
              self.redis_client.expire(user_key, USER_CONVERSATIONS_TTL)
          except Exception as e:
              print(f"[ERROR] 添加对话到用户列表失败: {str(e)}")

      def _update_conversation_meta(self, conversation_id: str):
          """Refresh ``updated_at``, ``message_count`` and the TTL of the meta hash.

          The TTL is re-applied so the meta lives as long as the message list,
          which ``save_message`` re-expires on every write. Without this, an
          active conversation's meta expired CONVERSATION_TTL after creation,
          which made ``_is_valid_conversation`` reject it. A later HSET then
          recreated the hash with no TTL at all.
          """
          meta_key = f"conversation:{conversation_id}:meta"
          try:
              message_count = self.redis_client.llen(f"conversation:{conversation_id}:messages")
              self.redis_client.hset(
                  meta_key,
                  mapping={
                      "updated_at": datetime.now().isoformat(),
                      "message_count": str(message_count)
                  }
              )
              self.redis_client.expire(meta_key, CONVERSATION_TTL)
          except Exception as e:
              print(f"[ERROR] 更新对话元信息失败: {str(e)}")

      # ==================== Administration ====================

      def get_stats(self) -> Dict:
          """Return key counts and basic Redis server info.

          Keys are counted with SCAN rather than KEYS, so the count is not
          taken in one blocking server-side call.
          """
          if not self.is_available():
              return {"available": False}
          try:
              redis_info = self.redis_client.info()
              used_memory_bytes = redis_info.get("used_memory", 0)
              return {
                  "available": True,
                  "total_users": len(self._scan_keys("user:*:conversations")),
                  "total_conversations": len(self._scan_keys("conversation:*:meta")),
                  "cached_qa_count": len(self._scan_keys("qa_cache:*")),
                  "redis_info": {
                      "memory_usage_mb": round(used_memory_bytes / (1024 * 1024), 2),
                      "connected_clients": redis_info.get("connected_clients")
                  }
              }
          except Exception as e:
              print(f"[ERROR] 获取统计信息失败: {str(e)}")
              return {"available": False, "error": str(e)}

      def cleanup_expired_conversations(self):
          """Remove IDs of expired conversations from every user's index.

          Redis TTLs expire the conversations themselves; this method only
          drops the dangling references left in the user indexes.
          """
          if not self.is_available():
              return
          try:
              cleaned_count = 0
              for user_key in self._scan_keys("user:*:conversations"):
                  for conv_id in self.redis_client.lrange(user_key, 0, -1):
                      if not self.redis_client.exists(f"conversation:{conv_id}:meta"):
                          # LREM removes only the dead ID, in place. The old
                          # delete-and-rebuild approach could drop conversations
                          # pushed concurrently by _add_conversation_to_user.
                          # The list's existing TTL is also left untouched.
                          cleaned_count += self.redis_client.lrem(user_key, 0, conv_id)
              print(f"[REDIS_CONV] 清理完成,移除了 {cleaned_count} 个无效对话引用")
          except Exception as e:
              print(f"[ERROR] 清理失败: {str(e)}")

      # ==================== QA cache administration ====================

      def get_qa_cache_stats(self) -> Dict:
          """Return QA-cache statistics, including an estimated payload size.

          ``memory_usage_mb`` is the mean payload size of up to
          ``_QA_SIZE_SAMPLE`` entries multiplied by the entry count. It is a
          rough estimate that ignores Redis per-key overhead.
          """
          if not self.is_available():
              return {"available": False}
          try:
              keys = self._scan_keys("qa_cache:*")
              stats = {
                  "available": True,
                  "enabled": ENABLE_QUESTION_ANSWER_CACHE,
                  "total_count": len(keys),
                  "memory_usage_mb": 0,
                  "ttl_seconds": QUESTION_ANSWER_TTL,
                  "ttl_hours": QUESTION_ANSWER_TTL / 3600
              }
              sample_sizes = []
              for key in keys[:self._QA_SIZE_SAMPLE]:
                  payload = self.redis_client.get(key)
                  if payload:  # the key may have expired since the scan
                      sample_sizes.append(len(payload.encode('utf-8')))
              if sample_sizes:
                  avg_size_bytes = sum(sample_sizes) / len(sample_sizes)
                  total_size_bytes = avg_size_bytes * len(keys)
                  stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
              return stats
          except Exception as e:
              print(f"[ERROR] 获取问答缓存统计失败: {str(e)}")
              return {"available": False, "error": str(e)}

      def get_qa_cache_list(self, limit: int = 50) -> List[Dict]:
          """Return summaries of the newest cached answers.

          All entries are read and sorted by ``cached_at`` (newest first) before
          ``limit`` is applied, so the result really is the newest ``limit``
          entries. A ``limit <= 0`` returns every entry. Entries without
          ``cached_at`` sort last.
          """
          if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
              return []
          try:
              entries = []
              for key in self._scan_keys("qa_cache:*"):
                  try:
                      cached_data = self.redis_client.get(key)
                      if not cached_data:
                          continue
                      data = json.loads(cached_data)
                      if isinstance(data, dict):
                          entries.append((key, data))
                  except json.JSONDecodeError:
                      # Skip entries that are not valid JSON.
                      continue
                  except Exception as e:
                      print(f"[WARNING] 处理缓存项 {key} 失败: {e}")
                      continue
              entries.sort(key=lambda item: str(item[1].get("cached_at", "")), reverse=True)
              if limit > 0:
                  entries = entries[:limit]
              cache_list = []
              for key, data in entries:
                  payload = data.get("data")
                  if not isinstance(payload, dict):
                      payload = {}
                  cache_list.append({
                      "cache_key": key,
                      "question": data.get("original_question", "未知问题"),
                      "cached_at": data.get("cached_at", "未知时间"),
                      "ttl_seconds": self.redis_client.ttl(key),
                      "response_type": payload.get("type", "未知类型"),
                      "has_sql": bool(payload.get("sql")),
                      "has_summary": bool(payload.get("summary"))
                  })
              return cache_list
          except Exception as e:
              print(f"[ERROR] 获取问答缓存列表失败: {str(e)}")
              return []

      def clear_all_qa_cache(self) -> int:
          """Delete every ``qa_cache:*`` key and return how many were deleted.

          Keys are found with SCAN and deleted in chunks of ``_SCAN_COUNT``, so
          no single DEL call receives an unbounded argument list.
          """
          if not self.is_available():
              return 0
          try:
              keys = self._scan_keys("qa_cache:*")
              if not keys:
                  print("[REDIS_CONV] 没有找到问答缓存项")
                  return 0
              deleted_count = 0
              for start in range(0, len(keys), self._SCAN_COUNT):
                  deleted_count += self.redis_client.delete(*keys[start:start + self._SCAN_COUNT])
              print(f"[REDIS_CONV] 清空问答缓存成功,删除了 {deleted_count} 个缓存项")
              return deleted_count
          except Exception as e:
              print(f"[ERROR] 清空问答缓存失败: {str(e)}")
              return 0