Bladeren bron

增加了对问题向量化缓存到Redis的模块,缓存问题的向量值。

wangxq 1 week geleden
bovenliggende
commit
56ca6e03a9
5 gewijzigde bestanden met toevoegingen van 468 en 18 verwijderingen
  1. 3 0
      .vscode/settings.json
  2. 6 1
      app_config.py
  3. 92 1
      citu_app.py
  4. 264 0
      common/embedding_cache_manager.py
  5. 103 16
      custompgvector/pgvector.py

+ 3 - 0
.vscode/settings.json

@@ -0,0 +1,3 @@
+{
+    "terminal.integrated.scrollback": 5000
+}

+ 6 - 1
app_config.py

@@ -183,8 +183,13 @@ REDIS_PASSWORD = None
 # 缓存开关配置
 ENABLE_CONVERSATION_CONTEXT = True      # 是否启用对话上下文
 ENABLE_QUESTION_ANSWER_CACHE = True     # 是否启用问答结果缓存
+ENABLE_EMBEDDING_CACHE = True           # 是否启用embedding向量缓存
 
 # TTL配置(单位:秒)
 CONVERSATION_TTL = 7 * 24 * 3600        # 对话保存7天
 USER_CONVERSATIONS_TTL = 7 * 24 * 3600  # 用户对话列表保存7天(所有用户统一)
-QUESTION_ANSWER_TTL = 24 * 3600         # 问答结果缓存24小时
+QUESTION_ANSWER_TTL = 24 * 3600         # 问答结果缓存24小时
+EMBEDDING_CACHE_TTL = 30 * 24 * 3600    # embedding向量缓存30天
+
+# Embedding缓存管理配置
+EMBEDDING_CACHE_MAX_SIZE = 5000        # 最大缓存问题数量

+ 92 - 1
citu_app.py

@@ -21,7 +21,8 @@ from common.result import (  # 统一导入所有需要的响应函数
 from app_config import (  # 添加Redis相关配置导入
     USER_MAX_CONVERSATIONS,
     CONVERSATION_CONTEXT_COUNT,
-    DEFAULT_ANONYMOUS_USER
+    DEFAULT_ANONYMOUS_USER,
+    ENABLE_QUESTION_ANSWER_CACHE
 )
 
 # 设置默认的最大返回行数
@@ -1704,6 +1705,96 @@ def get_user_conversations_with_messages(user_id: str):
         )), 500
 
 
