qianwen_chat.py

import os

from openai import OpenAI

from .base_llm_chat import BaseLLMChat


class QianWenChat(BaseLLMChat):
    """QianWen (Tongyi Qianwen) AI chat implementation."""

    def __init__(self, client=None, config=None):
        print("...QianWenChat init...")
        super().__init__(config=config)

        # An explicitly supplied client always wins.
        if client is not None:
            self.client = client
            return

        # No config and no client: fall back to the OPENAI_API_KEY
        # environment variable.
        if config is None:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return

        # Reject options that belong to the pre-1.x OpenAI SDK. These checks
        # run after the config-is-None case above so they never index into a
        # None config.
        if "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_base" in config:
            raise Exception(
                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_key" in config:
            # Default to the DashScope OpenAI-compatible endpoint unless the
            # config supplies its own base_url.
            if "base_url" not in config:
                self.client = OpenAI(
                    api_key=config["api_key"],
                    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                )
            else:
                self.client = OpenAI(
                    api_key=config["api_key"],
                    base_url=config["base_url"],
                )
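
    # Assumed config shape, inferred from the keys this class reads; beyond
    # what BaseLLMChat itself supports, these names are not guaranteed:
    #   {"api_key": "...", "base_url": "...", "model": "qwen-plus",
    #    "engine": "...", "enable_thinking": False, "stream": False}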

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log, using 4 characters
        # per token as a rough approximation.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4
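        # For example, a single 2,000-character message counts as ~500 tokens
        # under this heuristic.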

        # Resolve the enable_thinking setting: a runtime kwarg takes priority
        # over the config value; the default is False. (config may be None,
        # so guard the lookup.)
        enable_thinking = kwargs.get("enable_thinking", (self.config or {}).get("enable_thinking", False))

        # Resolve the stream setting with the same precedence:
        # runtime kwarg > config file > default (False).
        stream_mode = kwargs.get("stream", (self.config or {}).get("stream", False))

        # QianWen API constraint: enable_thinking=True requires stream=True.
        # If stream=False but enable_thinking=True, drop enable_thinking.
        if enable_thinking and not stream_mode:
            print("WARNING: enable_thinking=True has no effect because it requires stream=True")
            enable_thinking = False

        # Build a clean copy of kwargs, removing custom parameters that could
        # trigger API errors. Note that enable_thinking and stream are valid
        # QianWen parameters and must be passed through correctly.
        filtered_kwargs = {
            k: v for k, v in kwargs.items()
            if k not in ["model", "engine"]  # only strip model and engine
        }

        # Parameters shared by both the streaming and non-streaming calls.
        common_params = {
            "messages": prompt,
            "stop": None,
            "temperature": self.temperature,
            "stream": stream_mode,  # set the stream parameter explicitly
        }

        # QianWen's OpenAI-compatible endpoint expects enable_thinking to be
        # passed inside extra_body.
        if enable_thinking:
            common_params["extra_body"] = {"enable_thinking": True}

        # Pass through the remaining filtered parameters (enable_thinking and
        # stream are excluded because they were already handled above).
        for k, v in filtered_kwargs.items():
            if k not in ["enable_thinking", "stream"]:  # avoid setting them twice
                common_params[k] = v

        model = None
        # Resolve which model to use, in priority order:
        # kwargs["model"] > kwargs["engine"] > config["engine"]
        # > config["model"] > a default chosen by prompt size.
        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            common_params["model"] = model
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            common_params["engine"] = engine
            model = engine
        elif self.config is not None and "engine" in self.config:
            common_params["engine"] = self.config["engine"]
            model = self.config["engine"]
        elif self.config is not None and "model" in self.config:
            common_params["model"] = self.config["model"]
            model = self.config["model"]
        else:
            # Nothing configured: pick a model based on the prompt size.
            if num_tokens > 3500:
                model = "qwen-long"
            else:
                model = "qwen-plus"
            common_params["model"] = model

        print(f"\nUsing model {model} for {num_tokens} tokens (approx)")
        print(f"Enable thinking: {enable_thinking}, Stream mode: {stream_mode}")

        if stream_mode:
            # Streaming mode.
            if enable_thinking:
                print("Streaming mode, thinking enabled")
            else:
                print("Streaming mode, thinking disabled")

            response_stream = self.client.chat.completions.create(**common_params)

            # Accumulate the streamed response.
            collected_thinking = []
            collected_content = []

            for chunk in response_stream:
                # Thinking part (only emitted when enable_thinking=True).
                if enable_thinking and hasattr(chunk, "choices") and chunk.choices:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                        collected_thinking.append(delta.reasoning_content)

                # Content part.
                if hasattr(chunk, "choices") and chunk.choices:
                    delta = chunk.choices[0].delta
                    if hasattr(delta, "content") and delta.content:
                        collected_content.append(delta.content)

            # The thinking trace could be surfaced here, e.g. saved to a log.
            if enable_thinking and collected_thinking:
                print("Model thinking process:\n", "".join(collected_thinking))

            # Return the assembled content.
            return "".join(collected_content)
        else:
            # Non-streaming mode.
            print("Non-streaming mode")
            response = self.client.chat.completions.create(**common_params)
            # With the OpenAI v1 SDK the reply text lives on
            # choices[0].message.content; the legacy choice.text attribute no
            # longer exists, and a membership test like `"text" in choice`
            # raises a TypeError on the response objects. Return the message
            # content directly (it may be empty).
            return response.choices[0].message.content
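

# Minimal usage sketch, illustrative only. It assumes BaseLLMChat accepts this
# config dict, that the key is a valid DashScope API key, and that the
# DASHSCOPE_API_KEY env var name is a hypothetical choice, not required by
# this module. Because of the relative import above, run this module as part
# of its package (python -m <package>.qianwen_chat), not as a standalone file.
if __name__ == "__main__":
    chat = QianWenChat(config={
        "api_key": os.getenv("DASHSCOPE_API_KEY", ""),  # hypothetical env var
        "model": "qwen-plus",
        "stream": True,
        "enable_thinking": True,  # only honored because stream=True
    })
    answer = chat.submit_prompt([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Briefly explain what a vector index is."},
    ])
    print(answer)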