ollama_chat.py

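"""Ollama chat backend.

Implements the BaseLLMChat interface on top of a local Ollama server:
streaming and non-streaming chat, optional separation of reasoning output
for "thinking" models, and model-management helpers (list/pull/delete/show).
"""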
import requests
import json
import re
from typing import List, Dict, Any, Optional

from .base_llm_chat import BaseLLMChat


class OllamaChat(BaseLLMChat):
    """Ollama AI chat implementation."""

    def __init__(self, config=None):
        print("...OllamaChat init...")
        super().__init__(config=config)

        # Ollama-specific connection settings
        cfg = config or {}
        self.base_url = cfg.get("base_url", "http://localhost:11434")
        self.model = cfg.get("model", "qwen2.5:7b")
        self.timeout = cfg.get("timeout", 60)

        # Ollama-specific generation parameters
        self.num_ctx = cfg.get("num_ctx", 4096)               # context window size
        self.num_predict = cfg.get("num_predict", -1)         # max tokens to predict (-1 = no limit)
        self.repeat_penalty = cfg.get("repeat_penalty", 1.1)  # repetition penalty

        # Verify the connection on startup (can be disabled via config)
        if config and config.get("auto_check_connection", True):
            self._check_ollama_health()
    def _check_ollama_health(self) -> bool:
        """Check that the Ollama service is reachable."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                print(f"✅ Ollama service is reachable: {self.base_url}")
                return True
            print(f"⚠️ Unexpected response from Ollama service: {response.status_code}")
            return False
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to connect to Ollama service: {e}")
            return False
    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise ValueError("Prompt is None")
        if len(prompt) == 0:
            raise ValueError("Prompt is empty")

        # Rough token estimate: ~4 characters per token
        num_tokens = sum(len(message["content"]) for message in prompt) // 4

        # Resolve the stream parameter (kwargs override config)
        stream_mode = kwargs.get("stream", self.config.get("stream", False) if self.config else False)

        # Resolve the enable_thinking parameter (kwargs override config)
        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False) if self.config else False)

        # Ollama constraint: enable_thinking=True requires stream=True.
        # If stream=False but enable_thinking=True, ignore enable_thinking.
        if enable_thinking and not stream_mode:
            print("WARNING: enable_thinking=True has no effect because it requires stream=True")
            enable_thinking = False

        # Pick the model to use
        model = self._determine_model(kwargs, enable_thinking, num_tokens)

        # Check whether the chosen model is a reasoning model
        is_reasoning_model = self._is_reasoning_model(model)

        # Compatibility hint (does not force a model switch)
        if enable_thinking and not is_reasoning_model:
            print(f"Note: model {model} may not support reasoning; reasoning-related parameters will be ignored")

        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")

        # Build the Ollama API request
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": stream_mode,
            "options": self._build_options(kwargs, is_reasoning_model)
        }

        try:
            if stream_mode:
                # Streaming mode
                if is_reasoning_model and enable_thinking:
                    print("Streaming mode with reasoning enabled")
                else:
                    print("Streaming mode, regular chat")
                return self._handle_stream_response(url, payload, is_reasoning_model and enable_thinking)
            else:
                # Non-streaming mode
                if is_reasoning_model and enable_thinking:
                    print("Non-streaming mode with reasoning enabled")
                else:
                    print("Non-streaming mode, regular chat")
                return self._handle_non_stream_response(url, payload, is_reasoning_model and enable_thinking)
        except requests.exceptions.RequestException as e:
            print(f"Ollama API request failed: {e}")
            raise Exception(f"Ollama API call failed: {str(e)}")
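    # Illustrative request body produced by submit_prompt (values are examples;
    # "temperature" comes from the BaseLLMChat config):
    #
    #   POST {base_url}/api/chat
    #   {
    #     "model": "qwen2.5:7b",
    #     "messages": [{"role": "user", "content": "..."}],
    #     "stream": false,
    #     "options": {"temperature": 0.7, "num_ctx": 4096,
    #                 "num_predict": -1, "repeat_penalty": 1.1}
    #   }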
    def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Handle a streaming response."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"},
            stream=True
        )
        response.raise_for_status()

        collected_content = []
        for line in response.iter_lines():
            if line:
                try:
                    chunk_data = json.loads(line.decode('utf-8'))
                    if 'message' in chunk_data and 'content' in chunk_data['message']:
                        collected_content.append(chunk_data['message']['content'])
                    # Stop once the stream reports completion
                    if chunk_data.get('done', False):
                        break
                except json.JSONDecodeError:
                    continue

        # Join all streamed chunks
        full_content = "".join(collected_content)

        # If reasoning is enabled, try to separate the reasoning from the final answer
        if enable_reasoning:
            reasoning_content, final_content = self._extract_reasoning(full_content)
            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
            return final_content

        return full_content
    def _handle_non_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Handle a non-streaming response."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()

        result = response.json()
        content = result["message"]["content"]

        if enable_reasoning:
            # Try to separate the reasoning from the final answer
            reasoning_content, final_content = self._extract_reasoning(content)
            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
            return final_content

        return content
    def test_connection(self, test_prompt="你好") -> dict:
        """Test the Ollama connection."""
        result = {
            "success": False,
            "model": self.model,
            "base_url": self.base_url,
            "message": "",
            "available_models": [],
            "ollama_version": None
        }
        try:
            # Check service health
            if not self._check_ollama_health():
                result["message"] = "Ollama service unavailable"
                return result

            # Fetch the list of available models
            try:
                result["available_models"] = self.list_models()
                # Make sure the target model exists, pulling it if necessary
                if self.model not in result["available_models"]:
                    print(f"Warning: model {self.model} not found, trying to pull it...")
                    if not self.pull_model(self.model):
                        result["message"] = f"Model {self.model} not found and pull failed"
                        return result
            except Exception as e:
                print(f"Failed to list models: {e}")
                result["available_models"] = [self.model]

            print(f"Testing Ollama connection - model: {self.model}")
            print(f"Ollama service URL: {self.base_url}")
            print(f"Available models: {', '.join(result['available_models'])}")

            # Run a simple chat round-trip
            prompt = [self.user_message(test_prompt)]
            response = self.submit_prompt(prompt)

            result["success"] = True
            result["message"] = f"Ollama connection test succeeded, response: {response[:50]}..."
            return result
        except Exception as e:
            result["message"] = f"Ollama connection test failed: {str(e)}"
            return result
    def _determine_model(self, kwargs: dict, enable_thinking: bool, num_tokens: int) -> str:
        """Decide which model to use.

        Priority: runtime kwargs > config > heuristic selection.
        """
        if kwargs.get("model") is not None:
            return kwargs["model"]
        if kwargs.get("engine") is not None:
            return kwargs["engine"]
        if self.config is not None and "engine" in self.config:
            return self.config["engine"]
        if self.config is not None and "model" in self.config:
            return self.config["model"]

        if enable_thinking:
            # Prefer a reasoning model if one is available
            try:
                available_models = self.list_models()
                reasoning_models = [m for m in available_models if self._is_reasoning_model(m)]
                if reasoning_models:
                    return reasoning_models[0]  # first available reasoning model
                print("Warning: no reasoning model found, falling back to the default model")
                return self.model
            except Exception as e:
                print(f"Failed to list models: {e}, falling back to the default model")
                return self.model

        # Otherwise pick a model based on the estimated token count
        if num_tokens > 8000:
            # Long input: prefer a long-context model
            try:
                available_models = self.list_models()
                long_context_models = [m for m in available_models
                                       if any(keyword in m.lower() for keyword in ['long', '32k', '128k'])]
                if long_context_models:
                    return long_context_models[0]
            except Exception as e:
                print(f"Failed to list models: {e}, falling back to the default model")
        return self.model
    def _is_reasoning_model(self, model: str) -> bool:
        """Heuristically detect a reasoning model from its name."""
        reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
        return any(keyword in model.lower() for keyword in reasoning_keywords)
    def _build_options(self, kwargs: dict, is_reasoning_model: bool) -> dict:
        """Build the Ollama `options` payload."""
        options = {
            "temperature": self.temperature,
            "num_ctx": self.num_ctx,
            "num_predict": self.num_predict,
            "repeat_penalty": self.repeat_penalty,
        }

        # Filter out our own control parameters so they are not sent to the API
        filtered_kwargs = {k: v for k, v in kwargs.items()
                           if k not in ['model', 'engine', 'enable_thinking', 'stream', 'timeout']}

        # Pass any remaining keyword arguments through as Ollama options
        options.update(filtered_kwargs)

        # Adjustments specific to reasoning models
        if is_reasoning_model:
            # Reasoning models usually need a larger prediction budget
            if options["num_predict"] == -1:
                options["num_predict"] = 2048
            # Lower the repetition penalty, since reasoning tends to repeat itself
            options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)

        return options
    def _is_reasoning_content(self, content: str) -> bool:
        """Check whether the content looks like reasoning output."""
        # The last three patterns match Chinese markers: "思考:" (thinking),
        # "分析:" (analysis), "推理:" (reasoning).
        reasoning_patterns = [
            r'<think>.*?</think>',
            r'<reasoning>.*?</reasoning>',
            r'<analysis>.*?</analysis>',
            r'思考:',
            r'分析:',
            r'推理:'
        ]
        return any(re.search(pattern, content, re.DOTALL | re.IGNORECASE) for pattern in reasoning_patterns)
    def _extract_reasoning(self, content: str) -> tuple:
        """Split model output into (reasoning content, final answer)."""
        # Tag-style patterns first, then Chinese prose markers
        # ("思考:" thinking, "分析:" analysis, "推理:" reasoning).
        reasoning_patterns = [
            r'<think>(.*?)</think>',
            r'<reasoning>(.*?)</reasoning>',
            r'<analysis>(.*?)</analysis>',
            r'思考:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'分析:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'推理:(.*?)(?=\n\n|\n[^思考分析推理]|$)'
        ]

        reasoning_content = ""
        final_content = content

        for pattern in reasoning_patterns:
            matches = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
            if matches:
                reasoning_content = "\n".join(matches)
                final_content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE).strip()
                break

        # No explicit reasoning markers: for long outputs, fall back to a crude split
        if not reasoning_content and len(content) > 500:
            lines = content.split('\n')
            if len(lines) > 10:
                # Assume the first half is reasoning and the second half is the answer
                mid_point = len(lines) // 2
                potential_reasoning = '\n'.join(lines[:mid_point])
                potential_answer = '\n'.join(lines[mid_point:])

                # Simple heuristic: only split if the first half contains Chinese
                # reasoning keywords (think, analyze, reason, because, so, first, then)
                if any(keyword in potential_reasoning for keyword in ['思考', '分析', '推理', '因为', '所以', '首先', '然后']):
                    reasoning_content = potential_reasoning
                    final_content = potential_answer

        return reasoning_content, final_content
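    # Illustrative example of _extract_reasoning on assumed model output:
    #   _extract_reasoning("<think>User greets me; answer politely.</think>Hello!")
    #   -> ("User greets me; answer politely.", "Hello!")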
    # --- Ollama-specific helpers ---

    def list_models(self) -> List[str]:
        """List the models available on the Ollama server."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)  # deliberately short timeout
            response.raise_for_status()
            data = response.json()
            models = [model["name"] for model in data.get("models", [])]
            return models if models else [self.model]  # fall back to the default model
        except requests.exceptions.RequestException as e:
            print(f"Failed to list models: {e}")
            return [self.model]  # fall back to the default model
        except Exception as e:
            print(f"Failed to parse model list: {e}")
            return [self.model]  # fall back to the default model
    def pull_model(self, model_name: str) -> bool:
        """Pull a model onto the Ollama server."""
        try:
            print(f"Pulling model: {model_name}")
            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                timeout=300  # pulling a model can take a long time
            )
            response.raise_for_status()
            print(f"✅ Model {model_name} pulled successfully")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to pull model {model_name}: {e}")
            return False
    def delete_model(self, model_name: str) -> bool:
        """Delete a model from the Ollama server."""
        try:
            response = requests.delete(
                f"{self.base_url}/api/delete",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            print(f"✅ Model {model_name} deleted successfully")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to delete model {model_name}: {e}")
            return False
    def get_model_info(self, model_name: str) -> Optional[Dict]:
        """Fetch metadata for a model."""
        try:
            response = requests.post(
                f"{self.base_url}/api/show",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Failed to fetch model info: {e}")
            return None
    def get_system_info(self) -> Dict:
        """Collect Ollama system information."""
        try:
            # Server version
            version_response = requests.get(f"{self.base_url}/api/version", timeout=self.timeout)
            version_info = version_response.json() if version_response.status_code == 200 else {}

            # Available models
            models = self.list_models()

            return {
                "base_url": self.base_url,
                "version": version_info.get("version", "unknown"),
                "available_models": models,
                "current_model": self.model,
                "timeout": self.timeout,
                "num_ctx": self.num_ctx,
                "num_predict": self.num_predict,
                "repeat_penalty": self.repeat_penalty
            }
        except Exception as e:
            return {"error": str(e)}
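

# --- Usage sketch (not part of the original module) ---
# A minimal example, assuming the package containing base_llm_chat is
# importable and an Ollama server is running locally with the configured
# model already pulled. Kept as comments because the relative import above
# means this file cannot run as a standalone script.
#
#     chat = OllamaChat(config={
#         "base_url": "http://localhost:11434",
#         "model": "qwen2.5:7b",
#         "stream": False,
#     })
#     print(chat.test_connection())
#     answer = chat.submit_prompt([chat.user_message("Hello")])
#     print(answer)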