embedding_cache_manager.py

import redis
import json
import hashlib
import time
from typing import List, Optional, Dict, Any
from datetime import datetime

import app_config


class EmbeddingCacheManager:
    """Embedding vector cache manager."""

    def __init__(self):
        """Initialize the cache manager."""
        self.redis_client = None
        self.cache_enabled = app_config.ENABLE_EMBEDDING_CACHE

        if self.cache_enabled:
            try:
                self.redis_client = redis.Redis(
                    host=app_config.REDIS_HOST,
                    port=app_config.REDIS_PORT,
                    db=app_config.REDIS_DB,
                    password=app_config.REDIS_PASSWORD,
                    decode_responses=True,
                    socket_connect_timeout=5,
                    socket_timeout=5
                )
                # Verify the connection
                self.redis_client.ping()
                print("[DEBUG] Embedding cache manager initialized successfully")
            except Exception as e:
                print(f"[WARNING] Redis connection failed; embedding cache disabled: {e}")
                self.cache_enabled = False
                self.redis_client = None

    def is_available(self) -> bool:
        """Check whether the cache is available."""
        return self.cache_enabled and self.redis_client is not None

    def _get_cache_key(self, question: str, model_info: Dict[str, str]) -> str:
        """
        Build the cache key for a question/model pair.

        Args:
            question: Question text
            model_info: Model info dict containing model_name and embedding_dimension

        Returns:
            Cache key string
        """
        # Hash the question to keep keys short
        question_hash = hashlib.sha256(question.encode('utf-8')).hexdigest()[:16]
        model_name = model_info.get('model_name', 'unknown')
        dimension = model_info.get('embedding_dimension', 'unknown')
        return f"embedding_cache:{question_hash}:{model_name}:{dimension}"

    def _get_model_info(self) -> Dict[str, str]:
        """
        Get information about the current embedding model.

        Returns:
            Dict containing model_name and embedding_dimension
        """
        try:
            from common.utils import get_current_embedding_config
            embedding_config = get_current_embedding_config()
            return {
                'model_name': embedding_config.get('model_name', 'unknown'),
                'embedding_dimension': str(embedding_config.get('embedding_dimension', 'unknown'))
            }
        except Exception as e:
            print(f"[WARNING] Failed to get model info: {e}")
            return {'model_name': 'unknown', 'embedding_dimension': 'unknown'}

    def get_cached_embedding(self, question: str) -> Optional[List[float]]:
        """
        Fetch an embedding vector from the cache.

        Args:
            question: Question text

        Returns:
            The vector list on a cache hit, otherwise None
        """
        if not self.is_available():
            return None

        try:
            model_info = self._get_model_info()
            cache_key = self._get_cache_key(question, model_info)
            cached_data = self.redis_client.get(cache_key)

            if cached_data:
                data = json.loads(cached_data)
                vector = data.get('vector')
                if vector:
                    print(f"[DEBUG] ✓ Embedding cache hit: {question[:50]}...")
                    return vector

            return None
        except Exception as e:
            print(f"[WARNING] Failed to read embedding cache: {e}")
            return None

    def cache_embedding(self, question: str, vector: List[float]) -> bool:
        """
        Store an embedding vector in the cache.

        Args:
            question: Question text
            vector: Embedding vector

        Returns:
            True on success, False on failure
        """
        if not self.is_available() or not vector:
            return False

        try:
            model_info = self._get_model_info()
            cache_key = self._get_cache_key(question, model_info)
            cache_data = {
                "question": question,
                "vector": vector,
                "model_name": model_info['model_name'],
                "dimension": len(vector),
                "created_at": datetime.now().isoformat(),
                "version": "1.0"
            }

            # Store with the configured TTL
            ttl = app_config.EMBEDDING_CACHE_TTL
            self.redis_client.setex(
                cache_key,
                ttl,
                json.dumps(cache_data, ensure_ascii=False)
            )

            print(f"[DEBUG] ✓ Embedding vector cached: {question[:50]}... (dimension: {len(vector)})")

            # Enforce the size limit after each write
            self._cleanup_if_needed()

            return True
        except Exception as e:
            print(f"[WARNING] Failed to cache embedding: {e}")
            return False

    def _cleanup_if_needed(self):
        """
        Evict the oldest entries when the cache exceeds its size limit.
        """
        try:
            max_size = app_config.EMBEDDING_CACHE_MAX_SIZE
            pattern = "embedding_cache:*"

            # Collect all embedding cache keys
            keys = self.redis_client.keys(pattern)

            if len(keys) > max_size:
                # Over the limit: fetch each key's TTL and sort by remaining time
                keys_with_ttl = []
                for key in keys:
                    ttl = self.redis_client.ttl(key)
                    if ttl > 0:  # only consider keys that have a TTL
                        keys_with_ttl.append((key, ttl))

                # Sort ascending by TTL (least remaining time first)
                keys_with_ttl.sort(key=lambda x: x[1])

                # Delete the oldest keys beyond the limit
                cleanup_count = len(keys) - max_size
                keys_to_delete = [key for key, _ in keys_with_ttl[:cleanup_count]]

                if keys_to_delete:
                    self.redis_client.delete(*keys_to_delete)
                    print(f"[DEBUG] Evicted {len(keys_to_delete)} old embedding cache entries")
        except Exception as e:
            print(f"[WARNING] Failed to clean up embedding cache: {e}")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with cache statistics
        """
        stats = {
            "enabled": self.cache_enabled,
            "available": self.is_available(),
            "total_count": 0,
            "memory_usage_mb": 0
        }

        if not self.is_available():
            return stats

        try:
            pattern = "embedding_cache:*"
            keys = self.redis_client.keys(pattern)
            stats["total_count"] = len(keys)

            # Rough memory estimate extrapolated from a single sample entry
            if keys:
                sample_key = keys[0]
                sample_data = self.redis_client.get(sample_key)
                if sample_data:
                    avg_size_bytes = len(sample_data.encode('utf-8'))
                    total_size_bytes = avg_size_bytes * len(keys)
                    stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
        except Exception as e:
            print(f"[WARNING] Failed to get cache statistics: {e}")

        return stats

    def clear_all_cache(self) -> bool:
        """
        Clear all embedding cache entries.

        Returns:
            True on success, False on failure
        """
        if not self.is_available():
            return False

        try:
            pattern = "embedding_cache:*"
            keys = self.redis_client.keys(pattern)

            if keys:
                self.redis_client.delete(*keys)
                print(f"[DEBUG] Cleared all embedding cache entries ({len(keys)} keys)")
                return True
            else:
                print("[DEBUG] No embedding cache entries to clear")
                return True
        except Exception as e:
            print(f"[WARNING] Failed to clear embedding cache: {e}")
            return False


# Global instance
_embedding_cache_manager = None


def get_embedding_cache_manager() -> EmbeddingCacheManager:
    """
    Get the global embedding cache manager instance.

    Returns:
        EmbeddingCacheManager instance
    """
    global _embedding_cache_manager
    if _embedding_cache_manager is None:
        _embedding_cache_manager = EmbeddingCacheManager()
    return _embedding_cache_manager
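

# Usage sketch (not part of the original module; an illustrative assumption).
# It assumes a reachable Redis instance and an app_config that defines the
# settings referenced above (ENABLE_EMBEDDING_CACHE, REDIS_HOST, REDIS_PORT,
# REDIS_DB, REDIS_PASSWORD, EMBEDDING_CACHE_TTL, EMBEDDING_CACHE_MAX_SIZE).
# The placeholder vector stands in for the output of a real embedding model.
if __name__ == "__main__":
    manager = get_embedding_cache_manager()
    question = "What is the refund policy?"

    # First lookup misses; compute the vector elsewhere, then cache it.
    vector = manager.get_cached_embedding(question)
    if vector is None:
        vector = [0.1, 0.2, 0.3]  # placeholder embedding for illustration only
        manager.cache_embedding(question, vector)

    # A repeated lookup for the same question and model now hits the cache.
    print(manager.get_cached_embedding(question))
    print(manager.get_cache_stats())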