+# ==================== Embedding缓存管理接口 ====================
+
+@app.flask_app.route('/api/v0/embedding_cache_stats', methods=['GET'])
+def embedding_cache_stats():
+    """获取embedding缓存统计信息"""
+    try:
+        from common.embedding_cache_manager import get_embedding_cache_manager
+        
+        cache_manager = get_embedding_cache_manager()
+        stats = cache_manager.get_cache_stats()
+        
+        return jsonify(success_response(
+            response_text="获取embedding缓存统计成功",
+            data=stats
+        ))
+        
+    except Exception as e:
+        print(f"[ERROR] 获取embedding缓存统计失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="获取embedding缓存统计失败,请稍后重试"
+        )), 500
+
+@app.flask_app.route('/api/v0/embedding_cache_cleanup', methods=['POST'])
+def embedding_cache_cleanup():
+    """清空所有embedding缓存"""
+    try:
+        from common.embedding_cache_manager import get_embedding_cache_manager
+        
+        cache_manager = get_embedding_cache_manager()
+        
+        if not cache_manager.is_available():
+            return jsonify(internal_error_response(
+                response_text="Embedding缓存功能未启用或不可用"
+            )), 400
+        
+        success = cache_manager.clear_all_cache()
+        
+        if success:
+            return jsonify(success_response(
+                response_text="所有embedding缓存已清空",
+                data={"cleared": True}
+            ))
+        else:
+            return jsonify(internal_error_response(
+                response_text="清空embedding缓存失败"
+            )), 500
+        
+    except Exception as e:
+        print(f"[ERROR] 清空embedding缓存失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="清空embedding缓存失败,请稍后重试"
+        )), 500
+
+@app.flask_app.route('/api/v0/cache_overview_full', methods=['GET'])
+def cache_overview_full():
+    """获取所有缓存系统的综合概览"""
+    try:
+        from common.embedding_cache_manager import get_embedding_cache_manager
+        from common.vanna_instance import get_vanna_instance
+        from common.session_aware_cache import get_cache
+        
+        # 获取现有的缓存统计
+        vanna_cache = get_vanna_instance()
+        cache = get_cache()
+        
+        cache_overview = {
+            "conversation_aware_cache": {
+                "enabled": True,
+                "total_items": len(cache.cache),
+                "sessions": list(cache.cache.keys()) if hasattr(cache, 'cache') else []
+            },
+            "question_answer_cache": {
+                "enabled": ENABLE_QUESTION_ANSWER_CACHE,
+                "stats": redis_conversation_manager.get_stats() if redis_conversation_manager.is_available() else None
+            },
+            "embedding_cache": get_embedding_cache_manager().get_cache_stats()
+        }
+        
+        return jsonify(success_response(
+            response_text="获取综合缓存概览成功",
+            data=cache_overview
+        ))
+        
+    except Exception as e:
+        print(f"[ERROR] 获取综合缓存概览失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="获取缓存概览失败,请稍后重试"
+        )), 500
+
+
 # 前端JavaScript示例 - 如何维持会话
 """
 // 前端需要维护一个会话ID

+ 264 - 0
common/embedding_cache_manager.py

@@ -0,0 +1,264 @@
+import redis
+import json
+import hashlib
+import time
+from typing import List, Optional, Dict, Any
+from datetime import datetime
+import app_config
+
+
+class EmbeddingCacheManager:
+    """Embedding向量缓存管理器"""
+    
+    def __init__(self):
+        """初始化缓存管理器"""
+        self.redis_client = None
+        self.cache_enabled = app_config.ENABLE_EMBEDDING_CACHE
+        
+        if self.cache_enabled:
+            try:
+                self.redis_client = redis.Redis(
+                    host=app_config.REDIS_HOST,
+                    port=app_config.REDIS_PORT,
+                    db=app_config.REDIS_DB,
+                    password=app_config.REDIS_PASSWORD,
+                    decode_responses=True,
+                    socket_connect_timeout=5,
+                    socket_timeout=5
+                )
+                # 测试连接
+                self.redis_client.ping()
+                print(f"[DEBUG] Embedding缓存管理器初始化成功")
+            except Exception as e:
+                print(f"[WARNING] Redis连接失败,embedding缓存将被禁用: {e}")
+                self.cache_enabled = False
+                self.redis_client = None
+    
+    def is_available(self) -> bool:
+        """检查缓存是否可用"""
+        return self.cache_enabled and self.redis_client is not None
+    
+    def _get_cache_key(self, question: str, model_info: Dict[str, str]) -> str:
+        """
+        生成缓存键
+        
+        Args:
+            question: 问题文本
+            model_info: 模型信息字典,包含model_name和embedding_dimension
+            
+        Returns:
+            缓存键字符串
+        """
+        # 使用问题的hash值避免键太长
+        question_hash = hashlib.sha256(question.encode('utf-8')).hexdigest()[:16]
+        model_name = model_info.get('model_name', 'unknown')
+        dimension = model_info.get('embedding_dimension', 'unknown')
+        
+        return f"embedding_cache:{question_hash}:{model_name}:{dimension}"
+    
+    def _get_model_info(self) -> Dict[str, str]:
+        """
+        获取当前模型信息
+        
+        Returns:
+            包含model_name和embedding_dimension的字典
+        """
+        try:
+            from common.utils import get_current_embedding_config
+            embedding_config = get_current_embedding_config()
+            
+            return {
+                'model_name': embedding_config.get('model_name', 'unknown'),
+                'embedding_dimension': str(embedding_config.get('embedding_dimension', 'unknown'))
+            }
+        except Exception as e:
+            print(f"[WARNING] 获取模型信息失败: {e}")
+            return {'model_name': 'unknown', 'embedding_dimension': 'unknown'}
+    
+    def get_cached_embedding(self, question: str) -> Optional[List[float]]:
+        """
+        从缓存中获取embedding向量
+        
+        Args:
+            question: 问题文本
+            
+        Returns:
+            如果缓存命中返回向量列表,否则返回None
+        """
+        if not self.is_available():
+            return None
+        
+        try:
+            model_info = self._get_model_info()
+            cache_key = self._get_cache_key(question, model_info)
+            
+            cached_data = self.redis_client.get(cache_key)
+            if cached_data:
+                data = json.loads(cached_data)
+                vector = data.get('vector')
+                if vector:
+                    print(f"[DEBUG] ✓ Embedding缓存命中: {question[:50]}...")
+                    return vector
+            
+            return None
+            
+        except Exception as e:
+            print(f"[WARNING] 获取embedding缓存失败: {e}")
+            return None
+    
+    def cache_embedding(self, question: str, vector: List[float]) -> bool:
+        """
+        将embedding向量保存到缓存
+        
+        Args:
+            question: 问题文本
+            vector: embedding向量
+            
+        Returns:
+            成功返回True,失败返回False
+        """
+        if not self.is_available() or not vector:
+            return False
+        
+        try:
+            model_info = self._get_model_info()
+            cache_key = self._get_cache_key(question, model_info)
+            
+            cache_data = {
+                "question": question,
+                "vector": vector,
+                "model_name": model_info['model_name'],
+                "dimension": len(vector),
+                "created_at": datetime.now().isoformat(),
+                "version": "1.0"
+            }
+            
+            # 设置缓存,使用配置的TTL
+            ttl = app_config.EMBEDDING_CACHE_TTL
+            self.redis_client.setex(
+                cache_key,
+                ttl,
+                json.dumps(cache_data, ensure_ascii=False)
+            )
+            
+            print(f"[DEBUG] ✓ Embedding向量已缓存: {question[:50]}... (维度: {len(vector)})")
+            
+            # 检查缓存大小并清理
+            self._cleanup_if_needed()
+            
+            return True
+            
+        except Exception as e:
+            print(f"[WARNING] 缓存embedding失败: {e}")
+            return False
+    
+    def _cleanup_if_needed(self):
+        """
+        如果缓存大小超过限制,清理最旧的缓存
+        """
+        try:
+            max_size = app_config.EMBEDDING_CACHE_MAX_SIZE
+            pattern = "embedding_cache:*"
+            
+            # 获取所有embedding缓存键
+            keys = self.redis_client.keys(pattern)
+            
+            if len(keys) > max_size:
+                # 需要清理,获取键的TTL信息并按剩余时间排序
+                keys_with_ttl = []
+                for key in keys:
+                    ttl = self.redis_client.ttl(key)
+                    if ttl > 0:  # 只考虑有TTL的键
+                        keys_with_ttl.append((key, ttl))
+                
+                # 按TTL升序排序(剩余时间少的在前面)
+                keys_with_ttl.sort(key=lambda x: x[1])
+                
+                # 删除超出限制的旧键
+                cleanup_count = len(keys) - max_size
+                keys_to_delete = [key for key, _ in keys_with_ttl[:cleanup_count]]
+                
+                if keys_to_delete:
+                    self.redis_client.delete(*keys_to_delete)
+                    print(f"[DEBUG] 清理了 {len(keys_to_delete)} 个旧的embedding缓存")
+                    
+        except Exception as e:
+            print(f"[WARNING] 清理embedding缓存失败: {e}")
+    
+    def get_cache_stats(self) -> Dict[str, Any]:
+        """
+        获取缓存统计信息
+        
+        Returns:
+            包含缓存统计信息的字典
+        """
+        stats = {
+            "enabled": self.cache_enabled,
+            "available": self.is_available(),
+            "total_count": 0,
+            "memory_usage_mb": 0
+        }
+        
+        if not self.is_available():
+            return stats
+        
+        try:
+            pattern = "embedding_cache:*"
+            keys = self.redis_client.keys(pattern)
+            stats["total_count"] = len(keys)
+            
+            # 估算内存使用量(粗略计算)
+            if keys:
+                sample_key = keys[0]
+                sample_data = self.redis_client.get(sample_key)
+                if sample_data:
+                    avg_size_bytes = len(sample_data.encode('utf-8'))
+                    total_size_bytes = avg_size_bytes * len(keys)
+                    stats["memory_usage_mb"] = round(total_size_bytes / (1024 * 1024), 2)
+            
+        except Exception as e:
+            print(f"[WARNING] 获取缓存统计失败: {e}")
+        
+        return stats
+    
+    def clear_all_cache(self) -> bool:
+        """
+        清空所有embedding缓存
+        
+        Returns:
+            成功返回True,失败返回False
+        """
+        if not self.is_available():
+            return False
+        
+        try:
+            pattern = "embedding_cache:*"
+            keys = self.redis_client.keys(pattern)
+            
+            if keys:
+                self.redis_client.delete(*keys)
+                print(f"[DEBUG] 已清空所有embedding缓存 ({len(keys)} 条)")
+                return True
+            else:
+                print(f"[DEBUG] 没有embedding缓存需要清空")
+                return True
+                
+        except Exception as e:
+            print(f"[WARNING] 清空embedding缓存失败: {e}")
+            return False
+
+
+# 全局实例
+_embedding_cache_manager = None
+
+def get_embedding_cache_manager() -> EmbeddingCacheManager:
+    """
+    获取全局embedding缓存管理器实例
+    
+    Returns:
+        EmbeddingCacheManager实例
+    """
+    global _embedding_cache_manager
+    if _embedding_cache_manager is None:
+        _embedding_cache_manager = EmbeddingCacheManager()
+    return _embedding_cache_manager 

+ 103 - 16
custompgvector/pgvector.py

@@ -12,6 +12,9 @@ from vanna.exceptions import ValidationError
 from vanna.base import VannaBase
 from vanna.types import TrainingPlan, TrainingPlanItem
 
+# 导入embedding缓存管理器
+from common.embedding_cache_manager import get_embedding_cache_manager
+
 
 class PG_VectorStore(VannaBase):
     def __init__(self, config=None):
@@ -108,10 +111,31 @@ class PG_VectorStore(VannaBase):
 
     # 在原来的基础之上,增加相似度的值。
     def get_similar_question_sql(self, question: str) -> list:
-        docs_with_scores = self.sql_collection.similarity_search_with_score(
-            query=question,
-            k=self.n_results
-        )
+        # 尝试使用embedding缓存
+        embedding_cache = get_embedding_cache_manager()
+        cached_embedding = embedding_cache.get_cached_embedding(question)
+        
+        if cached_embedding is not None:
+            # 使用缓存的embedding进行向量搜索
+            docs_with_scores = self.sql_collection.similarity_search_with_score_by_vector(
+                embedding=cached_embedding,
+                k=self.n_results
+            )
+        else:
+            # 执行常规的向量搜索(会自动生成embedding)
+            docs_with_scores = self.sql_collection.similarity_search_with_score(
+                query=question,
+                k=self.n_results
+            )
+            
+            # 获取刚生成的embedding并缓存
+            try:
+                # 通过embedding_function生成向量并缓存
+                generated_embedding = self.embedding_function.generate_embedding(question)
+                if generated_embedding:
+                    embedding_cache.cache_embedding(question, generated_embedding)
+            except Exception as e:
+                print(f"[WARNING] 缓存embedding失败: {e}")
 
         results = []
         for doc, score in docs_with_scores:
@@ -138,10 +162,31 @@ class PG_VectorStore(VannaBase):
         return filtered_results
 
     def get_related_ddl(self, question: str, **kwargs) -> list:
-        docs_with_scores = self.ddl_collection.similarity_search_with_score(
-            query=question,
-            k=self.n_results
-        )
+        # 尝试使用embedding缓存
+        embedding_cache = get_embedding_cache_manager()
+        cached_embedding = embedding_cache.get_cached_embedding(question)
+        
+        if cached_embedding is not None:
+            # 使用缓存的embedding进行向量搜索
+            docs_with_scores = self.ddl_collection.similarity_search_with_score_by_vector(
+                embedding=cached_embedding,
+                k=self.n_results
+            )
+        else:
+            # 执行常规的向量搜索(会自动生成embedding)
+            docs_with_scores = self.ddl_collection.similarity_search_with_score(
+                query=question,
+                k=self.n_results
+            )
+            
+            # 获取刚生成的embedding并缓存
+            try:
+                # 通过embedding_function生成向量并缓存
+                generated_embedding = self.embedding_function.generate_embedding(question)
+                if generated_embedding:
+                    embedding_cache.cache_embedding(question, generated_embedding)
+            except Exception as e:
+                print(f"[WARNING] 缓存embedding失败: {e}")
 
         results = []
         for doc, score in docs_with_scores:
@@ -168,10 +213,31 @@ class PG_VectorStore(VannaBase):
         return filtered_results
 
     def get_related_documentation(self, question: str, **kwargs) -> list:
-        docs_with_scores = self.documentation_collection.similarity_search_with_score(
-            query=question,
-            k=self.n_results
-        )
+        # 尝试使用embedding缓存
+        embedding_cache = get_embedding_cache_manager()
+        cached_embedding = embedding_cache.get_cached_embedding(question)
+        
+        if cached_embedding is not None:
+            # 使用缓存的embedding进行向量搜索
+            docs_with_scores = self.documentation_collection.similarity_search_with_score_by_vector(
+                embedding=cached_embedding,
+                k=self.n_results
+            )
+        else:
+            # 执行常规的向量搜索(会自动生成embedding)
+            docs_with_scores = self.documentation_collection.similarity_search_with_score(
+                query=question,
+                k=self.n_results
+            )
+            
+            # 获取刚生成的embedding并缓存
+            try:
+                # 通过embedding_function生成向量并缓存
+                generated_embedding = self.embedding_function.generate_embedding(question)
+                if generated_embedding:
+                    embedding_cache.cache_embedding(question, generated_embedding)
+            except Exception as e:
+                print(f"[WARNING] 缓存embedding失败: {e}")
 
         results = []
         for doc, score in docs_with_scores:
@@ -520,10 +586,31 @@ class PG_VectorStore(VannaBase):
         self._ensure_error_sql_collection()
         
         try:
-            docs_with_scores = self.error_sql_collection.similarity_search_with_score(
-                query=question,
-                k=self.n_results
-            )
+            # 尝试使用embedding缓存
+            embedding_cache = get_embedding_cache_manager()
+            cached_embedding = embedding_cache.get_cached_embedding(question)
+            
+            if cached_embedding is not None:
+                # 使用缓存的embedding进行向量搜索
+                docs_with_scores = self.error_sql_collection.similarity_search_with_score_by_vector(
+                    embedding=cached_embedding,
+                    k=self.n_results
+                )
+            else:
+                # 执行常规的向量搜索(会自动生成embedding)
+                docs_with_scores = self.error_sql_collection.similarity_search_with_score(
+                    query=question,
+                    k=self.n_results
+                )
+                
+                # 获取刚生成的embedding并缓存
+                try:
+                    # 通过embedding_function生成向量并缓存
+                    generated_embedding = self.embedding_function.generate_embedding(question)
+                    if generated_embedding:
+                        embedding_cache.cache_embedding(question, generated_embedding)
+                except Exception as e:
+                    print(f"[WARNING] 缓存embedding失败: {e}")
             
             results = []
             for doc, score in docs_with_scores: