Please analyze my project code and help me evaluate a change: I want to add Ollama LLM support. There will be two kinds of LLMs: a reasoning model such as qwen3/deepseek, and an embedding model. I'm confused about three things: 1. What do I need to add to app_config.py? 2. Which code needs to be added or refactored? For example, do I need to write another class alongside DeepSeekChat and QianWenAI_Chat that follows Ollama's access pattern? 3. Overall, which parts of the current project need to change to be compatible with Ollama?
After analyzing your project code, here is my assessment of the changes needed to add Ollama support. Your project architecture is clean: it uses a factory pattern and configuration-driven design, which makes adding Ollama support relatively straightforward.
Add the Ollama configuration in `app_config.py`:

```python
# Model type to use ("qwen", "deepseek", or "ollama")
MODEL_TYPE = "ollama"  # new "ollama" option

# New: Ollama model configuration
OLLAMA_CONFIG = {
    "base_url": "http://localhost:11434",  # default Ollama address
    "model": "qwen2.5:7b",  # or another Ollama model name
    "allow_llm_to_see_data": True,
    "temperature": 0.7,
    "n_results": 6,
    "language": "Chinese",
    "timeout": 60  # Ollama may need a longer timeout
}

# Changed: allow Ollama as the embedding model
EMBEDDING_CONFIG = {
    "provider": "ollama",  # new field; allowed values: "api", "ollama"
    "model_name": "nomic-embed-text",  # Ollama embedding model
    "api_key": os.getenv("EMBEDDING_API_KEY"),  # used in API mode
    "base_url": "http://localhost:11434",  # Ollama address
    "embedding_dimension": 768  # adjust to the actual model
}

# Alternatively, keep separate API and Ollama embedding configs
EMBEDDING_CONFIG_API = {
    "model_name": "BAAI/bge-m3",
    "api_key": os.getenv("EMBEDDING_API_KEY"),
    "base_url": os.getenv("EMBEDDING_BASE_URL"),
    "embedding_dimension": 1024
}

EMBEDDING_CONFIG_OLLAMA = {
    "model_name": "nomic-embed-text",
    "base_url": "http://localhost:11434",
    "embedding_dimension": 768
}

# and pick one at import time
EMBEDDING_TYPE = "ollama"  # or "api"
EMBEDDING_CONFIG = EMBEDDING_CONFIG_OLLAMA if EMBEDDING_TYPE == "ollama" else EMBEDDING_CONFIG_API
```
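Before relying on this configuration, it helps to confirm the Ollama server is reachable and that both models are actually pulled. A minimal sketch using Ollama's `/api/tags` endpoint; the model names are the ones assumed in the config above:

```python
import requests

# Check that the Ollama server is up and the configured models are pulled.
base_url = "http://localhost:11434"
resp = requests.get(f"{base_url}/api/tags", timeout=5)
resp.raise_for_status()

available = {m["name"] for m in resp.json().get("models", [])}
for needed in ("qwen2.5:7b", "nomic-embed-text"):
    # Ollama may list a bare tag ("qwen2.5:7b") or an implicit ":latest" suffix.
    if not any(name == needed or name.startswith(needed + ":") for name in available):
        print(f"Warning: model '{needed}' not found; run `ollama pull {needed}`")
```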
Create a new file `customollama/ollama_chat.py`:
```python
import requests
from vanna.base import VannaBase
from typing import Dict


class OllamaChat(VannaBase):
    def __init__(self, config=None):
        print("...OllamaChat init...")
        if config is None:
            config = {}
        VannaBase.__init__(self, config=config)

        print("Received config parameters:")
        for key, value in self.config.items():
            print(f"  {key}: {value}")

        # Defaults
        self.temperature = 0.7
        self.base_url = config.get("base_url", "http://localhost:11434")
        self.model = config.get("model", "qwen2.5:7b")
        self.timeout = config.get("timeout", 60)

        if "temperature" in config:
            print(f"temperature is changed to: {config['temperature']}")
            self.temperature = config["temperature"]

    def system_message(self, message: str) -> Dict[str, str]:
        print(f"system_content: {message}")
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> Dict[str, str]:
        print(f"\nuser_content: {message}")
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> Dict[str, str]:
        print(f"assistant_content: {message}")
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token estimate (~4 characters per token)
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Resolve which model to use
        model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")

        # Build the Ollama chat API request
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": False,
            "options": {
                "temperature": self.temperature
            }
        }

        try:
            response = requests.post(
                url,
                json=payload,
                timeout=self.timeout,
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()
            result = response.json()
            return result["message"]["content"]
        except requests.exceptions.RequestException as e:
            print(f"Ollama API request failed: {e}")
            raise Exception(f"Ollama API call failed: {str(e)}")

    def generate_sql(self, question: str, **kwargs) -> str:
        """Override generate_sql to add error handling."""
        try:
            print(f"[DEBUG] Generating SQL for question: {question}")
            sql = super().generate_sql(question, **kwargs)

            if not sql or sql.strip() == "":
                print("[WARNING] Generated SQL is empty")
                return None

            # Check whether the response is actually valid SQL
            sql_lower = sql.lower().strip()
            # Phrases (English and Chinese) that signal a refusal rather than SQL
            error_indicators = [
                "insufficient context", "无法生成", "sorry", "cannot", "不能",
                "no relevant", "no suitable", "unable to", "无法", "抱歉"
            ]
            for indicator in error_indicators:
                if indicator in sql_lower:
                    print(f"[WARNING] LLM returned an error message instead of SQL: {sql}")
                    return None

            sql_keywords = ["select", "insert", "update", "delete", "with", "from", "where"]
            if not any(keyword in sql_lower for keyword in sql_keywords):
                print(f"[WARNING] Response does not look like valid SQL: {sql}")
                return None

            print(f"[SUCCESS] Generated SQL: {sql}")
            return sql
        except Exception as e:
            print(f"[ERROR] Exception while generating SQL: {str(e)}")
            return None

    def generate_question(self, sql: str, **kwargs) -> str:
        """Generate a Chinese business question from a SQL statement."""
        prompt = [
            self.system_message(
                "Infer the business question behind the SQL statement below. Return only a clear "
                "natural-language question, without any explanation, SQL, or table names. "
                "The question must be in Chinese and end with a question mark."
            ),
            self.user_message(sql)
        ]
        response = self.submit_prompt(prompt, **kwargs)
        return response

    def chat_with_llm(self, question: str, **kwargs) -> str:
        """Chat with the LLM directly."""
        try:
            prompt = [
                self.system_message(
                    "You are a friendly AI assistant. If the user asks a database-related question, "
                    "suggest they rephrase it so it can be answered with a SQL query. "
                    "For other questions, do your best to give a helpful answer."
                ),
                self.user_message(question)
            ]
            response = self.submit_prompt(prompt, **kwargs)
            return response
        except Exception as e:
            print(f"[ERROR] LLM chat failed: {str(e)}")
            return "Sorry, I cannot answer your question right now. Please try again later."
```
Create `customollama/__init__.py`:

```python
from .ollama_chat import OllamaChat
```
Modify `embedding_function.py` to add Ollama support:
```python
import requests
import time
import numpy as np
from typing import List, Callable


class OllamaEmbeddingFunction:
    def __init__(self, model_name: str, base_url: str, embedding_dimension: int):
        self.model_name = model_name
        self.base_url = base_url
        self.embedding_dimension = embedding_dimension
        self.max_retries = 2
        self.retry_interval = 2

    def __call__(self, input) -> List[List[float]]:
        """Generate embedding vectors for a list of texts."""
        if not isinstance(input, list):
            input = [input]

        embeddings = []
        for text in input:
            try:
                embedding = self.generate_embedding(text)
                embeddings.append(embedding)
            except Exception as e:
                print(f"Error while fetching embedding: {e}")
                embeddings.append([0.0] * self.embedding_dimension)
        return embeddings

    def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector for a single text."""
        print(f"Generating Ollama embedding, text length: {len(text)} characters")

        if not text or len(text.strip()) == 0:
            print("Input text is empty, returning a zero vector")
            return [0.0] * self.embedding_dimension

        url = f"{self.base_url}/api/embeddings"
        payload = {
            "model": self.model_name,
            "prompt": text
        }

        retries = 0
        while retries <= self.max_retries:
            try:
                response = requests.post(url, json=payload, timeout=30)

                if response.status_code != 200:
                    error_msg = f"Ollama API error: {response.status_code}, {response.text}"
                    print(error_msg)
                    # Retry only on transient server/rate-limit errors
                    if response.status_code in (429, 500, 502, 503, 504):
                        retries += 1
                        if retries <= self.max_retries:
                            wait_time = self.retry_interval * (2 ** (retries - 1))
                            print(f"Waiting {wait_time}s before retry ({retries}/{self.max_retries})")
                            time.sleep(wait_time)
                            continue
                    raise ValueError(error_msg)

                result = response.json()
                if "embedding" in result:
                    vector = result["embedding"]
                    # Validate the vector dimension
                    actual_dim = len(vector)
                    if actual_dim != self.embedding_dimension:
                        print(f"Dimension mismatch: expected {self.embedding_dimension}, got {actual_dim}")
                        # On mismatch, truncate or zero-pad to the configured dimension
                        if actual_dim > self.embedding_dimension:
                            vector = vector[:self.embedding_dimension]
                        else:
                            vector.extend([0.0] * (self.embedding_dimension - actual_dim))
                    print(f"Ollama embedding generated, dimension: {len(vector)}")
                    return vector
                else:
                    error_msg = f"Unexpected Ollama API response format: {result}"
                    print(error_msg)
                    raise ValueError(error_msg)
            except Exception as e:
                print(f"Error while generating Ollama embedding: {str(e)}")
                retries += 1
                if retries <= self.max_retries:
                    wait_time = self.retry_interval * (2 ** (retries - 1))
                    print(f"Waiting {wait_time}s before retry ({retries}/{self.max_retries})")
                    time.sleep(wait_time)
                else:
                    print(f"Max retries reached ({self.max_retries}), embedding generation failed")
                    return [0.0] * self.embedding_dimension

        raise RuntimeError("Failed to generate Ollama embedding")

    def test_connection(self, test_text="test text") -> dict:
        """Test connectivity to the Ollama embedding model."""
        result = {
            "success": False,
            "model": self.model_name,
            "base_url": self.base_url,
            "message": "",
            "actual_dimension": None,
            "expected_dimension": self.embedding_dimension
        }
        try:
            print(f"Testing Ollama embedding model connection - model: {self.model_name}")
            print(f"Ollama server address: {self.base_url}")
            vector = self.generate_embedding(test_text)
            actual_dimension = len(vector)

            result["success"] = True
            result["actual_dimension"] = actual_dimension

            if actual_dimension != self.embedding_dimension:
                result["message"] = (f"Warning: the model's actual vector dimension ({actual_dimension}) "
                                     f"differs from the configured dimension ({self.embedding_dimension})")
            else:
                result["message"] = f"Ollama connection test succeeded, vector dimension: {actual_dimension}"
            return result
        except Exception as e:
            result["message"] = f"Ollama connection test failed: {str(e)}"
            return result


# Modified version of the existing get_embedding_function
def get_embedding_function():
    """Create the appropriate EmbeddingFunction instance based on configuration."""
    try:
        import app_config
    except ImportError:
        raise ImportError("Unable to import app_config.py.")

    try:
        embedding_config = app_config.EMBEDDING_CONFIG
    except AttributeError:
        raise AttributeError("app_config.py is missing the EMBEDDING_CONFIG setting.")

    # Check whether the Ollama embedding should be used
    provider = embedding_config.get("provider", "api")

    if provider == "ollama":
        print("Using the Ollama embedding model")
        return OllamaEmbeddingFunction(
            model_name=embedding_config["model_name"],
            base_url=embedding_config["base_url"],
            embedding_dimension=embedding_config["embedding_dimension"]
        )
    else:
        print("Using the API embedding model")
        # Existing EmbeddingFunction logic
        api_key = embedding_config["api_key"]
        model_name = embedding_config["model_name"]
        base_url = embedding_config["base_url"]
        embedding_dimension = embedding_config["embedding_dimension"]

        if api_key is None:
            raise KeyError("'api_key' in EMBEDDING_CONFIG is not set.")

        return EmbeddingFunction(
            model_name=model_name,
            api_key=api_key,
            base_url=base_url,
            embedding_dimension=embedding_dimension
        )


# Modified test helper
def test_embedding_connection() -> dict:
    """Test the embedding model connection."""
    try:
        embedding_function = get_embedding_function()
        test_result = embedding_function.test_connection()

        if test_result["success"]:
            print("Embedding model connection test succeeded!")
            if "Warning" in test_result["message"]:
                print(test_result["message"])
        else:
            print(f"Embedding model connection test failed: {test_result['message']}")
        return test_result
    except Exception as e:
        error_message = f"Unable to test the embedding model connection: {str(e)}"
        print(error_message)
        return {
            "success": False,
            "message": error_message
        }
```
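A quick way to exercise the new embedding path end to end; a minimal sketch, assuming `nomic-embed-text` is pulled and `EMBEDDING_CONFIG` selects the Ollama provider:

```python
from core.embedding_function import get_embedding_function, test_embedding_connection

# Check connectivity and the configured dimension first.
print(test_embedding_connection())

# Then embed a small batch the same way ChromaDB will call it.
ef = get_embedding_function()
vectors = ef(["employee count by department", "monthly sales by region"])
print(len(vectors), len(vectors[0]))  # 2 vectors of the configured dimension
```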
"""
Vanna LLM 工厂文件,支持多种LLM提供商
"""
from vanna.chromadb import ChromaDB_VectorStore
from customqianwen.Custom_QianwenAI_chat import QianWenAI_Chat
from customdeepseek.custom_deepseek_chat import DeepSeekChat
from customollama.ollama_chat import OllamaChat # 新增
import app_config
from core.embedding_function import get_embedding_function
import os
class Vanna_Qwen_ChromaDB(ChromaDB_VectorStore, QianWenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
QianWenAI_Chat.__init__(self, config=config)
class Vanna_DeepSeek_ChromaDB(ChromaDB_VectorStore, DeepSeekChat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
DeepSeekChat.__init__(self, config=config)
# 新增Ollama支持
class Vanna_Ollama_ChromaDB(ChromaDB_VectorStore, OllamaChat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OllamaChat.__init__(self, config=config)
def create_vanna_instance(config_module=None):
"""工厂函数:创建并初始化Vanna实例"""
if config_module is None:
config_module = app_config
model_type = config_module.MODEL_TYPE.lower()
config = {}
if model_type == "deepseek":
config = config_module.API_DEEPSEEK_CONFIG.copy()
print(f"创建DeepSeek模型实例,使用模型: {config['model']}")
if not config.get("api_key"):
print(f"\n错误: DeepSeek API密钥未设置或为空")
print(f"请在.env文件中设置DEEPSEEK_API_KEY环境变量")
import sys
sys.exit(1)
elif model_type == "qwen":
config = config_module.API_QWEN_CONFIG.copy()
print(f"创建Qwen模型实例,使用模型: {config['model']}")
if not config.get("api_key"):
print(f"\n错误: Qwen API密钥未设置或为空")
print(f"请在.env文件中设置QWEN_API_KEY环境变量")
import sys
sys.exit(1)
elif model_type == "ollama": # 新增
config = config_module.OLLAMA_CONFIG.copy()
print(f"创建Ollama模型实例,使用模型: {config['model']}")
print(f"Ollama服务地址: {config['base_url']}")
# Ollama通常不需要API密钥,但可以检查服务是否可用
try:
import requests
response = requests.get(f"{config['base_url']}/api/tags", timeout=5)
if response.status_code != 200:
print(f"警告: 无法连接到Ollama服务 ({config['base_url']})")
except Exception as e:
print(f"警告: Ollama服务连接测试失败: {e}")
else:
raise ValueError(f"不支持的模型类型: {model_type}")
# 获取embedding函数(支持API和Ollama两种方式)
embedding_function = get_embedding_function()
config["embedding_function"] = embedding_function
# 打印embedding配置信息
embedding_config = config_module.EMBEDDING_CONFIG
provider = embedding_config.get("provider", "api")
print(f"已配置使用 {provider.upper()} 嵌入模型: {embedding_config['model_name']}, 维度: {embedding_config['embedding_dimension']}")
# 设置ChromaDB路径
project_root = os.path.dirname(os.path.abspath(__file__))
config["path"] = project_root
print(f"已配置使用ChromaDB作为向量数据库,路径:{project_root}")
# 创建对应的Vanna实例
vn = None
if model_type == "deepseek":
vn = Vanna_DeepSeek_ChromaDB(config=config)
print("创建DeepSeek+ChromaDB实例")
elif model_type == "qwen":
vn = Vanna_Qwen_ChromaDB(config=config)
print("创建Qwen+ChromaDB实例")
elif model_type == "ollama": # 新增
vn = Vanna_Ollama_ChromaDB(config=config)
print("创建Ollama+ChromaDB实例")
if vn is None:
raise ValueError(f"未能成功创建Vanna实例,不支持的模型类型: {model_type}")
# 连接到业务数据库
vn.connect_to_postgres(**config_module.APP_DB_CONFIG)
print(f"已连接到业务数据库: "
f"{config_module.APP_DB_CONFIG['host']}:"
f"{config_module.APP_DB_CONFIG['port']}/"
f"{config_module.APP_DB_CONFIG['dbname']}")
return vn
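With those pieces in place, creating an instance stays a one-liner. A minimal sketch, assuming `MODEL_TYPE = "ollama"` and a valid `APP_DB_CONFIG` in `app_config.py`:

```python
from vanna_llm_factory import create_vanna_instance

vn = create_vanna_instance()
# The usual Vanna flow now runs through Ollama end to end.
sql = vn.generate_sql("How many employees are in each department?")
print(sql)
```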
The new directory layout:

```
project root/
├── customollama/
│   ├── __init__.py
│   └── ollama_chat.py
```
Using Ollama in your project:
```python
# Set in app_config.py
MODEL_TYPE = "ollama"
OLLAMA_CONFIG = {
    "base_url": "http://localhost:11434",
    "model": "qwen2.5:7b",  # or another model
    # ... other settings
}

EMBEDDING_CONFIG = {
    "provider": "ollama",
    "model_name": "nomic-embed-text",
    "base_url": "http://localhost:11434",
    "embedding_dimension": 768
}
```
With these changes, your project supports Ollama. The main changes are:

1. Add the Ollama configuration to `app_config.py`
2. Create the `OllamaChat` class
3. Extend `embedding_function.py` to support Ollama embeddings
4. Add Ollama support in `vanna_llm_factory.py`

This design keeps the existing code paths compatible, so you can switch between LLM providers through configuration alone.