
Fixed stream and thinking support in the ollama/deepseek LLM implementations; the ollama side is still being revised.

wangxq · 3 weeks ago · commit ecb9b74349

+ 31 - 2
README.md

@@ -1,9 +1,38 @@
+# Vanna-Chainlit-Chromadb 项目
+
+## 项目结构
+
+该项目主要组织结构如下:
+
+- **core/**: 核心组件目录
+  - **embedding_function.py**: 嵌入函数实现
+  - **vanna_llm_factory.py**: Vanna实例工厂
+- **common/**: 通用工具和辅助函数
+- **customembedding/**: 自定义嵌入模型实现
+- **customllm/**: 自定义语言模型实现
+- **custompgvector/**: PgVector数据库集成
+- **docs/**: 项目文档
+- **public/**: 公共资源文件
+- **training/**: 训练工具和数据
+- **app_config.py**: 应用配置
+- **chainlit_app.py**: Chainlit应用入口
+- **flask_app.py**: Flask应用入口
+
+## 训练数据与Function的对应关系
 
-更新后的训练数据与Function的对应关系
 | 文件格式/扩展名 | 对应处理函数 | 用途说明 |
 |----------------|-------------|----------|
 | .ddl | train_ddl_statements() | 训练数据库定义语言文件 |
 | .md / .markdown | train_documentation_blocks() | 训练Markdown格式的文档 |
 | _pair.json / _pairs.json | train_json_question_sql_pairs() | 训练JSON格式的问答对 |
 | _pair.sql / _pairs.sql | train_formatted_question_sql_pairs() | 训练格式化的问答对文件 |
-| .sql (其他) | train_sql_examples() | 训练一般SQL示例文件 |
+| .sql (其他) | train_sql_examples() | 训练一般SQL示例文件 |
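
A minimal sketch of how a loader might dispatch files by these suffixes. The dispatcher itself and its string return values are hypothetical; only the function names come from the table, and the `_pair`/`_pairs` checks must run before the generic `.sql` branch:

```python
import os

def choose_training_function(path: str) -> str:
    """Hypothetical helper: map a training file to the function name from the table above."""
    name = os.path.basename(path).lower()
    if name.endswith(".ddl"):
        return "train_ddl_statements"
    if name.endswith((".md", ".markdown")):
        return "train_documentation_blocks"
    if name.endswith(("_pair.json", "_pairs.json")):
        return "train_json_question_sql_pairs"
    # the _pair/_pairs suffixes must be checked before the generic .sql branch
    if name.endswith(("_pair.sql", "_pairs.sql")):
        return "train_formatted_question_sql_pairs"
    if name.endswith(".sql"):
        return "train_sql_examples"
    return "unsupported"
```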
+
+
+## enable_thinking 与 stream 组合的行为总结
+
+| enable_thinking | stream (输入) | stream (实际) | 行为描述 |
+|-----------------|---------------|---------------|----------|
+| False | False | False | 非流式模式,无thinking |
+| False | True | True | 流式模式,无thinking |
+| True | False | False | 非流式模式,enable_thinking 被强制关闭并输出警告日志 |
+| True | True | True | 流式模式,有thinking |
+
+deepseek、qianwen、ollama 三个实现均遵循上述逻辑。
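
The same resolution logic is repeated in all three chat classes touched by this commit (deepseek_chat.py, qianwen_chat.py, ollama_chat.py); a condensed sketch of that shared guard:

```python
from typing import Optional, Tuple

def resolve_stream_and_thinking(kwargs: dict, config: Optional[dict]) -> Tuple[bool, bool]:
    """Resolve the effective stream / enable_thinking flags (condensed from the three chat classes)."""
    cfg = config or {}
    # Priority: runtime kwargs > config file > default False
    stream_mode = kwargs.get("stream", cfg.get("stream", False))
    enable_thinking = kwargs.get("enable_thinking", cfg.get("enable_thinking", False))
    # enable_thinking requires stream=True; otherwise it is switched off with a warning
    if enable_thinking and not stream_mode:
        print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
        enable_thinking = False
    return stream_mode, enable_thinking
```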

+ 16 - 6
app_config.py

@@ -13,8 +13,8 @@ LLM_MODEL_TYPE = "api"  # api, ollama
 EMBEDDING_MODEL_TYPE = "api"  # api, ollama
 
 # ===== 模型名称配置 =====
-# API LLM模型名称(当LLM_MODEL_TYPE="api"时使用:qwen 或 deepseek)
-API_LLM_MODEL = "qwen"
+# API LLM模型名称(当LLM_MODEL_TYPE="api"时使用:qianwen 或 deepseek)
+API_LLM_MODEL = "deepseek"
 
 # 向量数据库类型:chromadb 或 pgvector
 VECTOR_DB_TYPE = "pgvector"
@@ -28,18 +28,20 @@ API_DEEPSEEK_CONFIG = {
     "temperature": 0.7,
     "n_results": 6,
     "language": "Chinese",
+    "stream": True,  # 是否使用流式模式
     "enable_thinking": False  # 自定义,是否支持流模式
 }
 
 # Qwen模型配置
-API_QWEN_CONFIG = {
+API_QIANWEN_CONFIG = {
     "api_key": os.getenv("QWEN_API_KEY"),  # 从环境变量读取API密钥
-    "model": "qwen-plus",
+    "model": "qwen3-235b-a22b",
     "allow_llm_to_see_data": True,
     "temperature": 0.7,
     "n_results": 6,
     "language": "Chinese",
-    "enable_thinking": False #自定义,是否支持流模式,仅qwen3模型。
+    "stream": True,  # 是否使用流式模式
+    "enable_thinking": True  # 是否启用思考功能(要求stream=True)
 }
 #qwen3-30b-a3b
 #qwen3-235b-a22b
@@ -65,7 +67,15 @@ OLLAMA_LLM_CONFIG = {
     "temperature": 0.7,
     "n_results": 6,
     "language": "Chinese",
-    "timeout": 60  # Ollama可能需要更长超时时间
+    "timeout": 60,  # Ollama可能需要更长超时时间
+    "stream": True,  # 是否使用流式模式
+    "enable_thinking": True,  # 是否启用思考功能(推理模型支持)
+    
+    # Ollama 特定参数
+    #"num_ctx": 8192,  # 上下文长度
+    #"num_predict": 2048,  # 预测token数量,-1表示无限制
+    #"repeat_penalty": 1.1,  # 重复惩罚
+    #"auto_check_connection": True  # 是否自动检查连接
 }
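
For reference, a sketch of what this block could look like with the optional Ollama parameters enabled and the app switched to a local reasoning model; the `base_url` and `model` values here are assumptions, not part of this commit:

```python
# app_config.py (sketch) — switch to a local Ollama model with thinking enabled
LLM_MODEL_TYPE = "ollama"                    # instead of "api"

OLLAMA_LLM_CONFIG = {
    "base_url": "http://localhost:11434",    # assumed default Ollama endpoint
    "model": "deepseek-r1:7b",               # hypothetical reasoning-capable model tag
    "temperature": 0.7,
    "n_results": 6,
    "language": "Chinese",
    "timeout": 60,
    "stream": True,
    "enable_thinking": True,
    "num_ctx": 8192,           # context window
    "num_predict": 2048,       # max tokens to generate (-1 = unlimited)
    "repeat_penalty": 1.1,
    "auto_check_connection": True,
}
```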
 
 

+ 1 - 1
chainlit_app.py

@@ -1,6 +1,6 @@
 import chainlit as cl
 from chainlit.input_widget import Select
-from vanna_llm_factory import create_vanna_instance
+from core.vanna_llm_factory import create_vanna_instance
 import os
 
 # vn.set_api_key(os.environ['VANNA_API_KEY'])

+ 1 - 1
citu_app.py

@@ -1,6 +1,6 @@
 # 给dataops 对话助手返回结果
 from vanna.flask import VannaFlaskApp
-from vanna_llm_factory import create_vanna_instance
+from core.vanna_llm_factory import create_vanna_instance
 from flask import request, jsonify
 import pandas as pd
 import common.result as result

+ 2 - 2
common/utils.py

@@ -46,8 +46,8 @@ def get_current_llm_config():
     if app_config.LLM_MODEL_TYPE == "ollama":
         return app_config.OLLAMA_LLM_CONFIG
     elif app_config.LLM_MODEL_TYPE == "api":
-        if app_config.API_LLM_MODEL == "qwen":
-            return app_config.API_QWEN_CONFIG
+        if app_config.API_LLM_MODEL == "qianwen":
+            return app_config.API_QIANWEN_CONFIG
         elif app_config.API_LLM_MODEL == "deepseek":
             return app_config.API_DEEPSEEK_CONFIG
         else:
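
A quick usage sketch of the renamed branch, assuming app_config keeps the defaults shown above (LLM_MODEL_TYPE="api", API_LLM_MODEL="deepseek"):

```python
from common.utils import get_current_llm_config

llm_config = get_current_llm_config()
# With LLM_MODEL_TYPE="api" and API_LLM_MODEL="deepseek" this returns API_DEEPSEEK_CONFIG
print(llm_config["model"], llm_config.get("stream"), llm_config.get("enable_thinking"))
```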

+ 5 - 0
core/__init__.py

@@ -0,0 +1,5 @@
+"""
+Core package - 系统核心组件
+
+包含嵌入函数和Vanna实例创建等核心功能
+""" 

+ 1 - 1
embedding_function.py → core/embedding_function.py

@@ -314,4 +314,4 @@ def get_embedding_function():
             api_key=api_key,
             base_url=base_url,
             embedding_dimension=embedding_dimension
-        )
+        ) 

+ 3 - 3
vanna_llm_factory.py → core/vanna_llm_factory.py

@@ -2,7 +2,7 @@
 Vanna LLM 工厂文件,支持多种LLM提供商和向量数据库
 """
 import app_config, os
-from embedding_function import get_embedding_function
+from core.embedding_function import get_embedding_function
 from common.vanna_combinations import get_vanna_class, print_available_combinations
 
 def create_vanna_instance(config_module=None):
@@ -61,7 +61,7 @@ def create_vanna_instance(config_module=None):
     
     # 配置向量数据库
     if model_info["vector_db"] == "chromadb":
-        config["path"] = os.path.dirname(os.path.abspath(__file__))
+        config["path"] = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # 返回项目根目录
         print(f"已配置使用ChromaDB,路径:{config['path']}")
     elif model_info["vector_db"] == "pgvector":
         # 构建PostgreSQL连接字符串
@@ -84,4 +84,4 @@ def create_vanna_instance(config_module=None):
           f"{config_module.APP_DB_CONFIG['port']}/"
           f"{config_module.APP_DB_CONFIG['dbname']}")
     
-    return vn
+    return vn 
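
Typical usage of the relocated factory, as the docs below also show; the example question and the exact shape of `vn.ask()`'s return value depend on the Vanna version in use:

```python
from core.vanna_llm_factory import create_vanna_instance

vn = create_vanna_instance()                # picks LLM + vector DB from app_config
result = vn.ask("每个月的订单数量是多少?")   # example question only
print(result)
```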

+ 1 - 1
customllm/base_llm_chat.py

@@ -301,7 +301,7 @@ class BaseLLMChat(VannaBase, ABC):
                 print(f"[WARNING] 返回内容不像有效SQL: {sql}")
                 return None
                 
-            print(f"[SUCCESS] 成功生成SQL: {sql}")
+            print(f"[SUCCESS] 成功生成SQL:\n {sql}")
             return sql
             
         except Exception as e:

+ 115 - 11
customllm/deepseek_chat.py

@@ -32,6 +32,19 @@ class DeepSeekChat(BaseLLMChat):
         for message in prompt:
             num_tokens += len(message["content"]) / 4
 
+        # 获取 stream 参数
+        stream_mode = kwargs.get("stream", self.config.get("stream", False) if self.config else False)
+        
+        # 获取 enable_thinking 参数
+        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False) if self.config else False)
+
+        # DeepSeek API约束:enable_thinking=True时建议使用stream=True
+        # 如果stream=False但enable_thinking=True,则忽略enable_thinking
+        if enable_thinking and not stream_mode:
+            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
+            enable_thinking = False
+
+        # 确定使用的模型
         model = None
         if kwargs.get("model", None) is not None:
             model = kwargs.get("model", None)
@@ -42,19 +55,110 @@ class DeepSeekChat(BaseLLMChat):
         elif self.config is not None and "model" in self.config:
             model = self.config["model"]
         else:
-            if num_tokens > 3500:
-                model = "deepseek-chat"
+            # 根据 enable_thinking 选择模型
+            if enable_thinking:
+                model = "deepseek-reasoner"
             else:
-                model = "deepseek-chat"
+                if num_tokens > 3500:
+                    model = "deepseek-chat"
+                else:
+                    model = "deepseek-chat"
+
+        # 模型兼容性提示(但不强制切换)
+        if enable_thinking and model not in ["deepseek-reasoner"]:
+            print(f"提示:模型 {model} 可能不支持推理功能,推理相关参数将被忽略")
 
         print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
+        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
+
+        # 构建 API 调用参数
+        api_params = {
+            "model": model,
+            "messages": prompt,
+            "stop": None,
+            "temperature": self.temperature,
+            "stream": stream_mode,
+        }
+
+        # 过滤掉自定义参数,避免传递给 API
+        filtered_kwargs = {k: v for k, v in kwargs.items() 
+                          if k not in ['model', 'engine', 'enable_thinking', 'stream']}
+
+        # 根据模型过滤不支持的参数
+        if model == "deepseek-reasoner":
+            # deepseek-reasoner 不支持的参数
+            unsupported_params = ['top_p', 'presence_penalty', 'frequency_penalty', 'logprobs', 'top_logprobs']
+            for param in unsupported_params:
+                if param in filtered_kwargs:
+                    print(f"警告:deepseek-reasoner 不支持参数 {param},已忽略")
+                    filtered_kwargs.pop(param, None)
+        else:
+            # deepseek-chat 等其他模型,只过滤明确会导致错误的参数
+            # 目前 deepseek-chat 支持大部分标准参数,暂不过滤
+            pass
 
-        # DeepSeek不支持thinking功能,忽略enable_thinking参数
-        response = self.client.chat.completions.create(
-            model=model,
-            messages=prompt,
-            stop=None,
-            temperature=self.temperature,
-        )
+        # 添加其他参数
+        api_params.update(filtered_kwargs)
 
-        return response.choices[0].message.content 
+        if stream_mode:
+            # 流式处理模式
+            if model == "deepseek-reasoner" and enable_thinking:
+                print("使用流式处理模式,启用推理功能")
+            else:
+                print("使用流式处理模式,常规聊天")
+            
+            response_stream = self.client.chat.completions.create(**api_params)
+            
+            if model == "deepseek-reasoner" and enable_thinking:
+                # 推理模型的流式处理
+                collected_reasoning = []
+                collected_content = []
+                
+                for chunk in response_stream:
+                    if hasattr(chunk, 'choices') and chunk.choices:
+                        delta = chunk.choices[0].delta
+                        
+                        # 收集推理内容
+                        if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                            collected_reasoning.append(delta.reasoning_content)
+                        
+                        # 收集最终答案
+                        if hasattr(delta, 'content') and delta.content:
+                            collected_content.append(delta.content)
+                
+                # 可选:打印推理过程
+                if collected_reasoning:
+                    print("Model reasoning process:\n", "".join(collected_reasoning))
+                
+                return "".join(collected_content)
+            else:
+                # 其他模型的流式处理(如 deepseek-chat)
+                collected_content = []
+                for chunk in response_stream:
+                    if hasattr(chunk, 'choices') and chunk.choices:
+                        delta = chunk.choices[0].delta
+                        if hasattr(delta, 'content') and delta.content:
+                            collected_content.append(delta.content)
+                
+                return "".join(collected_content)
+        else:
+            # 非流式处理模式
+            if model == "deepseek-reasoner" and enable_thinking:
+                print("使用非流式处理模式,启用推理功能")
+            else:
+                print("使用非流式处理模式,常规聊天")
+            
+            response = self.client.chat.completions.create(**api_params)
+            
+            if model == "deepseek-reasoner" and enable_thinking:
+                # 推理模型的非流式处理
+                message = response.choices[0].message
+                
+                # 可选:打印推理过程
+                if hasattr(message, 'reasoning_content') and message.reasoning_content:
+                    print("Model reasoning process:\n", message.reasoning_content)
+                
+                return message.content
+            else:
+                # 其他模型的非流式处理(如 deepseek-chat)
+                return response.choices[0].message.content 
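
A hedged usage sketch of the new DeepSeek paths; the constructor is assumed to accept the same config dict as API_DEEPSEEK_CONFIG, and the API key is a placeholder:

```python
from customllm.deepseek_chat import DeepSeekChat

chat = DeepSeekChat(config={
    "api_key": "sk-...",           # placeholder
    "model": "deepseek-reasoner",  # or "deepseek-chat"
    "stream": True,
    "enable_thinking": True,
})

prompt = [{"role": "user", "content": "用一句话解释什么是数据库索引"}]
answer = chat.submit_prompt(prompt)                 # streams, prints reasoning, returns the final answer
answer2 = chat.submit_prompt(prompt, stream=False)  # enable_thinking is ignored with a warning
```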

+ 362 - 22
customllm/ollama_chat.py

@@ -1,6 +1,7 @@
 import requests
 import json
-from typing import List, Dict, Any
+import re
+from typing import List, Dict, Any, Optional
 from .base_llm_chat import BaseLLMChat
 
 
@@ -12,9 +13,32 @@ class OllamaChat(BaseLLMChat):
         super().__init__(config=config)
 
         # Ollama特定的配置参数
-        self.base_url = config.get("base_url", "http://localhost:11434")
-        self.model = config.get("model", "qwen2.5:7b")
-        self.timeout = config.get("timeout", 60)
+        self.base_url = config.get("base_url", "http://localhost:11434") if config else "http://localhost:11434"
+        self.model = config.get("model", "qwen2.5:7b") if config else "qwen2.5:7b"
+        self.timeout = config.get("timeout", 60) if config else 60
+        
+        # Ollama 特定参数
+        self.num_ctx = config.get("num_ctx", 4096) if config else 4096  # 上下文长度
+        self.num_predict = config.get("num_predict", -1) if config else -1  # 预测token数量
+        self.repeat_penalty = config.get("repeat_penalty", 1.1) if config else 1.1  # 重复惩罚
+        
+        # 验证连接
+        if config and config.get("auto_check_connection", True):
+            self._check_ollama_health()
+
+    def _check_ollama_health(self) -> bool:
+        """检查 Ollama 服务健康状态"""
+        try:
+            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
+            if response.status_code == 200:
+                print(f"✅ Ollama 服务连接正常: {self.base_url}")
+                return True
+            else:
+                print(f"⚠️ Ollama 服务响应异常: {response.status_code}")
+                return False
+        except requests.exceptions.RequestException as e:
+            print(f"❌ Ollama 服务连接失败: {e}")
+            return False
 
     def submit_prompt(self, prompt, **kwargs) -> str:
         if prompt is None:
@@ -28,38 +52,127 @@ class OllamaChat(BaseLLMChat):
         for message in prompt:
             num_tokens += len(message["content"]) / 4
 
-        # 确定使用的模型
-        model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
+        # 获取 stream 参数
+        stream_mode = kwargs.get("stream", self.config.get("stream", False) if self.config else False)
+        
+        # 获取 enable_thinking 参数
+        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False) if self.config else False)
+
+        # Ollama 约束:enable_thinking=True时建议使用stream=True
+        # 如果stream=False但enable_thinking=True,则忽略enable_thinking
+        if enable_thinking and not stream_mode:
+            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
+            enable_thinking = False
+
+        # 智能模型选择
+        model = self._determine_model(kwargs, enable_thinking, num_tokens)
+
+        # 检查是否为推理模型
+        is_reasoning_model = self._is_reasoning_model(model)
+        
+        # 模型兼容性提示(但不强制切换)
+        if enable_thinking and not is_reasoning_model:
+            print(f"提示:模型 {model} 可能不支持推理功能,推理相关参数将被忽略")
 
         print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
+        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
 
         # 准备Ollama API请求
         url = f"{self.base_url}/api/chat"
         payload = {
             "model": model,
             "messages": prompt,
-            "stream": False,
-            "options": {
-                "temperature": self.temperature
-            }
+            "stream": stream_mode,
+            "options": self._build_options(kwargs, is_reasoning_model)
         }
 
         try:
-            response = requests.post(
-                url, 
-                json=payload, 
-                timeout=self.timeout,
-                headers={"Content-Type": "application/json"}
-            )
-            response.raise_for_status()
-            
-            result = response.json()
-            return result["message"]["content"]
-            
+            if stream_mode:
+                # 流式处理模式
+                if is_reasoning_model and enable_thinking:
+                    print("使用流式处理模式,启用推理功能")
+                else:
+                    print("使用流式处理模式,常规聊天")
+                
+                return self._handle_stream_response(url, payload, is_reasoning_model and enable_thinking)
+            else:
+                # 非流式处理模式
+                if is_reasoning_model and enable_thinking:
+                    print("使用非流式处理模式,启用推理功能")
+                else:
+                    print("使用非流式处理模式,常规聊天")
+                
+                return self._handle_non_stream_response(url, payload, is_reasoning_model and enable_thinking)
+                
         except requests.exceptions.RequestException as e:
             print(f"Ollama API请求失败: {e}")
             raise Exception(f"Ollama API调用失败: {str(e)}")
 
+    def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
+        """处理流式响应"""
+        response = requests.post(
+            url, 
+            json=payload, 
+            timeout=self.timeout,
+            headers={"Content-Type": "application/json"},
+            stream=True
+        )
+        response.raise_for_status()
+        
+        collected_content = []
+        
+        for line in response.iter_lines():
+            if line:
+                try:
+                    chunk_data = json.loads(line.decode('utf-8'))
+                    
+                    if 'message' in chunk_data and 'content' in chunk_data['message']:
+                        content = chunk_data['message']['content']
+                        collected_content.append(content)
+                    
+                    # 检查是否完成
+                    if chunk_data.get('done', False):
+                        break
+                        
+                except json.JSONDecodeError:
+                    continue
+        
+        # 合并所有内容
+        full_content = "".join(collected_content)
+        
+        # 如果启用推理功能,尝试分离推理内容和最终答案
+        if enable_reasoning:
+            reasoning_content, final_content = self._extract_reasoning(full_content)
+            
+            if reasoning_content:
+                print("Model reasoning process:\n", reasoning_content)
+                return final_content
+        
+        return full_content
+
+    def _handle_non_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
+        """处理非流式响应"""
+        response = requests.post(
+            url, 
+            json=payload, 
+            timeout=self.timeout,
+            headers={"Content-Type": "application/json"}
+        )
+        response.raise_for_status()
+        
+        result = response.json()
+        content = result["message"]["content"]
+        
+        if enable_reasoning:
+            # 尝试分离推理内容和最终答案
+            reasoning_content, final_content = self._extract_reasoning(content)
+            
+            if reasoning_content:
+                print("Model reasoning process:\n", reasoning_content)
+                return final_content
+        
+        return content
+
     def test_connection(self, test_prompt="你好") -> dict:
         """测试Ollama连接"""
         result = {
@@ -67,11 +180,33 @@ class OllamaChat(BaseLLMChat):
             "model": self.model,
             "base_url": self.base_url,
             "message": "",
+            "available_models": [],
+            "ollama_version": None
         }
         
         try:
+            # 检查服务健康状态
+            if not self._check_ollama_health():
+                result["message"] = "Ollama 服务不可用"
+                return result
+            
+            # 获取可用模型列表
+            try:
+                result["available_models"] = self.list_models()
+                
+                # 检查目标模型是否存在
+                if self.model not in result["available_models"]:
+                    print(f"警告:模型 {self.model} 不存在,尝试拉取...")
+                    if not self.pull_model(self.model):
+                        result["message"] = f"模型 {self.model} 不存在且拉取失败"
+                        return result
+            except Exception as e:
+                print(f"获取模型列表失败: {e}")
+                result["available_models"] = [self.model]
+            
             print(f"测试Ollama连接 - 模型: {self.model}")
             print(f"Ollama服务地址: {self.base_url}")
+            print(f"可用模型: {', '.join(result['available_models'])}")
             
             # 测试简单对话
             prompt = [self.user_message(test_prompt)]
@@ -84,4 +219,209 @@ class OllamaChat(BaseLLMChat):
             
         except Exception as e:
             result["message"] = f"Ollama连接测试失败: {str(e)}"
-            return result 
+            return result
+
+    def _determine_model(self, kwargs: dict, enable_thinking: bool, num_tokens: int) -> str:
+        """智能确定使用的模型"""
+        # 优先级:运行时参数 > 配置文件 > 智能选择
+        if kwargs.get("model", None) is not None:
+            return kwargs.get("model")
+        elif kwargs.get("engine", None) is not None:
+            return kwargs.get("engine")
+        elif self.config is not None and "engine" in self.config:
+            return self.config["engine"]
+        elif self.config is not None and "model" in self.config:
+            return self.config["model"]
+        else:
+            # 智能选择模型
+            if enable_thinking:
+                # 优先选择推理模型
+                try:
+                    available_models = self.list_models()
+                    reasoning_models = [m for m in available_models if self._is_reasoning_model(m)]
+                    if reasoning_models:
+                        return reasoning_models[0]  # 选择第一个推理模型
+                    else:
+                        print("警告:未找到推理模型,使用默认模型")
+                        return self.model
+                except Exception as e:
+                    print(f"获取模型列表时出错: {e},使用默认模型")
+                    return self.model
+            else:
+                # 根据 token 数量选择模型
+                if num_tokens > 8000:
+                    # 长文本,选择长上下文模型
+                    try:
+                        available_models = self.list_models()
+                        long_context_models = [m for m in available_models if any(keyword in m.lower() for keyword in ['long', '32k', '128k'])]
+                        if long_context_models:
+                            return long_context_models[0]
+                    except Exception as e:
+                        print(f"获取模型列表时出错: {e},使用默认模型")
+                
+                return self.model
+
+    def _is_reasoning_model(self, model: str) -> bool:
+        """检查是否为推理模型"""
+        reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
+        return any(keyword in model.lower() for keyword in reasoning_keywords)
+
+    def _build_options(self, kwargs: dict, is_reasoning_model: bool) -> dict:
+        """构建 Ollama options 参数"""
+        options = {
+            "temperature": self.temperature,
+            "num_ctx": self.num_ctx,
+            "num_predict": self.num_predict,
+            "repeat_penalty": self.repeat_penalty,
+        }
+
+        # 过滤掉自定义参数,避免传递给 API
+        filtered_kwargs = {k: v for k, v in kwargs.items() 
+                          if k not in ['model', 'engine', 'enable_thinking', 'stream', 'timeout']}
+
+        # 添加其他参数到 options 中
+        for k, v in filtered_kwargs.items():
+            options[k] = v
+
+        # 推理模型特定参数调整
+        if is_reasoning_model:
+            # 推理模型可能需要更多的预测token
+            if options["num_predict"] == -1:
+                options["num_predict"] = 2048
+            # 降低重复惩罚,允许更多的推理重复
+            options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)
+
+        return options
+
+    def _is_reasoning_content(self, content: str) -> bool:
+        """判断内容是否为推理内容"""
+        reasoning_patterns = [
+            r'<think>.*?</think>',
+            r'<reasoning>.*?</reasoning>',
+            r'<analysis>.*?</analysis>',
+            r'思考:',
+            r'分析:',
+            r'推理:'
+        ]
+        return any(re.search(pattern, content, re.DOTALL | re.IGNORECASE) for pattern in reasoning_patterns)
+
+    def _extract_reasoning(self, content: str) -> tuple:
+        """提取推理内容和最终答案"""
+        reasoning_patterns = [
+            r'<think>(.*?)</think>',
+            r'<reasoning>(.*?)</reasoning>',
+            r'<analysis>(.*?)</analysis>',
+            r'思考:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
+            r'分析:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
+            r'推理:(.*?)(?=\n\n|\n[^思考分析推理]|$)'
+        ]
+        
+        reasoning_content = ""
+        final_content = content
+        
+        for pattern in reasoning_patterns:
+            matches = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
+            if matches:
+                reasoning_content = "\n".join(matches)
+                final_content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE).strip()
+                break
+        
+        # 如果没有找到明确的推理标记,但内容很长,尝试简单分割
+        if not reasoning_content and len(content) > 500:
+            lines = content.split('\n')
+            if len(lines) > 10:
+                # 假设前半部分是推理,后半部分是答案
+                mid_point = len(lines) // 2
+                potential_reasoning = '\n'.join(lines[:mid_point])
+                potential_answer = '\n'.join(lines[mid_point:])
+                
+                # 简单启发式:如果前半部分包含推理关键词,则分离
+                if any(keyword in potential_reasoning for keyword in ['思考', '分析', '推理', '因为', '所以', '首先', '然后']):
+                    reasoning_content = potential_reasoning
+                    final_content = potential_answer
+        
+        return reasoning_content, final_content
+
+    # Ollama 独特功能
+    def list_models(self) -> List[str]:
+        """列出可用的模型"""
+        try:
+            response = requests.get(f"{self.base_url}/api/tags", timeout=5)  # 使用较短的超时时间
+            response.raise_for_status()
+            data = response.json()
+            models = [model["name"] for model in data.get("models", [])]
+            return models if models else [self.model]  # 如果没有模型,返回默认模型
+        except requests.exceptions.RequestException as e:
+            print(f"获取模型列表失败: {e}")
+            return [self.model]  # 返回默认模型
+        except Exception as e:
+            print(f"解析模型列表失败: {e}")
+            return [self.model]  # 返回默认模型
+
+    def pull_model(self, model_name: str) -> bool:
+        """拉取模型"""
+        try:
+            print(f"正在拉取模型: {model_name}")
+            response = requests.post(
+                f"{self.base_url}/api/pull",
+                json={"name": model_name},
+                timeout=300  # 拉取模型可能需要较长时间
+            )
+            response.raise_for_status()
+            print(f"✅ 模型 {model_name} 拉取成功")
+            return True
+        except requests.exceptions.RequestException as e:
+            print(f"❌ 模型 {model_name} 拉取失败: {e}")
+            return False
+
+    def delete_model(self, model_name: str) -> bool:
+        """删除模型"""
+        try:
+            response = requests.delete(
+                f"{self.base_url}/api/delete",
+                json={"name": model_name},
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            print(f"✅ 模型 {model_name} 删除成功")
+            return True
+        except requests.exceptions.RequestException as e:
+            print(f"❌ 模型 {model_name} 删除失败: {e}")
+            return False
+
+    def get_model_info(self, model_name: str) -> Optional[Dict]:
+        """获取模型信息"""
+        try:
+            response = requests.post(
+                f"{self.base_url}/api/show",
+                json={"name": model_name},
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.RequestException as e:
+            print(f"获取模型信息失败: {e}")
+            return None
+
+    def get_system_info(self) -> Dict:
+        """获取 Ollama 系统信息"""
+        try:
+            # 获取版本信息
+            version_response = requests.get(f"{self.base_url}/api/version", timeout=self.timeout)
+            version_info = version_response.json() if version_response.status_code == 200 else {}
+            
+            # 获取模型列表
+            models = self.list_models()
+            
+            return {
+                "base_url": self.base_url,
+                "version": version_info.get("version", "unknown"),
+                "available_models": models,
+                "current_model": self.model,
+                "timeout": self.timeout,
+                "num_ctx": self.num_ctx,
+                "num_predict": self.num_predict,
+                "repeat_penalty": self.repeat_penalty
+            }
+        except Exception as e:
+            return {"error": str(e)} 

+ 40 - 14
customllm/qianwen_chat.py

@@ -58,18 +58,37 @@ class QianWenChat(BaseLLMChat):
         # 优先使用参数中传入的值,如果没有则从配置中读取,默认为False
         enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False))
         
+        # 从配置和参数中获取stream设置
+        # 优先级:运行时参数 > 配置文件 > 默认值(False)
+        stream_mode = kwargs.get("stream", self.config.get("stream", False))
+        
+        # 千问API约束:enable_thinking=True时必须stream=True
+        # 如果stream=False但enable_thinking=True,则忽略enable_thinking
+        if enable_thinking and not stream_mode:
+            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
+            enable_thinking = False
+        
+        # 创建一个干净的kwargs副本,移除可能导致API错误的自定义参数
+        # 注意:enable_thinking和stream是千问API的有效参数,需要正确传递
+        filtered_kwargs = {k: v for k, v in kwargs.items() 
+                          if k not in ['model', 'engine']}  # 只移除model和engine
+        
         # 公共参数
         common_params = {
             "messages": prompt,
             "stop": None,
             "temperature": self.temperature,
+            "stream": stream_mode,  # 明确设置stream参数
         }
         
-        # 如果启用了thinking,则使用流式处理,但不直接传递enable_thinking参数
+        # 千问OpenAI兼容接口要求enable_thinking参数放在extra_body中
         if enable_thinking:
-            common_params["stream"] = True
-            # 千问API不接受enable_thinking作为参数,可能需要通过header或其他方式传递
-            # 也可能它只是默认启用stream=True时的thinking功能
+            common_params["extra_body"] = {"enable_thinking": True}
+        
+        # 传递其他过滤后的参数(排除enable_thinking,因为我们已经单独处理了)
+        for k, v in filtered_kwargs.items():
+            if k not in ['enable_thinking', 'stream']:  # 避免重复设置
+                common_params[k] = v
         
         model = None
         # 确定使用的模型
@@ -94,12 +113,15 @@ class QianWenChat(BaseLLMChat):
             common_params["model"] = model
         
         print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
+        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")
         
-        if enable_thinking:
+        if stream_mode:
             # 流式处理模式
-            print("使用流式处理模式,启用thinking功能")
+            if enable_thinking:
+                print("使用流式处理模式,启用thinking功能")
+            else:
+                print("使用流式处理模式,不启用thinking功能")
             
-            # 检查是否需要通过headers传递enable_thinking参数
             response_stream = self.client.chat.completions.create(**common_params)
             
             # 收集流式响应
@@ -107,17 +129,21 @@ class QianWenChat(BaseLLMChat):
             collected_content = []
             
             for chunk in response_stream:
-                # 处理thinking部分
-                if hasattr(chunk, 'thinking') and chunk.thinking:
-                    collected_thinking.append(chunk.thinking)
+                # 处理thinking部分(仅当enable_thinking=True时)
+                if enable_thinking and hasattr(chunk, 'choices') and chunk.choices:
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                        collected_thinking.append(delta.reasoning_content)
                 
                 # 处理content部分
-                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
-                    collected_content.append(chunk.choices[0].delta.content)
+                if hasattr(chunk, 'choices') and chunk.choices:
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, 'content') and delta.content:
+                        collected_content.append(delta.content)
             
             # 可以在这里处理thinking的展示逻辑,如保存到日志等
-            if collected_thinking:
-                print("Model thinking process:", "".join(collected_thinking))
+            if enable_thinking and collected_thinking:
+                print("Model thinking process:\n", "".join(collected_thinking))
             
             # 返回完整的内容
             return "".join(collected_content)

+ 1 - 1
docs/ollama 集成方案.md

@@ -421,7 +421,7 @@ from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
 from customdeepseek.custom_deepseek_chat import DeepSeekChat
 from customollama.ollama_chat import OllamaChat  # 新增
 import app_config 
-from embedding_function import get_embedding_function
+from core.embedding_function import get_embedding_function
 import os
 
 class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):

+ 1 - 1
docs/ollama_integration_guide.md

@@ -152,7 +152,7 @@ print(model_info)
 ### 2. 创建Vanna实例
 
 ```python
-from vanna_llm_factory import create_vanna_instance
+from core.vanna_llm_factory import create_vanna_instance
 
 # 根据配置自动创建合适的实例
 vn = create_vanna_instance()

+ 18 - 19
flask_app.py

@@ -1,24 +1,23 @@
 # 给dataops 对话助手返回结果
-from vanna.flask import VannaFlaskApp
-from vanna_llm_factory import create_vanna_instance
-from flask import request, jsonify
-import pandas as pd
+from flask import Flask, jsonify, request
+from core.vanna_llm_factory import create_vanna_instance
 
+app = Flask(__name__)
 vn = create_vanna_instance()
 
-# 实例化 VannaFlaskApp
-app = VannaFlaskApp(
-    vn,
-    title="辞图智能数据问答平台",
-    logo = "https://www.citupro.com/img/logo-black-2.png",
-    subtitle="让 AI 为你写 SQL",
-    chart=True,
-    allow_llm_to_see_data=True,
-    ask_results_correct=True,
-    followup_questions=True,
-    debug=True
-)
+@app.route('/ask', methods=['POST'])
+def ask_endpoint():
+    try:
+        data = request.json
+        question = data.get('question', '')
+        if not question:
+            return jsonify({"error": "Question is required"}), 400
+        
+        # 获取SQL答案
+        result = vn.ask(question)
+        return jsonify({"result": result})
+    except Exception as e:
+        return jsonify({"error": str(e)}), 500
 
-
-print("正在启动Flask应用...")
-app.run(host="0.0.0.0", port=8084, debug=True)
+if __name__ == '__main__':
+    app.run(debug=True, port=5000)
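
A quick client-side sketch against the new endpoint (the question text is just an example):

```python
import requests

# Default dev server runs on port 5000
resp = requests.post(
    "http://localhost:5000/ask",
    json={"question": "每个月的订单数量是多少?"},
)
print(resp.status_code, resp.json())  # {"result": ...} on success, {"error": ...} otherwise
```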

+ 2 - 2
training/run_training.py

@@ -26,7 +26,7 @@ def check_embedding_model_connection():
     Returns:
         bool: 连接成功返回True,否则终止程序
     """
-    from embedding_function import test_embedding_connection
+    from core.embedding_function import test_embedding_connection
 
     print("正在检查嵌入模型连接...")
     
@@ -559,7 +559,7 @@ def main():
         
         # 验证数据是否成功写入
         print("\n===== 验证训练数据 =====")
-        from vanna_llm_factory import create_vanna_instance
+        from core.vanna_llm_factory import create_vanna_instance
         vn = create_vanna_instance()
         
         # 根据向量数据库类型执行不同的验证逻辑

+ 1 - 1
training/vanna_trainer.py

@@ -16,7 +16,7 @@ import app_config
 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
 # 创建vanna实例
-from vanna_llm_factory import create_vanna_instance
+from core.vanna_llm_factory import create_vanna_instance
 
 vn = create_vanna_instance()