session_aware_cache.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # 简化后的对话感知缓存
  2. from datetime import datetime
  3. from vanna.flask import MemoryCache
  4. import uuid
  class ConversationAwareMemoryCache(MemoryCache):
      """Time-aware ``MemoryCache`` keyed by conversation ID.

      IDs produced by :meth:`generate_id` have the form
      ``{user_id}:YYYYMMDDHHMMSSsss`` (naive local time from ``datetime.now()``,
      millisecond precision). Besides the fields stored by the parent cache,
      this class keeps ``conversation_start_times``, an in-memory
      ``{conversation_id: datetime}`` map of when each conversation started.

      The timestamp part never contains ``":"``, so the user ID is taken as
      everything before the *last* colon; this keeps user IDs that themselves
      contain a colon intact.
      """

      # Characters in the timestamp part of a conversation ID:
      # 14 for YYYYMMDDHHMMSS plus 3 for milliseconds.
      _TIMESTAMP_LENGTH = 17

      def __init__(self):
          super().__init__()
          # Start time of each conversation: {conversation_id: datetime}
          self.conversation_start_times = {}

      @staticmethod
      def _format_timestamp(moment: datetime) -> str:
          """Render *moment* as ``YYYYMMDDHHMMSSsss`` (milliseconds truncated)."""
          return moment.strftime("%Y%m%d%H%M%S") + f"{moment.microsecond // 1000:03d}"

      def generate_id(self, question: str = None, user_id: str = None) -> str:
          """Create a new conversation ID and record its start time.

          Args:
              question: Not used by this implementation.
              user_id: Owner of the conversation; ``"guest"`` is used when it
                  is empty or ``None``.

          Returns:
              A conversation ID of the form ``{user_id}:YYYYMMDDHHMMSSsss``.
          """
          from datetime import timedelta

          if not user_id:
              user_id = "guest"
          now = datetime.now()
          conversation_id = f"{user_id}:{self._format_timestamp(now)}"
          # Two calls for the same user within one millisecond (e.g. two
          # anonymous "guest" requests) would otherwise get the same ID and
          # share cached state. Advance the timestamp 1 ms at a time until the
          # ID is unused; the ID format, and so parse_conversation_id, is unchanged.
          while conversation_id in self.conversation_start_times:
              now += timedelta(milliseconds=1)
              conversation_id = f"{user_id}:{self._format_timestamp(now)}"
          self.conversation_start_times[conversation_id] = now
          return conversation_id

      def set(self, id: str, field: str, value, **kwargs):
          """Store *value* under (*id*, *field*) and keep the start time alongside.

          An *id* with no recorded start time is treated as a conversation that
          starts now. Unless *field* is ``'conversation_start_time'`` itself, the
          recorded start time is also written to that field so it travels with
          the rest of the cached data.

          Extra keyword arguments are accepted but not forwarded to the parent
          ``set``.
          """
          if id not in self.conversation_start_times:
              self.conversation_start_times[id] = datetime.now()
          super().set(id=id, field=field, value=value)
          if field != 'conversation_start_time':
              super().set(id=id, field='conversation_start_time',
                          value=self.conversation_start_times[id])

      def delete(self, id: str):
          """Delete conversation *id* from the cache and forget its start time.

          Keeps ``conversation_start_times`` in step with the parent cache, so a
          deleted conversation no longer appears in :meth:`get_all_conversations`
          and a later ``set`` on the same *id* records a fresh start time.
          """
          super().delete(id=id)
          self.conversation_start_times.pop(id, None)

      def get_conversation_start_time(self, conversation_id: str) -> datetime:
          """Return the recorded start time, or ``None`` for an unknown ID."""
          return self.conversation_start_times.get(conversation_id)

      def _summarize(self, conversation_id: str, start_time: datetime, now: datetime) -> dict:
          """Build the per-conversation info shared by the public getters."""
          duration = now - start_time
          return {
              'user_id': self.extract_user_id(conversation_id),
              'start_time': start_time,
              'duration_seconds': duration.total_seconds(),
              'duration_formatted': str(duration),
          }

      def get_conversation_info(self, conversation_id: str):
          """Return a summary of one conversation, or ``None`` if it is unknown.

          The dict has keys ``conversation_id``, ``user_id``, ``start_time``,
          ``duration_seconds`` and ``duration_formatted``; the duration runs
          from the recorded start time to now.
          """
          start_time = self.get_conversation_start_time(conversation_id)
          if start_time is None:
              return None
          return {
              'conversation_id': conversation_id,
              **self._summarize(conversation_id, start_time, datetime.now()),
          }

      def get_all_conversations(self):
          """Return ``{conversation_id: summary}`` for every known conversation.

          Each summary has ``user_id``, ``start_time``, ``duration_seconds`` and
          ``duration_formatted``. All durations are measured against a single
          "now" so they are consistent with each other.
          """
          now = datetime.now()
          return {
              conversation_id: self._summarize(conversation_id, start_time, now)
              for conversation_id, start_time in self.conversation_start_times.items()
          }

      @staticmethod
      def parse_conversation_id(conversation_id: str):
          """Split a conversation ID into ``(user_id, timestamp)``.

          Returns ``(None, None)`` when *conversation_id* is empty or contains
          no colon. Otherwise ``user_id`` is the text before the last colon and
          ``timestamp`` is the parsed datetime, or ``None`` when the text after
          that colon is not a valid ``YYYYMMDDHHMMSSsss`` timestamp.
          """
          if not conversation_id or ":" not in conversation_id:
              return None, None
          user_id, _, timestamp_str = conversation_id.rpartition(":")
          if len(timestamp_str) != ConversationAwareMemoryCache._TIMESTAMP_LENGTH:
              return user_id, None
          try:
              # First 14 characters: date and time to the second; last 3: milliseconds.
              timestamp = datetime.strptime(timestamp_str[:14], "%Y%m%d%H%M%S")
              milliseconds = int(timestamp_str[14:])
              return user_id, timestamp.replace(microsecond=milliseconds * 1000)
          except ValueError:
              return user_id, None

      @staticmethod
      def extract_user_id(conversation_id: str) -> str:
          """Return the user ID part of *conversation_id*, or ``"unknown"``.

          The user ID is everything before the last colon; ``"unknown"`` is
          returned for an empty ID or one without a colon.
          """
          if not conversation_id or ":" not in conversation_id:
              return "unknown"
          return conversation_id.rpartition(":")[0]

      @staticmethod
      def validate_user_id_consistency(conversation_id: str, provided_user_id: str) -> tuple[bool, str]:
          """Check that *conversation_id* belongs to *provided_user_id*.

          The check is skipped (reported as valid) when either argument is empty.

          Returns:
              ``(is_valid, error_message)``; ``error_message`` is empty when valid.
          """
          if not conversation_id or not provided_user_id:
              return True, ""  # nothing to compare against
          extracted_user_id = ConversationAwareMemoryCache.extract_user_id(conversation_id)
          if extracted_user_id != provided_user_id:
              return False, f"用户ID不匹配:conversation_id中的用户ID '{extracted_user_id}' 与提供的用户ID '{provided_user_id}' 不一致"
          return True, ""
  # Backward-compatible aliases: older code imports the cache under these names.
  WebSessionAwareMemoryCache = SessionAwareMemoryCache = ConversationAwareMemoryCache