session_aware_cache.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # 修正后的 custom_cache.py
  2. from datetime import datetime
  3. from vanna.flask import MemoryCache
  4. import uuid
  5. class SessionAwareMemoryCache(MemoryCache):
  6. """区分会话(Session)和对话(Conversation)的缓存实现"""
  7. def __init__(self):
  8. super().__init__()
  9. self.conversation_start_times = {} # 每个对话的开始时间
  10. self.session_info = {} # 会话信息: {session_id: {'start_time': datetime, 'conversations': []}}
  11. self.conversation_to_session = {} # 对话ID到会话ID的映射
  12. def create_or_get_session_id(self, user_identifier=None):
  13. """
  14. 创建或获取会话ID
  15. 在实际应用中,这可以通过以下方式确定:
  16. 1. HTTP请求中的session cookie
  17. 2. JWT token中的session信息
  18. 3. 前端传递的session_id
  19. 4. IP地址 + User-Agent的组合
  20. """
  21. # 简化实现:使用时间窗口来判断是否为同一会话
  22. # 实际应用中应该从HTTP请求中获取session信息
  23. current_time = datetime.now()
  24. # 检查是否有近期的会话(比如30分钟内)
  25. for session_id, session_data in self.session_info.items():
  26. last_activity = session_data.get('last_activity', session_data['start_time'])
  27. if (current_time - last_activity).total_seconds() < 1800: # 30分钟内
  28. # 更新最后活动时间
  29. session_data['last_activity'] = current_time
  30. return session_id
  31. # 创建新会话
  32. new_session_id = str(uuid.uuid4())
  33. self.session_info[new_session_id] = {
  34. 'start_time': current_time,
  35. 'last_activity': current_time,
  36. 'conversations': []
  37. }
  38. return new_session_id
  39. def generate_id(self, question: str = None, session_id: str = None) -> str:
  40. """重载generate_id方法,关联会话和对话"""
  41. conversation_id = super().generate_id(question=question)
  42. # 确定会话ID
  43. if not session_id:
  44. session_id = self.create_or_get_session_id()
  45. # 记录对话开始时间
  46. conversation_start_time = datetime.now()
  47. self.conversation_start_times[conversation_id] = conversation_start_time
  48. # 建立对话与会话的关联
  49. self.conversation_to_session[conversation_id] = session_id
  50. self.session_info[session_id]['conversations'].append(conversation_id)
  51. self.session_info[session_id]['last_activity'] = conversation_start_time
  52. return conversation_id
  53. def set(self, id: str, field: str, value, session_id: str = None):
  54. """重载set方法,确保时间信息正确"""
  55. # 如果这是新对话,初始化时间信息
  56. if id not in self.conversation_start_times:
  57. if not session_id:
  58. session_id = self.create_or_get_session_id()
  59. conversation_start_time = datetime.now()
  60. self.conversation_start_times[id] = conversation_start_time
  61. self.conversation_to_session[id] = session_id
  62. self.session_info[session_id]['conversations'].append(id)
  63. self.session_info[session_id]['last_activity'] = conversation_start_time
  64. # 调用父类的set方法
  65. super().set(id=id, field=field, value=value)
  66. # 设置时间相关字段
  67. if field != 'conversation_start_time' and field != 'session_start_time':
  68. # 设置对话开始时间
  69. super().set(id=id, field='conversation_start_time',
  70. value=self.conversation_start_times[id])
  71. # 设置会话开始时间
  72. session_id = self.conversation_to_session.get(id)
  73. if session_id and session_id in self.session_info:
  74. super().set(id=id, field='session_start_time',
  75. value=self.session_info[session_id]['start_time'])
  76. super().set(id=id, field='session_id', value=session_id)
  77. def get_conversation_start_time(self, conversation_id: str) -> datetime:
  78. """获取对话开始时间"""
  79. return self.conversation_start_times.get(conversation_id)
  80. def get_session_start_time(self, conversation_id: str) -> datetime:
  81. """获取会话开始时间"""
  82. session_id = self.conversation_to_session.get(conversation_id)
  83. if session_id and session_id in self.session_info:
  84. return self.session_info[session_id]['start_time']
  85. return None
  86. def get_session_info(self, session_id: str = None, conversation_id: str = None):
  87. """获取会话信息"""
  88. if conversation_id:
  89. session_id = self.conversation_to_session.get(conversation_id)
  90. if session_id and session_id in self.session_info:
  91. session_data = self.session_info[session_id].copy()
  92. session_data['conversation_count'] = len(session_data['conversations'])
  93. if session_data['conversations']:
  94. # 计算会话持续时间
  95. duration = datetime.now() - session_data['start_time']
  96. session_data['session_duration_seconds'] = duration.total_seconds()
  97. session_data['session_duration_formatted'] = str(duration)
  98. return session_data
  99. return None
  100. def get_all_sessions(self):
  101. """获取所有会话信息"""
  102. result = {}
  103. for session_id, session_data in self.session_info.items():
  104. session_info = session_data.copy()
  105. session_info['conversation_count'] = len(session_data['conversations'])
  106. if session_data['conversations']:
  107. duration = datetime.now() - session_data['start_time']
  108. session_info['session_duration_seconds'] = duration.total_seconds()
  109. session_info['session_duration_formatted'] = str(duration)
  110. result[session_id] = session_info
  111. return result
  112. # 升级版:支持前端传递会话ID
  113. class WebSessionAwareMemoryCache(SessionAwareMemoryCache):
  114. """支持从前端获取会话ID的版本"""
  115. def __init__(self):
  116. super().__init__()
  117. self.browser_sessions = {} # browser_session_id -> our_session_id
  118. def register_browser_session(self, browser_session_id: str, user_info: dict = None):
  119. """注册浏览器会话"""
  120. if browser_session_id not in self.browser_sessions:
  121. our_session_id = str(uuid.uuid4())
  122. self.browser_sessions[browser_session_id] = our_session_id
  123. self.session_info[our_session_id] = {
  124. 'start_time': datetime.now(),
  125. 'last_activity': datetime.now(),
  126. 'conversations': [],
  127. 'browser_session_id': browser_session_id,
  128. 'user_info': user_info or {}
  129. }
  130. return self.browser_sessions[browser_session_id]
  131. def generate_id_with_browser_session(self, question: str = None, browser_session_id: str = None) -> str:
  132. """使用浏览器会话ID生成对话ID"""
  133. if browser_session_id:
  134. our_session_id = self.register_browser_session(browser_session_id)
  135. else:
  136. our_session_id = self.create_or_get_session_id()
  137. return super().generate_id(question=question, session_id=our_session_id)