Explorar o código

增加了对ollama大模型的支持,修改了app_config.py的参数名称。

wangxq hai 3 semanas
pai
achega
f1b9d8919e

+ 39 - 8
app_config.py

@@ -4,13 +4,23 @@ import os
 # 加载.env文件中的环境变量
 load_dotenv()
 
-# 使用的模型类型("qwen" 或 "deepseek")
-LLM_MODEL_NAME = "qwen"
-# 向量数据库类型, chromadb 或 pgvector
-VECTOR_DB_NAME = "pgvector"
+# ===== 模型提供商类型配置 =====
+# LLM模型提供商类型:api 或 ollama
+LLM_MODEL_TYPE = "ollama"  # api, ollama
 
+# Embedding模型提供商类型:api 或 ollama  
+EMBEDDING_MODEL_TYPE = "ollama"  # api, ollama
+
+# ===== 模型名称配置 =====
+# API LLM模型名称(当LLM_MODEL_TYPE="api"时使用:qwen 或 deepseek)
+API_LLM_MODEL = "qwen"
+
+# 向量数据库类型:chromadb 或 pgvector
+VECTOR_DB_TYPE = "pgvector"
+
+# ===== API LLM模型配置 =====
 # DeepSeek模型配置
-DEEPSEEK_CONFIG = {
+API_DEEPSEEK_CONFIG = {
     "api_key": os.getenv("DEEPSEEK_API_KEY"),  # 从环境变量读取API密钥
     "model": "deepseek-reasoner",  # deepseek-chat, deepseek-reasoner
     "allow_llm_to_see_data": True,
@@ -20,9 +30,8 @@ DEEPSEEK_CONFIG = {
     "enable_thinking": False  # 自定义,是否支持流模式
 }
 
-
 # Qwen模型配置
-QWEN_CONFIG = {
+API_QWEN_CONFIG = {
     "api_key": os.getenv("QWEN_API_KEY"),  # 从环境变量读取API密钥
     "model": "qwen-plus",
     "allow_llm_to_see_data": True,
@@ -36,7 +45,8 @@ QWEN_CONFIG = {
 #qwen-plus-latest
 #qwen-plus
 
-EMBEDDING_CONFIG = {
+# ===== API Embedding模型配置 =====
+API_EMBEDDING_CONFIG = {
     "model_name": "BAAI/bge-m3",
     "api_key": os.getenv("EMBEDDING_API_KEY"),
     "base_url": os.getenv("EMBEDDING_BASE_URL"),
@@ -44,6 +54,27 @@ EMBEDDING_CONFIG = {
 }
 
 
+# ===== Ollama LLM模型配置 =====
+OLLAMA_LLM_CONFIG = {
+    "base_url": "http://192.168.3.204:11434",  # Ollama服务地址
+    "model": "qwen3:32b",  # Ollama模型名称,如:qwen3:32b, deepseek-r1:32b
+    "allow_llm_to_see_data": True,
+    "temperature": 0.7,
+    "n_results": 6,
+    "language": "Chinese",
+    "timeout": 60  # Ollama可能需要更长超时时间
+}
+
+
+# ===== Ollama Embedding模型配置 =====
+OLLAMA_EMBEDDING_CONFIG = {
+    "base_url": "http://192.168.3.204:11434",  # Ollama服务地址
+    "model_name": "bge-m3:567m",  # Ollama embedding模型名称
+    "embedding_dimension": 1024  # 根据实际模型调整
+}
+
+
+
 # 应用数据库连接配置 (业务数据库)
 APP_DB_CONFIG = {
     "host": "192.168.67.1",

+ 13 - 0
citu_app.py

@@ -204,6 +204,19 @@ def ask_cached():
 
 @app.flask_app.route('/api/v1/citu_train_question_sql', methods=['POST'])
 def citu_train_question_sql():
+    """
+    训练问题-SQL对接口
+    
+    此API将接收的question/sql pair写入到training库中,用于训练和改进AI模型。
+    支持仅传入SQL或同时传入问题和SQL进行训练。
+    
+    Args:
+        question (str, optional): 用户问题
+        sql (str, required): 对应的SQL查询语句
+    
+    Returns:
+        JSON: 包含训练ID和成功消息的响应
+    """
     try:
         req = request.get_json(force=True)
         question = req.get('question')

+ 191 - 0
common/utils.py

@@ -0,0 +1,191 @@
+"""
+配置相关的工具函数
+用于处理不同模型类型的配置选择逻辑
+"""
+
+def get_current_embedding_config():
+    """
+    根据EMBEDDING_MODEL_TYPE返回当前应该使用的embedding配置
+    
+    Returns:
+        dict: 当前的embedding配置字典
+        
+    Raises:
+        ImportError: 如果无法导入app_config
+        ValueError: 如果EMBEDDING_MODEL_TYPE值无效
+    """
+    try:
+        import app_config
+    except ImportError:
+        raise ImportError("无法导入 app_config.py,请确保该文件存在")
+    
+    if app_config.EMBEDDING_MODEL_TYPE == "ollama":
+        return app_config.OLLAMA_EMBEDDING_CONFIG
+    elif app_config.EMBEDDING_MODEL_TYPE == "api":
+        return app_config.API_EMBEDDING_CONFIG
+    else:
+        raise ValueError(f"不支持的EMBEDDING_MODEL_TYPE: {app_config.EMBEDDING_MODEL_TYPE}")
+
+
+def get_current_llm_config():
+    """
+    根据LLM_MODEL_TYPE和API_LLM_MODEL返回当前应该使用的LLM配置
+    
+    Returns:
+        dict: 当前的LLM配置字典
+        
+    Raises:
+        ImportError: 如果无法导入app_config
+        ValueError: 如果LLM_MODEL_TYPE或API_LLM_MODEL值无效
+    """
+    try:
+        import app_config
+    except ImportError:
+        raise ImportError("无法导入 app_config.py,请确保该文件存在")
+    
+    if app_config.LLM_MODEL_TYPE == "ollama":
+        return app_config.OLLAMA_LLM_CONFIG
+    elif app_config.LLM_MODEL_TYPE == "api":
+        if app_config.API_LLM_MODEL == "qwen":
+            return app_config.API_QWEN_CONFIG
+        elif app_config.API_LLM_MODEL == "deepseek":
+            return app_config.API_DEEPSEEK_CONFIG
+        else:
+            raise ValueError(f"不支持的API_LLM_MODEL: {app_config.API_LLM_MODEL}")
+    else:
+        raise ValueError(f"不支持的LLM_MODEL_TYPE: {app_config.LLM_MODEL_TYPE}")
+
+
+def get_current_vector_db_config():
+    """
+    根据VECTOR_DB_TYPE返回当前应该使用的向量数据库配置
+    
+    Returns:
+        dict: 当前的向量数据库配置字典
+        
+    Raises:
+        ImportError: 如果无法导入app_config
+        ValueError: 如果VECTOR_DB_TYPE值无效
+    """
+    try:
+        import app_config
+    except ImportError:
+        raise ImportError("无法导入 app_config.py,请确保该文件存在")
+    
+    if app_config.VECTOR_DB_TYPE == "pgvector":
+        return app_config.PGVECTOR_CONFIG
+    elif app_config.VECTOR_DB_TYPE == "chromadb":
+        # ChromaDB通常不需要复杂配置,返回项目根目录路径
+        import os
+        return {"path": os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}
+    else:
+        raise ValueError(f"不支持的VECTOR_DB_TYPE: {app_config.VECTOR_DB_TYPE}")
+
+
+def get_current_model_info():
+    """
+    获取当前配置的模型信息摘要
+    
+    Returns:
+        dict: 包含当前所有模型配置信息的字典
+        
+    Raises:
+        ImportError: 如果无法导入app_config
+    """
+    try:
+        import app_config
+    except ImportError:
+        raise ImportError("无法导入 app_config.py,请确保该文件存在")
+    
+    # 获取LLM模型名称
+    if app_config.LLM_MODEL_TYPE == "ollama":
+        llm_model_name = app_config.OLLAMA_LLM_CONFIG.get("model", "unknown")
+    else:
+        llm_model_name = app_config.API_LLM_MODEL
+    
+    # 获取Embedding模型名称
+    if app_config.EMBEDDING_MODEL_TYPE == "ollama":
+        embedding_model_name = app_config.OLLAMA_EMBEDDING_CONFIG.get("model_name", "unknown")
+    else:
+        embedding_model_name = app_config.API_EMBEDDING_CONFIG.get("model_name", "unknown")
+    
+    return {
+        "llm_type": app_config.LLM_MODEL_TYPE,
+        "llm_model": llm_model_name,
+        "embedding_type": app_config.EMBEDDING_MODEL_TYPE,
+        "embedding_model": embedding_model_name,
+        "vector_db": app_config.VECTOR_DB_TYPE
+    }
+
+
+def is_using_ollama_llm():
+    """
+    检查当前是否使用Ollama作为LLM提供商
+    
+    Returns:
+        bool: 如果使用Ollama LLM返回True,否则返回False
+    """
+    try:
+        import app_config
+        return app_config.LLM_MODEL_TYPE == "ollama"
+    except ImportError:
+        return False
+
+
+def is_using_ollama_embedding():
+    """
+    检查当前是否使用Ollama作为Embedding提供商
+    
+    Returns:
+        bool: 如果使用Ollama Embedding返回True,否则返回False
+    """
+    try:
+        import app_config
+        return app_config.EMBEDDING_MODEL_TYPE == "ollama"
+    except ImportError:
+        return False
+
+
+def is_using_api_llm():
+    """
+    检查当前是否使用API作为LLM提供商
+    
+    Returns:
+        bool: 如果使用API LLM返回True,否则返回False
+    """
+    try:
+        import app_config
+        return app_config.LLM_MODEL_TYPE == "api"
+    except ImportError:
+        return False
+
+
+def is_using_api_embedding():
+    """
+    检查当前是否使用API作为Embedding提供商
+    
+    Returns:
+        bool: 如果使用API Embedding返回True,否则返回False
+    """
+    try:
+        import app_config
+        return app_config.EMBEDDING_MODEL_TYPE == "api"
+    except ImportError:
+        return False
+
+
+def print_current_config():
+    """
+    打印当前的配置信息,用于调试和确认配置
+    """
+    try:
+        model_info = get_current_model_info()
+        print("=== 当前模型配置 ===")
+        print(f"LLM提供商: {model_info['llm_type']}")
+        print(f"LLM模型: {model_info['llm_model']}")
+        print(f"Embedding提供商: {model_info['embedding_type']}")
+        print(f"Embedding模型: {model_info['embedding_model']}")
+        print(f"向量数据库: {model_info['vector_db']}")
+        print("==================")
+    except Exception as e:
+        print(f"无法获取配置信息: {e}") 

+ 190 - 0
common/vanna_combinations.py

@@ -0,0 +1,190 @@
+"""
+Vanna LLM与向量数据库的组合类
+统一管理所有LLM提供商与向量数据库的组合
+"""
+
+# 向量数据库导入
+from vanna.chromadb import ChromaDB_VectorStore
+try:
+    from custompgvector import PG_VectorStore
+except ImportError:
+    print("警告: 无法导入 PG_VectorStore,PGVector相关组合类将不可用")
+    PG_VectorStore = None
+
+# LLM提供商导入
+from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
+from customdeepseek.custom_deepseek_chat import DeepSeekChat
+try:
+    from customollama.ollama_chat import OllamaChat
+except ImportError:
+    print("警告: 无法导入 OllamaChat,Ollama相关组合类将不可用")
+    OllamaChat = None
+
+
+# ===== API LLM + ChromaDB 组合 =====
+
+class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
+    """Qwen LLM + ChromaDB 向量数据库组合"""
+    def __init__(self, config=None):
+        ChromaDB_VectorStore.__init__(self, config=config)
+        QianWenAI_Chat.__init__(self, config=config)
+
+
+class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
+    """DeepSeek LLM + ChromaDB 向量数据库组合"""
+    def __init__(self, config=None):
+        ChromaDB_VectorStore.__init__(self, config=config)
+        DeepSeekChat.__init__(self, config=config)
+
+
+# ===== API LLM + PGVector 组合 =====
+
+if PG_VectorStore is not None:
+    class Vanna_Qwen_PGVector(PG_VectorStore, QianWenAI_Chat):
+        """Qwen LLM + PGVector 向量数据库组合"""
+        def __init__(self, config=None):
+            PG_VectorStore.__init__(self, config=config)
+            QianWenAI_Chat.__init__(self, config=config)
+
+    class Vanna_DeepSeek_PGVector(PG_VectorStore, DeepSeekChat):
+        """DeepSeek LLM + PGVector 向量数据库组合"""
+        def __init__(self, config=None):
+            PG_VectorStore.__init__(self, config=config)
+            DeepSeekChat.__init__(self, config=config)
+else:
+    # 如果PG_VectorStore不可用,创建占位符类
+    class Vanna_Qwen_PGVector:
+        def __init__(self, config=None):
+            raise ImportError("PG_VectorStore 不可用,无法创建 Vanna_Qwen_PGVector 实例")
+    
+    class Vanna_DeepSeek_PGVector:
+        def __init__(self, config=None):
+            raise ImportError("PG_VectorStore 不可用,无法创建 Vanna_DeepSeek_PGVector 实例")
+
+
+# ===== Ollama LLM + ChromaDB 组合 =====
+
+if OllamaChat is not None:
+    class Vanna_Ollama_ChromaDB(ChromaDB_VectorStore, OllamaChat):
+        """Ollama LLM + ChromaDB 向量数据库组合"""
+        def __init__(self, config=None):
+            ChromaDB_VectorStore.__init__(self, config=config)
+            OllamaChat.__init__(self, config=config)
+else:
+    class Vanna_Ollama_ChromaDB:
+        def __init__(self, config=None):
+            raise ImportError("OllamaChat 不可用,无法创建 Vanna_Ollama_ChromaDB 实例")
+
+
+# ===== Ollama LLM + PGVector 组合 =====
+
+if OllamaChat is not None and PG_VectorStore is not None:
+    class Vanna_Ollama_PGVector(PG_VectorStore, OllamaChat):
+        """Ollama LLM + PGVector 向量数据库组合"""
+        def __init__(self, config=None):
+            PG_VectorStore.__init__(self, config=config)
+            OllamaChat.__init__(self, config=config)
+else:
+    class Vanna_Ollama_PGVector:
+        def __init__(self, config=None):
+            error_msg = []
+            if OllamaChat is None:
+                error_msg.append("OllamaChat 不可用")
+            if PG_VectorStore is None:
+                error_msg.append("PG_VectorStore 不可用")
+            raise ImportError(f"{', '.join(error_msg)},无法创建 Vanna_Ollama_PGVector 实例")
+
+
+# ===== 组合类映射表 =====
+
+# LLM类型到类名的映射
+LLM_CLASS_MAP = {
+    "qwen": {
+        "chromadb": Vanna_Qwen_ChromaDB,
+        "pgvector": Vanna_Qwen_PGVector,
+    },
+    "deepseek": {
+        "chromadb": Vanna_DeepSeek_ChromaDB,
+        "pgvector": Vanna_DeepSeek_PGVector,
+    },
+    "ollama": {
+        "chromadb": Vanna_Ollama_ChromaDB,
+        "pgvector": Vanna_Ollama_PGVector,
+    }
+}
+
+
+def get_vanna_class(llm_type: str, vector_db_type: str):
+    """
+    根据LLM类型和向量数据库类型获取对应的Vanna组合类
+    
+    Args:
+        llm_type: LLM类型 ("qwen", "deepseek", "ollama")
+        vector_db_type: 向量数据库类型 ("chromadb", "pgvector")
+        
+    Returns:
+        对应的Vanna组合类
+        
+    Raises:
+        ValueError: 如果不支持的组合类型
+    """
+    llm_type = llm_type.lower()
+    vector_db_type = vector_db_type.lower()
+    
+    if llm_type not in LLM_CLASS_MAP:
+        raise ValueError(f"不支持的LLM类型: {llm_type},支持的类型: {list(LLM_CLASS_MAP.keys())}")
+    
+    if vector_db_type not in LLM_CLASS_MAP[llm_type]:
+        raise ValueError(f"不支持的向量数据库类型: {vector_db_type},支持的类型: {list(LLM_CLASS_MAP[llm_type].keys())}")
+    
+    return LLM_CLASS_MAP[llm_type][vector_db_type]
+
+
+def list_available_combinations():
+    """
+    列出所有可用的LLM与向量数据库组合
+    
+    Returns:
+        dict: 可用组合的字典
+    """
+    available = {}
+    
+    for llm_type, vector_dbs in LLM_CLASS_MAP.items():
+        available[llm_type] = []
+        for vector_db_type, cls in vector_dbs.items():
+            try:
+                # 尝试创建实例来检查是否可用
+                cls(config={})
+                available[llm_type].append(vector_db_type)
+            except ImportError:
+                # 如果导入错误,说明不可用
+                continue
+            except Exception:
+                # 其他错误(如配置错误)仍然认为是可用的
+                available[llm_type].append(vector_db_type)
+    
+    return available
+
+
+def print_available_combinations():
+    """打印所有可用的组合"""
+    print("可用的LLM与向量数据库组合:")
+    print("=" * 40)
+    
+    combinations = list_available_combinations()
+    
+    for llm_type, vector_dbs in combinations.items():
+        print(f"\n{llm_type.upper()} LLM:")
+        for vector_db in vector_dbs:
+            class_name = LLM_CLASS_MAP[llm_type][vector_db].__name__
+            print(f"  + {vector_db} -> {class_name}")
+    
+    if not any(combinations.values()):
+        print("没有可用的组合,请检查依赖是否正确安装")
+
+
+# ===== 向后兼容性支持 =====
+
+# 为了保持向后兼容,可以在这里添加别名
+# 例如:
+# VannaQwenChromaDB = Vanna_Qwen_ChromaDB  # 旧的命名风格 

+ 2 - 0
customollama/__init__.py

@@ -0,0 +1,2 @@
+from .ollama_chat import OllamaChat
+from .ollama_embedding import OllamaEmbeddingFunction 

+ 165 - 0
customollama/ollama_chat.py

@@ -0,0 +1,165 @@
+import requests
+import json
+from vanna.base import VannaBase
+from typing import List, Dict, Any
+
+class OllamaChat(VannaBase):
+    def __init__(self, config=None):
+        print("...OllamaChat init...")
+        VannaBase.__init__(self, config=config)
+
+        print("传入的 config 参数如下:")
+        for key, value in self.config.items():
+            print(f"  {key}: {value}")
+
+        # 默认参数
+        self.temperature = 0.7
+        self.base_url = config.get("base_url", "http://localhost:11434")
+        self.model = config.get("model", "qwen2.5:7b")
+        self.timeout = config.get("timeout", 60)
+
+        if "temperature" in config:
+            print(f"temperature is changed to: {config['temperature']}")
+            self.temperature = config["temperature"]
+
+    def system_message(self, message: str) -> any:
+        print(f"system_content: {message}")
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        print(f"\nuser_content: {message}")
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        print(f"assistant_content: {message}")
+        return {"role": "assistant", "content": message}
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        if prompt is None:
+            raise Exception("Prompt is None")
+
+        if len(prompt) == 0:
+            raise Exception("Prompt is empty")
+
+        # 计算token数量估计
+        num_tokens = 0
+        for message in prompt:
+            num_tokens += len(message["content"]) / 4
+
+        # 确定使用的模型
+        model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
+
+        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
+
+        # 准备Ollama API请求
+        url = f"{self.base_url}/api/chat"
+        payload = {
+            "model": model,
+            "messages": prompt,
+            "stream": False,
+            "options": {
+                "temperature": self.temperature
+            }
+        }
+
+        try:
+            response = requests.post(
+                url, 
+                json=payload, 
+                timeout=self.timeout,
+                headers={"Content-Type": "application/json"}
+            )
+            response.raise_for_status()
+            
+            result = response.json()
+            return result["message"]["content"]
+            
+        except requests.exceptions.RequestException as e:
+            print(f"Ollama API请求失败: {e}")
+            raise Exception(f"Ollama API调用失败: {str(e)}")
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        """重写generate_sql方法,增加异常处理"""
+        try:
+            print(f"[DEBUG] 尝试为问题生成SQL: {question}")
+            sql = super().generate_sql(question, **kwargs)
+            
+            if not sql or sql.strip() == "":
+                print(f"[WARNING] 生成的SQL为空")
+                return None
+            
+            # 检查返回内容是否为有效SQL
+            sql_lower = sql.lower().strip()
+            error_indicators = [
+                "insufficient context", "无法生成", "sorry", "cannot", "不能",
+                "no relevant", "no suitable", "unable to", "无法", "抱歉"
+            ]
+            
+            for indicator in error_indicators:
+                if indicator in sql_lower:
+                    print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
+                    return None
+            
+            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
+            if not any(keyword in sql_lower for keyword in sql_keywords):
+                print(f"[WARNING] 返回内容不像有效SQL: {sql}")
+                return None
+                
+            print(f"[SUCCESS] 成功生成SQL: {sql}")
+            return sql
+            
+        except Exception as e:
+            print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
+            return None
+
+    def generate_question(self, sql: str, **kwargs) -> str:
+        """根据SQL生成中文问题"""
+        prompt = [
+            self.system_message(
+                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
+            ),
+            self.user_message(sql)
+        ]
+        response = self.submit_prompt(prompt, **kwargs)
+        return response
+    
+    def chat_with_llm(self, question: str, **kwargs) -> str:
+        """直接与LLM对话"""
+        try:
+            prompt = [
+                self.system_message(
+                    "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
+                ),
+                self.user_message(question)
+            ]
+            response = self.submit_prompt(prompt, **kwargs)
+            return response
+        except Exception as e:
+            print(f"[ERROR] LLM对话失败: {str(e)}")
+            return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
+
+    def test_connection(self, test_prompt="你好") -> dict:
+        """测试Ollama连接"""
+        result = {
+            "success": False,
+            "model": self.model,
+            "base_url": self.base_url,
+            "message": "",
+        }
+        
+        try:
+            print(f"测试Ollama连接 - 模型: {self.model}")
+            print(f"Ollama服务地址: {self.base_url}")
+            
+            # 测试简单对话
+            prompt = [self.user_message(test_prompt)]
+            response = self.submit_prompt(prompt)
+            
+            result["success"] = True
+            result["message"] = f"Ollama连接测试成功,响应: {response[:50]}..."
+            
+            return result
+            
+        except Exception as e:
+            result["message"] = f"Ollama连接测试失败: {str(e)}"
+            return result 

+ 141 - 0
customollama/ollama_embedding.py

@@ -0,0 +1,141 @@
+import requests
+import time
+import numpy as np
+from typing import List, Callable
+
+class OllamaEmbeddingFunction:
+    def __init__(self, model_name: str, base_url: str, embedding_dimension: int):
+        self.model_name = model_name
+        self.base_url = base_url
+        self.embedding_dimension = embedding_dimension
+        self.max_retries = 2
+        self.retry_interval = 2
+
+    def __call__(self, input) -> List[List[float]]:
+        """为文本列表生成嵌入向量"""
+        if not isinstance(input, list):
+            input = [input]
+            
+        embeddings = []
+        for text in input:
+            try:
+                embedding = self.generate_embedding(text)
+                embeddings.append(embedding)
+            except Exception as e:
+                print(f"获取embedding时出错: {e}")
+                embeddings.append([0.0] * self.embedding_dimension)
+                
+        return embeddings
+    
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """为文档列表生成嵌入向量(兼容ChromaDB接口)"""
+        return self.__call__(texts)
+    
+    def embed_query(self, text: str) -> List[float]:
+        """为单个查询文本生成嵌入向量(兼容ChromaDB接口)"""
+        return self.generate_embedding(text)
+    
+    def generate_embedding(self, text: str) -> List[float]:
+        """为单个文本生成嵌入向量"""
+        print(f"生成Ollama嵌入向量,文本长度: {len(text)} 字符")
+        
+        if not text or len(text.strip()) == 0:
+            print("输入文本为空,返回零向量")
+            return [0.0] * self.embedding_dimension
+
+        url = f"{self.base_url}/api/embeddings"
+        payload = {
+            "model": self.model_name,
+            "prompt": text
+        }
+        
+        retries = 0
+        while retries <= self.max_retries:
+            try:
+                response = requests.post(
+                    url, 
+                    json=payload,
+                    timeout=30
+                )
+                
+                if response.status_code != 200:
+                    error_msg = f"Ollama API请求错误: {response.status_code}, {response.text}"
+                    print(error_msg)
+                    
+                    if response.status_code in (429, 500, 502, 503, 504):
+                        retries += 1
+                        if retries <= self.max_retries:
+                            wait_time = self.retry_interval * (2 ** (retries - 1))
+                            print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                            time.sleep(wait_time)
+                            continue
+                    
+                    raise ValueError(error_msg)
+                
+                result = response.json()
+                
+                if "embedding" in result:
+                    vector = result["embedding"]
+                    
+                    # 验证向量维度
+                    actual_dim = len(vector)
+                    if actual_dim != self.embedding_dimension:
+                        print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
+                        # 如果维度不匹配,可以选择截断或填充
+                        if actual_dim > self.embedding_dimension:
+                            vector = vector[:self.embedding_dimension]
+                        else:
+                            vector.extend([0.0] * (self.embedding_dimension - actual_dim))
+                    
+                    print(f"成功生成Ollama embedding向量,维度: {len(vector)}")
+                    return vector
+                else:
+                    error_msg = f"Ollama API返回格式异常: {result}"
+                    print(error_msg)
+                    raise ValueError(error_msg)
+                
+            except Exception as e:
+                print(f"生成Ollama embedding时出错: {str(e)}")
+                retries += 1
+                
+                if retries <= self.max_retries:
+                    wait_time = self.retry_interval * (2 ** (retries - 1))
+                    print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                    time.sleep(wait_time)
+                else:
+                    print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
+                    return [0.0] * self.embedding_dimension
+        
+        raise RuntimeError("生成Ollama embedding失败")
+
+    def test_connection(self, test_text="测试文本") -> dict:
+        """测试Ollama嵌入模型的连接"""
+        result = {
+            "success": False,
+            "model": self.model_name,
+            "base_url": self.base_url,
+            "message": "",
+            "actual_dimension": None,
+            "expected_dimension": self.embedding_dimension
+        }
+        
+        try:
+            print(f"测试Ollama嵌入模型连接 - 模型: {self.model_name}")
+            print(f"Ollama服务地址: {self.base_url}")
+            
+            vector = self.generate_embedding(test_text)
+            actual_dimension = len(vector)
+            
+            result["success"] = True
+            result["actual_dimension"] = actual_dimension
+            
+            if actual_dimension != self.embedding_dimension:
+                result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
+            else:
+                result["message"] = f"Ollama连接测试成功,向量维度: {actual_dimension}"
+                
+            return result
+            
+        except Exception as e:
+            result["message"] = f"Ollama连接测试失败: {str(e)}"
+            return result 

+ 127 - 0
docs/config_refactor_summary.md

@@ -0,0 +1,127 @@
+# 配置重构总结
+
+## 重构内容
+
+本次重构对 `app_config.py` 中的配置参数名称进行了标准化,使配置命名更加清晰和一致。
+
+### 重构的配置参数
+
+| 旧配置名称 | 新配置名称 | 说明 |
+|-----------|-----------|------|
+| `DEEPSEEK_CONFIG` | `API_DEEPSEEK_CONFIG` | DeepSeek API模型配置 |
+| `QWEN_CONFIG` | `API_QWEN_CONFIG` | Qwen API模型配置 |
+| `EMBEDDING_OLLAMA_CONFIG` | `OLLAMA_EMBEDDING_CONFIG` | Ollama Embedding模型配置 |
+| `LLM_MODEL_NAME` | `API_LLM_MODEL` | API LLM模型名称配置 |
+| `VECTOR_DB_NAME` | `VECTOR_DB_TYPE` | 向量数据库类型配置 |
+| `EMBEDDING_API_CONFIG` | `API_EMBEDDING_CONFIG` | API Embedding模型配置 |
+
+## 修改的文件
+
+### 1. `app_config.py`
+- 将 `DEEPSEEK_CONFIG` 重命名为 `API_DEEPSEEK_CONFIG`
+- 将 `QWEN_CONFIG` 重命名为 `API_QWEN_CONFIG`
+- 将 `EMBEDDING_OLLAMA_CONFIG` 重命名为 `OLLAMA_EMBEDDING_CONFIG`
+- 将 `LLM_MODEL_NAME` 重命名为 `API_LLM_MODEL`
+- 将 `VECTOR_DB_NAME` 重命名为 `VECTOR_DB_TYPE`
+- 将 `EMBEDDING_API_CONFIG` 重命名为 `API_EMBEDDING_CONFIG`
+
+### 2. `common/utils.py`
+- 更新 `get_current_llm_config()` 函数中的配置引用
+- 更新 `get_current_embedding_config()` 函数中的配置引用
+- 更新 `get_current_vector_db_config()` 函数中的配置引用
+- 更新 `get_current_model_info()` 函数中的配置引用
+
+### 3. 训练脚本
+- `training/run_training.py` - 更新向量数据库配置引用和embedding配置引用
+- `training/vanna_trainer.py` - 更新embedding配置引用
+
+### 4. 文档文件
+- `docs/ollama 集成方案.md` - 更新配置引用
+- `docs/ollama_integration_guide.md` - 更新配置引用和示例
+- `docs/training_integration_fixes.md` - 更新配置引用
+
+## 重构的好处
+
+### 1. 命名一致性
+- API模型配置统一使用 `API_` 前缀
+- Ollama模型配置统一使用 `OLLAMA_` 前缀
+- 配置名称更清晰地表明了提供商类型
+
+### 2. 更好的可读性
+- 配置名称现在明确指示了模型提供商类型
+- 便于理解和维护
+
+### 3. 向后兼容性
+- 旧的配置名称已完全移除
+- 所有引用都已更新到新的配置名称
+
+## 验证结果
+
+运行 `python test_config_refactor.py` 的测试结果:
+
+```
+=== 配置重构测试 ===
+✓ app_config 导入成功
+
+--- 新配置检查 ---
+✓ API_DEEPSEEK_CONFIG 存在
+✓ API_QWEN_CONFIG 存在
+✓ OLLAMA_EMBEDDING_CONFIG 存在
+
+--- 旧配置检查 ---
+✓ DEEPSEEK_CONFIG 已删除
+✓ QWEN_CONFIG 已删除
+✓ EMBEDDING_OLLAMA_CONFIG 已删除
+
+--- Utils函数测试 ---
+✓ get_current_llm_config() 成功,返回类型: <class 'dict'>
+✓ get_current_embedding_config() 成功,返回类型: <class 'dict'>
+
+--- 配置内容验证 ---
+✓ API_QWEN_CONFIG 结构正确
+✓ API_DEEPSEEK_CONFIG 结构正确
+✓ OLLAMA_EMBEDDING_CONFIG 结构正确
+
+=== 配置重构测试完成 ===
+✓ 所有测试通过!配置重构成功!
+```
+
+## 使用示例
+
+### 重构后的配置使用
+
+```python
+# 导入配置
+import app_config
+
+# 使用新的配置名称
+deepseek_config = app_config.API_DEEPSEEK_CONFIG
+qwen_config = app_config.API_QWEN_CONFIG
+ollama_embedding_config = app_config.OLLAMA_EMBEDDING_CONFIG
+api_embedding_config = app_config.API_EMBEDDING_CONFIG
+
+# 使用工具函数
+from common.utils import get_current_llm_config, get_current_embedding_config
+
+current_llm = get_current_llm_config()
+current_embedding = get_current_embedding_config()
+```
+
+## 注意事项
+
+1. **完全向后不兼容**:旧的配置名称已完全移除,如果有其他代码使用了旧的配置名称,需要更新。
+
+2. **测试验证**:建议在使用前运行 `test_config_refactor.py` 确保重构成功。
+
+3. **文档同步**:相关文档已同步更新,确保示例代码使用新的配置名称。
+
+## 总结
+
+本次配置重构成功实现了:
+- ✅ 配置名称标准化
+- ✅ 提高代码可读性
+- ✅ 保持功能完整性
+- ✅ 更新所有相关引用
+- ✅ 通过完整测试验证
+
+重构后的配置结构更加清晰,便于后续维护和扩展。 

+ 567 - 0
docs/ollama 集成方案.md

@@ -0,0 +1,567 @@
+请分析我这个项目的代码,帮我评估一个变更,我希望添加ollama的大模型。会有两种LLM,一种是推理模型,比如qwen3/deepseek,还有一种是embedding模型。
+所以我现在很困惑:
+1.我需要如何在app_config.py上添加配置?
+2.哪些代码需要增加或重构,比如DeepSeekChat和QianWenAI_Chat,需要根据ollama的访问方式在写一个吗?
+3.帮我评估当前项目要兼容ollama,需要修改哪些内容?
+
+通过分析你的项目代码,我来帮你评估添加Ollama支持需要的变更。你的项目架构很清晰,采用了工厂模式和配置驱动的设计,这让添加Ollama支持相对简单。
+
+## 1. 配置层面修改(app_config.py)
+
+```python
+# 在app_config.py中添加Ollama配置
+
+# 使用的模型类型("qwen" 或 "deepseek" 或 "ollama")
+MODEL_TYPE = "ollama"  # 新增ollama选项
+
+# 新增:Ollama模型配置
+OLLAMA_CONFIG = {
+    "base_url": "http://localhost:11434",  # Ollama默认地址
+    "model": "qwen2.5:7b",  # 或者其他ollama模型名称
+    "allow_llm_to_see_data": True,
+    "temperature": 0.7,
+    "n_results": 6,
+    "language": "Chinese",
+    "timeout": 60  # ollama可能需要更长超时时间
+}
+
+# 修改:支持Ollama作为embedding模型
+EMBEDDING_CONFIG = {
+    "provider": "ollama",  # 新增provider字段,可选值:"api", "ollama"
+    "model_name": "nomic-embed-text",  # ollama embedding模型
+    "api_key": os.getenv("EMBEDDING_API_KEY"),  # API方式时使用
+    "base_url": "http://localhost:11434",  # ollama地址
+    "embedding_dimension": 768  # 根据实际模型调整
+}
+
+# 或者分别配置API和Ollama embedding
+EMBEDDING_CONFIG_API = {
+    "model_name": "BAAI/bge-m3",
+    "api_key": os.getenv("EMBEDDING_API_KEY"),
+    "base_url": os.getenv("EMBEDDING_BASE_URL"),
+    "embedding_dimension": 1024
+}
+
+EMBEDDING_CONFIG_OLLAMA = {
+    "model_name": "nomic-embed-text",
+    "base_url": "http://localhost:11434",
+    "embedding_dimension": 768
+}
+
+# 选择使用哪种embedding
+EMBEDDING_TYPE = "ollama"  # 或 "api"
+EMBEDDING_CONFIG = EMBEDDING_CONFIG_OLLAMA if EMBEDDING_TYPE == "ollama" else EMBEDDING_CONFIG_API
+```
+
+## 2. 创建Ollama LLM实现
+
+创建新文件 `customollama/ollama_chat.py`:
+
+```python
+import requests
+import json
+from vanna.base import VannaBase
+from typing import List, Dict, Any
+
+class OllamaChat(VannaBase):
+    def __init__(self, config=None):
+        print("...OllamaChat init...")
+        VannaBase.__init__(self, config=config)
+
+        print("传入的 config 参数如下:")
+        for key, value in self.config.items():
+            print(f"  {key}: {value}")
+
+        # 默认参数
+        self.temperature = 0.7
+        self.base_url = config.get("base_url", "http://localhost:11434")
+        self.model = config.get("model", "qwen2.5:7b")
+        self.timeout = config.get("timeout", 60)
+
+        if "temperature" in config:
+            print(f"temperature is changed to: {config['temperature']}")
+            self.temperature = config["temperature"]
+
+    def system_message(self, message: str) -> any:
+        print(f"system_content: {message}")
+        return {"role": "system", "content": message}
+
+    def user_message(self, message: str) -> any:
+        print(f"\nuser_content: {message}")
+        return {"role": "user", "content": message}
+
+    def assistant_message(self, message: str) -> any:
+        print(f"assistant_content: {message}")
+        return {"role": "assistant", "content": message}
+
+    def submit_prompt(self, prompt, **kwargs) -> str:
+        if prompt is None:
+            raise Exception("Prompt is None")
+
+        if len(prompt) == 0:
+            raise Exception("Prompt is empty")
+
+        # 计算token数量估计
+        num_tokens = 0
+        for message in prompt:
+            num_tokens += len(message["content"]) / 4
+
+        # 确定使用的模型
+        model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
+
+        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
+
+        # 准备Ollama API请求
+        url = f"{self.base_url}/api/chat"
+        payload = {
+            "model": model,
+            "messages": prompt,
+            "stream": False,
+            "options": {
+                "temperature": self.temperature
+            }
+        }
+
+        try:
+            response = requests.post(
+                url, 
+                json=payload, 
+                timeout=self.timeout,
+                headers={"Content-Type": "application/json"}
+            )
+            response.raise_for_status()
+            
+            result = response.json()
+            return result["message"]["content"]
+            
+        except requests.exceptions.RequestException as e:
+            print(f"Ollama API请求失败: {e}")
+            raise Exception(f"Ollama API调用失败: {str(e)}")
+
+    def generate_sql(self, question: str, **kwargs) -> str:
+        """重写generate_sql方法,增加异常处理"""
+        try:
+            print(f"[DEBUG] 尝试为问题生成SQL: {question}")
+            sql = super().generate_sql(question, **kwargs)
+            
+            if not sql or sql.strip() == "":
+                print(f"[WARNING] 生成的SQL为空")
+                return None
+            
+            # 检查返回内容是否为有效SQL
+            sql_lower = sql.lower().strip()
+            error_indicators = [
+                "insufficient context", "无法生成", "sorry", "cannot", "不能",
+                "no relevant", "no suitable", "unable to", "无法", "抱歉"
+            ]
+            
+            for indicator in error_indicators:
+                if indicator in sql_lower:
+                    print(f"[WARNING] LLM返回错误信息而非SQL: {sql}")
+                    return None
+            
+            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
+            if not any(keyword in sql_lower for keyword in sql_keywords):
+                print(f"[WARNING] 返回内容不像有效SQL: {sql}")
+                return None
+                
+            print(f"[SUCCESS] 成功生成SQL: {sql}")
+            return sql
+            
+        except Exception as e:
+            print(f"[ERROR] SQL生成过程中出现异常: {str(e)}")
+            return None
+
+    def generate_question(self, sql: str, **kwargs) -> str:
+        """根据SQL生成中文问题"""
+        prompt = [
+            self.system_message(
+                "请你根据下方SQL语句推测用户的业务提问,只返回清晰的自然语言问题,不要包含任何解释或SQL内容,也不要出现表名,问题要使用中文,并以问号结尾。"
+            ),
+            self.user_message(sql)
+        ]
+        response = self.submit_prompt(prompt, **kwargs)
+        return response
+    
+    def chat_with_llm(self, question: str, **kwargs) -> str:
+        """直接与LLM对话"""
+        try:
+            prompt = [
+                self.system_message(
+                    "你是一个友好的AI助手。如果用户询问的是数据库相关问题,请建议他们重新表述问题以便进行SQL查询。对于其他问题,请尽力提供有帮助的回答。"
+                ),
+                self.user_message(question)
+            ]
+            response = self.submit_prompt(prompt, **kwargs)
+            return response
+        except Exception as e:
+            print(f"[ERROR] LLM对话失败: {str(e)}")
+            return f"抱歉,我暂时无法回答您的问题。请稍后再试。"
+```
+
+创建 `customollama/__init__.py`:
+
+```python
+from .ollama_chat import OllamaChat
+```
+
+## 3. 修改Embedding功能
+
+修改 `embedding_function.py`,添加Ollama支持:
+
+```python
+import requests
+import time
+import numpy as np
+from typing import List, Callable
+
+class OllamaEmbeddingFunction:
+    def __init__(self, model_name: str, base_url: str, embedding_dimension: int):
+        self.model_name = model_name
+        self.base_url = base_url
+        self.embedding_dimension = embedding_dimension
+        self.max_retries = 2
+        self.retry_interval = 2
+
+    def __call__(self, input) -> List[List[float]]:
+        """为文本列表生成嵌入向量"""
+        if not isinstance(input, list):
+            input = [input]
+            
+        embeddings = []
+        for text in input:
+            try:
+                embedding = self.generate_embedding(text)
+                embeddings.append(embedding)
+            except Exception as e:
+                print(f"获取embedding时出错: {e}")
+                embeddings.append([0.0] * self.embedding_dimension)
+                
+        return embeddings
+    
+    def generate_embedding(self, text: str) -> List[float]:
+        """为单个文本生成嵌入向量"""
+        print(f"生成Ollama嵌入向量,文本长度: {len(text)} 字符")
+        
+        if not text or len(text.strip()) == 0:
+            print("输入文本为空,返回零向量")
+            return [0.0] * self.embedding_dimension
+
+        url = f"{self.base_url}/api/embeddings"
+        payload = {
+            "model": self.model_name,
+            "prompt": text
+        }
+        
+        retries = 0
+        while retries <= self.max_retries:
+            try:
+                response = requests.post(
+                    url, 
+                    json=payload,
+                    timeout=30
+                )
+                
+                if response.status_code != 200:
+                    error_msg = f"Ollama API请求错误: {response.status_code}, {response.text}"
+                    print(error_msg)
+                    
+                    if response.status_code in (429, 500, 502, 503, 504):
+                        retries += 1
+                        if retries <= self.max_retries:
+                            wait_time = self.retry_interval * (2 ** (retries - 1))
+                            print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                            time.sleep(wait_time)
+                            continue
+                    
+                    raise ValueError(error_msg)
+                
+                result = response.json()
+                
+                if "embedding" in result:
+                    vector = result["embedding"]
+                    
+                    # 验证向量维度
+                    actual_dim = len(vector)
+                    if actual_dim != self.embedding_dimension:
+                        print(f"向量维度不匹配: 期望 {self.embedding_dimension}, 实际 {actual_dim}")
+                        # 如果维度不匹配,可以选择截断或填充
+                        if actual_dim > self.embedding_dimension:
+                            vector = vector[:self.embedding_dimension]
+                        else:
+                            vector.extend([0.0] * (self.embedding_dimension - actual_dim))
+                    
+                    print(f"成功生成Ollama embedding向量,维度: {len(vector)}")
+                    return vector
+                else:
+                    error_msg = f"Ollama API返回格式异常: {result}"
+                    print(error_msg)
+                    raise ValueError(error_msg)
+                
+            except Exception as e:
+                print(f"生成Ollama embedding时出错: {str(e)}")
+                retries += 1
+                
+                if retries <= self.max_retries:
+                    wait_time = self.retry_interval * (2 ** (retries - 1))
+                    print(f"等待 {wait_time} 秒后重试 ({retries}/{self.max_retries})")
+                    time.sleep(wait_time)
+                else:
+                    print(f"已达到最大重试次数 ({self.max_retries}),生成embedding失败")
+                    return [0.0] * self.embedding_dimension
+        
+        raise RuntimeError("生成Ollama embedding失败")
+
+    def test_connection(self, test_text="测试文本") -> dict:
+        """测试Ollama嵌入模型的连接"""
+        result = {
+            "success": False,
+            "model": self.model_name,
+            "base_url": self.base_url,
+            "message": "",
+            "actual_dimension": None,
+            "expected_dimension": self.embedding_dimension
+        }
+        
+        try:
+            print(f"测试Ollama嵌入模型连接 - 模型: {self.model_name}")
+            print(f"Ollama服务地址: {self.base_url}")
+            
+            vector = self.generate_embedding(test_text)
+            actual_dimension = len(vector)
+            
+            result["success"] = True
+            result["actual_dimension"] = actual_dimension
+            
+            if actual_dimension != self.embedding_dimension:
+                result["message"] = f"警告: 模型实际生成的向量维度({actual_dimension})与配置维度({self.embedding_dimension})不一致"
+            else:
+                result["message"] = f"Ollama连接测试成功,向量维度: {actual_dimension}"
+                
+            return result
+            
+        except Exception as e:
+            result["message"] = f"Ollama连接测试失败: {str(e)}"
+            return result
+
+# 修改原有的get_embedding_function函数
+def get_embedding_function():
+    """根据配置创建合适的EmbeddingFunction实例"""
+    try:
+        import app_config
+    except ImportError:
+        raise ImportError("无法导入 app_config.py。")
+
+    try:
+        embedding_config = app_config.EMBEDDING_CONFIG
+    except AttributeError:
+        raise AttributeError("app_config.py 中缺少 EMBEDDING_CONFIG 配置。")
+
+    # 检查是否使用Ollama embedding
+    provider = embedding_config.get("provider", "api")
+    
+    if provider == "ollama":
+        print("使用Ollama Embedding模型")
+        return OllamaEmbeddingFunction(
+            model_name=embedding_config["model_name"],
+            base_url=embedding_config["base_url"],
+            embedding_dimension=embedding_config["embedding_dimension"]
+        )
+    else:
+        print("使用API Embedding模型")
+        # 原有的EmbeddingFunction逻辑
+        api_key = embedding_config["api_key"]
+        model_name = embedding_config["model_name"]
+        base_url = embedding_config["base_url"]
+        embedding_dimension = embedding_config["embedding_dimension"]
+        
+        if api_key is None:
+            raise KeyError("EMBEDDING_CONFIG 中的 'api_key' 未设置。")
+            
+        return EmbeddingFunction(
+            model_name=model_name,
+            api_key=api_key,
+            base_url=base_url,
+            embedding_dimension=embedding_dimension
+        )
+
+# 修改测试函数
+def test_embedding_connection() -> dict:
+    """测试嵌入模型连接"""
+    try:
+        embedding_function = get_embedding_function()
+        test_result = embedding_function.test_connection()
+        
+        if test_result["success"]:
+            print(f"嵌入模型连接测试成功!")
+            if "警告" in test_result["message"]:
+                print(test_result["message"])
+        else:
+            print(f"嵌入模型连接测试失败: {test_result['message']}")
+            
+        return test_result
+        
+    except Exception as e:
+        error_message = f"无法测试嵌入模型连接: {str(e)}"
+        print(error_message)
+        return {
+            "success": False,
+            "message": error_message
+        }
+```
+
+## 4. 修改工厂模式(vanna_llm_factory.py)
+
+```python
+"""
+Vanna LLM 工厂文件,支持多种LLM提供商
+"""
+from vanna.chromadb import ChromaDB_VectorStore
+from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
+from customdeepseek.custom_deepseek_chat import DeepSeekChat
+from customollama.ollama_chat import OllamaChat  # 新增
+import app_config 
+from embedding_function import get_embedding_function
+import os
+
+class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
+    def __init__(self, config=None):
+        ChromaDB_VectorStore.__init__(self, config=config)
+        QianWenAI_Chat.__init__(self, config=config)
+
+class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
+    def __init__(self, config=None):
+        ChromaDB_VectorStore.__init__(self, config=config)
+        DeepSeekChat.__init__(self, config=config)
+
+# 新增Ollama支持
+class Vanna_Ollama_ChromaDB(ChromaDB_VectorStore, OllamaChat):
+    def __init__(self, config=None):
+        ChromaDB_VectorStore.__init__(self, config=config)
+        OllamaChat.__init__(self, config=config)
+
+def create_vanna_instance(config_module=None):
+    """工厂函数:创建并初始化Vanna实例"""
+    if config_module is None:
+        config_module = app_config
+
+    model_type = config_module.MODEL_TYPE.lower()
+    
+    config = {}
+    if model_type == "deepseek":
+        config = config_module.API_DEEPSEEK_CONFIG.copy()
+        print(f"创建DeepSeek模型实例,使用模型: {config['model']}")
+        if not config.get("api_key"):
+            print(f"\n错误: DeepSeek API密钥未设置或为空")
+            print(f"请在.env文件中设置DEEPSEEK_API_KEY环境变量")
+            import sys
+            sys.exit(1)
+            
+    elif model_type == "qwen":
+        config = config_module.API_QWEN_CONFIG.copy()
+        print(f"创建Qwen模型实例,使用模型: {config['model']}")
+        if not config.get("api_key"):
+            print(f"\n错误: Qwen API密钥未设置或为空")
+            print(f"请在.env文件中设置QWEN_API_KEY环境变量")
+            import sys
+            sys.exit(1)
+            
+    elif model_type == "ollama":  # 新增
+        config = config_module.OLLAMA_CONFIG.copy()
+        print(f"创建Ollama模型实例,使用模型: {config['model']}")
+        print(f"Ollama服务地址: {config['base_url']}")
+        # Ollama通常不需要API密钥,但可以检查服务是否可用
+        try:
+            import requests
+            response = requests.get(f"{config['base_url']}/api/tags", timeout=5)
+            if response.status_code != 200:
+                print(f"警告: 无法连接到Ollama服务 ({config['base_url']})")
+        except Exception as e:
+            print(f"警告: Ollama服务连接测试失败: {e}")
+            
+    else:
+        raise ValueError(f"不支持的模型类型: {model_type}") 
+    
+    # 获取embedding函数(支持API和Ollama两种方式)
+    embedding_function = get_embedding_function()
+    config["embedding_function"] = embedding_function
+    
+    # 打印embedding配置信息
+    embedding_config = config_module.EMBEDDING_CONFIG
+    provider = embedding_config.get("provider", "api")
+    print(f"已配置使用 {provider.upper()} 嵌入模型: {embedding_config['model_name']}, 维度: {embedding_config['embedding_dimension']}")
+    
+    # 设置ChromaDB路径
+    project_root = os.path.dirname(os.path.abspath(__file__))
+    config["path"] = project_root
+    print(f"已配置使用ChromaDB作为向量数据库,路径:{project_root}")
+    
+    # 创建对应的Vanna实例
+    vn = None
+    if model_type == "deepseek":
+        vn = Vanna_DeepSeek_ChromaDB(config=config)
+        print("创建DeepSeek+ChromaDB实例")
+    elif model_type == "qwen":
+        vn = Vanna_Qwen_ChromaDB(config=config)
+        print("创建Qwen+ChromaDB实例")
+    elif model_type == "ollama":  # 新增
+        vn = Vanna_Ollama_ChromaDB(config=config)
+        print("创建Ollama+ChromaDB实例")
+    
+    if vn is None:
+        raise ValueError(f"未能成功创建Vanna实例,不支持的模型类型: {model_type}")
+
+    # 连接到业务数据库
+    vn.connect_to_postgres(**config_module.APP_DB_CONFIG)           
+    print(f"已连接到业务数据库: "
+          f"{config_module.APP_DB_CONFIG['host']}:"
+          f"{config_module.APP_DB_CONFIG['port']}/"
+          f"{config_module.APP_DB_CONFIG['dbname']}")
+    return vn
+```
+
+## 5. 需要新增的文件结构
+
+```
+项目根目录/
+├── customollama/
+│   ├── __init__.py
+│   └── ollama_chat.py
+```
+
+## 6. 使用示例
+
+在你的项目中使用Ollama:
+
+### 使用Ollama推理模型:
+```python
+# 在app_config.py中设置
+MODEL_TYPE = "ollama"
+OLLAMA_CONFIG = {
+    "base_url": "http://localhost:11434",
+    "model": "qwen2.5:7b",  # 或其他模型
+    # ... 其他配置
+}
+```
+
+### 使用Ollama embedding模型:
+```python
+# 在app_config.py中设置
+EMBEDDING_CONFIG = {
+    "provider": "ollama",
+    "model_name": "nomic-embed-text",
+    "base_url": "http://localhost:11434",
+    "embedding_dimension": 768
+}
+```
+
+## 总结
+
+通过以上修改,你的项目就可以支持Ollama了。主要变更包括:
+
+1. **配置扩展**:在`app_config.py`中添加Ollama相关配置
+2. **新增LLM实现**:创建`OllamaChat`类
+3. **扩展Embedding支持**:修改`embedding_function.py`支持Ollama embedding
+4. **工厂模式扩展**:在`vanna_llm_factory.py`中添加Ollama支持
+
+这种设计保持了现有代码的兼容性,你可以通过配置轻松切换不同的LLM提供商。

+ 371 - 0
docs/ollama_integration_guide.md

@@ -0,0 +1,371 @@
+# Ollama 集成使用指南
+
+本指南介绍如何在项目中使用 Ollama 作为 LLM 和/或 Embedding 提供商。
+
+## 概述
+
+项目现在支持以下配置组合:
+
+1. **全API模式**:API LLM + API Embedding
+2. **全Ollama模式**:Ollama LLM + Ollama Embedding  
+3. **混合模式1**:API LLM + Ollama Embedding
+4. **混合模式2**:Ollama LLM + API Embedding
+
+## 前置条件
+
+### 安装和启动 Ollama
+
+1. 从 [Ollama官网](https://ollama.ai) 下载并安装 Ollama
+2. 启动 Ollama 服务:
+   ```bash
+   ollama serve
+   ```
+3. 拉取所需的模型:
+   ```bash
+   # LLM模型(选择其中一个)
+   ollama pull qwen2.5:7b
+   ollama pull deepseek-r1:7b
+   ollama pull llama3:8b
+   
+   # Embedding模型
+   ollama pull nomic-embed-text
+   ```
+
+## 配置说明
+
+### 1. 基本配置参数
+
+在 `app_config.py` 中设置以下参数:
+
+```python
+# 模型提供商类型
+LLM_MODEL_TYPE = "ollama"        # "api" 或 "ollama"
+EMBEDDING_MODEL_TYPE = "ollama"  # "api" 或 "ollama"
+
+# API模式下的模型选择(Ollama模式下不使用)
+API_LLM_MODEL = "qwen"           # "qwen" 或 "deepseek"
+
+# 向量数据库类型
+VECTOR_DB_TYPE = "pgvector"      # "chromadb" 或 "pgvector"
+```
+
+### 2. Ollama LLM 配置
+
+```python
+OLLAMA_LLM_CONFIG = {
+    "base_url": "http://localhost:11434",  # Ollama服务地址
+    "model": "qwen2.5:7b",                 # 模型名称
+    "allow_llm_to_see_data": True,
+    "temperature": 0.7,
+    "n_results": 6,
+    "language": "Chinese",
+    "timeout": 60                          # 超时时间(秒)
+}
+```
+
+### 3. Ollama Embedding 配置
+
+```python
+OLLAMA_EMBEDDING_CONFIG = {
+    "model_name": "nomic-embed-text",      # Embedding模型名称
+    "base_url": "http://localhost:11434",  # Ollama服务地址
+    "embedding_dimension": 768             # 向量维度
+}
+```
+
+## 使用示例
+
+### 示例1:全Ollama模式
+
+```python
+# app_config.py
+LLM_MODEL_TYPE = "ollama"
+EMBEDDING_MODEL_TYPE = "ollama"
+VECTOR_DB_TYPE = "chromadb"
+
+OLLAMA_LLM_CONFIG = {
+    "base_url": "http://localhost:11434",
+    "model": "qwen2.5:7b",
+    "temperature": 0.7,
+    "timeout": 60
+}
+
+OLLAMA_EMBEDDING_CONFIG = {
+    "model_name": "nomic-embed-text",
+    "base_url": "http://localhost:11434",
+    "embedding_dimension": 768
+}
+```
+
+### 示例2:混合模式(API LLM + Ollama Embedding)
+
+```python
+# app_config.py
+LLM_MODEL_TYPE = "api"
+EMBEDDING_MODEL_TYPE = "ollama"
+API_LLM_MODEL = "qwen"
+VECTOR_DB_TYPE = "pgvector"
+
+# 使用现有的 API_QWEN_CONFIG
+# 使用 OLLAMA_EMBEDDING_CONFIG
+```
+
+### 示例3:混合模式(Ollama LLM + API Embedding)
+
+```python
+# app_config.py
+LLM_MODEL_TYPE = "ollama"
+EMBEDDING_MODEL_TYPE = "api"
+VECTOR_DB_TYPE = "chromadb"
+
+# 使用 OLLAMA_LLM_CONFIG
+# 使用现有的 API_EMBEDDING_CONFIG
+```
+
+## 代码使用
+
+### 1. 使用工具函数检查配置
+
+```python
+from common.utils import (
+    is_using_ollama_llm,
+    is_using_ollama_embedding,
+    get_current_model_info,
+    print_current_config
+)
+
+# 检查当前配置
+print_current_config()
+
+# 检查是否使用Ollama
+if is_using_ollama_llm():
+    print("当前使用Ollama LLM")
+
+if is_using_ollama_embedding():
+    print("当前使用Ollama Embedding")
+
+# 获取模型信息
+model_info = get_current_model_info()
+print(model_info)
+```
+
+### 2. 创建Vanna实例
+
+```python
+from vanna_llm_factory import create_vanna_instance
+
+# 根据配置自动创建合适的实例
+vn = create_vanna_instance()
+
+# 使用实例
+sql = vn.generate_sql("查询所有用户的信息")
+print(sql)
+```
+
+### 3. 直接使用Ollama组件
+
+```python
+# 直接使用Ollama LLM
+from customollama.ollama_chat import OllamaChat
+
+config = {
+    "base_url": "http://localhost:11434",
+    "model": "qwen2.5:7b",
+    "temperature": 0.7
+}
+
+ollama_llm = OllamaChat(config=config)
+response = ollama_llm.chat_with_llm("你好")
+
+# 直接使用Ollama Embedding
+from customollama.ollama_embedding import OllamaEmbeddingFunction
+
+embedding_func = OllamaEmbeddingFunction(
+    model_name="nomic-embed-text",
+    base_url="http://localhost:11434",
+    embedding_dimension=768
+)
+
+embeddings = embedding_func(["文本1", "文本2"])
+```
+
+## 测试和验证
+
+### 1. 运行配置测试
+
+```bash
+python test_config_utils.py
+```
+
+### 2. 运行Ollama集成测试
+
+```bash
+python test_ollama_integration.py
+```
+
+### 3. 测试连接
+
+```python
+# 测试Ollama LLM连接
+from customollama.ollama_chat import OllamaChat
+
+config = {"base_url": "http://localhost:11434", "model": "qwen2.5:7b"}
+ollama_llm = OllamaChat(config=config)
+result = ollama_llm.test_connection()
+print(result)
+
+# 测试Ollama Embedding连接
+from customollama.ollama_embedding import OllamaEmbeddingFunction
+
+embedding_func = OllamaEmbeddingFunction(
+    model_name="nomic-embed-text",
+    base_url="http://localhost:11434",
+    embedding_dimension=768
+)
+result = embedding_func.test_connection()
+print(result)
+```
+
+## 常见问题
+
+### 1. 连接失败
+
+**问题**:`Ollama API调用失败: Connection refused`
+
+**解决方案**:
+- 确保Ollama服务正在运行:`ollama serve`
+- 检查服务地址是否正确(默认:`http://localhost:11434`)
+- 确保防火墙没有阻止连接
+
+### 2. 模型不存在
+
+**问题**:`model 'qwen2.5:7b' not found`
+
+**解决方案**:
+- 拉取所需模型:`ollama pull qwen2.5:7b`
+- 检查可用模型:`ollama list`
+
+### 3. 向量维度不匹配
+
+**问题**:`向量维度不匹配: 期望 768, 实际 384`
+
+**解决方案**:
+- 更新配置中的 `embedding_dimension` 为实际维度
+- 或者选择匹配的embedding模型
+
+### 4. 超时错误
+
+**问题**:`Ollama API调用超时`
+
+**解决方案**:
+- 增加 `timeout` 配置值
+- 检查模型是否已完全加载
+- 考虑使用更小的模型
+
+## 性能优化建议
+
+### 1. 模型选择
+
+- **小型模型**:`qwen2.5:7b`, `llama3:8b` - 适合资源有限的环境
+- **大型模型**:`qwen2.5:14b`, `deepseek-r1:32b` - 适合性能要求高的场景
+
+### 2. 配置优化
+
+```python
+# 针对性能优化的配置
+OLLAMA_LLM_CONFIG = {
+    "base_url": "http://localhost:11434",
+    "model": "qwen2.5:7b",
+    "temperature": 0.1,  # 降低随机性,提高一致性
+    "timeout": 120,      # 增加超时时间
+}
+```
+
+### 3. 缓存策略
+
+- 启用向量数据库缓存
+- 使用会话感知缓存
+- 合理设置缓存过期时间
+
+## 部署注意事项
+
+### 1. 生产环境
+
+- 确保Ollama服务的稳定性和可用性
+- 配置适当的资源限制(CPU、内存、GPU)
+- 设置监控和日志记录
+
+### 2. 安全考虑
+
+- 限制Ollama服务的网络访问
+- 使用防火墙保护服务端口
+- 定期更新Ollama和模型
+
+### 3. 备份和恢复
+
+- 备份模型文件和配置
+- 准备API服务作为备用方案
+- 测试故障转移流程
+
+## 架构说明
+
+### 统一组合类管理
+
+项目采用了统一的组合类管理方式,所有LLM与向量数据库的组合都在 `common/vanna_combinations.py` 中定义:
+
+```python
+# 可用的组合类
+from common.vanna_combinations import (
+    Vanna_Qwen_ChromaDB,
+    Vanna_DeepSeek_ChromaDB,
+    Vanna_Ollama_ChromaDB,
+    Vanna_Qwen_PGVector,
+    Vanna_DeepSeek_PGVector,
+    Vanna_Ollama_PGVector,
+    get_vanna_class,
+    print_available_combinations
+)
+
+# 动态获取组合类
+cls = get_vanna_class("ollama", "chromadb")  # 返回 Vanna_Ollama_ChromaDB
+
+# 查看所有可用组合
+print_available_combinations()
+```
+
+### 工厂模式
+
+`vanna_llm_factory.py` 使用统一的组合类文件,自动根据配置选择合适的组合:
+
+```python
+from vanna_llm_factory import create_vanna_instance
+
+# 根据 app_config.py 中的配置自动创建实例
+vn = create_vanna_instance()
+```
+
+## 测试和验证
+
+### 1. 运行配置测试
+
+```bash
+python test_config_utils.py
+```
+
+### 2. 运行Ollama集成测试
+
+```bash
+python test_ollama_integration.py
+```
+
+### 3. 运行组合类测试
+
+```bash
+python test_vanna_combinations.py
+```
+
+## 更多资源
+
+- [Ollama官方文档](https://ollama.ai/docs)
+- [支持的模型列表](https://ollama.ai/library)
+- [API参考文档](https://github.com/ollama/ollama/blob/main/docs/api.md) 

+ 187 - 0
docs/training_integration_fixes.md

@@ -0,0 +1,187 @@
+# Training目录集成修复总结
+
+本文档总结了为使training目录与新的Ollama集成配置结构兼容所做的修复。
+
+## 修复的问题
+
+### 1. 配置访问方式更新
+
+**问题**:training目录中的代码直接访问旧的配置结构,与新的配置系统不兼容。
+
+**修复**:
+
+#### `training/vanna_trainer.py`
+- **修复前**:直接访问 `app_config.EMBEDDING_CONFIG`
+- **修复后**:使用 `common.utils` 中的工具函数
+
+```python
+# 修复前
+embedding_model = app_config.EMBEDDING_CONFIG.get('model_name')
+
+# 修复后
+from common.utils import get_current_embedding_config, get_current_model_info
+embedding_config = get_current_embedding_config()
+model_info = get_current_model_info()
+```
+
+#### `training/run_training.py`
+- **修复前**:直接访问 `app_config.EMBEDDING_CONFIG`
+- **修复后**:同样使用新的工具函数,并提供回退机制
+
+```python
+# 修复后
+try:
+    from common.utils import get_current_embedding_config, get_current_model_info
+    embedding_config = get_current_embedding_config()
+    model_info = get_current_model_info()
+except ImportError as e:
+    # 回退到旧的配置访问方式
+            embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
+```
+
+### 2. 向后兼容性
+
+**特点**:
+- 所有修复都包含了错误处理和回退机制
+- 如果新的配置工具函数不可用,会自动回退到旧的访问方式
+- 不会破坏现有的功能
+
+### 3. 配置验证
+
+**验证的配置项**:
+- ✅ `TRAINING_BATCH_PROCESSING_ENABLED` - 存在
+- ✅ `TRAINING_BATCH_SIZE` - 存在  
+- ✅ `TRAINING_MAX_WORKERS` - 存在
+- ✅ `TRAINING_DATA_PATH` - 存在
+- ✅ `PGVECTOR_CONFIG` - 存在
+- ✅ 新的embedding配置工具函数 - 已实现
+
+## 测试验证
+
+创建了 `test_training_integration.py` 脚本来验证修复效果:
+
+```bash
+python test_training_integration.py
+```
+
+### 测试覆盖范围
+
+1. **训练模块导入** - 验证所有训练函数可以正常导入
+2. **配置访问** - 验证新旧配置访问方式都能正常工作
+3. **Vanna实例创建** - 验证工厂函数能正常创建实例
+4. **批处理器** - 验证BatchProcessor类能正常工作
+5. **训练函数** - 验证所有训练函数都是可调用的
+6. **Embedding连接** - 验证embedding模型连接
+7. **run_training脚本** - 验证主训练脚本的基本功能
+
+## 支持的配置组合
+
+现在training目录支持所有新的配置组合:
+
+### 1. 全API模式
+```python
+LLM_MODEL_TYPE = "api"
+EMBEDDING_MODEL_TYPE = "api"
+API_LLM_MODEL = "qwen"  # 或 "deepseek"
+```
+
+### 2. 全Ollama模式
+```python
+LLM_MODEL_TYPE = "ollama"
+EMBEDDING_MODEL_TYPE = "ollama"
+```
+
+### 3. 混合模式1(API LLM + Ollama Embedding)
+```python
+LLM_MODEL_TYPE = "api"
+EMBEDDING_MODEL_TYPE = "ollama"
+API_LLM_MODEL = "qwen"
+```
+
+### 4. 混合模式2(Ollama LLM + API Embedding)
+```python
+LLM_MODEL_TYPE = "ollama"
+EMBEDDING_MODEL_TYPE = "api"
+```
+
+## 使用方法
+
+### 1. 运行训练
+
+```bash
+# 使用默认配置
+python training/run_training.py
+
+# 指定训练数据路径
+python training/run_training.py --data_path /path/to/training/data
+```
+
+### 2. 编程方式使用
+
+```python
+from training import (
+    train_ddl,
+    train_documentation,
+    train_sql_example,
+    train_question_sql_pair,
+    flush_training,
+    shutdown_trainer
+)
+
+# 训练DDL
+train_ddl("CREATE TABLE users (id INT, name VARCHAR(50));")
+
+# 训练文档
+train_documentation("用户表包含用户的基本信息")
+
+# 训练SQL示例
+train_sql_example("SELECT * FROM users WHERE age > 18;")
+
+# 训练问答对
+train_question_sql_pair("查询所有成年用户", "SELECT * FROM users WHERE age >= 18;")
+
+# 完成训练
+flush_training()
+shutdown_trainer()
+```
+
+## 注意事项
+
+1. **配置优先级**:新的配置工具函数优先,如果不可用则回退到旧配置
+2. **错误处理**:所有配置访问都包含了适当的错误处理
+3. **向后兼容**:现有的训练脚本和代码无需修改即可继续使用
+4. **性能优化**:批处理功能仍然可用,提高训练效率
+
+## 文件修改清单
+
+### 修改的文件
+- `training/vanna_trainer.py` - 更新配置访问方式
+- `training/run_training.py` - 更新配置访问方式
+
+### 新增的文件
+- `test_training_integration.py` - 训练集成测试脚本
+- `docs/training_integration_fixes.md` - 本文档
+
+### 未修改的文件
+- `training/__init__.py` - 无需修改
+- `training/data/` - 训练数据目录保持不变
+
+## 验证步骤
+
+1. **运行集成测试**:
+   ```bash
+   python test_training_integration.py
+   ```
+
+2. **测试训练功能**:
+   ```bash
+   python training/run_training.py --data_path training/data
+   ```
+
+3. **验证不同配置**:
+   - 修改 `app_config.py` 中的配置
+   - 重新运行测试和训练
+
+## 总结
+
+通过这些修复,training目录现在完全兼容新的Ollama集成配置结构,同时保持了向后兼容性。用户可以无缝地在不同的LLM和embedding提供商之间切换,而无需修改训练代码。 

+ 44 - 35
embedding_function.py

@@ -300,46 +300,55 @@ def test_embedding_connection() -> dict:
             "message": error_message
         }
 
-def get_embedding_function() -> EmbeddingFunction:
+def get_embedding_function():
     """
-    从 app_config.py 的 EMBEDDING_CONFIG 字典加载配置并创建 EmbeddingFunction 实例。
-    如果任何必需的配置未找到,则抛出异常。
-
+    根据当前配置创建合适的EmbeddingFunction实例
+    支持API和Ollama两种提供商
+    
     Returns:
-        EmbeddingFunction: EmbeddingFunction 的实例。
-
+        EmbeddingFunction或OllamaEmbeddingFunction: 根据配置类型返回相应的实例
+        
     Raises:
-        ImportError: 如果 app_config.py 无法导入。
-        AttributeError: 如果 app_config.py 中缺少 EMBEDDING_CONFIG。
-        KeyError: 如果 EMBEDDING_CONFIG 字典中缺少任何必要的键。
+        ImportError: 如果无法导入必要的模块
+        ValueError: 如果配置无效
     """
     try:
-        import app_config
+        from common.utils import get_current_embedding_config, is_using_ollama_embedding
     except ImportError:
-        raise ImportError("无法导入 app_config.py。请确保该文件存在且在PYTHONPATH中。")
-
-    try:
-        embedding_config_dict = app_config.EMBEDDING_CONFIG
-    except AttributeError:
-        raise AttributeError("app_config.py 中缺少 EMBEDDING_CONFIG 配置字典。")
-
-    try:
-        api_key = embedding_config_dict["api_key"]
-        model_name = embedding_config_dict["model_name"]
-        base_url = embedding_config_dict["base_url"]
-        embedding_dimension = embedding_config_dict["embedding_dimension"]
-        
-        if api_key is None:
-            # 明确指出 api_key (可能来自环境变量) 未设置的问题
-            raise KeyError("EMBEDDING_CONFIG 中的 'api_key' 未设置 (可能环境变量 EMBEDDING_API_KEY 未定义)。")
+        raise ImportError("无法导入 common.utils,请确保该文件存在")
+    
+    # 获取当前embedding配置
+    embedding_config = get_current_embedding_config()
+    
+    if is_using_ollama_embedding():
+        # 使用Ollama Embedding
+        try:
+            from customollama.ollama_embedding import OllamaEmbeddingFunction
+        except ImportError:
+            raise ImportError("无法导入 OllamaEmbeddingFunction,请确保 customollama 包存在")
             
-    except KeyError as e:
-        # 将原始的KeyError e 作为原因传递,可以提供更详细的上下文,比如哪个键确实缺失了
-        raise KeyError(f"app_config.py 的 EMBEDDING_CONFIG 字典中缺少必要的键或值无效:{e}")
+        return OllamaEmbeddingFunction(
+            model_name=embedding_config["model_name"],
+            base_url=embedding_config["base_url"],
+            embedding_dimension=embedding_config["embedding_dimension"]
+        )
+    else:
+        # 使用API Embedding
+        try:
+            api_key = embedding_config["api_key"]
+            model_name = embedding_config["model_name"]
+            base_url = embedding_config["base_url"]
+            embedding_dimension = embedding_config["embedding_dimension"]
+            
+            if api_key is None:
+                raise KeyError("API模式下 'api_key' 未设置 (可能环境变量 EMBEDDING_API_KEY 未定义)")
+                
+        except KeyError as e:
+            raise KeyError(f"API Embedding配置中缺少必要的键或值无效:{e}")
 
-    return EmbeddingFunction(
-        model_name=model_name,
-        api_key=api_key,
-        base_url=base_url,
-        embedding_dimension=embedding_dimension
-    )
+        return EmbeddingFunction(
+            model_name=model_name,
+            api_key=api_key,
+            base_url=base_url,
+            embedding_dimension=embedding_dimension
+        )

+ 19 - 4
training/run_training.py

@@ -504,7 +504,7 @@ def main():
     check_embedding_model_connection()
     
     # 根据配置的向量数据库类型显示相应信息
-    vector_db_type = app_config.VECTOR_DB_NAME.lower()
+    vector_db_type = app_config.VECTOR_DB_TYPE.lower()
     
     if vector_db_type == "chromadb":
         # 打印ChromaDB相关信息
@@ -588,9 +588,24 @@ def main():
     
     # 输出embedding模型信息
     print("\n===== Embedding模型信息 =====")
-    print(f"模型名称: {app_config.EMBEDDING_CONFIG.get('model_name')}")
-    print(f"向量维度: {app_config.EMBEDDING_CONFIG.get('embedding_dimension')}")
-    print(f"API服务: {app_config.EMBEDDING_CONFIG.get('base_url')}")
+    try:
+        from common.utils import get_current_embedding_config, get_current_model_info
+        
+        embedding_config = get_current_embedding_config()
+        model_info = get_current_model_info()
+        
+        print(f"模型类型: {model_info['embedding_type']}")
+        print(f"模型名称: {model_info['embedding_model']}")
+        print(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
+        if 'base_url' in embedding_config:
+            print(f"API服务: {embedding_config['base_url']}")
+    except ImportError as e:
+        print(f"警告: 无法导入配置工具函数: {e}")
+        # 回退到旧的配置访问方式
+        embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
+        print(f"模型名称: {embedding_config.get('model_name', '未知')}")
+        print(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
+        print(f"API服务: {embedding_config.get('base_url', '未知')}")
     
     # 根据配置显示向量数据库信息
     if vector_db_type == "chromadb":

+ 21 - 10
training/vanna_trainer.py

@@ -20,16 +20,27 @@ from vanna_llm_factory import create_vanna_instance
 
 vn = create_vanna_instance()
 
-# 直接从配置文件获取模型名称
-embedding_model = app_config.EMBEDDING_CONFIG.get('model_name')
-print(f"\n===== Embedding模型信息 =====")
-print(f"模型名称: {embedding_model}")
-if hasattr(app_config, 'EMBEDDING_CONFIG'):
-    if 'embedding_dimension' in app_config.EMBEDDING_CONFIG:
-        print(f"向量维度: {app_config.EMBEDDING_CONFIG['embedding_dimension']}")
-    if 'base_url' in app_config.EMBEDDING_CONFIG:
-        print(f"API服务: {app_config.EMBEDDING_CONFIG['base_url']}")
-print("==============================")
+# 使用新的配置工具函数获取embedding配置
+try:
+    from common.utils import get_current_embedding_config, get_current_model_info
+    
+    embedding_config = get_current_embedding_config()
+    model_info = get_current_model_info()
+    
+    print(f"\n===== Embedding模型信息 =====")
+    print(f"模型类型: {model_info['embedding_type']}")
+    print(f"模型名称: {model_info['embedding_model']}")
+    print(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
+    if 'base_url' in embedding_config:
+        print(f"API服务: {embedding_config['base_url']}")
+    print("==============================")
+except ImportError as e:
+    print(f"警告: 无法导入配置工具函数: {e}")
+    print("使用默认配置...")
+    embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
+    print(f"\n===== Embedding模型信息 (默认) =====")
+    print(f"模型名称: {embedding_config.get('model_name', '未知')}")
+    print("==============================")
 
 # 从app_config获取训练批处理配置
 BATCH_PROCESSING_ENABLED = app_config.TRAINING_BATCH_PROCESSING_ENABLED

+ 53 - 63
vanna_llm_factory.py

@@ -1,44 +1,14 @@
 """
-Vanna LLM 工厂文件,专注于 ChromaDB 并简化配置。
+Vanna LLM 工厂文件,支持多种LLM提供商和向量数据库
 """
 import app_config, os
-from vanna.chromadb import ChromaDB_VectorStore  # 从 Vanna 系统获取
-from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
-from customdeepseek.custom_deepseek_chat import DeepSeekChat
 from embedding_function import get_embedding_function
-from custompgvector import PG_VectorStore
-
-class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
-    def __init__(self, config=None):
-        ChromaDB_VectorStore.__init__(self, config=config)
-        QianWenAI_Chat.__init__(self, config=config)
-
-class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
-    def __init__(self, config=None):
-        ChromaDB_VectorStore.__init__(self, config=config)
-        DeepSeekChat.__init__(self, config=config)
-
-class Vanna_Qwen_PGVector(PG_VectorStore, QianWenAI_Chat):
-    def __init__(self, config=None):
-        PG_VectorStore.__init__(self, config=config)
-        QianWenAI_Chat.__init__(self, config=config)
-
-class Vanna_DeepSeek_PGVector(PG_VectorStore, DeepSeekChat):
-    def __init__(self, config=None):
-        PG_VectorStore.__init__(self, config=config)
-        DeepSeekChat.__init__(self, config=config)
-
-# 组合映射表
-LLM_VECTOR_DB_MAP = {
-    ('deepseek', 'chromadb'): Vanna_DeepSeek_ChromaDB,
-    ('deepseek', 'pgvector'): Vanna_DeepSeek_PGVector,
-    ('qwen', 'chromadb'): Vanna_Qwen_ChromaDB,
-    ('qwen', 'pgvector'): Vanna_Qwen_PGVector,
-}
+from common.vanna_combinations import get_vanna_class, print_available_combinations
 
 def create_vanna_instance(config_module=None):
     """
-    工厂函数:创建并初始化一个Vanna实例 (LLM 和 ChromaDB 特定版本)
+    工厂函数:创建并初始化一个Vanna实例
+    支持API和Ollama两种LLM提供商,以及ChromaDB和PgVector两种向量数据库
     
     Args:
         config_module: 配置模块,默认为None时使用 app_config
@@ -49,49 +19,69 @@ def create_vanna_instance(config_module=None):
     if config_module is None:
         config_module = app_config
 
-    llm_model_name  = config_module.LLM_MODEL_NAME.lower()
-    vector_db_name = config_module.VECTOR_DB_NAME.lower()   
+    try:
+        from common.utils import (
+            get_current_llm_config, 
+            get_current_vector_db_config,
+            get_current_model_info,
+            is_using_ollama_llm,
+            print_current_config
+        )
+    except ImportError:
+        raise ImportError("无法导入 common.utils,请确保该文件存在")
 
-    if (llm_model_name, vector_db_name) not in LLM_VECTOR_DB_MAP:
-        raise ValueError(f"不支持的模型类型: {llm_model_name} 或 向量数据库类型: {vector_db_name}")
+    # 打印当前配置信息
+    print_current_config()
     
-    config = {}
-    if llm_model_name == "deepseek":
-        config = config_module.DEEPSEEK_CONFIG.copy()
-        print(f"创建DeepSeek模型实例,使用模型: {config.get('model', 'deepseek-chat')}")
-    elif llm_model_name == "qwen":
-        config = config_module.QWEN_CONFIG.copy()
-        print(f"创建Qwen模型实例,使用模型: {config.get('model', 'qwen-plus-latest')}")
-    else:
-        raise ValueError(f"不支持的模型类型: {llm_model_name}") 
+    # 获取当前配置
+    llm_config = get_current_llm_config()
+    vector_db_config = get_current_vector_db_config()
+    model_info = get_current_model_info()
     
-    if vector_db_name == "chromadb":
+    # 获取对应的Vanna组合类
+    try:
+        if is_using_ollama_llm():
+            llm_type = "ollama"
+        else:
+            llm_type = model_info["llm_model"].lower()
+        
+        vector_db_type = model_info["vector_db"].lower()
+        
+        cls = get_vanna_class(llm_type, vector_db_type)
+        print(f"创建{llm_type.upper()}+{vector_db_type.upper()}实例")
+        
+    except ValueError as e:
+        print(f"错误: {e}")
+        print("\n可用的组合:")
+        print_available_combinations()
+        raise
+    
+    # 准备配置
+    config = llm_config.copy()
+    
+    # 配置向量数据库
+    if model_info["vector_db"] == "chromadb":
         config["path"] = os.path.dirname(os.path.abspath(__file__))
-        print(f"已配置使用ChromaDB作为向量数据库,路径:{config['path']}")
-    elif vector_db_name == "pgvector":
+        print(f"已配置使用ChromaDB,路径:{config['path']}")
+    elif model_info["vector_db"] == "pgvector":
         # 构建PostgreSQL连接字符串
-        pg_config = config_module.PGVECTOR_CONFIG
-        connection_string = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
+        connection_string = f"postgresql://{vector_db_config['user']}:{vector_db_config['password']}@{vector_db_config['host']}:{vector_db_config['port']}/{vector_db_config['dbname']}"
         config["connection_string"] = connection_string
-        print(f"已配置使用PgVector作为向量数据库,连接字符串: {connection_string}")
-    else:
-        raise ValueError(f"不支持的向量数据库类型: {vector_db_name}")    
+        print(f"已配置使用PgVector,连接字符串: {connection_string}")
     
+    # 配置embedding函数
     embedding_function = get_embedding_function()
-
     config["embedding_function"] = embedding_function
-    print(f"已配置使用 EMBEDDING_CONFIG 中的嵌入模型: {config_module.EMBEDDING_CONFIG['model_name']}, 维度: {config_module.EMBEDDING_CONFIG['embedding_dimension']}")
-    
-    key = (llm_model_name, vector_db_name)
-    cls = LLM_VECTOR_DB_MAP.get(key)
-    if cls is None:
-        raise ValueError(f"不支持的组合: 模型类型={llm_model_name}, 向量数据库类型={vector_db_name}")
+    print(f"已配置使用{model_info['embedding_type'].upper()}嵌入模型: {model_info['embedding_model']}")
     
+    # 创建实例
     vn = cls(config=config)
 
+    # 连接到业务数据库
     vn.connect_to_postgres(**config_module.APP_DB_CONFIG)           
-    print(f"连接到PostgreSQL业务数据库: "
+    print(f"连接到业务数据库: "
           f"{config_module.APP_DB_CONFIG['host']}:"
           f"{config_module.APP_DB_CONFIG['port']}/"
           f"{config_module.APP_DB_CONFIG['dbname']}")
+    
     return vn