sql_tools.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. """
  2. 数据库查询相关的工具集
  3. """
  4. import re
  5. import json
  6. import logging
  7. from langchain_core.tools import tool
  8. import pandas as pd
  9. logger = logging.getLogger(__name__)
  10. # --- 工具函数 ---
  11. @tool
  12. def generate_sql(question: str) -> str:
  13. """
  14. 根据用户问题生成SQL查询语句。
  15. Args:
  16. question: 用户的原始问题。
  17. Returns:
  18. 生成的SQL语句或错误信息。
  19. """
  20. logger.info(f"🔧 [Tool] generate_sql - 问题: '{question}'")
  21. try:
  22. from common.vanna_instance import get_vanna_instance
  23. vn = get_vanna_instance()
  24. sql = vn.generate_sql(question)
  25. if not sql or sql.strip() == "":
  26. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  27. error_info = vn.last_llm_explanation
  28. logger.warning(f" Vanna返回了错误解释: {error_info}")
  29. return f"数据库查询失败,具体原因:{error_info}"
  30. else:
  31. logger.warning(" Vanna未能生成SQL且无解释。")
  32. return "无法生成SQL:问题可能不适合数据库查询"
  33. sql_upper = sql.upper().strip()
  34. if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
  35. logger.warning(f" Vanna返回了疑似错误信息而非SQL: {sql}")
  36. return f"数据库查询失败,具体原因:{sql}"
  37. logger.info(f" ✅ 成功生成SQL: {sql}")
  38. return sql
  39. except Exception as e:
  40. logger.error(f" SQL生成过程中发生异常: {e}", exc_info=True)
  41. return f"SQL生成失败: {str(e)}"
  42. @tool
  43. def valid_sql(sql: str) -> str:
  44. """
  45. 验证SQL语句的正确性和安全性。
  46. Args:
  47. sql: 待验证的SQL语句。
  48. Returns:
  49. 验证结果。
  50. """
  51. logger.info(f"🔧 [Tool] valid_sql - 待验证SQL (前100字符): {sql[:100]}...")
  52. if not sql or sql.strip() == "":
  53. logger.warning(" SQL验证失败:SQL语句为空。")
  54. return "SQL验证失败:SQL语句为空"
  55. sql_upper = sql.upper().strip()
  56. if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
  57. logger.warning(f" SQL验证失败:不是有效的查询语句。SQL: {sql}")
  58. return "SQL验证失败:不是有效的查询语句"
  59. # 简单的安全检查
  60. dangerous_patterns = [r'\bDROP\b', r'\bDELETE\b', r'\bTRUNCATE\b', r'\bALTER\b', r'\bCREATE\b', r'\bUPDATE\b']
  61. for pattern in dangerous_patterns:
  62. if re.search(pattern, sql_upper):
  63. keyword = pattern.replace(r'\b', '').replace('\\', '')
  64. logger.error(f" SQL验证失败:包含危险操作 {keyword}。SQL: {sql}")
  65. return f"SQL验证失败:包含危险操作 {keyword}"
  66. logger.info(f" ✅ SQL验证通过。")
  67. return "SQL验证通过:语法正确"
  68. @tool
  69. def run_sql(sql: str) -> str:
  70. """
  71. 执行SQL查询并以JSON字符串格式返回结果。
  72. Args:
  73. sql: 待执行的SQL语句。
  74. Returns:
  75. JSON字符串格式的查询结果,或包含错误的JSON字符串。
  76. """
  77. logger.info(f"🔧 [Tool] run_sql - 待执行SQL (前100字符): {sql[:100]}...")
  78. try:
  79. from common.vanna_instance import get_vanna_instance
  80. vn = get_vanna_instance()
  81. df = vn.run_sql(sql)
  82. if df is None:
  83. logger.warning(" SQL执行成功,但查询结果为空。")
  84. result = {"status": "success", "data": [], "message": "查询无结果"}
  85. return json.dumps(result, ensure_ascii=False)
  86. logger.info(f" ✅ SQL执行成功,返回 {len(df)} 条记录。")
  87. # 将DataFrame转换为JSON,并妥善处理datetime等特殊类型
  88. return df.to_json(orient='records', date_format='iso')
  89. except Exception as e:
  90. logger.error(f" SQL执行过程中发生异常: {e}", exc_info=True)
  91. error_result = {"status": "error", "error_message": str(e)}
  92. return json.dumps(error_result, ensure_ascii=False)
  93. # 将所有工具函数收集到一个列表中,方便Agent导入和使用
  94. sql_tools = [generate_sql, valid_sql, run_sql]