import requests
import json
import re
from typing import List, Dict, Any, Optional

from .base_llm_chat import BaseLLMChat

class OllamaChat(BaseLLMChat):
    """Ollama AI chat implementation."""

    def __init__(self, config=None):
        print("...OllamaChat init...")
        super().__init__(config=config)

        # Ollama-specific connection settings
        self.base_url = config.get("base_url", "http://localhost:11434") if config else "http://localhost:11434"
        self.model = config.get("model", "qwen2.5:7b") if config else "qwen2.5:7b"
        self.timeout = config.get("timeout", 60) if config else 60

        # Ollama-specific generation parameters
        self.num_ctx = config.get("num_ctx", 4096) if config else 4096  # context window size
        self.num_predict = config.get("num_predict", -1) if config else -1  # max tokens to predict (-1 = no limit)
        self.repeat_penalty = config.get("repeat_penalty", 1.1) if config else 1.1  # repetition penalty

        # Verify the connection
        if config and config.get("auto_check_connection", True):
            self._check_ollama_health()
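
    # A sketch of a typical config dict; every key is optional and falls back
    # to the defaults set in __init__ above:
    #   chat = OllamaChat(config={
    #       "base_url": "http://localhost:11434",
    #       "model": "qwen2.5:7b",
    #       "timeout": 60,
    #       "num_ctx": 4096,
    #       "auto_check_connection": True,
    #   })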

    def _check_ollama_health(self) -> bool:
        """Check the health of the Ollama service."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                print(f"✅ Ollama service reachable: {self.base_url}")
                return True
            else:
                print(f"⚠️ Unexpected response from Ollama service: {response.status_code}")
                return False
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to connect to Ollama service: {e}")
            return False

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Rough token-count estimate (~4 characters per token)
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        # Resolve the stream parameter
        stream_mode = kwargs.get("stream", self.config.get("stream", False) if self.config else False)

        # Resolve the enable_thinking parameter
        enable_thinking = kwargs.get("enable_thinking", self.config.get("enable_thinking", False) if self.config else False)

        # Ollama constraint: enable_thinking=True requires stream=True;
        # if stream=False but enable_thinking=True, enable_thinking is ignored.
        if enable_thinking and not stream_mode:
            print("WARNING: enable_thinking=True has no effect because it requires stream=True")
            enable_thinking = False

        # Smart model selection
        model = self._determine_model(kwargs, enable_thinking, num_tokens)

        # Check whether the model is a reasoning model
        is_reasoning_model = self._is_reasoning_model(model)

        # Compatibility hint (does not force a model switch)
        if enable_thinking and not is_reasoning_model:
            print(f"Note: model {model} is not a dedicated reasoning model, but thinking will still be attempted")

        print(f"\nUsing Ollama model {model} for {num_tokens} tokens (approx)")
        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")

        # Prepare the Ollama API request
        url = f"{self.base_url}/api/chat"
        payload = {
            "model": model,
            "messages": prompt,
            "stream": stream_mode,
            "think": enable_thinking,  # the Ollama API uses the "think" field to control reasoning
            "options": self._build_options(kwargs, is_reasoning_model, enable_thinking)
        }

        try:
            if stream_mode:
                # Streaming mode
                if enable_thinking:
                    print("Using streaming mode with thinking enabled")
                else:
                    print("Using streaming mode, regular chat")
                return self._handle_stream_response(url, payload, enable_thinking)
            else:
                # Non-streaming mode
                if enable_thinking:
                    print("Using non-streaming mode with thinking enabled")
                else:
                    print("Using non-streaming mode, regular chat")
                return self._handle_non_stream_response(url, payload, enable_thinking)
        except requests.exceptions.RequestException as e:
            print(f"Ollama API request failed: {e}")
            raise Exception(f"Ollama API call failed: {str(e)}")

    def _handle_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Handle a streaming response."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"},
            stream=True
        )
        response.raise_for_status()

        collected_content = []
        for line in response.iter_lines():
            if line:
                try:
                    chunk_data = json.loads(line.decode('utf-8'))
                    if 'message' in chunk_data and 'content' in chunk_data['message']:
                        content = chunk_data['message']['content']
                        collected_content.append(content)
                    # Check whether the stream has finished
                    if chunk_data.get('done', False):
                        break
                except json.JSONDecodeError:
                    continue

        # Join all collected chunks
        full_content = "".join(collected_content)

        # With reasoning enabled, try to split the reasoning from the final answer
        if enable_reasoning:
            reasoning_content, final_content = self._extract_reasoning(full_content)
            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
            return final_content
        return full_content

    def _handle_non_stream_response(self, url: str, payload: dict, enable_reasoning: bool) -> str:
        """Handle a non-streaming response."""
        response = requests.post(
            url,
            json=payload,
            timeout=self.timeout,
            headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()
        result = response.json()
        content = result["message"]["content"]

        if enable_reasoning:
            # Try to split the reasoning from the final answer
            reasoning_content, final_content = self._extract_reasoning(content)
            if reasoning_content:
                print("Model reasoning process:\n", reasoning_content)
            return final_content
        return content

    def test_connection(self, test_prompt="你好") -> dict:  # default prompt: "Hello" in Chinese
        """Test the Ollama connection."""
        result = {
            "success": False,
            "model": self.model,
            "base_url": self.base_url,
            "message": "",
            "available_models": [],
            "ollama_version": None
        }

        try:
            # Check service health
            if not self._check_ollama_health():
                result["message"] = "Ollama service is unavailable"
                return result

            # Fetch the list of available models
            try:
                result["available_models"] = self.list_models()
                # Check whether the target model exists
                if self.model not in result["available_models"]:
                    print(f"Warning: model {self.model} not found, trying to pull it...")
                    if not self.pull_model(self.model):
                        result["message"] = f"Model {self.model} not found and the pull failed"
                        return result
            except Exception as e:
                print(f"Failed to fetch the model list: {e}")
                result["available_models"] = [self.model]

            print(f"Testing Ollama connection - model: {self.model}")
            print(f"Ollama service URL: {self.base_url}")
            print(f"Available models: {', '.join(result['available_models'])}")

            # Try a simple chat round-trip
            prompt = [self.user_message(test_prompt)]
            response = self.submit_prompt(prompt)

            result["success"] = True
            result["message"] = f"Ollama connection test succeeded, response: {response[:50]}..."
            return result
        except Exception as e:
            result["message"] = f"Ollama connection test failed: {str(e)}"
            return result

    def _determine_model(self, kwargs: dict, enable_thinking: bool, num_tokens: int) -> str:
        """Determine which model to use."""
        # Priority: runtime kwargs > config file > smart selection
        if kwargs.get("model", None) is not None:
            return kwargs.get("model")
        elif kwargs.get("engine", None) is not None:
            return kwargs.get("engine")
        elif self.config is not None and "engine" in self.config:
            return self.config["engine"]
        elif self.config is not None and "model" in self.config:
            return self.config["model"]
        else:
            # Smart model selection
            if enable_thinking:
                # Prefer a reasoning model
                try:
                    available_models = self.list_models()
                    reasoning_models = [m for m in available_models if self._is_reasoning_model(m)]
                    if reasoning_models:
                        return reasoning_models[0]  # pick the first reasoning model
                    else:
                        print("Warning: no reasoning model found, using the default model")
                        return self.model
                except Exception as e:
                    print(f"Error fetching the model list: {e}, using the default model")
                    return self.model
            else:
                # Pick a model based on the token count
                if num_tokens > 8000:
                    # Long input: prefer a long-context model
                    try:
                        available_models = self.list_models()
                        long_context_models = [m for m in available_models if any(keyword in m.lower() for keyword in ['long', '32k', '128k'])]
                        if long_context_models:
                            return long_context_models[0]
                    except Exception as e:
                        print(f"Error fetching the model list: {e}, using the default model")
                return self.model

    def _is_reasoning_model(self, model: str) -> bool:
        """Check whether a model is a reasoning model (e.g. "deepseek-r1:7b" matches 'r1')."""
        reasoning_keywords = ['r1', 'reasoning', 'think', 'cot', 'chain-of-thought']
        return any(keyword in model.lower() for keyword in reasoning_keywords)

    def _build_options(self, kwargs: dict, is_reasoning_model: bool, enable_thinking: bool = False) -> dict:
        """Build the Ollama options payload."""
        options = {
            "temperature": self.temperature,
            "num_ctx": self.num_ctx,
            "num_predict": self.num_predict,
            "repeat_penalty": self.repeat_penalty,
        }

        # Filter out custom parameters so they are not passed through to the API
        filtered_kwargs = {k: v for k, v in kwargs.items()
                           if k not in ['model', 'engine', 'enable_thinking', 'stream', 'timeout']}

        # Merge the remaining parameters into options
        for k, v in filtered_kwargs.items():
            options[k] = v

        # Parameter adjustments when thinking is enabled
        if enable_thinking:
            # Thinking may need more prediction tokens
            if options["num_predict"] == -1:
                options["num_predict"] = 2048
            # Lower the repeat penalty to allow more repetition during reasoning
            options["repeat_penalty"] = min(options["repeat_penalty"], 1.05)
            # Reasoning models can be tuned further
            if is_reasoning_model:
                # Reasoning models may need a longer context
                options["num_ctx"] = max(options["num_ctx"], 8192)

        return options

    def _is_reasoning_content(self, content: str) -> bool:
        """Check whether the content looks like reasoning output."""
        reasoning_patterns = [
            r'<think>.*?</think>',
            r'<reasoning>.*?</reasoning>',
            r'<analysis>.*?</analysis>',
            r'思考:',  # "thinking:" marker emitted by Chinese-language models
            r'分析:',  # "analysis:"
            r'推理:'   # "reasoning:"
        ]
        return any(re.search(pattern, content, re.DOTALL | re.IGNORECASE) for pattern in reasoning_patterns)

    def _extract_reasoning(self, content: str) -> tuple:
        """Extract the reasoning content and the final answer."""
        # The Chinese patterns match the "thinking:"/"analysis:"/"reasoning:"
        # markers that Chinese-language models emit in their output.
        reasoning_patterns = [
            r'<think>(.*?)</think>',
            r'<reasoning>(.*?)</reasoning>',
            r'<analysis>(.*?)</analysis>',
            r'思考:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'分析:(.*?)(?=\n\n|\n[^思考分析推理]|$)',
            r'推理:(.*?)(?=\n\n|\n[^思考分析推理]|$)'
        ]

        reasoning_content = ""
        final_content = content

        for pattern in reasoning_patterns:
            matches = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
            if matches:
                reasoning_content = "\n".join(matches)
                final_content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE).strip()
                break

        # If no explicit reasoning markers were found but the content is long,
        # fall back to a simple split
        if not reasoning_content and len(content) > 500:
            lines = content.split('\n')
            if len(lines) > 10:
                # Assume the first half is reasoning and the second half is the answer
                mid_point = len(lines) // 2
                potential_reasoning = '\n'.join(lines[:mid_point])
                potential_answer = '\n'.join(lines[mid_point:])
                # Simple heuristic: only split if the first half contains reasoning
                # keywords (thinking/analysis/reasoning/because/so/first/then)
                if any(keyword in potential_reasoning for keyword in ['思考', '分析', '推理', '因为', '所以', '首先', '然后']):
                    reasoning_content = potential_reasoning
                    final_content = potential_answer

        return reasoning_content, final_content
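
    # Example (a sketch) of the marker-based path in _extract_reasoning:
    #   _extract_reasoning("<think>Check each digit.</think>9.9 is larger.")
    #   returns ("Check each digit.", "9.9 is larger."); content without any
    #   markers falls through to the length-based split heuristic above.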

    # Ollama-specific features

    def list_models(self) -> List[str]:
        """List the available models."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)  # deliberately short timeout
            response.raise_for_status()
            data = response.json()
            models = [model["name"] for model in data.get("models", [])]
            return models if models else [self.model]  # fall back to the default model if none are installed
        except requests.exceptions.RequestException as e:
            print(f"Failed to fetch the model list: {e}")
            return [self.model]  # fall back to the default model
        except Exception as e:
            print(f"Failed to parse the model list: {e}")
            return [self.model]  # fall back to the default model

    def pull_model(self, model_name: str) -> bool:
        """Pull a model."""
        try:
            print(f"Pulling model: {model_name}")
            response = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                timeout=300  # pulling a model can take a long time
            )
            response.raise_for_status()
            print(f"✅ Model {model_name} pulled successfully")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to pull model {model_name}: {e}")
            return False

    def delete_model(self, model_name: str) -> bool:
        """Delete a model."""
        try:
            response = requests.delete(
                f"{self.base_url}/api/delete",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            print(f"✅ Model {model_name} deleted successfully")
            return True
        except requests.exceptions.RequestException as e:
            print(f"❌ Failed to delete model {model_name}: {e}")
            return False

    def get_model_info(self, model_name: str) -> Optional[Dict]:
        """Get information about a model."""
        try:
            response = requests.post(
                f"{self.base_url}/api/show",
                json={"name": model_name},
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Failed to fetch model info: {e}")
            return None

    def get_system_info(self) -> Dict:
        """Get Ollama system information."""
        try:
            # Fetch version information
            version_response = requests.get(f"{self.base_url}/api/version", timeout=self.timeout)
            version_info = version_response.json() if version_response.status_code == 200 else {}

            # Fetch the model list
            models = self.list_models()

            return {
                "base_url": self.base_url,
                "version": version_info.get("version", "unknown"),
                "available_models": models,
                "current_model": self.model,
                "timeout": self.timeout,
                "num_ctx": self.num_ctx,
                "num_predict": self.num_predict,
                "repeat_penalty": self.repeat_penalty
            }
        except Exception as e:
            return {"error": str(e)}
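

# Minimal usage sketch, not part of the original class: it assumes a running
# Ollama server at http://localhost:11434 with qwen2.5:7b pulled, and that
# BaseLLMChat provides user_message() and a temperature attribute. Because of
# the relative import above, run it via the package, e.g.
# `python -m <your_package>.ollama_chat`.
if __name__ == "__main__":
    chat = OllamaChat(config={"model": "qwen2.5:7b"})
    print(chat.test_connection())
    answer = chat.submit_prompt([chat.user_message("Why is the sky blue?")])
    print(answer)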