| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 | 
							- # agent/utils.py
 
- """
 
- Agent相关的工具函数
 
- """
 
- import functools
 
- import json
 
- from typing import Dict, Any, Callable, List, Optional
 
- from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage
 
- from langchain_core.tools import BaseTool
 
def handle_tool_errors(func: Callable) -> Callable:
    """Decorator that shields a tool function from raising.

    Any exception escaping *func* is logged and converted into the common
    tool-result payload (``success``/``error``/``error_type``) so callers
    always receive a dict instead of an exception.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Dict[str, Any]:
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            # Log the failure, then report it in the standard error shape.
            print(f"[ERROR] 工具 {func.__name__} 执行失败: {str(e)}")
            return {
                "success": False,
                "error": f"工具执行异常: {str(e)}",
                "error_type": "tool_exception",
            }
        return result
    return wrapper
 
class LLMWrapper:
    """LangChain-compatible wrapper around a custom LLM instance, with
    emulated tool-calling support.

    The wrapped object is expected to expose a Vanna-style prompt API:
    ``system_message`` / ``user_message`` / ``assistant_message`` builders
    plus ``submit_prompt`` for execution.  Because the underlying model has
    no native function-calling, tool definitions are injected into the
    system prompt as text and tool calls are parsed back out of the plain
    response string.
    """

    def __init__(self, llm_instance):
        self.llm = llm_instance
        # Fall back to a generic name when the instance has no ``model`` attribute.
        self._model_name = getattr(llm_instance, 'model', 'custom_llm')
        self._bound_tools = []

    def invoke(self, input_data, **kwargs):
        """LangChain ``invoke`` interface.

        Accepts a plain string, a list of LangChain messages, or any other
        object (coerced via ``str``).  Always returns an ``AIMessage``;
        failures are reported in the message content rather than raised.
        """
        try:
            if isinstance(input_data, str):
                messages = [HumanMessage(content=input_data)]
            elif isinstance(input_data, list):
                messages = input_data
            else:
                messages = [HumanMessage(content=str(input_data))]

            # Route through the tool-calling path only when tools are bound
            # and the conversation looks tool-related.
            if self._bound_tools and self._should_use_tools(messages):
                return self._invoke_with_tools(messages, **kwargs)
            else:
                return self._invoke_without_tools(messages, **kwargs)

        except Exception as e:
            print(f"[ERROR] LLM包装器调用失败: {str(e)}")
            return AIMessage(content=f"LLM调用失败: {str(e)}")

    def _should_use_tools(self, messages: List[BaseMessage]) -> bool:
        """Heuristically decide whether the tool-calling path applies.

        Checks the last message for tool-related keywords; when the last
        message is not a ``HumanMessage`` (or the list is empty is false-y),
        defaults to using tools.
        """
        if messages:
            last_message = messages[-1]
            if isinstance(last_message, HumanMessage):
                content = last_message.content.lower()
                # Keyword match against common SQL/database phrasings.
                tool_keywords = ["生成sql", "执行sql", "generate sql", "execute sql", "查询", "数据库"]
                return any(keyword in content for keyword in tool_keywords)
        return True  # Default to using tools.

    def _invoke_with_tools(self, messages: List[BaseMessage], **kwargs):
        """Invoke the LLM with tool definitions injected into the prompt.

        On any failure, falls back to the plain (tool-free) invocation.
        """
        try:
            # Build a prompt that advertises the bound tools.
            tool_prompt = self._build_tool_prompt(messages)

            # Call the underlying LLM.
            response = self.llm.submit_prompt(tool_prompt, **kwargs)

            # Recover any tool invocations mentioned in the text response.
            tool_calls = self._parse_tool_calls(response)

            if tool_calls:
                # Surface them the LangChain way: AIMessage.tool_calls.
                return AIMessage(
                    content=response,
                    tool_calls=tool_calls
                )
            else:
                # No tool call detected — return the raw response.
                return AIMessage(content=response)

        except Exception as e:
            print(f"[ERROR] 工具调用失败: {str(e)}")
            return self._invoke_without_tools(messages, **kwargs)

    def _invoke_without_tools(self, messages: List[BaseMessage], **kwargs):
        """Plain invocation: translate LangChain messages to the wrapped
        LLM's prompt format and submit them."""
        prompt = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                prompt.append(self.llm.system_message(msg.content))
            elif isinstance(msg, HumanMessage):
                prompt.append(self.llm.user_message(msg.content))
            elif isinstance(msg, AIMessage):
                prompt.append(self.llm.assistant_message(msg.content))
            else:
                # Unknown message types (e.g. ToolMessage) are forwarded as user text.
                prompt.append(self.llm.user_message(str(msg.content)))

        response = self.llm.submit_prompt(prompt, **kwargs)

        # Wrap the raw string back into a LangChain message.
        return AIMessage(content=response)

    def _build_tool_prompt(self, messages: List[BaseMessage]) -> List[Dict[str, str]]:
        """Build the prompt list, replacing any original system message with
        one that also carries the tool definitions."""
        prompt = []

        # System message first, augmented with the tool catalogue.
        system_content = self._get_system_message_with_tools(messages)
        prompt.append(self.llm.system_message(system_content))

        # Then the conversation turns.  (The original code also tested
        # ``not isinstance(msg, SystemMessage)`` on the AIMessage branch,
        # but SystemMessage is not a subclass of AIMessage, so that check
        # was always true and has been dropped.)
        for msg in messages:
            if isinstance(msg, HumanMessage):
                prompt.append(self.llm.user_message(msg.content))
            elif isinstance(msg, AIMessage):
                prompt.append(self.llm.assistant_message(msg.content))

        return prompt

    def _get_system_message_with_tools(self, messages: List[BaseMessage]) -> str:
        """Return the system message text extended with the bound tools.

        When no tools are bound, the original system message (possibly
        empty) is returned unchanged.
        """
        # Locate the caller-supplied system message, if any.
        original_system = ""
        for msg in messages:
            if isinstance(msg, SystemMessage):
                original_system = msg.content
                break

        # One bullet line per bound tool.  (A structured per-tool dict was
        # previously built here but never used; it has been removed.)
        tool_definitions = []
        for tool in self._bound_tools:
            tool_definitions.append(f"- {tool.name}: {tool.description}")

        if tool_definitions:
            tools_text = "\n".join(tool_definitions)
            return f"""{original_system}

你有以下工具可以使用:
{tools_text}

使用工具时,请明确说明你要调用哪个工具以及需要的参数。对于数据库查询问题,请按照以下步骤:
1. 使用 generate_sql 工具生成SQL查询
2. 使用 execute_sql 工具执行SQL查询
3. 使用 generate_summary 工具生成结果摘要

请直接开始执行工具调用,不要只是提供指导。"""
        else:
            return original_system

    def _parse_tool_calls(self, response: str) -> List[Dict[str, Any]]:
        """Parse tool invocations out of the LLM's plain-text response.

        NOTE(review): only ``generate_sql`` is currently detected even
        though the system prompt advertises ``execute_sql`` and
        ``generate_summary`` as well — confirm whether the downstream graph
        relies on this before extending the parser.
        """
        tool_calls = []

        # Simple keyword-based detection; adjust to the actual LLM response
        # format as needed.
        response_lower = response.lower()
        if "generate_sql" in response_lower:
            tool_calls.append({
                "name": "generate_sql",
                "args": {},
                "id": "generate_sql_call"
            })

        return tool_calls

    @property
    def model_name(self) -> str:
        """Name of the wrapped model (or ``'custom_llm'``)."""
        return self._model_name

    def bind_tools(self, tools):
        """Bind tools for emulated tool calling; accepts one tool or a list.

        Returns ``self`` so it can be used fluently like LangChain's
        ``bind_tools``.
        """
        self._bound_tools = tools if isinstance(tools, list) else [tools]
        return self
 
