import requests
import json
import re
from typing import List, Dict, Any, Optional
from .base_llm_chat import BaseLLMChat
class OllamaChat(BaseLLMChat):
    """Ollama-backed chat implementation.

    Talks to an Ollama server over its HTTP API (``/api/chat``, ``/api/tags``,
    ``/api/pull``, ...).  Adds heuristic model selection, optional streaming,
    and separation of a model's "reasoning" trace from its final answer.
    """

    def __init__(self, config=None):
        """Initialize from an optional config dict.

        Recognized keys: base_url, model, timeout, num_ctx, num_predict,
        repeat_penalty, auto_check_connection (plus whatever the base class
        consumes).
        """
        print("...OllamaChat init...")
        super().__init__(config=config)
        cfg = config or {}
        # Ollama-specific connection settings.
        self.base_url = cfg.get("base_url", "http://localhost:11434")
        self.model = cfg.get("model", "qwen2.5:7b")
        self.timeout = cfg.get("timeout", 60)
        # Ollama generation options.
        self.num_ctx = cfg.get("num_ctx", 4096)          # context window length
        self.num_predict = cfg.get("num_predict", -1)    # max tokens to generate (-1 = no limit)
        self.repeat_penalty = cfg.get("repeat_penalty", 1.1)  # repetition penalty
        # Optionally verify the server is reachable at construction time.
        if config and config.get("auto_check_connection", True):
            self._check_ollama_health()

    def _check_ollama_health(self) -> bool:
        """Return True if the Ollama server answers ``/api/tags`` with HTTP 200."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                print(f"✅ Ollama 服务连接正常: {self.base_url}")
                return True
            print(f"⚠️ Ollama 服务响应异常: {response.status_code}")
            return False
        except requests.exceptions.RequestException as e:
            print(f"❌ Ollama 服务连接失败: {e}")
            return False

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a chat ``prompt`` (list of message dicts) and return the reply text.

        Honours ``stream`` and ``enable_thinking`` from kwargs (falling back to
        config), picks a model via :meth:`_determine_model`, and raises
        ``Exception`` when the prompt is missing/empty or the API call fails.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token estimate: ~4 characters per token.
        num_tokens = sum(len(message["content"]) for message in prompt) // 4

        cfg = self.config if self.config else {}
        # Runtime kwargs override config defaults.
        stream_mode = kwargs.get("stream", cfg.get("stream", False))
        enable_thinking = kwargs.get("enable_thinking", cfg.get("enable_thinking", False))

        # Ollama constraint: thinking output is only usable with streaming,
        # so enable_thinking without stream is ignored (with a warning).
        if enable_thinking and not stream_mode:
            print("WARNING: enable_thinking=True 不生效,因为它需要 stream=True")
            enable_thinking = False

        model = self._determine_model(kwargs, enable_thinking, num_tokens)
        is_reasoning_model = self._is_reasoning_model(model)
        # Compatibility note only — we still attempt thinking on other models.
        if enable_thinking and not is_reasoning_model:
            print(f"提示:模型 {model} 不是专门的推理模型,但仍会尝试启用推理功能")

        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")

        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": stream_mode,
            "think": enable_thinking,  # Ollama API flag that toggles reasoning output
            "options": self._build_options(kwargs, is_reasoning_model, enable_thinking)
        }

        try:
            if stream_mode:
                if enable_thinking:
                    print("使用流式处理模式,启用推理功能")
                else:
                    print("使用流式处理模式,常规聊天")
                return self._handle_stream_response(url, payload, enable_thinking)
            if enable_thinking:
                print("使用非流式处理模式,启用推理功能")
            else:
                print("使用非流式处理模式,常规聊天")
            return self._handle_non_stream_response(url, payload, enable_thinking)
        except requests.exceptions.RequestException as e:
            print(f"Ollama API请求失败: {e}")
            raise Exception(f"Ollama API调用失败: {str(e)}")

    def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Consume a streaming ``/api/chat`` response and return the joined text.

        When ``enable_reasoning`` is set, the reasoning trace (if any) is
        printed and only the final answer is returned.
        """
        collected_content = []
        # Use a context manager so the connection is closed even when we
        # break out of the stream early (the original leaked it on break).
        with requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"},
            stream=True
        ) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    chunk_data = json.loads(line.decode('utf-8'))
                except json.JSONDecodeError:
                    # Skip malformed/partial chunks.
                    continue
                if 'message' in chunk_data and 'content' in chunk_data['message']:
                    collected_content.append(chunk_data['message']['content'])
                # A chunk flagged "done" terminates the stream.
                if chunk_data.get('done', False):
                    break

        full_content = "".join(collected_content)
        if enable_reasoning:
            reasoning_content, final_content = self._extract_reasoning(full_content)
            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
            return final_content
        return full_content

    def _handle_non_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Issue a blocking ``/api/chat`` request and return the reply text."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()
        result = response.json()
        content = result["message"]["content"]
        if enable_reasoning:
            # Separate the reasoning trace from the final answer, if present.
            reasoning_content, final_content = self._extract_reasoning(content)
            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
            return final_content
        return content

    def test_connection(self, test_prompt="你好") -> dict:
        """Run an end-to-end connectivity test and return a result dict.

        Checks server health, enumerates models (pulling the configured model
        when missing), then round-trips ``test_prompt`` through the chat API.
        """
        result = {
            "success": False,
            "model": self.model,
            "base_url": self.base_url,
            "message": "",
            "available_models": [],
            "ollama_version": None
        }
        try:
            if not self._check_ollama_health():
                result["message"] = "Ollama 服务不可用"
                return result
            # Enumerate models; attempt a pull when the target is missing.
            try:
                result["available_models"] = self.list_models()
                if self.model not in result["available_models"]:
                    print(f"警告:模型 {self.model} 不存在,尝试拉取...")
                    if not self.pull_model(self.model):
                        result["message"] = f"模型 {self.model} 不存在且拉取失败"
                        return result
            except Exception as e:
                print(f"获取模型列表失败: {e}")
                result["available_models"] = [self.model]
            print(f"测试Ollama连接 - 模型: {self.model}")
            print(f"Ollama服务地址: {self.base_url}")
            print(f"可用模型: {', '.join(result['available_models'])}")
            # Round-trip a trivial chat message.
            prompt = [self.user_message(test_prompt)]
            response = self.submit_prompt(prompt)
            result["success"] = True
            result["message"] = f"Ollama连接测试成功,响应: {response[:50]}..."
            return result
        except Exception as e:
            result["message"] = f"Ollama连接测试失败: {str(e)}"
            return result

    def _determine_model(self, kwargs: dict, enable_thinking: bool, num_tokens: int) -> str:
        """Pick the model to use for a request.

        Priority: runtime kwargs ("model"/"engine") > config ("engine"/"model")
        > heuristic selection (a reasoning model when thinking is enabled, a
        long-context model for large prompts, else the default ``self.model``).
        """
        if kwargs.get("model", None) is not None:
            return kwargs.get("model")
        if kwargs.get("engine", None) is not None:
            return kwargs.get("engine")
        if self.config is not None and "engine" in self.config:
            return self.config["engine"]
        if self.config is not None and "model" in self.config:
            return self.config["model"]
        if enable_thinking:
            # Prefer a dedicated reasoning model when one is installed.
            try:
                available_models = self.list_models()
                reasoning_models = [m for m in available_models if self._is_reasoning_model(m)]
                if reasoning_models:
                    return reasoning_models[0]  # first reasoning model wins
                print("警告:未找到推理模型,使用默认模型")
                return self.model
            except Exception as e:
                print(f"获取模型列表时出错: {e},使用默认模型")
                return self.model
        if num_tokens > 8000:
            # Large prompt: try to find a long-context variant.
            try:
                available_models = self.list_models()
                long_context_models = [
                    m for m in available_models
                    if any(keyword in m.lower() for keyword in ['long', '32k', '128k'])
                ]
                if long_context_models:
                    return long_context_models[0]
            except Exception as e:
                print(f"获取模型列表时出错: {e},使用默认模型")
        return self.model

    def _is_reasoning_model(self, model: str) -> bool:
        """Return True when the model name suggests a reasoning-tuned model."""
        reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
        return any(keyword in model.lower() for keyword in reasoning_keywords)

    def _build_options(self, kwargs: dict, is_reasoning_model: bool, enable_thinking: bool = False) -> dict:
        """Build the Ollama ``options`` payload from defaults plus pass-through kwargs."""
        options = {
            "temperature": self.temperature,
            "num_ctx": self.num_ctx,
            "num_predict": self.num_predict,
            "repeat_penalty": self.repeat_penalty,
        }
        # Drop wrapper-level parameters the Ollama API does not accept;
        # forward everything else verbatim.
        for k, v in kwargs.items():
            if k not in ('model', 'engine', 'enable_thinking', 'stream', 'timeout'):
                options[k] = v
        if enable_thinking:
            # Reasoning output tends to be long; give an unlimited (-1)
            # prediction budget a concrete floor.
            if options["num_predict"] == -1:
                options["num_predict"] = 2048
            # Lower the repeat penalty — reasoning traces repeat naturally.
            options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)
            if is_reasoning_model:
                # Dedicated reasoning models benefit from a larger context.
                options["num_ctx"] = max(options["num_ctx"], 8192)
        return options

    def _is_reasoning_content(self, content: str) -> bool:
        """Heuristically detect whether ``content`` contains reasoning markup."""
        # XML-style tags emitted by reasoning models plus common Chinese
        # prefixes.  (The tag patterns were restored — the previous bare
        # r'.*?' patterns matched *any* content.)
        reasoning_patterns = [
            r'<think>.*?</think>',
            r'<thinking>.*?</thinking>',
            r'<reasoning>.*?</reasoning>',
            r'思考:',
            r'分析:',
            r'推理:'
        ]
        return any(re.search(pattern, content, re.DOTALL | re.IGNORECASE) for pattern in reasoning_patterns)

    def _extract_reasoning(self, content: str) -> tuple:
        """Split ``content`` into ``(reasoning, final_answer)``.

        Tries explicit markers first (``<think>``-style tags, Chinese
        prefixes); falls back to a length-based heuristic for long unmarked
        output.  Returns ``("", content)`` when nothing looks like reasoning.
        """
        # Restored tag-delimited patterns — the previous bare r'(.*?)'
        # patterns captured only empty strings and stripped nothing.
        reasoning_patterns = [
            r'<think>(.*?)</think>',
            r'<thinking>(.*?)</thinking>',
            r'<reasoning>(.*?)</reasoning>',
            r'思考:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'分析:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'推理:(.*?)(?=\n\n|\n[^思考分析推理]|$)'
        ]
        reasoning_content = ""
        final_content = content
        for pattern in reasoning_patterns:
            matches = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
            if matches:
                reasoning_content = "\n".join(matches)
                final_content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE).strip()
                break
        # No explicit markers: for long multi-line output, assume the first
        # half is reasoning when it contains reasoning-ish keywords.
        if not reasoning_content and len(content) > 500:
            lines = content.split('\n')
            if len(lines) > 10:
                mid_point = len(lines) // 2
                potential_reasoning = '\n'.join(lines[:mid_point])
                potential_answer = '\n'.join(lines[mid_point:])
                if any(keyword in potential_reasoning for keyword in ['思考', '分析', '推理', '因为', '所以', '首先', '然后']):
                    reasoning_content = potential_reasoning
                    final_content = potential_answer
        return reasoning_content, final_content

    # --- Ollama-specific management helpers -------------------------------

    def list_models(self) -> List[str]:
        """List model names available on the server (falls back to ``[self.model]``)."""
        try:
            # Short timeout: this is called opportunistically from selection paths.
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            response.raise_for_status()
            data = response.json()
            models = [model["name"] for model in data.get("models", [])]
            return models if models else [self.model]
        except requests.exceptions.RequestException as e:
            print(f"获取模型列表失败: {e}")
            return [self.model]
        except Exception as e:
            print(f"解析模型列表失败: {e}")
            return [self.model]

    def pull_model(self, model_name: str) -> bool:
        """Pull ``model_name`` onto the server; return True on success."""
        try:
            print(f"正在拉取模型: {model_name}")
            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                timeout=300  # pulling a model can take a long time
            )
            response.raise_for_status()
            print(f"✅ 模型 {model_name} 拉取成功")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ 模型 {model_name} 拉取失败: {e}")
            return False

    def delete_model(self, model_name: str) -> bool:
        """Delete ``model_name`` from the server; return True on success."""
        try:
            response = requests.delete(
                f"{self.base_url}/api/delete",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            print(f"✅ 模型 {model_name} 删除成功")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ 模型 {model_name} 删除失败: {e}")
            return False

    def get_model_info(self, model_name: str) -> Optional[Dict]:
        """Return the server's ``/api/show`` metadata for a model, or None on error."""
        try:
            response = requests.post(
                f"{self.base_url}/api/show",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"获取模型信息失败: {e}")
            return None

    def get_system_info(self) -> Dict:
        """Return a summary of the Ollama server and this client's settings."""
        try:
            version_response = requests.get(f"{self.base_url}/api/version", timeout=self.timeout)
            version_info = version_response.json() if version_response.status_code == 200 else {}
            models = self.list_models()
            return {
                "base_url": self.base_url,
                "version": version_info.get("version", "unknown"),
                "available_models": models,
                "current_model": self.model,
                "timeout": self.timeout,
                "num_ctx": self.num_ctx,
                "num_predict": self.num_predict,
                "repeat_penalty": self.repeat_penalty
            }
        except Exception as e:
            return {"error": str(e)}