Ver Fonte

增加提示词,在SQL中增加 NULL LAST,增加embedding尝试的次数.

wangxq há 2 semanas atrás
pai
commit
ed8d605d94

+ 10 - 9
app_config.py

@@ -35,13 +35,13 @@ API_DEEPSEEK_CONFIG = {
 # Qwen模型配置
 API_QIANWEN_CONFIG = {
     "api_key": os.getenv("QWEN_API_KEY"),  # 从环境变量读取API密钥
-    "model": "qwen3-235b-a22b",
+    "model": "qwen-plus",
     "allow_llm_to_see_data": True,
     "temperature": 0.6,
     "n_results": 6,
     "language": "Chinese",
-    "stream": True,  # 是否使用流式模式
-    "enable_thinking": True  # 是否启用思考功能(要求stream=True)
+    "stream": False,  # 是否使用流式模式
+    "enable_thinking": False  # 是否启用思考功能(要求stream=True)
 }
 #qwen3-30b-a3b
 #qwen3-235b-a22b
@@ -50,7 +50,7 @@ API_QIANWEN_CONFIG = {
 
 # ===== API Embedding模型配置 =====
 API_EMBEDDING_CONFIG = {
-    "model_name": "BAAI/bge-m3",
+    "model_name": "text-embedding-v4",
     "api_key": os.getenv("EMBEDDING_API_KEY"),
     "base_url": os.getenv("EMBEDDING_BASE_URL"),
     "embedding_dimension": 1024
@@ -90,8 +90,8 @@ OLLAMA_EMBEDDING_CONFIG = {
 # 应用数据库连接配置 (业务数据库)
 APP_DB_CONFIG = {
     "host": "192.168.67.1",
-    "port": 5432,
-    "dbname": "bank_db",
+    "port": 6432,
+    "dbname": "highway_db",
     "user": os.getenv("APP_DB_USER"),
     "password": os.getenv("APP_DB_PASSWORD")
 }
@@ -103,16 +103,17 @@ APP_DB_CONFIG = {
 PGVECTOR_CONFIG = {
     "host": "192.168.67.1",
     "port": 5432,
-    "dbname": "pgvector_db",
+    "dbname": "highway_pgvector_db",
     "user": os.getenv("PGVECTOR_DB_USER"),
     "password": os.getenv("PGVECTOR_DB_PASSWORD")
 }
 
 # 训练脚本批处理配置
 # 这些配置仅用于 training/run_training.py 训练脚本的批处理优化
-TRAINING_BATCH_PROCESSING_ENABLED = True    # 是否启用训练数据批处理
+# 注意:当使用阿里云等API服务时,建议关闭批处理或设置单线程以避免并发连接错误
+TRAINING_BATCH_PROCESSING_ENABLED = False   # 是否启用训练数据批处理(关闭以避免并发问题)
 TRAINING_BATCH_SIZE = 10                    # 每批处理的训练项目数量
-TRAINING_MAX_WORKERS = 4                    # 训练批处理的最大工作线程数
+TRAINING_MAX_WORKERS = 1                    # 训练批处理的最大工作线程数(设置为1确保单线程)
 
 # 训练数据路径配置
 # 支持以下格式:

+ 4 - 1
core/embedding_function.py

@@ -13,7 +13,7 @@ class EmbeddingFunction:
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json"
         }
-        self.max_retries = 2  # 设置默认的最大重试次数
+        self.max_retries = 3  # 设置默认的最大重试次数
         self.retry_interval = 2  # 设置默认的重试间隔秒数
         self.normalize_embeddings = True # 设置默认是否归一化
 
@@ -161,6 +161,9 @@ class EmbeddingFunction:
                     if self.normalize_embeddings:
                         vector = self._normalize_vector(vector)
                     
+                    # 添加成功生成embedding的debug日志
+                    print(f"[DEBUG] ✓ 成功生成embedding向量,维度: {len(vector)}")
+                    
                     return vector
                 else:
                     error_msg = f"API返回格式异常: {result}"

+ 3 - 2
customembedding/ollama_embedding.py

@@ -8,7 +8,7 @@ class OllamaEmbeddingFunction:
         self.model_name = model_name
         self.base_url = base_url
         self.embedding_dimension = embedding_dimension
-        self.max_retries = 2
+        self.max_retries = 3
         self.retry_interval = 2
 
     def __call__(self, input) -> List[List[float]]:
@@ -87,7 +87,8 @@ class OllamaEmbeddingFunction:
                         else:
                             vector.extend([0.0] * (self.embedding_dimension - actual_dim))
                     
-                    print(f"成功生成Ollama embedding向量,维度: {len(vector)}")
+                    # 添加成功生成embedding的debug日志
+                    print(f"[DEBUG] ✓ 成功生成Ollama embedding向量,维度: {len(vector)}")
                     return vector
                 else:
                     error_msg = f"Ollama API返回格式异常: {result}"

+ 8 - 1
customllm/base_llm_chat.py

@@ -129,7 +129,14 @@ class BaseLLMChat(VannaBase, ABC):
             "4. Please use the most relevant table(s). \n"
             "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
             f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
-            "7. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
+            "7. 在生成 SQL 查询时,如果出现 ORDER BY 子句,请遵循以下规则:\n"
+            "   - 对所有的排序字段(如聚合字段 SUM()、普通列等),请在 ORDER BY 中显式添加 NULLS LAST。\n"
+            "   - 不论是否使用 LIMIT,只要排序字段存在,都必须添加 NULLS LAST,以防止 NULL 排在结果顶部。\n"
+            "   - 示例参考:\n"
+            "     - ORDER BY total DESC NULLS LAST\n"
+            "     - ORDER BY zf_order DESC NULLS LAST\n"
+            "     - ORDER BY SUM(c.customer_count) DESC NULLS LAST \n"
+            "8. 【重要】请在SQL查询中为所有SELECT的列都使用中文别名:\n"
             "   - 每个列都必须使用 AS 中文别名 的格式,没有例外\n"
             "   - 包括原始字段名也要添加中文别名,例如:SELECT gender AS 性别, card_category AS 卡片类型\n"
             "   - 计算字段也要有中文别名,例如:SELECT COUNT(*) AS 持卡人数\n"