sql_tools.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. """
  2. 数据库查询相关的工具集
  3. """
  4. import re
  5. import json
  6. import sys
  7. import os
  8. from pathlib import Path
  9. from langchain_core.tools import tool
  10. from pydantic.v1 import BaseModel, Field
  11. from typing import List, Dict, Any
  12. import pandas as pd
  13. # 添加项目根目录到sys.path以解决common模块导入问题
  14. try:
  15. project_root = Path(__file__).parent.parent
  16. if str(project_root) not in sys.path:
  17. sys.path.insert(0, str(project_root))
  18. except Exception as e:
  19. pass # 忽略路径添加错误
  20. # 使用统一日志系统
  21. try:
  22. # 尝试相对导入(当作为模块导入时)
  23. from core.logging import get_react_agent_logger
  24. except ImportError:
  25. # 如果相对导入失败,尝试绝对导入(直接运行时)
  26. from core.logging import get_react_agent_logger
  27. logger = get_react_agent_logger("SQLTools")
  28. # --- Pydantic Schema for Tool Arguments ---
  29. class GenerateSqlArgs(BaseModel):
  30. """Input schema for the generate_sql tool."""
  31. question: str = Field(description="The user's question to be converted to SQL.")
  32. history_messages: List[Dict[str, Any]] = Field(
  33. default=[],
  34. description="The conversation history messages for context."
  35. )
  36. # --- Tool Functions ---
  37. @tool(args_schema=GenerateSqlArgs)
  38. def generate_sql(question: str, history_messages: List[Dict[str, Any]] = None) -> str:
  39. """
  40. Generates an SQL query based on the user's question and the conversation history.
  41. """
  42. logger.info(f"🔧 [Tool] generate_sql - Question: '{question}'")
  43. if history_messages is None:
  44. history_messages = []
  45. logger.info(f" History contains {len(history_messages)} messages.")
  46. # Combine history and the current question to form a rich prompt
  47. if history_messages:
  48. history_str = "\n".join([f"{msg['type']}: {msg.get('content', '') or ''}" for msg in history_messages])
  49. enriched_question = f"""Previous conversation context:
  50. {history_str}
  51. Current user question:
  52. human: {question}
  53. Please analyze the conversation history to understand any references (like "this service area", "that branch", etc.) in the current question, and generate the appropriate SQL query."""
  54. else:
  55. # If no history messages, use the original question directly
  56. enriched_question = question
  57. # 🎯 添加稳定的Vanna输入日志
  58. logger.info("📝 [Vanna Input] Complete question being sent to Vanna:")
  59. logger.info("--- BEGIN VANNA INPUT ---")
  60. logger.info(enriched_question)
  61. logger.info("--- END VANNA INPUT ---")
  62. try:
  63. from common.vanna_instance import get_vanna_instance
  64. vn = get_vanna_instance()
  65. sql = vn.generate_sql(enriched_question)
  66. if not sql or sql.strip() == "":
  67. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  68. error_info = vn.last_llm_explanation
  69. logger.warning(f" Vanna returned an explanation instead of SQL: {error_info}")
  70. return f"Database query failed. Reason: {error_info}"
  71. else:
  72. logger.warning(" Vanna failed to generate SQL and provided no explanation.")
  73. return "Could not generate SQL: The question may not be suitable for a database query."
  74. sql_upper = sql.upper().strip()
  75. if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
  76. logger.warning(f" Vanna returned a message that does not appear to be a valid SQL query: {sql}")
  77. return f"Database query failed. Reason: {sql}"
  78. logger.info(f" ✅ SQL Generated Successfully:")
  79. logger.info(f" {sql}")
  80. return sql
  81. except Exception as e:
  82. logger.error(f" An exception occurred during SQL generation: {e}", exc_info=True)
  83. return f"SQL generation failed: {str(e)}"
  84. def _check_basic_syntax(sql: str) -> bool:
  85. """规则1: 检查SQL是否包含基础查询关键词"""
  86. if not sql or sql.strip() == "":
  87. return False
  88. sql_upper = sql.upper().strip()
  89. return any(keyword in sql_upper for keyword in ['SELECT', 'WITH'])
  90. def _check_security(sql: str) -> tuple[bool, str]:
  91. """规则2: 检查SQL是否包含危险操作
  92. Returns:
  93. tuple: (是否安全, 错误信息)
  94. """
  95. sql_upper = sql.upper().strip()
  96. dangerous_patterns = [r'\bDROP\b', r'\bDELETE\b', r'\bTRUNCATE\b', r'\bALTER\b', r'\bCREATE\b', r'\bUPDATE\b']
  97. for pattern in dangerous_patterns:
  98. if re.search(pattern, sql_upper):
  99. keyword = pattern.replace(r'\b', '').replace('\\', '')
  100. return False, f"包含危险操作 {keyword}"
  101. return True, ""
  102. def _has_limit_clause(sql: str) -> bool:
  103. """检测SQL是否包含LIMIT子句"""
  104. # 使用正则表达式检测LIMIT关键词,支持多种格式
  105. # LIMIT n 或 LIMIT offset, count 格式
  106. limit_pattern = r'\bLIMIT\s+\d+(?:\s*,\s*\d+)?\s*(?:;|\s*$)'
  107. return bool(re.search(limit_pattern, sql, re.IGNORECASE))
  108. def _validate_with_limit_zero(sql: str) -> str:
  109. """规则3: 使用LIMIT 0验证SQL(适用于无LIMIT子句的SQL)"""
  110. try:
  111. from common.vanna_instance import get_vanna_instance
  112. vn = get_vanna_instance()
  113. # 添加 LIMIT 0 避免返回大量数据,只验证SQL结构
  114. test_sql = sql.rstrip(';') + " LIMIT 0"
  115. logger.info(f" 执行LIMIT 0验证:")
  116. logger.info(f" {test_sql}")
  117. vn.run_sql(test_sql)
  118. logger.info(" ✅ SQL验证通过:语法正确且字段/表存在")
  119. return "SQL验证通过:语法正确且字段存在"
  120. except Exception as e:
  121. return _format_validation_error(str(e))
  122. def _validate_with_prepare(sql: str) -> str:
  123. """规则4: 使用PREPARE/DEALLOCATE验证SQL(适用于包含LIMIT子句的SQL)"""
  124. import time
  125. try:
  126. from common.vanna_instance import get_vanna_instance
  127. vn = get_vanna_instance()
  128. # 生成唯一的语句名,避免并发冲突
  129. stmt_name = f"validation_stmt_{int(time.time() * 1000)}"
  130. prepare_executed = False
  131. try:
  132. # 执行PREPARE验证
  133. prepare_sql = f"PREPARE {stmt_name} AS {sql.rstrip(';')}"
  134. logger.info(f" 执行PREPARE验证:")
  135. logger.info(f" {prepare_sql}")
  136. vn.run_sql(prepare_sql)
  137. prepare_executed = True
  138. # 如果执行到这里没有异常,说明PREPARE成功
  139. logger.info(" ✅ PREPARE执行成功,SQL验证通过")
  140. return "SQL验证通过:语法正确且字段存在"
  141. except Exception as e:
  142. error_msg = str(e).lower()
  143. # PostgreSQL中PREPARE不返回结果集是正常行为
  144. if "no results to fetch" in error_msg:
  145. prepare_executed = True # 标记为成功执行
  146. logger.info(" ✅ PREPARE执行成功(无结果集),SQL验证通过")
  147. return "SQL验证通过:语法正确且字段存在"
  148. else:
  149. # 真正的错误(语法错误、字段不存在等)
  150. raise e
  151. finally:
  152. # 只有在PREPARE成功执行时才尝试清理资源
  153. if prepare_executed:
  154. try:
  155. deallocate_sql = f"DEALLOCATE {stmt_name}"
  156. logger.info(f" 清理PREPARE资源: {deallocate_sql}")
  157. vn.run_sql(deallocate_sql)
  158. except Exception as cleanup_error:
  159. # 清理失败不影响验证结果,只记录警告
  160. logger.warning(f" 清理PREPARE资源失败: {cleanup_error}")
  161. except Exception as e:
  162. return _format_validation_error(str(e))
  163. def _format_validation_error(error_msg: str) -> str:
  164. """格式化验证错误信息"""
  165. logger.warning(f" SQL验证失败:执行测试时出错 - {error_msg}")
  166. # 提供更详细的错误信息供LLM理解和处理
  167. if "column" in error_msg.lower() and ("does not exist" in error_msg.lower() or "不存在" in error_msg):
  168. return f"SQL验证失败:字段不存在。详细错误:{error_msg}"
  169. elif "table" in error_msg.lower() and ("does not exist" in error_msg.lower() or "不存在" in error_msg):
  170. return f"SQL验证失败:表不存在。详细错误:{error_msg}"
  171. elif "syntax error" in error_msg.lower() or "语法错误" in error_msg:
  172. return f"SQL验证失败:语法错误。详细错误:{error_msg}"
  173. else:
  174. return f"SQL验证失败:执行失败。详细错误:{error_msg}"
  175. @tool
  176. def valid_sql(sql: str) -> str:
  177. """
  178. 验证SQL语句的正确性和安全性,使用四规则递进验证:
  179. 1. 基础语法检查(SELECT/WITH关键词)
  180. 2. 安全检查(无危险操作)
  181. 3. 语义验证:无LIMIT时使用LIMIT 0验证
  182. 4. 语义验证:有LIMIT时使用PREPARE/DEALLOCATE验证
  183. Args:
  184. sql: 待验证的SQL语句。
  185. Returns:
  186. 验证结果。
  187. """
  188. logger.info(f"🔧 [Tool] valid_sql - 待验证SQL:")
  189. logger.info(f" {sql}")
  190. # 规则1: 基础语法检查
  191. if not _check_basic_syntax(sql):
  192. logger.warning(" SQL验证失败:SQL语句为空或不是有效的查询语句")
  193. return "SQL验证失败:SQL语句为空或不是有效的查询语句"
  194. # 规则2: 安全检查
  195. is_safe, security_error = _check_security(sql)
  196. if not is_safe:
  197. logger.error(f" SQL验证失败:{security_error}")
  198. return f"SQL验证失败:{security_error}"
  199. # 规则3/4: 语义验证(二选一)
  200. if _has_limit_clause(sql):
  201. logger.info(" 检测到LIMIT子句,使用PREPARE验证")
  202. return _validate_with_prepare(sql)
  203. else:
  204. logger.info(" 未检测到LIMIT子句,使用LIMIT 0验证")
  205. return _validate_with_limit_zero(sql)
  206. @tool
  207. def run_sql(sql: str) -> str:
  208. """
  209. 执行SQL查询并以JSON字符串格式返回结果。
  210. Args:
  211. sql: 待执行的SQL语句。
  212. Returns:
  213. JSON字符串格式的查询结果,或包含错误的JSON字符串。
  214. """
  215. logger.info(f"🔧 [Tool] run_sql - 待执行SQL:")
  216. logger.info(f" {sql}")
  217. try:
  218. from common.vanna_instance import get_vanna_instance
  219. vn = get_vanna_instance()
  220. df = vn.run_sql(sql)
  221. logger.debug(f"SQL执行结果:\n{df}")
  222. if df is None:
  223. logger.warning(" SQL执行成功,但查询结果为空。")
  224. result = {"status": "success", "data": [], "message": "查询无结果"}
  225. return json.dumps(result, ensure_ascii=False)
  226. logger.info(f" ✅ SQL执行成功,返回 {len(df)} 条记录。")
  227. # 将DataFrame转换为JSON,并妥善处理datetime等特殊类型
  228. return df.to_json(orient='records', date_format='iso')
  229. except Exception as e:
  230. logger.error(f" SQL执行过程中发生异常: {e}", exc_info=True)
  231. error_result = {"status": "error", "error_message": str(e)}
  232. return json.dumps(error_result, ensure_ascii=False)
  233. # 将所有工具函数收集到一个列表中,方便Agent导入和使用
  234. sql_tools = [generate_sql, valid_sql, run_sql]