@@ -1,6 +1,7 @@
 import requests
 import json
-from typing import List, Dict, Any
+import re
+from typing import List, Dict, Any, Optional
 from .base_llm_chat import BaseLLMChat
@@ -12,9 +13,32 @@ class OllamaChat(BaseLLMChat):
         super().__init__(config=config)

         # Ollama-specific configuration parameters
-        self.base_url = config.get("base_url", "http://localhost:11434")
-        self.model = config.get("model", "qwen2.5:7b")
-        self.timeout = config.get("timeout", 60)
+        self.base_url = config.get("base_url", "http://localhost:11434") if config else "http://localhost:11434"
+        self.model = config.get("model", "qwen2.5:7b") if config else "qwen2.5:7b"
+        self.timeout = config.get("timeout", 60) if config else 60
+
+        # Ollama-specific generation options
+        self.num_ctx = config.get("num_ctx", 4096) if config else 4096  # context window size
+        self.num_predict = config.get("num_predict", -1) if config else -1  # max tokens to predict
+        self.repeat_penalty = config.get("repeat_penalty", 1.1) if config else 1.1  # repetition penalty
+
+        # Verify the connection
+        if config and config.get("auto_check_connection", True):
+            self._check_ollama_health()
+
+    def _check_ollama_health(self) -> bool:
+        """Check that the Ollama service is reachable."""
+        try:
+            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
+            if response.status_code == 200:
+                print(f"✅ Ollama service is reachable: {self.base_url}")
+                return True
+            else:
+                print(f"⚠️ Ollama service returned an unexpected status: {response.status_code}")
+                return False
+        except requests.exceptions.RequestException as e:
+            print(f"❌ Failed to connect to the Ollama service: {e}")
+            return False

     def submit_prompt(self, prompt, **kwargs) -> str:
         if prompt is None:
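
Note: the constructor now tolerates config=None and exposes Ollama's generation knobs directly. A minimal construction sketch follows; the import path is hypothetical, and it assumes BaseLLMChat accepts this config dict (neither is confirmed by this diff):

    # Hypothetical module path; adjust to the actual package layout.
    from app.llm.ollama_chat import OllamaChat

    chat = OllamaChat(config={
        "base_url": "http://localhost:11434",  # default shown above
        "model": "qwen2.5:7b",
        "timeout": 60,
        "num_ctx": 8192,                 # larger context window
        "num_predict": 512,              # cap on generated tokens
        "repeat_penalty": 1.1,
        "auto_check_connection": False,  # skip the startup health probe
    })
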
@@ -28,38 +52,127 @@ class OllamaChat(BaseLLMChat):
         for message in prompt:
             num_tokens += len(message["content"]) / 4

-        # Determine which model to use
-        model = kwargs.get("model") or kwargs.get("engine") or self.config.get("model") or self.model
+        # Resolve the stream parameter
+        stream_mode = kwargs.get("stream", self.config.get("stream", False) if self.config else False)
+
+        # Resolve the enable_thinking parameter
+        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False) if self.config else False)
+
+        # Ollama constraint: enable_thinking=True is recommended together with stream=True.
+        # If stream=False but enable_thinking=True, ignore enable_thinking.
+        if enable_thinking and not stream_mode:
+            print("WARNING: enable_thinking=True has no effect because it requires stream=True")
+            enable_thinking = False
+
+        # Smart model selection
+        model = self._determine_model(kwargs, enable_thinking, num_tokens)
+
+        # Check whether the selected model is a reasoning model
+        is_reasoning_model = self._is_reasoning_model(model)
+
+        # Compatibility hint (the model is not switched automatically)
+        if enable_thinking and not is_reasoning_model:
+            print(f"Note: model {model} may not support reasoning; reasoning-related parameters will be ignored")

         print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
+        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")

         # Prepare the Ollama API request
         url = f"{self.base_url}/api/chat"
         payload = {
             "model": model,
             "messages": prompt,
-            "stream": False,
-            "options": {
-                "temperature": self.temperature
-            }
+            "stream": stream_mode,
+            "options": self._build_options(kwargs, is_reasoning_model)
         }

         try:
-            response = requests.post(
-                url,
-                json=payload,
-                timeout=self.timeout,
-                headers={"Content-Type": "application/json"}
-            )
-            response.raise_for_status()
-
-            result = response.json()
-            return result["message"]["content"]
-
+            if stream_mode:
+                # Streaming mode
+                if is_reasoning_model and enable_thinking:
+                    print("Using streaming mode with reasoning enabled")
+                else:
+                    print("Using streaming mode for regular chat")
+
+                return self._handle_stream_response(url, payload, is_reasoning_model and enable_thinking)
+            else:
+                # Non-streaming mode
+                if is_reasoning_model and enable_thinking:
+                    print("Using non-streaming mode with reasoning enabled")
+                else:
+                    print("Using non-streaming mode for regular chat")
+
+                return self._handle_non_stream_response(url, payload, is_reasoning_model and enable_thinking)
+
         except requests.exceptions.RequestException as e:
             print(f"Ollama API request failed: {e}")
             raise Exception(f"Ollama API call failed: {str(e)}")

+    def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
+        """Handle a streaming response."""
+        response = requests.post(
+            url,
+            json=payload,
+            timeout=self.timeout,
+            headers={"Content-Type": "application/json"},
+            stream=True
+        )
+        response.raise_for_status()
+
+        collected_content = []
+
+        for line in response.iter_lines():
+            if line:
+                try:
+                    chunk_data = json.loads(line.decode('utf-8'))
+
+                    if 'message' in chunk_data and 'content' in chunk_data['message']:
+                        content = chunk_data['message']['content']
+                        collected_content.append(content)
+
+                    # Check whether the stream has finished
+                    if chunk_data.get('done', False):
+                        break
+
+                except json.JSONDecodeError:
+                    continue
+
+        # Join all collected chunks
+        full_content = "".join(collected_content)
+
+        # If reasoning is enabled, try to separate the reasoning from the final answer
+        if enable_reasoning:
+            reasoning_content, final_content = self._extract_reasoning(full_content)
+
+            if reasoning_content:
+                print("Model reasoning process:\n", reasoning_content)
+            return final_content
+
+        return full_content
+
+    def _handle_non_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
+        """Handle a non-streaming response."""
+        response = requests.post(
+            url,
+            json=payload,
+            timeout=self.timeout,
+            headers={"Content-Type": "application/json"}
+        )
+        response.raise_for_status()
+
+        result = response.json()
+        content = result["message"]["content"]
+
+        if enable_reasoning:
+            # Try to separate the reasoning from the final answer
+            reasoning_content, final_content = self._extract_reasoning(content)
+
+            if reasoning_content:
+                print("Model reasoning process:\n", reasoning_content)
+            return final_content
+
+        return content
+
     def test_connection(self, test_prompt="你好") -> dict:
         """Test the Ollama connection."""
         result = {
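
Note: _handle_stream_response relies on Ollama's NDJSON streaming format, where POST /api/chat with "stream": true returns one JSON object per line and the final object carries "done": true. A standalone sketch of that wire format, independent of this class (the URL and model name are placeholders):

    import json
    import requests

    resp = requests.post(
        "http://localhost:11434/api/chat",
        json={
            "model": "qwen2.5:7b",
            "messages": [{"role": "user", "content": "hi"}],
            "stream": True,
        },
        stream=True,
        timeout=60,
    )
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # one JSON object per line
        print(chunk.get("message", {}).get("content", ""), end="", flush=True)
        if chunk.get("done"):     # final chunk of the stream
            break
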
@@ -67,11 +180,33 @@ class OllamaChat(BaseLLMChat):
             "model": self.model,
             "base_url": self.base_url,
             "message": "",
+            "available_models": [],
+            "ollama_version": None
         }

         try:
+            # Check the health of the service first
+            if not self._check_ollama_health():
+                result["message"] = "Ollama service is unavailable"
+                return result
+
+            # Fetch the list of available models
+            try:
+                result["available_models"] = self.list_models()
+
+                # Check whether the target model exists
+                if self.model not in result["available_models"]:
+                    print(f"Warning: model {self.model} not found, trying to pull it...")
+                    if not self.pull_model(self.model):
+                        result["message"] = f"Model {self.model} does not exist and pulling it failed"
+                        return result
+            except Exception as e:
+                print(f"Failed to fetch the model list: {e}")
+                result["available_models"] = [self.model]
+
             print(f"Testing Ollama connection - model: {self.model}")
             print(f"Ollama service URL: {self.base_url}")
+            print(f"Available models: {', '.join(result['available_models'])}")

             # Test a simple conversation
             prompt = [self.user_message(test_prompt)]
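
Note: test_connection now degrades step by step (health probe, model listing, optional pull) instead of failing on the first request. A usage sketch against the result fields visible in this hunk; keys defined in the elided lines above the hunk are deliberately not assumed:

    chat = OllamaChat(config={"model": "qwen2.5:7b"})  # hypothetical config
    info = chat.test_connection(test_prompt="你好")
    print(info["model"], "@", info["base_url"])
    print("available models:", ", ".join(info["available_models"]))
    if info["message"]:
        print("problem:", info["message"])
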
@@ -84,4 +219,209 @@ class OllamaChat(BaseLLMChat):

         except Exception as e:
             result["message"] = f"Ollama connection test failed: {str(e)}"
-        return result
+        return result
+
+    def _determine_model(self, kwargs: dict, enable_thinking: bool, num_tokens: int) -> str:
+        """Determine which model to use, with smart fallbacks."""
+        # Priority: runtime arguments > config file > smart selection
+        if kwargs.get("model", None) is not None:
+            return kwargs.get("model")
+        elif kwargs.get("engine", None) is not None:
+            return kwargs.get("engine")
+        elif self.config is not None and "engine" in self.config:
+            return self.config["engine"]
+        elif self.config is not None and "model" in self.config:
+            return self.config["model"]
+        else:
+            # Smart model selection
+            if enable_thinking:
+                # Prefer a reasoning model
+                try:
+                    available_models = self.list_models()
+                    reasoning_models = [m for m in available_models if self._is_reasoning_model(m)]
+                    if reasoning_models:
+                        return reasoning_models[0]  # pick the first reasoning model
+                    else:
+                        print("Warning: no reasoning model found, falling back to the default model")
+                        return self.model
+                except Exception as e:
+                    print(f"Failed to fetch the model list: {e}, falling back to the default model")
+                    return self.model
+            else:
+                # Pick a model based on the token count
+                if num_tokens > 8000:
+                    # Long input: prefer a long-context model
+                    try:
+                        available_models = self.list_models()
+                        long_context_models = [m for m in available_models if any(keyword in m.lower() for keyword in ['long', '32k', '128k'])]
+                        if long_context_models:
+                            return long_context_models[0]
+                    except Exception as e:
+                        print(f"Failed to fetch the model list: {e}, falling back to the default model")
+
+                return self.model
+
+    def _is_reasoning_model(self, model: str) -> bool:
+        """Check whether a model is a reasoning model."""
+        reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
+        return any(keyword in model.lower() for keyword in reasoning_keywords)
+
+    def _build_options(self, kwargs: dict, is_reasoning_model: bool) -> dict:
+        """Build the Ollama options payload."""
+        options = {
+            "temperature": self.temperature,
+            "num_ctx": self.num_ctx,
+            "num_predict": self.num_predict,
+            "repeat_penalty": self.repeat_penalty,
+        }
+
+        # Filter out custom parameters so they are not passed to the API
+        filtered_kwargs = {k: v for k, v in kwargs.items()
+                           if k not in ['model', 'engine', 'enable_thinking', 'stream', 'timeout']}
+
+        # Merge the remaining parameters into options
+        for k, v in filtered_kwargs.items():
+            options[k] = v
+
+        # Reasoning-model-specific adjustments
+        if is_reasoning_model:
+            # Reasoning models may need a larger prediction budget
+            if options["num_predict"] == -1:
+                options["num_predict"] = 2048
+            # Lower the repeat penalty to allow more repetition during reasoning
+            options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)
+
+        return options
+
+    def _is_reasoning_content(self, content: str) -> bool:
+        """Check whether a piece of content looks like reasoning."""
+        reasoning_patterns = [
+            r'<think>.*?</think>',
+            r'<reasoning>.*?</reasoning>',
+            r'<analysis>.*?</analysis>',
+            r'思考:',
+            r'分析:',
+            r'推理:'
+        ]
+        return any(re.search(pattern, content, re.DOTALL | re.IGNORECASE) for pattern in reasoning_patterns)
+
+    def _extract_reasoning(self, content: str) -> tuple:
+        """Split content into the reasoning part and the final answer."""
+        reasoning_patterns = [
+            r'<think>(.*?)</think>',
+            r'<reasoning>(.*?)</reasoning>',
+            r'<analysis>(.*?)</analysis>',
+            r'思考:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
+            r'分析:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
+            r'推理:(.*?)(?=\n\n|\n[^思考分析推理]|$)'
+        ]
+
+        reasoning_content = ""
+        final_content = content
+
+        for pattern in reasoning_patterns:
+            matches = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
+            if matches:
+                reasoning_content = "\n".join(matches)
+                final_content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE).strip()
+                break
+
+        # If no explicit reasoning markers were found but the content is long,
+        # fall back to a simple split
+        if not reasoning_content and len(content) > 500:
+            lines = content.split('\n')
+            if len(lines) > 10:
+                # Assume the first half is reasoning and the second half is the answer
+                mid_point = len(lines) // 2
+                potential_reasoning = '\n'.join(lines[:mid_point])
+                potential_answer = '\n'.join(lines[mid_point:])
+
+                # Simple heuristic: split only if the first half contains reasoning keywords
+                if any(keyword in potential_reasoning for keyword in ['思考', '分析', '推理', '因为', '所以', '首先', '然后']):
+                    reasoning_content = potential_reasoning
+                    final_content = potential_answer
+
+        return reasoning_content, final_content
+
+    # Ollama-specific features
+    def list_models(self) -> List[str]:
+        """List the available models."""
+        try:
+            response = requests.get(f"{self.base_url}/api/tags", timeout=5)  # use a short timeout
+            response.raise_for_status()
+            data = response.json()
+            models = [model["name"] for model in data.get("models", [])]
+            return models if models else [self.model]  # fall back to the default model if the list is empty
+        except requests.exceptions.RequestException as e:
+            print(f"Failed to fetch the model list: {e}")
+            return [self.model]  # fall back to the default model
+        except Exception as e:
+            print(f"Failed to parse the model list: {e}")
+            return [self.model]  # fall back to the default model
+
+    def pull_model(self, model_name: str) -> bool:
+        """Pull a model."""
+        try:
+            print(f"Pulling model: {model_name}")
+            response = requests.post(
+                f"{self.base_url}/api/pull",
+                json={"name": model_name},
+                timeout=300  # pulling a model can take a long time
+            )
+            response.raise_for_status()
+            print(f"✅ Model {model_name} pulled successfully")
+            return True
+        except requests.exceptions.RequestException as e:
+            print(f"❌ Failed to pull model {model_name}: {e}")
+            return False
+
+    def delete_model(self, model_name: str) -> bool:
+        """Delete a model."""
+        try:
+            response = requests.delete(
+                f"{self.base_url}/api/delete",
+                json={"name": model_name},
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            print(f"✅ Model {model_name} deleted successfully")
+            return True
+        except requests.exceptions.RequestException as e:
+            print(f"❌ Failed to delete model {model_name}: {e}")
+            return False
+
+    def get_model_info(self, model_name: str) -> Optional[Dict]:
+        """Fetch model metadata."""
+        try:
+            response = requests.post(
+                f"{self.base_url}/api/show",
+                json={"name": model_name},
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.RequestException as e:
+            print(f"Failed to fetch model info: {e}")
+            return None
+
+    def get_system_info(self) -> Dict:
+        """Fetch Ollama system information."""
+        try:
+            # Fetch the version
+            version_response = requests.get(f"{self.base_url}/api/version", timeout=self.timeout)
+            version_info = version_response.json() if version_response.status_code == 200 else {}
+
+            # Fetch the model list
+            models = self.list_models()
+
+            return {
+                "base_url": self.base_url,
+                "version": version_info.get("version", "unknown"),
+                "available_models": models,
+                "current_model": self.model,
+                "timeout": self.timeout,
+                "num_ctx": self.num_ctx,
+                "num_predict": self.num_predict,
+                "repeat_penalty": self.repeat_penalty
+            }
+        except Exception as e:
+            return {"error": str(e)}
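
Note: _extract_reasoning strips the first matching marker pair and returns a (reasoning, answer) tuple. A minimal self-contained check of the <think> path, mirroring the first pattern in the method above:

    import re

    raw = "<think>The user greets me, so answer politely.</think>Hello!"
    matches = re.findall(r'<think>(.*?)</think>', raw, re.DOTALL | re.MULTILINE)
    reasoning = "\n".join(matches)
    answer = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL | re.MULTILINE).strip()
    assert reasoning == "The user greets me, so answer politely."
    assert answer == "Hello!"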