def get_compatible_llm():
    """Return an LLM instance compatible with the agent graph.

    Preference order:
      1. ``langchain_openai.ChatOpenAI`` when the current config supplies
         both ``base_url`` and ``api_key`` and the package is installed.
      2. The shared Vanna instance wrapped in :class:`LLMWrapper`.
      3. On any failure above, a freshly constructed ``QianWenChat``
         wrapped in :class:`LLMWrapper`.
    """
    try:
        from common.utils import get_current_llm_config

        llm_config = get_current_llm_config()

        # Prefer the standard OpenAI-compatible API when fully configured.
        if llm_config.get("base_url") and llm_config.get("api_key"):
            try:
                from langchain_openai import ChatOpenAI
                llm = ChatOpenAI(
                    base_url=llm_config.get("base_url"),
                    api_key=llm_config.get("api_key"),
                    model=llm_config.get("model"),
                    temperature=llm_config.get("temperature", 0.7),
                )
            except ImportError:
                print("[WARNING] langchain_openai 未安装,使用 Vanna 实例包装器")
            else:
                print("[INFO] 使用标准OpenAI兼容API")
                return llm

        # Otherwise use the shared Vanna instance.
        from common.vanna_instance import get_vanna_instance

        vn = get_vanna_instance()
        print("[INFO] 使用Vanna实例包装器")
        return LLMWrapper(vn)

    except Exception as e:
        print(f"[ERROR] 获取 Vanna 实例失败: {str(e)}")
        # Last resort: wrap a directly constructed QianWen chat model.
        from common.utils import get_current_llm_config
        from customllm.qianwen_chat import QianWenChat

        custom_llm = QianWenChat(config=get_current_llm_config())
        print("[INFO] 使用QianWen包装器")
        return LLMWrapper(custom_llm)
 
- def _is_valid_sql_format(sql_text: str) -> bool:
 
-     """验证文本是否为有效的SQL查询格式"""
 
-     if not sql_text or not sql_text.strip():
 
-         return False
 
-     
 
-     sql_clean = sql_text.strip().upper()
 
-     
 
-     # 检查是否以SQL关键字开头
 
-     sql_keywords = ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'WITH']
 
-     starts_with_sql = any(sql_clean.startswith(keyword) for keyword in sql_keywords)
 
-     
 
-     # 检查是否包含解释性语言
 
-     explanation_phrases = [
 
-         '无法', '不能', '抱歉', 'SORRY', 'UNABLE', 'CANNOT', 
 
-         '需要更多信息', '请提供', '表不存在', '字段不存在',
 
-         '不清楚', '不确定', '没有足够', '无法理解', '无法生成',
 
-         '无法确定', '不支持', '不可用', '缺少', '未找到'
 
-     ]
 
-     contains_explanation = any(phrase in sql_clean for phrase in explanation_phrases)
 
-     
 
-     return starts_with_sql and not contains_explanation
 
 
  |