# embedding_cache_manager.py
import hashlib
import heapq
import json
import time
from datetime import datetime
from typing import List, Optional, Dict, Any

import redis

import app_config
from core.logging import get_app_logger
  9. class EmbeddingCacheManager:
  10. """Embedding向量缓存管理器"""
  11. def __init__(self):
  12. """初始化缓存管理器"""
  13. self.logger = get_app_logger("EmbeddingCacheManager")
  14. self.redis_client = None
  15. self.cache_enabled = app_config.ENABLE_EMBEDDING_CACHE
  16. if self.cache_enabled:
  17. try:
  18. self.redis_client = redis.Redis(
  19. host=app_config.REDIS_HOST,
  20. port=app_config.REDIS_PORT,
  21. db=app_config.REDIS_DB,
  22. password=app_config.REDIS_PASSWORD,
  23. decode_responses=True,
  24. socket_connect_timeout=5,
  25. socket_timeout=5
  26. )
  27. # 测试连接
  28. self.redis_client.ping()
  29. self.logger.debug("Embedding缓存管理器初始化成功")
  30. except Exception as e:
  31. self.logger.warning(f"Redis连接失败,embedding缓存将被禁用: {e}")
  32. self.cache_enabled = False
  33. self.redis_client = None
  34. def is_available(self) -> bool:
  35. """检查缓存是否可用"""
  36. return self.cache_enabled and self.redis_client is not None
  37. def _get_cache_key(self, question: str, model_info: Dict[str, str]) -> str:
  38. """
  39. 生成缓存键
  40. Args:
  41. question: 问题文本
  42. model_info: 模型信息字典,包含model_name和embedding_dimension
  43. Returns:
  44. 缓存键字符串
  45. """
  46. # 使用问题的hash值避免键太长
  47. question_hash = hashlib.sha256(question.encode('utf-8')).hexdigest()[:16]
  48. model_name = model_info.get('model_name', 'unknown')
  49. dimension = model_info.get('embedding_dimension', 'unknown')
  50. return f"embedding_cache:{question_hash}:{model_name}:{dimension}"
  51. def _get_model_info(self) -> Dict[str, str]:
  52. """
  53. 获取当前模型信息
  54. Returns:
  55. 包含model_name和embedding_dimension的字典
  56. """
  57. try:
  58. from common.utils import get_current_embedding_config
  59. embedding_config = get_current_embedding_config()
  60. return {
  61. 'model_name': embedding_config.get('model_name', 'unknown'),
  62. 'embedding_dimension': str(embedding_config.get('embedding_dimension', 'unknown'))
  63. }
  64. except Exception as e:
  65. self.logger.warning(f"获取模型信息失败: {e}")
  66. return {'model_name': 'unknown', 'embedding_dimension': 'unknown'}
  67. def get_cached_embedding(self, question: str) -> Optional[List[float]]:
  68. """
  69. 从缓存中获取embedding向量
  70. Args:
  71. question: 问题文本
  72. Returns:
  73. 如果缓存命中返回向量列表,否则返回None
  74. """
  75. if not self.is_available():
  76. return None
  77. try:
  78. model_info = self._get_model_info()
  79. cache_key = self._get_cache_key(question, model_info)
  80. cached_data = self.redis_client.get(cache_key)
  81. if cached_data:
  82. data = json.loads(cached_data)
  83. vector = data.get('vector')
  84. if vector:
  85. self.logger.debug(f"✓ Embedding缓存命中: {question[:50]}...")
  86. return vector
  87. return None
  88. except Exception as e:
  89. self.logger.warning(f"获取embedding缓存失败: {e}")
  90. return None
  91. def cache_embedding(self, question: str, vector: List[float]) -> bool:
  92. """
  93. 将embedding向量保存到缓存
  94. Args:
  95. question: 问题文本
  96. vector: embedding向量
  97. Returns:
  98. 成功返回True,失败返回False
  99. """
  100. if not self.is_available() or not vector:
  101. return False
  102. try:
  103. model_info = self._get_model_info()
  104. cache_key = self._get_cache_key(question, model_info)
  105. cache_data = {
  106. "question": question,
  107. "vector": vector,
  108. "model_name": model_info['model_name'],
  109. "dimension": len(vector),
  110. "created_at": datetime.now().isoformat(),
  111. "version": "1.0"
  112. }
  113. # 设置缓存,使用配置的TTL
  114. ttl = app_config.EMBEDDING_CACHE_TTL
  115. self.redis_client.setex(
  116. cache_key,
  117. ttl,
  118. json.dumps(cache_data, ensure_ascii=False)
  119. )
  120. self.logger.debug(f"✓ Embedding向量已缓存: {question[:50]}... (维度: {len(vector)})")
  121. # 检查缓存大小并清理
  122. self._cleanup_if_needed()
  123. return True
  124. except Exception as e:
  125. self.logger.warning(f"缓存embedding失败: {e}")
  126. return False
  127. def _cleanup_if_needed(self):
  128. """
  129. 如果缓存大小超过限制,清理最旧的缓存
  130. """
  131. try:
  132. max_size = app_config.EMBEDDING_CACHE_MAX_SIZE
  133. pattern = "embedding_cache:*"
  134. # 获取所有embedding缓存键
  135. keys = self.redis_client.keys(pattern)
  136. if len(keys) > max_size:
  137. # 需要清理,获取键的TTL信息并按剩余时间排序
  138. keys_with_ttl = []
  139. for key in keys:
  140. ttl = self.redis_client.ttl(key)
  141. if ttl > 0: # 只考虑有TTL的键
  142. keys_with_ttl.append((key, ttl))
  143. # 按TTL升序排序(剩余时间少的在前面)
  144. keys_with_ttl.sort(key=lambda x: x[1])
  145. # 删除超出限制的旧键
  146. cleanup_count = len(keys) - max_size
  147. keys_to_delete = [key for key, _ in keys_with_ttl[:cleanup_count]]
  148. if keys_to_delete:
  149. self.redis_client.delete(*keys_to_delete)
  150. self.logger.debug(f"清理了 {len(keys_to_delete)} 个旧的embedding缓存")
  151. except Exception as e:
  152. self.logger.warning(f"清理embedding缓存失败: {e}")
  153. def get_cache_stats(self) -> Dict[str, Any]:
  154. """
  155. 获取缓存统计信息
  156. Returns:
  157. 包含缓存统计信息的字典
  158. """
  159. stats = {
  160. "enabled": self.cache_enabled,
  161. "available": self.is_available(),
  162. "total_count": 0,
  163. "memory_usage_mb": 0
  164. }
  165. if not self.is_available():
  166. return stats
  167. try:
  168. pattern = "embedding_cache:*"
  169. keys = self.redis_client.keys(pattern)
  170. stats["total_count"] = len(keys)
  171. # 估算内存使用量(粗略计算)
  172. if keys:
  173. sample_key = keys[0]
  174. sample_data = self.redis_client.get(sample_key)
  175. if sample_data:
  176. avg_size_bytes = len(sample_data.encode('utf-8'))
  177. total_size_bytes = avg_size_bytes * len(keys)
  178. stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
  179. except Exception as e:
  180. self.logger.warning(f"获取缓存统计失败: {e}")
  181. return stats
  182. def clear_all_cache(self) -> bool:
  183. """
  184. 清空所有embedding缓存
  185. Returns:
  186. 成功返回True,失败返回False
  187. """
  188. if not self.is_available():
  189. return False
  190. try:
  191. pattern = "embedding_cache:*"
  192. keys = self.redis_client.keys(pattern)
  193. if keys:
  194. self.redis_client.delete(*keys)
  195. self.logger.debug(f"已清空所有embedding缓存 ({len(keys)} 条)")
  196. return True
  197. else:
  198. self.logger.debug("没有embedding缓存需要清空")
  199. return True
  200. except Exception as e:
  201. self.logger.warning(f"清空embedding缓存失败: {e}")
  202. return False
  203. # 全局实例
  204. _embedding_cache_manager = None
  205. def get_embedding_cache_manager() -> EmbeddingCacheManager:
  206. """
  207. 获取全局embedding缓存管理器实例
  208. Returns:
  209. EmbeddingCacheManager实例
  210. """
  211. global _embedding_cache_manager
  212. if _embedding_cache_manager is None:
  213. _embedding_cache_manager = EmbeddingCacheManager()
  214. return _embedding_cache_manager