
Add question rewriting and merging, and add score-threshold filtering for vector query results.

wangxq 3 weeks ago
parent
commit
841074877b

+ 18 - 4
app_config.py

@@ -2,14 +2,15 @@ from dotenv import load_dotenv
 import os
 
 # Load environment variables from the .env file
-load_dotenv()
+# Use override=True so that updated environment variables are reloaded
+load_dotenv(override=True)
 
 # ===== Model provider type configuration =====
 # LLM provider type: api or ollama
-LLM_MODEL_TYPE = "ollama"  # api, ollama
+LLM_MODEL_TYPE = "api"  # api, ollama
 
 # Embedding provider type: api or ollama
-EMBEDDING_MODEL_TYPE = "ollama"  # api, ollama
+EMBEDDING_MODEL_TYPE = "api"  # api, ollama
 
 # ===== Model name configuration =====
 # API LLM model name (used when LLM_MODEL_TYPE="api": qwen or deepseek)
@@ -53,6 +54,8 @@ API_EMBEDDING_CONFIG = {
     "embedding_dimension": 1024
 }
 
+# Candidate API embedding model names: BAAI/bge-m3, text-embedding-v4
 
 # ===== Ollama LLM model configuration =====
 OLLAMA_LLM_CONFIG = {
@@ -114,4 +117,15 @@ TRAINING_MAX_WORKERS = 4                    # Max worker threads for training batches
 # 3. Relative paths (not starting with "."):
 #    "training/data"       - relative to the project root
 #    "my_data"             - the my_data folder under the project root
-TRAINING_DATA_PATH = "./training/data"
+TRAINING_DATA_PATH = "./training/data"
+
+# Whether to enable question rewriting, i.e. merging the new question with the previous one for context.
+REWRITE_QUESTION_ENABLED = False
+
+# Whether to enable score-threshold filtering of vector query results
+# At least max((n + 1) // 2, 1) of the n retrieved results are always kept
+ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True
+# Score thresholds for vector query results
+RESULT_VECTOR_SQL_SCORE_THRESHOLD = 0.65
+RESULT_VECTOR_DDL_SCORE_THRESHOLD = 0.5
+RESULT_VECTOR_DOC_SCORE_THRESHOLD = 0.5
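
For reference, a minimal sketch of the "keep at least half" floor these thresholds are paired with (the helper name is hypothetical; the formula is the one noted above and used by _apply_score_threshold_filter in custompgvector/pgvector.py):

def min_results_to_keep(n: int) -> int:
    # Guaranteed minimum number of results kept, regardless of threshold
    return max((n + 1) // 2, 1)

assert min_results_to_keep(1) == 1  # a lone result is never filtered away
assert min_results_to_keep(4) == 2  # at least half of 4
assert min_results_to_keep(5) == 3  # "half" rounds up for odd counts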

+ 43 - 1
customdeepseek/custom_deepseek_chat.py

@@ -3,6 +3,8 @@ import os
 from openai import OpenAI
 from vanna.base import VannaBase
 #from base import VannaBase
+# Import configuration flags
+from app_config import REWRITE_QUESTION_ENABLED
 
 
 # from vanna.chromadb import ChromaDB_VectorStore
@@ -168,4 +170,44 @@ class DeepSeekChat(VannaBase):
             return response
         except Exception as e:
             print(f"[ERROR] LLM conversation failed: {str(e)}")
-            return "Sorry, I can't answer your question right now. Please try again later."
+            return "Sorry, I can't answer your question right now. Please try again later."
+
+    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
+        """
+        Question-merging method; a config flag controls whether merging is enabled.
+
+        Args:
+            last_question (str): the previous question
+            new_question (str): the new question
+            **kwargs: additional arguments
+
+        Returns:
+            str: the merged question if merging is enabled and the questions are related; otherwise the new question
+        """
+        # If merging is disabled or there is no previous question, return the new question as-is
+        if not REWRITE_QUESTION_ENABLED or last_question is None:
+            print(f"[DEBUG] Question merging {'disabled' if not REWRITE_QUESTION_ENABLED else 'skipped (no previous question)'}; returning the new question as-is")
+            return new_question
+
+        print("[DEBUG] Question merging enabled; attempting to merge")
+        print(f"[DEBUG] Previous question: {last_question}")
+        print(f"[DEBUG] New question: {new_question}")
+
+        try:
+            prompt = [
+                self.system_message(
+                    "Your goal is to merge a sequence of related questions into a single question. If the second question is unrelated to and fully independent of the first, return the second question. "
+                    "Return only the new merged question without any extra explanation. The question should, in principle, be answerable by a single SQL statement. "
+                    "Please answer in Chinese."
+                ),
+                self.user_message(f"First question: {last_question}\nSecond question: {new_question}")
+            ]
+
+            rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
+            print(f"[DEBUG] Merged question: {rewritten_question}")
+            return rewritten_question
+
+        except Exception as e:
+            print(f"[ERROR] Question merging failed: {str(e)}")
+            # If merging fails, fall back to the new question
+            return new_question
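
A hypothetical usage sketch (the instance setup is an assumption; only generate_rewritten_question is part of this commit):

# Assumes `vn` is a configured DeepSeekChat instance and
# REWRITE_QUESTION_ENABLED = True in app_config.py.
last_q = "How many cards were issued each month in 2024?"
new_q = "Only count credit cards"
merged = vn.generate_rewritten_question(last_question=last_q, new_question=new_q)
# Expected: one merged question answerable by a single SQL statement, e.g.
# "How many credit cards were issued each month in 2024?".
# With the flag off, `merged` is simply new_q.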

+ 43 - 1
customollama/ollama_chat.py

@@ -2,6 +2,8 @@ import requests
 import json
 from vanna.base import VannaBase
 from typing import List, Dict, Any
+# Import configuration flags
+from app_config import REWRITE_QUESTION_ENABLED
 
 class OllamaChat(VannaBase):
     def __init__(self, config=None):
@@ -162,4 +164,44 @@ class OllamaChat(VannaBase):
             
         except Exception as e:
             result["message"] = f"Ollama connection test failed: {str(e)}"
-            return result 
+            return result 
+
+    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
+        """
+        Question-merging method; a config flag controls whether merging is enabled.
+
+        Args:
+            last_question (str): the previous question
+            new_question (str): the new question
+            **kwargs: additional arguments
+
+        Returns:
+            str: the merged question if merging is enabled and the questions are related; otherwise the new question
+        """
+        # If merging is disabled or there is no previous question, return the new question as-is
+        if not REWRITE_QUESTION_ENABLED or last_question is None:
+            print(f"[DEBUG] Question merging {'disabled' if not REWRITE_QUESTION_ENABLED else 'skipped (no previous question)'}; returning the new question as-is")
+            return new_question
+
+        print("[DEBUG] Question merging enabled; attempting to merge")
+        print(f"[DEBUG] Previous question: {last_question}")
+        print(f"[DEBUG] New question: {new_question}")
+
+        try:
+            prompt = [
+                self.system_message(
+                    "Your goal is to merge a sequence of related questions into a single question. If the second question is unrelated to and fully independent of the first, return the second question. "
+                    "Return only the new merged question without any extra explanation. The question should, in principle, be answerable by a single SQL statement. "
+                    "Please answer in Chinese."
+                ),
+                self.user_message(f"First question: {last_question}\nSecond question: {new_question}")
+            ]
+
+            rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
+            print(f"[DEBUG] Merged question: {rewritten_question}")
+            return rewritten_question
+
+        except Exception as e:
+            print(f"[ERROR] Question merging failed: {str(e)}")
+            # If merging fails, fall back to the new question
+            return new_question
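
generate_rewritten_question is byte-for-byte identical in DeepSeekChat, OllamaChat, and QianWenAI_Chat (below). A hedged refactoring sketch, not part of this commit: since all three classes provide system_message, user_message, and submit_prompt, the shared logic could live in a single mixin.

from app_config import REWRITE_QUESTION_ENABLED

class QuestionMergeMixin:
    """Hypothetical shared home for the duplicated question-merging logic."""

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        if not REWRITE_QUESTION_ENABLED or last_question is None:
            return new_question
        prompt = [
            self.system_message("Merge the two related questions into one; if they are unrelated, return the second."),
            self.user_message(f"First question: {last_question}\nSecond question: {new_question}"),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)

# e.g. class OllamaChat(QuestionMergeMixin, VannaBase): ...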

+ 152 - 6
custompgvector/pgvector.py

@@ -95,17 +95,163 @@ class PG_VectorStore(VannaBase):
             case _:
                 raise ValueError("Specified collection does not exist.")
 
+    # def get_similar_question_sql(self, question: str) -> list:
+    #     documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
+    #     return [ast.literal_eval(document.page_content) for document in documents]
+
+    # Same as the original, but with a similarity value attached to each result.
     def get_similar_question_sql(self, question: str) -> list:
-        documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
-        return [ast.literal_eval(document.page_content) for document in documents]
+        docs_with_scores = self.sql_collection.similarity_search_with_score(
+            query=question,
+            k=self.n_results
+        )
+
+        results = []
+        for doc, score in docs_with_scores:
+            # Parse the document content into a dict
+            base = ast.literal_eval(doc.page_content)
+
+            # Convert the distance score to a similarity
+            similarity = round(1 - score, 4)
+
+            # Log each match individually
+            print(f"[DEBUG] SQL Match: {base.get('question', '')} | similarity: {similarity}")
+
+            # Attach a similarity field
+            base["similarity"] = similarity
+            results.append(base)
+
+        # Apply threshold filtering
+        filtered_results = self._apply_score_threshold_filter(
+            results, 
+            "RESULT_VECTOR_SQL_SCORE_THRESHOLD",
+            "SQL"
+        )
+
+        return filtered_results
 
     def get_related_ddl(self, question: str, **kwargs) -> list:
-        documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
-        return [document.page_content for document in documents]
+        docs_with_scores = self.ddl_collection.similarity_search_with_score(
+            query=question,
+            k=self.n_results
+        )
+
+        results = []
+        for doc, score in docs_with_scores:
+            # Convert the distance score to a similarity
+            similarity = round(1 - score, 4)
+
+            # Log each match individually
+            print(f"[DEBUG] DDL Match: {doc.page_content[:50]}... | similarity: {similarity}")
+
+            # Attach a similarity field
+            result = {
+                "content": doc.page_content,
+                "similarity": similarity
+            }
+            results.append(result)
+
+        # Apply threshold filtering
+        filtered_results = self._apply_score_threshold_filter(
+            results, 
+            "RESULT_VECTOR_DDL_SCORE_THRESHOLD",
+            "DDL"
+        )
+
+        return filtered_results
 
     def get_related_documentation(self, question: str, **kwargs) -> list:
-        documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
-        return [document.page_content for document in documents]
+        docs_with_scores = self.documentation_collection.similarity_search_with_score(
+            query=question,
+            k=self.n_results
+        )
+
+        results = []
+        for doc, score in docs_with_scores:
+            # Convert the distance score to a similarity
+            similarity = round(1 - score, 4)
+
+            # Log each match individually
+            print(f"[DEBUG] Doc Match: {doc.page_content[:50]}... | similarity: {similarity}")
+
+            # Attach a similarity field
+            result = {
+                "content": doc.page_content,
+                "similarity": similarity
+            }
+            results.append(result)
+
+        # Apply threshold filtering
+        filtered_results = self._apply_score_threshold_filter(
+            results, 
+            "RESULT_VECTOR_DOC_SCORE_THRESHOLD",
+            "DOC"
+        )
+
+        return filtered_results
+
+    def _apply_score_threshold_filter(self, results: list, threshold_config_key: str, result_type: str) -> list:
+        """
+        Apply similarity-threshold filtering.
+
+        Args:
+            results: raw result list; each element carries a similarity field
+            threshold_config_key: name of the threshold setting in app_config
+            result_type: result type (used for logging)
+
+        Returns:
+            The filtered result list
+        """
+        if not results:
+            return results
+
+        # Load the configuration
+        try:
+            import app_config
+            enable_threshold = getattr(app_config, 'ENABLE_RESULT_VECTOR_SCORE_THRESHOLD', False)
+            threshold = getattr(app_config, threshold_config_key, 0.65)
+        except (ImportError, AttributeError) as e:
+            print(f"[WARNING] Failed to load threshold configuration: {e}; using defaults")
+            enable_threshold = False
+            threshold = 0.65
+
+        # If threshold filtering is disabled, return the results unchanged
+        if not enable_threshold:
+            print(f"[DEBUG] {result_type} threshold filtering disabled; returning all {len(results)} results")
+            return results
+
+        total_count = len(results)
+        min_required = max((total_count + 1) // 2, 1)
+
+        print(f"[DEBUG] {result_type} threshold filtering: total={total_count}, threshold={threshold}, min_required={min_required}")
+
+        # Sort by similarity in descending order (most similar first)
+        sorted_results = sorted(results, key=lambda x: x.get('similarity', 0), reverse=True)
+
+        # Collect the results that meet the threshold
+        above_threshold = [r for r in sorted_results if r.get('similarity', 0) >= threshold]
+
+        # Apply the filtering rule
+        if len(above_threshold) >= min_required:
+            # Case 1: enough results meet the threshold; return exactly those
+            filtered_results = above_threshold
+            filtered_count = len(above_threshold)
+            print(f"[DEBUG] {result_type} filtering: kept {filtered_count}, dropped {total_count - filtered_count} (all kept results meet the threshold)")
+        else:
+            # Case 2: too few results meet the threshold; force-keep the top min_required
+            filtered_results = sorted_results[:min_required]
+            above_count = len(above_threshold)
+            below_count = min_required - above_count
+            filtered_count = min_required
+            print(f"[DEBUG] {result_type} filtering: kept {filtered_count}, dropped {total_count - filtered_count} (meeting threshold: {above_count}, force-kept: {below_count})")
+
+        # Log details of the kept results
+        for i, result in enumerate(filtered_results):
+            similarity = result.get('similarity', 0)
+            status = "✓" if similarity >= threshold else "✗"
+            print(f"[DEBUG] {result_type} kept {i+1}: similarity={similarity} {status}")
+
+        return filtered_results
 
     def train(
         self,

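To make both branches of _apply_score_threshold_filter concrete, here is a standalone rehearsal of the same rule with made-up records (not part of the commit). Note that similarity = round(1 - score, 4) above assumes the store returns a distance, so lower scores mean closer matches.

records = [
    {"content": "ddl-a", "similarity": 0.82},
    {"content": "ddl-b", "similarity": 0.55},
    {"content": "ddl-c", "similarity": 0.31},
]
threshold = 0.5                                 # RESULT_VECTOR_DDL_SCORE_THRESHOLD
min_required = max((len(records) + 1) // 2, 1)  # 2
ranked = sorted(records, key=lambda r: r["similarity"], reverse=True)
above = [r for r in ranked if r["similarity"] >= threshold]
kept = above if len(above) >= min_required else ranked[:min_required]
# Case 1 here: two records meet the 0.5 threshold, so kept == above and ddl-c is dropped.
# With threshold = 0.9 instead, case 2 would force-keep the top two even though both fall below it.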
+ 65 - 4
customqianwen/Custom_QianwenAI_chat.py

@@ -1,6 +1,8 @@
 import os
 from openai import OpenAI
 from vanna.base import VannaBase
+# Import configuration flags
+from app_config import REWRITE_QUESTION_ENABLED
 
 
 class QianWenAI_Chat(VannaBase):
@@ -61,15 +63,33 @@ class QianWenAI_Chat(VannaBase):
             initial_prompt = f"You are a {self.dialect} expert. " + \
             "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
 
+        # Extract the DDL content (adapted to the new dict format)
+        ddl_content_list = []
+        if ddl_list:
+            for item in ddl_list:
+                if isinstance(item, dict) and "content" in item:
+                    ddl_content_list.append(item["content"])
+                elif isinstance(item, str):
+                    ddl_content_list.append(item)
+        
         initial_prompt = self.add_ddl_to_prompt(
-            initial_prompt, ddl_list, max_tokens=self.max_tokens
+            initial_prompt, ddl_content_list, max_tokens=self.max_tokens
         )
 
+        # Extract the documentation content (adapted to the new dict format)
+        doc_content_list = []
+        if doc_list:
+            for item in doc_list:
+                if isinstance(item, dict) and "content" in item:
+                    doc_content_list.append(item["content"])
+                elif isinstance(item, str):
+                    doc_content_list.append(item)
+        
         if self.static_documentation != "":
-            doc_list.append(self.static_documentation)
+            doc_content_list.append(self.static_documentation)
 
         initial_prompt = self.add_documentation_to_prompt(
-            initial_prompt, doc_list, max_tokens=self.max_tokens
+            initial_prompt, doc_content_list, max_tokens=self.max_tokens
         )
 
         initial_prompt += (
@@ -87,6 +107,7 @@ class QianWenAI_Chat(VannaBase):
             "   - Chinese aliases must accurately reflect the business meaning of each field\n"
             "   - No field may ever be left without a Chinese alias, as this hurts table readability\n"
             "   - This improves chart readability and the user experience\n"
+            "   - Never use a Chinese alias inside a WHERE condition; e.g. WHERE gender = 'F' AS 性别 is invalid syntax\n"
             "   Correct example: SELECT gender AS 性别, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
             "   Incorrect example: SELECT gender, card_category AS 卡片类型, COUNT(*) AS 持卡人数 FROM table_name\n"
         )
@@ -399,4 +420,44 @@ class QianWenAI_Chat(VannaBase):
             return response
         except Exception as e:
             print(f"[ERROR] LLM conversation failed: {str(e)}")
-            return "Sorry, I can't answer your question right now. Please try again later."
+            return "Sorry, I can't answer your question right now. Please try again later."
+
+    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
+        """
+        Question-merging method; a config flag controls whether merging is enabled.
+
+        Args:
+            last_question (str): the previous question
+            new_question (str): the new question
+            **kwargs: additional arguments
+
+        Returns:
+            str: the merged question if merging is enabled and the questions are related; otherwise the new question
+        """
+        # If merging is disabled or there is no previous question, return the new question as-is
+        if not REWRITE_QUESTION_ENABLED or last_question is None:
+            print(f"[DEBUG] Question merging {'disabled' if not REWRITE_QUESTION_ENABLED else 'skipped (no previous question)'}; returning the new question as-is")
+            return new_question
+
+        print("[DEBUG] Question merging enabled; attempting to merge")
+        print(f"[DEBUG] Previous question: {last_question}")
+        print(f"[DEBUG] New question: {new_question}")
+
+        try:
+            prompt = [
+                self.system_message(
+                    "Your goal is to merge a sequence of related questions into a single question. If the second question is unrelated to and fully independent of the first, return the second question. "
+                    "Return only the new merged question without any extra explanation. The question should, in principle, be answerable by a single SQL statement. "
+                    "Please answer in Chinese."
+                ),
+                self.user_message(f"First question: {last_question}\nSecond question: {new_question}")
+            ]
+
+            rewritten_question = self.submit_prompt(prompt=prompt, **kwargs)
+            print(f"[DEBUG] Merged question: {rewritten_question}")
+            return rewritten_question
+
+        except Exception as e:
+            print(f"[ERROR] Question merging failed: {str(e)}")
+            # If merging fails, fall back to the new question
+            return new_question
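
The dict-or-string extraction above appears twice in this file and twice more in Custom_QiawenAI_chat_cn.py below; a sketch of a shared helper (hypothetical name, not part of the commit):

def extract_contents(items: list) -> list:
    """Accept both legacy plain strings and the new {"content", "similarity"} dicts."""
    contents = []
    for item in items or []:
        if isinstance(item, dict) and "content" in item:
            contents.append(item["content"])
        elif isinstance(item, str):
            contents.append(item)
    return contents

# ddl_content_list = extract_contents(ddl_list)
# doc_content_list = extract_contents(doc_list)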

+ 18 - 2
customqianwen/Custom_QiawenAI_chat_cn.py

@@ -224,7 +224,15 @@ class QianWenAI_Chat_CN(VannaBase):
 
         # Add related DDL (if any)
         if ddl_list and len(ddl_list) > 0:
-            ddl_text = "\n\n".join([f"-- DDL item {i+1}:\n{ddl}" for i, ddl in enumerate(ddl_list)])
+            ddl_items = []
+            for i, item in enumerate(ddl_list):
+                if isinstance(item, dict) and "content" in item:
+                    similarity_info = f" (similarity: {item.get('similarity', 'N/A')})" if "similarity" in item else ""
+                    ddl_items.append(f"-- DDL item {i+1}{similarity_info}:\n{item['content']}")
+                elif isinstance(item, str):
+                    ddl_items.append(f"-- DDL item {i+1}:\n{item}")
+
+            ddl_text = "\n\n".join(ddl_items)
             messages.append(
                 self.user_message(
                     f"""
@@ -239,7 +247,15 @@ class QianWenAI_Chat_CN(VannaBase):
 
         # Add related documentation (if any)
         if doc_list and len(doc_list) > 0:
-            doc_text = "\n\n".join([f"-- Doc item {i+1}:\n{doc}" for i, doc in enumerate(doc_list)])
+            doc_items = []
+            for i, item in enumerate(doc_list):
+                if isinstance(item, dict) and "content" in item:
+                    similarity_info = f" (similarity: {item.get('similarity', 'N/A')})" if "similarity" in item else ""
+                    doc_items.append(f"-- Doc item {i+1}{similarity_info}:\n{item['content']}")
+                elif isinstance(item, str):
+                    doc_items.append(f"-- Doc item {i+1}:\n{item}")
+
+            doc_text = "\n\n".join(doc_items)
             messages.append(
                 self.user_message(
                     f"""

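For illustration, here is what one rendered DDL block looks like under the new dict format (record values made up):

item = {"content": "CREATE TABLE cards (card_id BIGINT, card_category TEXT);", "similarity": 0.78}
similarity_info = f" (similarity: {item.get('similarity', 'N/A')})" if "similarity" in item else ""
print(f"-- DDL item 1{similarity_info}:\n{item['content']}")
# -- DDL item 1 (similarity: 0.78):
# CREATE TABLE cards (card_id BIGINT, card_category TEXT);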
+ 14 - 0
docs/release.md

@@ -0,0 +1,14 @@
+1. Added the switch REWRITE_QUESTION_ENABLED = False, which controls whether question rewriting is enabled, i.e. merging the new question with the previous one for context. When enabled, the merged question is the one presented back to the user ("I interpreted your question as").
+
+
+2. Added the switch ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True, which controls whether score-threshold filtering of vector query results is enabled.
+
+# Whether to enable score-threshold filtering of vector query results
+# At least max((n + 1) // 2, 1) of the n retrieved results are always kept
+ENABLE_RESULT_VECTOR_SCORE_THRESHOLD = True
+# Score thresholds for vector query results
+RESULT_VECTOR_SQL_SCORE_THRESHOLD = 0.65
+RESULT_VECTOR_DDL_SCORE_THRESHOLD = 0.5
+RESULT_VECTOR_DOC_SCORE_THRESHOLD = 0.5
+

+ 12 - 49
embedding_function.py

@@ -49,35 +49,14 @@ class EmbeddingFunction:
             
         embeddings = []
         for text in input:
-            payload = {
-                "model": self.model_name,
-                "input": text,
-                "encoding_format": "float"
-            }
-            
+            # Call generate_embedding directly and let it handle errors
             try:
-                # Fix URL concatenation
-                url = self.base_url
-                if not url.endswith("/embeddings"):
-                    url = url.rstrip("/")  # strip the trailing slash to avoid a double slash
-                    if not url.endswith("/v1/embeddings"):
-                        url = f"{url}/embeddings"
-                
-                response = requests.post(url, json=payload, headers=self.headers)
-                response.raise_for_status()
-                
-                result = response.json()
-                
-                if "data" in result and len(result["data"]) > 0:
-                    vector = result["data"][0]["embedding"]
-                    embeddings.append(vector)
-                else:
-                    raise ValueError(f"Invalid API response: {result}")
-                    
+                vector = self.generate_embedding(text)
+                embeddings.append(vector)
             except Exception as e:
-                print(f"Error fetching embedding: {e}")
-                # Build a zero vector using the instance's embedding_dimension
-                embeddings.append([0.0] * self.embedding_dimension)
+                print(f"Failed to generate embedding for text '{text}': {e}")
+                # Re-raise instead of returning a zero vector
+                raise
                 
         return embeddings
     
@@ -115,16 +94,10 @@ class EmbeddingFunction:
         Returns:
             List[float]: the embedding vector
         """
-        print(f"Generating embedding; text length: {len(text)} characters")
-
         # Handle empty text
         if not text or len(text.strip()) == 0:
-            print("Input text is empty; returning a zero vector")
-            # self.embedding_dimension is mandatory at initialization,
-            # so it should never be None or need a default
+            # Returning a zero vector for empty text is reasonable behavior
             if self.embedding_dimension is None:
-                # This branch should in theory never run, because the factory ensures embedding_dimension is set,
-                # but raise for robustness if it is unexpectedly None
                 raise ValueError("Embedding dimension (self.embedding_dimension) was not properly initialized.")
             return [0.0] * self.embedding_dimension
         
@@ -145,7 +118,6 @@ class EmbeddingFunction:
                     url = url.rstrip("/")  # strip the trailing slash to avoid a double slash
                     if not url.endswith("/v1/embeddings"):
                         url = f"{url}/embeddings"
-                print(f"Request URL: {url}")
                 
                 response = requests.post(
                     url, 
@@ -157,14 +129,13 @@ class EmbeddingFunction:
                 # Check the response status
                 if response.status_code != 200:
                     error_msg = f"API request error: {response.status_code}, {response.text}"
-                    print(error_msg)
                     
                     # Decide from the status code whether to retry
                     if response.status_code in (429, 500, 502, 503, 504):
                         retries += 1
                         if retries <= self.max_retries:
                             wait_time = self.retry_interval * (2 ** (retries - 1))  # exponential backoff
-                            print(f"Waiting {wait_time}s before retrying ({retries}/{self.max_retries})")
+                            print(f"API request failed; waiting {wait_time}s before retrying ({retries}/{self.max_retries})")
                             time.sleep(wait_time)
                             continue
                     
@@ -180,39 +151,31 @@ class EmbeddingFunction:
                     # On the first call, if no dimension was provided, set it automatically
                     if self.embedding_dimension is None:
                         self.embedding_dimension = len(vector)
-                        print(f"Auto-set embedding dimension to: {self.embedding_dimension}")
                     else:
                         # Validate the vector dimension
                         actual_dim = len(vector)
                         if actual_dim != self.embedding_dimension:
-                            print(f"Vector dimension mismatch: expected {self.embedding_dimension}, got {actual_dim}")
+                            print(f"Warning: vector dimension mismatch: expected {self.embedding_dimension}, got {actual_dim}")
                     
                     # Normalize if requested
                     if self.normalize_embeddings:
                         vector = self._normalize_vector(vector)
                     
-                    print(f"Successfully generated embedding vector, dimension: {len(vector)}")
                     return vector
                 else:
                     error_msg = f"Unexpected API response format: {result}"
-                    print(error_msg)
                     raise ValueError(error_msg)
                 
             except Exception as e:
-                print(f"Error generating embedding: {str(e)}")
                 retries += 1
                 
                 if retries <= self.max_retries:
                     wait_time = self.retry_interval * (2 ** (retries - 1))  # exponential backoff
-                    print(f"Waiting {wait_time}s before retrying ({retries}/{self.max_retries})")
+                    print(f"Error generating embedding: {str(e)}; waiting {wait_time}s before retrying ({retries}/{self.max_retries})")
                     time.sleep(wait_time)
                 else:
-                    print(f"Reached the maximum number of retries ({self.max_retries}); embedding generation failed")
-                    # Decide whether to return a zero vector or re-raise
-                    if self.embedding_dimension:
-                        print(f"Returning a zero vector (dimension: {self.embedding_dimension})")
-                        return [0.0] * self.embedding_dimension
-                    raise
+                    # Raise instead of returning a zero vector so failures are never masked
+                    raise RuntimeError(f"Embedding generation failed after {self.max_retries} retries: {str(e)}")
         
         # Should be unreachable, but included for completeness
         raise RuntimeError("Embedding generation failed")
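
For reference, the exponential backoff used in both retry paths grows as retry_interval * (2 ** (retries - 1)); a quick rehearsal (the two values here are assumptions, the real ones come from the instance configuration):

retry_interval = 2   # assumed seconds
max_retries = 3      # assumed
for retries in range(1, max_retries + 1):
    wait_time = retry_interval * (2 ** (retries - 1))
    print(f"retry {retries}/{max_retries}: wait {wait_time}s")
# Output: waits of 2s, 4s, 8s; each failure doubles the pause before the next attempt.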