# agent/utils.py
"""
Agent-related utility functions.
"""
import functools
from typing import Dict, Any, Callable
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
  8. def handle_tool_errors(func: Callable) -> Callable:
  9. """
  10. 工具函数错误处理装饰器
  11. """
  12. @functools.wraps(func)
  13. def wrapper(*args, **kwargs) -> Dict[str, Any]:
  14. try:
  15. return func(*args, **kwargs)
  16. except Exception as e:
  17. print(f"[ERROR] 工具 {func.__name__} 执行失败: {str(e)}")
  18. return {
  19. "success": False,
  20. "error": f"工具执行异常: {str(e)}",
  21. "error_type": "tool_exception"
  22. }
  23. return wrapper
  24. class LLMWrapper:
  25. """自定义LLM的LangChain兼容包装器"""
  26. def __init__(self, llm_instance):
  27. self.llm = llm_instance
  28. self._model_name = getattr(llm_instance, 'model', 'custom_llm')
  29. def invoke(self, input_data, **kwargs):
  30. """LangChain invoke接口"""
  31. try:
  32. if isinstance(input_data, str):
  33. messages = [HumanMessage(content=input_data)]
  34. elif isinstance(input_data, list):
  35. messages = input_data
  36. else:
  37. messages = [HumanMessage(content=str(input_data))]
  38. # 转换消息格式
  39. prompt = []
  40. for msg in messages:
  41. if isinstance(msg, SystemMessage):
  42. prompt.append(self.llm.system_message(msg.content))
  43. elif isinstance(msg, HumanMessage):
  44. prompt.append(self.llm.user_message(msg.content))
  45. elif isinstance(msg, AIMessage):
  46. prompt.append(self.llm.assistant_message(msg.content))
  47. else:
  48. prompt.append(self.llm.user_message(str(msg.content)))
  49. # 调用底层LLM
  50. response = self.llm.submit_prompt(prompt, **kwargs)
  51. # 返回LangChain格式的结果
  52. return AIMessage(content=response)
  53. except Exception as e:
  54. print(f"[ERROR] LLM包装器调用失败: {str(e)}")
  55. return AIMessage(content=f"LLM调用失败: {str(e)}")
  56. @property
  57. def model_name(self) -> str:
  58. return self._model_name
  59. def bind_tools(self, tools):
  60. """绑定工具(用于支持工具调用)"""
  61. return self
  62. def get_compatible_llm():
  63. """获取兼容的LLM实例"""
  64. try:
  65. from common.utils import get_current_llm_config
  66. llm_config = get_current_llm_config()
  67. # 尝试使用标准的OpenAI兼容API
  68. if llm_config.get("base_url") and llm_config.get("api_key"):
  69. try:
  70. from langchain_openai import ChatOpenAI
  71. return ChatOpenAI(
  72. base_url=llm_config.get("base_url"),
  73. api_key=llm_config.get("api_key"),
  74. model=llm_config.get("model"),
  75. temperature=llm_config.get("temperature", 0.7)
  76. )
  77. except ImportError:
  78. print("[WARNING] langchain_openai 未安装,使用自定义包装器")
  79. # 使用自定义LLM包装器
  80. from customllm.qianwen_chat import QianWenChat
  81. custom_llm = QianWenChat(config=llm_config)
  82. return LLMWrapper(custom_llm)
  83. except Exception as e:
  84. print(f"[ERROR] 获取LLM失败: {str(e)}")
  85. # 返回基础包装器
  86. from common.utils import get_current_llm_config
  87. from customllm.qianwen_chat import QianWenChat
  88. llm_config = get_current_llm_config()
  89. custom_llm = QianWenChat(config=llm_config)
  90. return LLMWrapper(custom_llm)