|
@@ -0,0 +1,264 @@
|
|
|
+import redis
|
|
|
+import json
|
|
|
+import hashlib
|
|
|
+import time
|
|
|
+from typing import List, Optional, Dict, Any
|
|
|
+from datetime import datetime
|
|
|
+import app_config
|
|
|
+
|
|
|
+
|
|
|
+class EmbeddingCacheManager:
|
|
|
+ """Embedding向量缓存管理器"""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ """初始化缓存管理器"""
|
|
|
+ self.redis_client = None
|
|
|
+ self.cache_enabled = app_config.ENABLE_EMBEDDING_CACHE
|
|
|
+
|
|
|
+ if self.cache_enabled:
|
|
|
+ try:
|
|
|
+ self.redis_client = redis.Redis(
|
|
|
+ host=app_config.REDIS_HOST,
|
|
|
+ port=app_config.REDIS_PORT,
|
|
|
+ db=app_config.REDIS_DB,
|
|
|
+ password=app_config.REDIS_PASSWORD,
|
|
|
+ decode_responses=True,
|
|
|
+ socket_connect_timeout=5,
|
|
|
+ socket_timeout=5
|
|
|
+ )
|
|
|
+ # 测试连接
|
|
|
+ self.redis_client.ping()
|
|
|
+ print(f"[DEBUG] Embedding缓存管理器初始化成功")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] Redis连接失败,embedding缓存将被禁用: {e}")
|
|
|
+ self.cache_enabled = False
|
|
|
+ self.redis_client = None
|
|
|
+
|
|
|
+ def is_available(self) -> bool:
|
|
|
+ """检查缓存是否可用"""
|
|
|
+ return self.cache_enabled and self.redis_client is not None
|
|
|
+
|
|
|
+ def _get_cache_key(self, question: str, model_info: Dict[str, str]) -> str:
|
|
|
+ """
|
|
|
+ 生成缓存键
|
|
|
+
|
|
|
+ Args:
|
|
|
+ question: 问题文本
|
|
|
+ model_info: 模型信息字典,包含model_name和embedding_dimension
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 缓存键字符串
|
|
|
+ """
|
|
|
+ # 使用问题的hash值避免键太长
|
|
|
+ question_hash = hashlib.sha256(question.encode('utf-8')).hexdigest()[:16]
|
|
|
+ model_name = model_info.get('model_name', 'unknown')
|
|
|
+ dimension = model_info.get('embedding_dimension', 'unknown')
|
|
|
+
|
|
|
+ return f"embedding_cache:{question_hash}:{model_name}:{dimension}"
|
|
|
+
|
|
|
+ def _get_model_info(self) -> Dict[str, str]:
|
|
|
+ """
|
|
|
+ 获取当前模型信息
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 包含model_name和embedding_dimension的字典
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ from common.utils import get_current_embedding_config
|
|
|
+ embedding_config = get_current_embedding_config()
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'model_name': embedding_config.get('model_name', 'unknown'),
|
|
|
+ 'embedding_dimension': str(embedding_config.get('embedding_dimension', 'unknown'))
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 获取模型信息失败: {e}")
|
|
|
+ return {'model_name': 'unknown', 'embedding_dimension': 'unknown'}
|
|
|
+
|
|
|
+ def get_cached_embedding(self, question: str) -> Optional[List[float]]:
|
|
|
+ """
|
|
|
+ 从缓存中获取embedding向量
|
|
|
+
|
|
|
+ Args:
|
|
|
+ question: 问题文本
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 如果缓存命中返回向量列表,否则返回None
|
|
|
+ """
|
|
|
+ if not self.is_available():
|
|
|
+ return None
|
|
|
+
|
|
|
+ try:
|
|
|
+ model_info = self._get_model_info()
|
|
|
+ cache_key = self._get_cache_key(question, model_info)
|
|
|
+
|
|
|
+ cached_data = self.redis_client.get(cache_key)
|
|
|
+ if cached_data:
|
|
|
+ data = json.loads(cached_data)
|
|
|
+ vector = data.get('vector')
|
|
|
+ if vector:
|
|
|
+ print(f"[DEBUG] ✓ Embedding缓存命中: {question[:50]}...")
|
|
|
+ return vector
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 获取embedding缓存失败: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ def cache_embedding(self, question: str, vector: List[float]) -> bool:
|
|
|
+ """
|
|
|
+ 将embedding向量保存到缓存
|
|
|
+
|
|
|
+ Args:
|
|
|
+ question: 问题文本
|
|
|
+ vector: embedding向量
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 成功返回True,失败返回False
|
|
|
+ """
|
|
|
+ if not self.is_available() or not vector:
|
|
|
+ return False
|
|
|
+
|
|
|
+ try:
|
|
|
+ model_info = self._get_model_info()
|
|
|
+ cache_key = self._get_cache_key(question, model_info)
|
|
|
+
|
|
|
+ cache_data = {
|
|
|
+ "question": question,
|
|
|
+ "vector": vector,
|
|
|
+ "model_name": model_info['model_name'],
|
|
|
+ "dimension": len(vector),
|
|
|
+ "created_at": datetime.now().isoformat(),
|
|
|
+ "version": "1.0"
|
|
|
+ }
|
|
|
+
|
|
|
+ # 设置缓存,使用配置的TTL
|
|
|
+ ttl = app_config.EMBEDDING_CACHE_TTL
|
|
|
+ self.redis_client.setex(
|
|
|
+ cache_key,
|
|
|
+ ttl,
|
|
|
+ json.dumps(cache_data, ensure_ascii=False)
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"[DEBUG] ✓ Embedding向量已缓存: {question[:50]}... (维度: {len(vector)})")
|
|
|
+
|
|
|
+ # 检查缓存大小并清理
|
|
|
+ self._cleanup_if_needed()
|
|
|
+
|
|
|
+ return True
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 缓存embedding失败: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ def _cleanup_if_needed(self):
|
|
|
+ """
|
|
|
+ 如果缓存大小超过限制,清理最旧的缓存
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ max_size = app_config.EMBEDDING_CACHE_MAX_SIZE
|
|
|
+ pattern = "embedding_cache:*"
|
|
|
+
|
|
|
+ # 获取所有embedding缓存键
|
|
|
+ keys = self.redis_client.keys(pattern)
|
|
|
+
|
|
|
+ if len(keys) > max_size:
|
|
|
+ # 需要清理,获取键的TTL信息并按剩余时间排序
|
|
|
+ keys_with_ttl = []
|
|
|
+ for key in keys:
|
|
|
+ ttl = self.redis_client.ttl(key)
|
|
|
+ if ttl > 0: # 只考虑有TTL的键
|
|
|
+ keys_with_ttl.append((key, ttl))
|
|
|
+
|
|
|
+ # 按TTL升序排序(剩余时间少的在前面)
|
|
|
+ keys_with_ttl.sort(key=lambda x: x[1])
|
|
|
+
|
|
|
+ # 删除超出限制的旧键
|
|
|
+ cleanup_count = len(keys) - max_size
|
|
|
+ keys_to_delete = [key for key, _ in keys_with_ttl[:cleanup_count]]
|
|
|
+
|
|
|
+ if keys_to_delete:
|
|
|
+ self.redis_client.delete(*keys_to_delete)
|
|
|
+ print(f"[DEBUG] 清理了 {len(keys_to_delete)} 个旧的embedding缓存")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 清理embedding缓存失败: {e}")
|
|
|
+
|
|
|
+ def get_cache_stats(self) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 获取缓存统计信息
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 包含缓存统计信息的字典
|
|
|
+ """
|
|
|
+ stats = {
|
|
|
+ "enabled": self.cache_enabled,
|
|
|
+ "available": self.is_available(),
|
|
|
+ "total_count": 0,
|
|
|
+ "memory_usage_mb": 0
|
|
|
+ }
|
|
|
+
|
|
|
+ if not self.is_available():
|
|
|
+ return stats
|
|
|
+
|
|
|
+ try:
|
|
|
+ pattern = "embedding_cache:*"
|
|
|
+ keys = self.redis_client.keys(pattern)
|
|
|
+ stats["total_count"] = len(keys)
|
|
|
+
|
|
|
+ # 估算内存使用量(粗略计算)
|
|
|
+ if keys:
|
|
|
+ sample_key = keys[0]
|
|
|
+ sample_data = self.redis_client.get(sample_key)
|
|
|
+ if sample_data:
|
|
|
+ avg_size_bytes = len(sample_data.encode('utf-8'))
|
|
|
+ total_size_bytes = avg_size_bytes * len(keys)
|
|
|
+ stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 获取缓存统计失败: {e}")
|
|
|
+
|
|
|
+ return stats
|
|
|
+
|
|
|
+ def clear_all_cache(self) -> bool:
|
|
|
+ """
|
|
|
+ 清空所有embedding缓存
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 成功返回True,失败返回False
|
|
|
+ """
|
|
|
+ if not self.is_available():
|
|
|
+ return False
|
|
|
+
|
|
|
+ try:
|
|
|
+ pattern = "embedding_cache:*"
|
|
|
+ keys = self.redis_client.keys(pattern)
|
|
|
+
|
|
|
+ if keys:
|
|
|
+ self.redis_client.delete(*keys)
|
|
|
+ print(f"[DEBUG] 已清空所有embedding缓存 ({len(keys)} 条)")
|
|
|
+ return True
|
|
|
+ else:
|
|
|
+ print(f"[DEBUG] 没有embedding缓存需要清空")
|
|
|
+ return True
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"[WARNING] 清空embedding缓存失败: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
+# 全局实例
|
|
|
+_embedding_cache_manager = None
|
|
|
+
|
|
|
+def get_embedding_cache_manager() -> EmbeddingCacheManager:
|
|
|
+ """
|
|
|
+ 获取全局embedding缓存管理器实例
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ EmbeddingCacheManager实例
|
|
|
+ """
|
|
|
+ global _embedding_cache_manager
|
|
|
+ if _embedding_cache_manager is None:
|
|
|
+ _embedding_cache_manager = EmbeddingCacheManager()
|
|
|
+ return _embedding_cache_manager
|