# redis_conversation_manager.py


  1. import redis
  2. import json
  3. import hashlib
  4. import uuid
  5. import time
  6. from datetime import datetime
  7. from typing import List, Dict, Any, Optional
  8. from app_config import (
  9. REDIS_HOST, REDIS_PORT, REDIS_DB, REDIS_PASSWORD,
  10. CONVERSATION_CONTEXT_COUNT, CONVERSATION_MAX_LENGTH, USER_MAX_CONVERSATIONS,
  11. CONVERSATION_TTL, USER_CONVERSATIONS_TTL, QUESTION_ANSWER_TTL,
  12. ENABLE_CONVERSATION_CONTEXT, ENABLE_QUESTION_ANSWER_CACHE,
  13. DEFAULT_ANONYMOUS_USER_PREFIX, MAX_GUEST_CONVERSATIONS, MAX_REGISTERED_CONVERSATIONS,
  14. GUEST_USER_TTL
  15. )
  16. class RedisConversationManager:
  17. """Redis对话管理器 - 修正版"""
  18. def __init__(self):
  19. """初始化Redis连接"""
  20. try:
  21. self.redis_client = redis.Redis(
  22. host=REDIS_HOST,
  23. port=REDIS_PORT,
  24. db=REDIS_DB,
  25. password=REDIS_PASSWORD,
  26. decode_responses=True,
  27. socket_connect_timeout=5,
  28. socket_timeout=5
  29. )
  30. # 测试连接
  31. self.redis_client.ping()
  32. print(f"[REDIS_CONV] Redis连接成功: {REDIS_HOST}:{REDIS_PORT}")
  33. except Exception as e:
  34. print(f"[ERROR] Redis连接失败: {str(e)}")
  35. self.redis_client = None
  36. def is_available(self) -> bool:
  37. """检查Redis是否可用"""
  38. try:
  39. return self.redis_client is not None and self.redis_client.ping()
  40. except:
  41. return False
  42. # ==================== 用户ID解析(修正版)====================
  43. def resolve_user_id(self, user_id_from_request: Optional[str],
  44. session_id: Optional[str], request_ip: str,
  45. login_user_id: Optional[str] = None) -> str:
  46. """
  47. 智能解析用户ID - 修正版
  48. Args:
  49. user_id_from_request: 请求参数中的user_id
  50. session_id: 浏览器session_id
  51. request_ip: 请求IP地址
  52. login_user_id: 从Flask session中获取的登录用户ID(在ask_agent中获取)
  53. """
  54. # 1. 优先使用登录用户ID
  55. if login_user_id:
  56. print(f"[REDIS_CONV] 使用登录用户ID: {login_user_id}")
  57. return login_user_id
  58. # 2. 如果没有登录,尝试从请求参数获取user_id
  59. if user_id_from_request:
  60. print(f"[REDIS_CONV] 使用请求参数user_id: {user_id_from_request}")
  61. return user_id_from_request
  62. # 3. 都没有则为匿名用户(guest)
  63. if session_id:
  64. guest_suffix = hashlib.md5(session_id.encode()).hexdigest()[:8]
  65. guest_id = f"{DEFAULT_ANONYMOUS_USER_PREFIX}_{guest_suffix}"
  66. print(f"[REDIS_CONV] 生成稳定guest用户: {guest_id}")
  67. return guest_id
  68. # 4. 最后基于IP的临时guest ID
  69. ip_suffix = hashlib.md5(request_ip.encode()).hexdigest()[:8]
  70. temp_guest_id = f"{DEFAULT_ANONYMOUS_USER_PREFIX}_temp_{ip_suffix}"
  71. print(f"[REDIS_CONV] 生成临时guest用户: {temp_guest_id}")
  72. return temp_guest_id
  73. def resolve_conversation_id(self, user_id: str, conversation_id_input: Optional[str],
  74. continue_conversation: bool) -> tuple[str, dict]:
  75. """
  76. 智能解析对话ID - 改进版
  77. Returns:
  78. tuple: (conversation_id, status_info)
  79. status_info包含:
  80. - status: "existing" | "new" | "invalid_id_new"
  81. - message: 状态说明
  82. - requested_id: 原始请求的ID(如果有)
  83. """
  84. # 1. 如果指定了conversation_id,验证后使用
  85. if conversation_id_input:
  86. if self._is_valid_conversation(conversation_id_input, user_id):
  87. print(f"[REDIS_CONV] 使用指定对话: {conversation_id_input}")
  88. return conversation_id_input, {
  89. "status": "existing",
  90. "message": "继续已有对话"
  91. }
  92. else:
  93. print(f"[WARN] 无效的conversation_id: {conversation_id_input},创建新对话")
  94. new_conversation_id = self.create_conversation(user_id)
  95. return new_conversation_id, {
  96. "status": "invalid_id_new",
  97. "message": "您请求的对话不存在或无权访问,已为您创建新对话",
  98. "requested_id": conversation_id_input
  99. }
  100. # 2. 如果要继续最近对话
  101. if continue_conversation:
  102. recent_conversation = self._get_recent_conversation(user_id)
  103. if recent_conversation:
  104. print(f"[REDIS_CONV] 继续最近对话: {recent_conversation}")
  105. return recent_conversation, {
  106. "status": "existing",
  107. "message": "继续最近对话"
  108. }
  109. # 3. 创建新对话
  110. new_conversation_id = self.create_conversation(user_id)
  111. print(f"[REDIS_CONV] 创建新对话: {new_conversation_id}")
  112. return new_conversation_id, {
  113. "status": "new",
  114. "message": "创建新对话"
  115. }
  116. def _is_valid_conversation(self, conversation_id: str, user_id: str) -> bool:
  117. """验证对话是否存在且属于该用户"""
  118. if not self.is_available():
  119. return False
  120. try:
  121. # 检查对话元信息是否存在
  122. meta_data = self.redis_client.hgetall(f"conversation:{conversation_id}:meta")
  123. if not meta_data:
  124. return False
  125. # 检查是否属于该用户
  126. return meta_data.get('user_id') == user_id
  127. except Exception:
  128. return False
  129. def _get_recent_conversation(self, user_id: str) -> Optional[str]:
  130. """获取用户最近的对话ID"""
  131. if not self.is_available():
  132. return None
  133. try:
  134. conversations = self.redis_client.lrange(
  135. f"user:{user_id}:conversations", 0, 0
  136. )
  137. return conversations[0] if conversations else None
  138. except Exception:
  139. return None
  140. # ==================== 对话管理 ====================
  141. def create_conversation(self, user_id: str) -> str:
  142. """创建新对话"""
  143. # 生成包含时间戳的conversation_id
  144. timestamp = int(datetime.now().timestamp())
  145. conversation_id = f"conv_{timestamp}_{uuid.uuid4().hex[:8]}"
  146. if not self.is_available():
  147. return conversation_id # Redis不可用时返回ID,但不存储
  148. try:
  149. # 创建对话元信息
  150. meta_data = {
  151. "conversation_id": conversation_id,
  152. "user_id": user_id,
  153. "created_at": datetime.now().isoformat(),
  154. "updated_at": datetime.now().isoformat(),
  155. "message_count": "0"
  156. }
  157. # 保存对话元信息
  158. self.redis_client.hset(
  159. f"conversation:{conversation_id}:meta",
  160. mapping=meta_data
  161. )
  162. self.redis_client.expire(f"conversation:{conversation_id}:meta", CONVERSATION_TTL)
  163. # 添加到用户的对话列表
  164. self._add_conversation_to_user(user_id, conversation_id)
  165. print(f"[REDIS_CONV] 创建对话成功: {conversation_id}")
  166. return conversation_id
  167. except Exception as e:
  168. print(f"[ERROR] 创建对话失败: {str(e)}")
  169. return conversation_id # 返回ID但可能未存储
  170. def save_message(self, conversation_id: str, role: str, content: str,
  171. metadata: Optional[Dict] = None) -> bool:
  172. """保存消息到对话历史"""
  173. if not self.is_available() or not conversation_id:
  174. return False
  175. try:
  176. message_data = {
  177. "message_id": str(uuid.uuid4()),
  178. "timestamp": datetime.now().isoformat(),
  179. "role": role, # user, assistant
  180. "content": content,
  181. "metadata": metadata or {}
  182. }
  183. # 保存到消息列表(LPUSH添加到头部,最新消息在前)
  184. self.redis_client.lpush(
  185. f"conversation:{conversation_id}:messages",
  186. json.dumps(message_data)
  187. )
  188. # 设置TTL
  189. self.redis_client.expire(f"conversation:{conversation_id}:messages", CONVERSATION_TTL)
  190. # 限制消息数量
  191. self.redis_client.ltrim(
  192. f"conversation:{conversation_id}:messages",
  193. 0, CONVERSATION_MAX_LENGTH - 1
  194. )
  195. # 更新元信息
  196. self._update_conversation_meta(conversation_id)
  197. return True
  198. except Exception as e:
  199. print(f"[ERROR] 保存消息失败: {str(e)}")
  200. return False
  201. def get_context(self, conversation_id: str, count: Optional[int] = None) -> str:
  202. """获取对话上下文,格式化为prompt"""
  203. if not self.is_available() or not ENABLE_CONVERSATION_CONTEXT:
  204. return ""
  205. try:
  206. if count is None:
  207. count = CONVERSATION_CONTEXT_COUNT
  208. # 获取最近的消息(count*2 因为包含用户和助手消息)
  209. message_count = count * 2
  210. messages = self.redis_client.lrange(
  211. f"conversation:{conversation_id}:messages",
  212. 0, message_count - 1
  213. )
  214. if not messages:
  215. return ""
  216. # 解析消息并构建上下文(按时间正序)
  217. context_parts = []
  218. for msg_json in reversed(messages): # Redis返回倒序,需要反转
  219. try:
  220. msg_data = json.loads(msg_json)
  221. role = msg_data.get("role", "")
  222. content = msg_data.get("content", "")
  223. if role == "user":
  224. context_parts.append(f"用户: {content}")
  225. elif role == "assistant":
  226. context_parts.append(f"助手: {content}")
  227. except json.JSONDecodeError:
  228. continue
  229. context = "\n".join(context_parts)
  230. print(f"[REDIS_CONV] 获取上下文成功: {len(context_parts)}条消息")
  231. return context
  232. except Exception as e:
  233. print(f"[ERROR] 获取上下文失败: {str(e)}")
  234. return ""
  235. def get_conversation_messages(self, conversation_id: str, limit: Optional[int] = None) -> List[Dict]:
  236. """获取对话的消息列表"""
  237. if not self.is_available():
  238. return []
  239. try:
  240. if limit:
  241. messages = self.redis_client.lrange(
  242. f"conversation:{conversation_id}:messages", 0, limit - 1
  243. )
  244. else:
  245. messages = self.redis_client.lrange(
  246. f"conversation:{conversation_id}:messages", 0, -1
  247. )
  248. # 解析并按时间正序返回
  249. parsed_messages = []
  250. for msg_json in reversed(messages): # 反转为时间正序
  251. try:
  252. parsed_messages.append(json.loads(msg_json))
  253. except json.JSONDecodeError:
  254. continue
  255. return parsed_messages
  256. except Exception as e:
  257. print(f"[ERROR] 获取对话消息失败: {str(e)}")
  258. return []
  259. def get_conversation_meta(self, conversation_id: str) -> Dict:
  260. """获取对话元信息"""
  261. if not self.is_available():
  262. return {}
  263. try:
  264. meta_data = self.redis_client.hgetall(f"conversation:{conversation_id}:meta")
  265. return meta_data if meta_data else {}
  266. except Exception as e:
  267. print(f"[ERROR] 获取对话元信息失败: {str(e)}")
  268. return {}
  269. def get_conversations(self, user_id: str, limit: int = None) -> List[Dict]:
  270. """获取用户的对话列表(按时间倒序)"""
  271. if not self.is_available():
  272. return []
  273. if limit is None:
  274. limit = USER_MAX_CONVERSATIONS
  275. try:
  276. # 获取对话ID列表(已经按时间倒序)
  277. conversation_ids = self.redis_client.lrange(
  278. f"user:{user_id}:conversations", 0, limit - 1
  279. )
  280. conversations = []
  281. for conv_id in conversation_ids:
  282. meta_data = self.get_conversation_meta(conv_id)
  283. if meta_data: # 只返回仍然存在的对话
  284. conversations.append(meta_data)
  285. return conversations
  286. except Exception as e:
  287. print(f"[ERROR] 获取用户对话列表失败: {str(e)}")
  288. return []
  289. # ==================== 智能缓存(修正版)====================
  290. def get_cached_answer(self, question: str, context: str = "") -> Optional[Dict]:
  291. """检查问答缓存 - 真正上下文感知版"""
  292. if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
  293. return None
  294. try:
  295. cache_key = self._get_cache_key(question, context)
  296. cached_answer = self.redis_client.get(cache_key) # 使用独立key而不是hash
  297. if cached_answer:
  298. print(f"[REDIS_CONV] 缓存命中: {cache_key}")
  299. return json.loads(cached_answer)
  300. return None
  301. except Exception as e:
  302. print(f"[ERROR] 获取缓存答案失败: {str(e)}")
  303. return None
  304. def cache_answer(self, question: str, answer: Dict, context: str = ""):
  305. """缓存问答结果 - 真正上下文感知版"""
  306. if not self.is_available() or not ENABLE_QUESTION_ANSWER_CACHE:
  307. return
  308. try:
  309. cache_key = self._get_cache_key(question, context)
  310. # 添加缓存时间戳和上下文哈希
  311. answer_with_meta = {
  312. **answer,
  313. "cached_at": datetime.now().isoformat(),
  314. "original_question": question,
  315. "context_hash": hashlib.md5(context.encode()).hexdigest()[:8] if context else ""
  316. }
  317. # 使用独立key,每个缓存项单独设置TTL
  318. self.redis_client.setex(
  319. cache_key,
  320. QUESTION_ANSWER_TTL,
  321. json.dumps(answer_with_meta)
  322. )
  323. print(f"[REDIS_CONV] 缓存答案成功: {cache_key}")
  324. except Exception as e:
  325. print(f"[ERROR] 缓存答案失败: {str(e)}")
  326. def _get_cache_key(self, question: str, context: str = "") -> str:
  327. """生成真正包含上下文的缓存键"""
  328. if context and ENABLE_CONVERSATION_CONTEXT:
  329. # 使用上下文内容而不是conversation_id
  330. cache_input = f"context:{context}\nquestion:{question}"
  331. else:
  332. cache_input = question
  333. normalized = cache_input.strip().lower()
  334. question_hash = hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]
  335. return f"qa_cache:{question_hash}"
  336. # ==================== 私有方法 ====================
  337. def _add_conversation_to_user(self, user_id: str, conversation_id: str):
  338. """添加对话到用户列表,按时间自动排序"""
  339. try:
  340. # 获取用户类型配置
  341. config = self._get_user_type_config(user_id)
  342. # LPUSH添加到列表头部(最新的)
  343. self.redis_client.lpush(f"user:{user_id}:conversations", conversation_id)
  344. # 根据用户类型限制数量
  345. self.redis_client.ltrim(
  346. f"user:{user_id}:conversations",
  347. 0, config["max_conversations"] - 1
  348. )
  349. # 设置TTL
  350. self.redis_client.expire(
  351. f"user:{user_id}:conversations",
  352. config["ttl"]
  353. )
  354. except Exception as e:
  355. print(f"[ERROR] 添加对话到用户列表失败: {str(e)}")
  356. def _get_user_type_config(self, user_id: str) -> Dict:
  357. """根据用户类型返回不同的配置 - 修正版"""
  358. if user_id.startswith(DEFAULT_ANONYMOUS_USER_PREFIX):
  359. return {
  360. "max_conversations": MAX_GUEST_CONVERSATIONS,
  361. "ttl": GUEST_USER_TTL # 使用专门的guest TTL
  362. }
  363. else:
  364. return {
  365. "max_conversations": MAX_REGISTERED_CONVERSATIONS,
  366. "ttl": USER_CONVERSATIONS_TTL
  367. }
  368. def _update_conversation_meta(self, conversation_id: str):
  369. """更新对话元信息"""
  370. try:
  371. # 获取消息数量
  372. message_count = self.redis_client.llen(f"conversation:{conversation_id}:messages")
  373. # 更新元信息
  374. self.redis_client.hset(
  375. f"conversation:{conversation_id}:meta",
  376. mapping={
  377. "updated_at": datetime.now().isoformat(),
  378. "message_count": str(message_count)
  379. }
  380. )
  381. except Exception as e:
  382. print(f"[ERROR] 更新对话元信息失败: {str(e)}")
  383. # ==================== 管理方法 ====================
  384. def get_stats(self) -> Dict:
  385. """获取统计信息"""
  386. if not self.is_available():
  387. return {"available": False}
  388. try:
  389. stats = {
  390. "available": True,
  391. "total_users": len(self.redis_client.keys("user:*:conversations")),
  392. "total_conversations": len(self.redis_client.keys("conversation:*:meta")),
  393. "cached_qa_count": len(self.redis_client.keys("qa_cache:*")), # 修正缓存统计
  394. "redis_info": {
  395. "used_memory": self.redis_client.info().get("used_memory_human"),
  396. "connected_clients": self.redis_client.info().get("connected_clients")
  397. }
  398. }
  399. return stats
  400. except Exception as e:
  401. print(f"[ERROR] 获取统计信息失败: {str(e)}")
  402. return {"available": False, "error": str(e)}
  403. def cleanup_expired_conversations(self):
  404. """清理过期对话(Redis TTL自动处理,这里可添加额外逻辑)"""
  405. if not self.is_available():
  406. return
  407. try:
  408. # 清理用户对话列表中的无效对话ID
  409. user_keys = self.redis_client.keys("user:*:conversations")
  410. cleaned_count = 0
  411. for user_key in user_keys:
  412. conversation_ids = self.redis_client.lrange(user_key, 0, -1)
  413. valid_ids = []
  414. for conv_id in conversation_ids:
  415. # 检查对话是否仍然存在
  416. if self.redis_client.exists(f"conversation:{conv_id}:meta"):
  417. valid_ids.append(conv_id)
  418. else:
  419. cleaned_count += 1
  420. # 如果有无效ID,重建列表
  421. if len(valid_ids) != len(conversation_ids):
  422. self.redis_client.delete(user_key)
  423. if valid_ids:
  424. self.redis_client.lpush(user_key, *reversed(valid_ids))
  425. # 重新设置TTL
  426. self.redis_client.expire(user_key, USER_CONVERSATIONS_TTL)
  427. print(f"[REDIS_CONV] 清理完成,移除了 {cleaned_count} 个无效对话引用")
  428. except Exception as e:
  429. print(f"[ERROR] 清理失败: {str(e)}")