import requests
import json
import re
from typing import List, Dict, Optional

from .base_llm_chat import BaseLLMChat

class OllamaChat(BaseLLMChat):
    """Ollama AI chat implementation."""

    def __init__(self, config=None):
        print("...OllamaChat init...")
        super().__init__(config=config)
        cfg = config or {}

        # Ollama-specific connection settings
        self.base_url = cfg.get("base_url", "http://localhost:11434")
        self.model = cfg.get("model", "qwen2.5:7b")
        self.timeout = cfg.get("timeout", 60)

        # Ollama-specific generation parameters
        self.num_ctx = cfg.get("num_ctx", 4096)               # context window size
        self.num_predict = cfg.get("num_predict", -1)         # tokens to predict (-1 = no limit)
        self.repeat_penalty = cfg.get("repeat_penalty", 1.1)  # repetition penalty

        # Verify the connection at startup (skipped when no config is given)
        if config and config.get("auto_check_connection", True):
            self._check_ollama_health()

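    # Construction sketch: every key below is optional and is read either in
    # __init__ above or in submit_prompt below; the values shown are the defaults.
    #
    #   chat = OllamaChat({
    #       "base_url": "http://localhost:11434",
    #       "model": "qwen2.5:7b",
    #       "timeout": 60,
    #       "num_ctx": 4096,
    #       "num_predict": -1,
    #       "repeat_penalty": 1.1,
    #       "auto_check_connection": True,
    #       "stream": False,
    #       "enable_thinking": False,
    #   })
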
    def _check_ollama_health(self) -> bool:
        """Check whether the Ollama service is reachable and healthy."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                print(f"✅ Ollama service is reachable: {self.base_url}")
                return True
            else:
                print(f"⚠️ Unexpected response from Ollama service: {response.status_code}")
                return False
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to connect to Ollama service: {e}")
            return False

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise ValueError("Prompt is None")
        if len(prompt) == 0:
            raise ValueError("Prompt is empty")

        # Rough token estimate: roughly 4 characters per token
        num_tokens = sum(len(message["content"]) for message in prompt) // 4

        # Resolve the stream flag: runtime kwarg takes precedence over config
        stream_mode = kwargs.get("stream", self.config.get("stream", False) if self.config else False)

        # Resolve the enable_thinking flag the same way
        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False) if self.config else False)

        # Ollama constraint: enable_thinking=True requires stream=True.
        # If stream=False but enable_thinking=True, drop enable_thinking.
        if enable_thinking and not stream_mode:
            print("WARNING: enable_thinking=True has no effect because it requires stream=True")
            enable_thinking = False

        # Pick a model (runtime kwargs > config > heuristics)
        model = self._determine_model(kwargs, enable_thinking, num_tokens)

        # Is the chosen model a dedicated reasoning model?
        is_reasoning_model = self._is_reasoning_model(model)

        # Compatibility note (no forced model switch)
        if enable_thinking and not is_reasoning_model:
            print(f"Note: model {model} is not a dedicated reasoning model, but thinking will still be attempted")

        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")

        # Build the Ollama API request
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": stream_mode,
            "think": enable_thinking,  # the Ollama API's "think" flag controls reasoning output
            "options": self._build_options(kwargs, is_reasoning_model, enable_thinking)
        }

        try:
            if stream_mode:
                # Streaming mode
                if enable_thinking:
                    print("Streaming mode with thinking enabled")
                else:
                    print("Streaming mode, regular chat")

                return self._handle_stream_response(url, payload, enable_thinking)
            else:
                # Non-streaming mode
                if enable_thinking:
                    print("Non-streaming mode with thinking enabled")
                else:
                    print("Non-streaming mode, regular chat")

                return self._handle_non_stream_response(url, payload, enable_thinking)

        except requests.exceptions.RequestException as e:
            print(f"Ollama API request failed: {e}")
            raise RuntimeError(f"Ollama API call failed: {e}") from e

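    # Wire-format sketch of what submit_prompt POSTs to /api/chat (values are
    # illustrative; "options" comes from _build_options below):
    #
    #   {
    #     "model": "qwen2.5:7b",
    #     "messages": [{"role": "user", "content": "Hi"}],
    #     "stream": false,
    #     "think": false,
    #     "options": {"temperature": 0.7, "num_ctx": 4096,
    #                 "num_predict": -1, "repeat_penalty": 1.1}
    #   }
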
    def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Handle a streaming response."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"},
            stream=True
        )
        response.raise_for_status()

        collected_content = []

        for line in response.iter_lines():
            if line:
                try:
                    chunk_data = json.loads(line.decode('utf-8'))

                    if 'message' in chunk_data and 'content' in chunk_data['message']:
                        content = chunk_data['message']['content']
                        collected_content.append(content)

                    # Stop once the stream reports completion
                    if chunk_data.get('done', False):
                        break

                except json.JSONDecodeError:
                    continue

        # Join all streamed chunks
        full_content = "".join(collected_content)

        # With reasoning enabled, try to split the reasoning trace from the final answer
        if enable_reasoning:
            reasoning_content, final_content = self._extract_reasoning(full_content)

            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
                return final_content

        return full_content

    def _handle_non_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Handle a non-streaming response."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()

        result = response.json()
        content = result["message"]["content"]

        if enable_reasoning:
            # Try to split the reasoning trace from the final answer
            reasoning_content, final_content = self._extract_reasoning(content)

            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
                return final_content

        return content

    def test_connection(self, test_prompt="你好") -> dict:
        """Test the connection to Ollama."""
        result = {
            "success": False,
            "model": self.model,
            "base_url": self.base_url,
            "message": "",
            "available_models": [],
            "ollama_version": None
        }

        try:
            # Check service health
            if not self._check_ollama_health():
                result["message"] = "Ollama service is unavailable"
                return result

            # Fetch the list of available models
            try:
                result["available_models"] = self.list_models()

                # Pull the target model if it is missing
                if self.model not in result["available_models"]:
                    print(f"Warning: model {self.model} not found, attempting to pull it...")
                    if not self.pull_model(self.model):
                        result["message"] = f"Model {self.model} not found and pulling it failed"
                        return result
            except Exception as e:
                print(f"Failed to fetch model list: {e}")
                result["available_models"] = [self.model]

            print(f"Testing Ollama connection - model: {self.model}")
            print(f"Ollama service URL: {self.base_url}")
            print(f"Available models: {', '.join(result['available_models'])}")

            # Run a simple round-trip chat
            prompt = [self.user_message(test_prompt)]
            response = self.submit_prompt(prompt)

            result["success"] = True
            result["message"] = f"Ollama connection test succeeded, response: {response[:50]}..."

            return result

        except Exception as e:
            result["message"] = f"Ollama connection test failed: {str(e)}"
            return result

    def _determine_model(self, kwargs: dict, enable_thinking: bool, num_tokens: int) -> str:
        """Decide which model to use."""
        # Precedence: runtime kwargs > config file > heuristic selection
        if kwargs.get("model") is not None:
            return kwargs["model"]
        elif kwargs.get("engine") is not None:
            return kwargs["engine"]
        elif self.config is not None and "engine" in self.config:
            return self.config["engine"]
        elif self.config is not None and "model" in self.config:
            return self.config["model"]
        else:
            # Heuristic selection
            if enable_thinking:
                # Prefer a dedicated reasoning model
                try:
                    available_models = self.list_models()
                    reasoning_models = [m for m in available_models if self._is_reasoning_model(m)]
                    if reasoning_models:
                        return reasoning_models[0]  # take the first reasoning model
                    else:
                        print("Warning: no reasoning model found, falling back to the default model")
                        return self.model
                except Exception as e:
                    print(f"Failed to fetch model list: {e}, falling back to the default model")
                    return self.model
            else:
                # Pick a model based on the token estimate
                if num_tokens > 8000:
                    # Long input: prefer a long-context model
                    try:
                        available_models = self.list_models()
                        long_context_models = [m for m in available_models
                                               if any(keyword in m.lower() for keyword in ['long', '32k', '128k'])]
                        if long_context_models:
                            return long_context_models[0]
                    except Exception as e:
                        print(f"Failed to fetch model list: {e}, falling back to the default model")

                return self.model

    def _is_reasoning_model(self, model: str) -> bool:
        """Check whether a model name looks like a reasoning model."""
        reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
        return any(keyword in model.lower() for keyword in reasoning_keywords)

    def _build_options(self, kwargs: dict, is_reasoning_model: bool, enable_thinking: bool = False) -> dict:
        """Build the Ollama options payload."""
        options = {
            "temperature": self.temperature,
            "num_ctx": self.num_ctx,
            "num_predict": self.num_predict,
            "repeat_penalty": self.repeat_penalty,
        }

        # Filter out custom parameters so they are not forwarded to the API
        filtered_kwargs = {k: v for k, v in kwargs.items()
                           if k not in ['model', 'engine', 'enable_thinking', 'stream', 'timeout']}

        # Merge the remaining parameters into options
        options.update(filtered_kwargs)

        # Parameter adjustments when thinking is enabled
        if enable_thinking:
            # Thinking tends to need more prediction tokens
            if options["num_predict"] == -1:
                options["num_predict"] = 2048
            # Lower the repeat penalty to tolerate repetitive reasoning
            options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)

            # Dedicated reasoning models can be tuned further
            if is_reasoning_model:
                # Reasoning models usually benefit from a longer context
                options["num_ctx"] = max(options["num_ctx"], 8192)

        return options

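    # Illustrative result: with enable_thinking=True on a reasoning model, no
    # overriding kwargs, and the defaults above, _build_options yields
    #
    #   {"temperature": <inherited>, "num_ctx": 8192,
    #    "num_predict": 2048, "repeat_penalty": 1.05}
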
    def _is_reasoning_content(self, content: str) -> bool:
        """Check whether a piece of content looks like reasoning output."""
        # The last three patterns match the Chinese markers for
        # "thinking:", "analysis:" and "reasoning:"
        reasoning_patterns = [
            r'<think>.*?</think>',
            r'<reasoning>.*?</reasoning>',
            r'<analysis>.*?</analysis>',
            r'思考:',
            r'分析:',
            r'推理:'
        ]
        return any(re.search(pattern, content, re.DOTALL | re.IGNORECASE) for pattern in reasoning_patterns)

    def _extract_reasoning(self, content: str) -> tuple:
        """Split content into (reasoning trace, final answer)."""
        # The last three patterns match the Chinese markers for
        # "thinking:", "analysis:" and "reasoning:"
        reasoning_patterns = [
            r'<think>(.*?)</think>',
            r'<reasoning>(.*?)</reasoning>',
            r'<analysis>(.*?)</analysis>',
            r'思考:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'分析:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'推理:(.*?)(?=\n\n|\n[^思考分析推理]|$)'
        ]

        reasoning_content = ""
        final_content = content

        for pattern in reasoning_patterns:
            matches = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
            if matches:
                reasoning_content = "\n".join(matches)
                final_content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE).strip()
                break

        # No explicit reasoning markers: for long content, fall back to a naive split
        if not reasoning_content and len(content) > 500:
            lines = content.split('\n')
            if len(lines) > 10:
                # Assume the first half is reasoning and the second half is the answer
                mid_point = len(lines) // 2
                potential_reasoning = '\n'.join(lines[:mid_point])
                potential_answer = '\n'.join(lines[mid_point:])

                # Simple heuristic: only split if the first half contains Chinese
                # reasoning keywords ("thinking", "analysis", "reasoning",
                # "because", "therefore", "first", "then")
                if any(keyword in potential_reasoning for keyword in ['思考', '分析', '推理', '因为', '所以', '首先', '然后']):
                    reasoning_content = potential_reasoning
                    final_content = potential_answer

        return reasoning_content, final_content

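    # Sketch of the tag-based path on a hypothetical response (the heuristic
    # split above only kicks in for unmarked content longer than 500 chars):
    #
    #   >>> chat._extract_reasoning("<think>2 + 2 = 4</think>The answer is 4.")
    #   ('2 + 2 = 4', 'The answer is 4.')
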
    # Ollama-specific features
    def list_models(self) -> List[str]:
        """List the available models."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)  # use a short timeout
            response.raise_for_status()
            data = response.json()
            models = [model["name"] for model in data.get("models", [])]
            return models if models else [self.model]  # fall back to the default model if none are listed
        except requests.exceptions.RequestException as e:
            print(f"Failed to fetch model list: {e}")
            return [self.model]  # fall back to the default model
        except Exception as e:
            print(f"Failed to parse model list: {e}")
            return [self.model]  # fall back to the default model

    def pull_model(self, model_name: str) -> bool:
        """Pull a model."""
        try:
            print(f"Pulling model: {model_name}")
            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                timeout=300  # pulling a model can take a long time
            )
            response.raise_for_status()
            print(f"✅ Model {model_name} pulled successfully")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to pull model {model_name}: {e}")
            return False

    def delete_model(self, model_name: str) -> bool:
        """Delete a model."""
        try:
            response = requests.delete(
                f"{self.base_url}/api/delete",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            print(f"✅ Model {model_name} deleted successfully")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to delete model {model_name}: {e}")
            return False

    def get_model_info(self, model_name: str) -> Optional[Dict]:
        """Fetch information about a model."""
        try:
            response = requests.post(
                f"{self.base_url}/api/show",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Failed to fetch model info: {e}")
            return None

    def get_system_info(self) -> Dict:
        """Collect Ollama system information."""
        try:
            # Fetch version information
            version_response = requests.get(f"{self.base_url}/api/version", timeout=self.timeout)
            version_info = version_response.json() if version_response.status_code == 200 else {}

            # Fetch the model list
            models = self.list_models()

            return {
                "base_url": self.base_url,
                "version": version_info.get("version", "unknown"),
                "available_models": models,
                "current_model": self.model,
                "timeout": self.timeout,
                "num_ctx": self.num_ctx,
                "num_predict": self.num_predict,
                "repeat_penalty": self.repeat_penalty
            }
        except Exception as e:
            return {"error": str(e)}