sql_tools.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """
  2. 数据库查询相关的工具集
  3. """
  4. import re
  5. import json
  6. import logging
  7. from langchain_core.tools import tool
  8. from pydantic.v1 import BaseModel, Field
  9. from typing import List, Dict, Any
  10. import pandas as pd
  11. logger = logging.getLogger(__name__)
  12. # --- Pydantic Schema for Tool Arguments ---
  13. class GenerateSqlArgs(BaseModel):
  14. """Input schema for the generate_sql tool."""
  15. question: str = Field(description="The user's question to be converted to SQL.")
  16. history_messages: List[Dict[str, Any]] = Field(
  17. default=[],
  18. description="The conversation history messages for context."
  19. )
  20. # --- Tool Functions ---
  21. @tool(args_schema=GenerateSqlArgs)
  22. def generate_sql(question: str, history_messages: List[Dict[str, Any]] = None) -> str:
  23. """
  24. Generates an SQL query based on the user's question and the conversation history.
  25. """
  26. logger.info(f"🔧 [Tool] generate_sql - Question: '{question}'")
  27. if history_messages is None:
  28. history_messages = []
  29. logger.info(f" History contains {len(history_messages)} messages.")
  30. # Combine history and the current question to form a rich prompt
  31. if history_messages:
  32. history_str = "\n".join([f"{msg['type']}: {msg.get('content', '') or ''}" for msg in history_messages])
  33. enriched_question = f"""Previous conversation context:
  34. {history_str}
  35. Current user question: {question}
  36. Please analyze the conversation history to understand any references (like "this service area", "that branch", etc.) in the current question, and generate the appropriate SQL query."""
  37. else:
  38. # If no history messages, use the original question directly
  39. enriched_question = question
  40. try:
  41. from common.vanna_instance import get_vanna_instance
  42. vn = get_vanna_instance()
  43. sql = vn.generate_sql(enriched_question)
  44. if not sql or sql.strip() == "":
  45. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  46. error_info = vn.last_llm_explanation
  47. logger.warning(f" Vanna returned an explanation instead of SQL: {error_info}")
  48. return f"Database query failed. Reason: {error_info}"
  49. else:
  50. logger.warning(" Vanna failed to generate SQL and provided no explanation.")
  51. return "Could not generate SQL: The question may not be suitable for a database query."
  52. sql_upper = sql.upper().strip()
  53. if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
  54. logger.warning(f" Vanna returned a message that does not appear to be a valid SQL query: {sql}")
  55. return f"Database query failed. Reason: {sql}"
  56. logger.info(f" ✅ SQL Generated Successfully: {sql}")
  57. return sql
  58. except Exception as e:
  59. logger.error(f" An exception occurred during SQL generation: {e}", exc_info=True)
  60. return f"SQL generation failed: {str(e)}"
  61. @tool
  62. def valid_sql(sql: str) -> str:
  63. """
  64. 验证SQL语句的正确性和安全性。
  65. Args:
  66. sql: 待验证的SQL语句。
  67. Returns:
  68. 验证结果。
  69. """
  70. logger.info(f"🔧 [Tool] valid_sql - 待验证SQL (前100字符): {sql[:100]}...")
  71. if not sql or sql.strip() == "":
  72. logger.warning(" SQL验证失败:SQL语句为空。")
  73. return "SQL验证失败:SQL语句为空"
  74. sql_upper = sql.upper().strip()
  75. if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
  76. logger.warning(f" SQL验证失败:不是有效的查询语句。SQL: {sql}")
  77. return "SQL验证失败:不是有效的查询语句"
  78. # 简单的安全检查
  79. dangerous_patterns = [r'\bDROP\b', r'\bDELETE\b', r'\bTRUNCATE\b', r'\bALTER\b', r'\bCREATE\b', r'\bUPDATE\b']
  80. for pattern in dangerous_patterns:
  81. if re.search(pattern, sql_upper):
  82. keyword = pattern.replace(r'\b', '').replace('\\', '')
  83. logger.error(f" SQL验证失败:包含危险操作 {keyword}。SQL: {sql}")
  84. return f"SQL验证失败:包含危险操作 {keyword}"
  85. logger.info(f" ✅ SQL验证通过。")
  86. return "SQL验证通过:语法正确"
  87. @tool
  88. def run_sql(sql: str) -> str:
  89. """
  90. 执行SQL查询并以JSON字符串格式返回结果。
  91. Args:
  92. sql: 待执行的SQL语句。
  93. Returns:
  94. JSON字符串格式的查询结果,或包含错误的JSON字符串。
  95. """
  96. logger.info(f"🔧 [Tool] run_sql - 待执行SQL (前100字符): {sql[:100]}...")
  97. try:
  98. from common.vanna_instance import get_vanna_instance
  99. vn = get_vanna_instance()
  100. df = vn.run_sql(sql)
  101. print("-------------run_sql() df -------------------")
  102. print(df)
  103. print("--------------------------------")
  104. if df is None:
  105. logger.warning(" SQL执行成功,但查询结果为空。")
  106. result = {"status": "success", "data": [], "message": "查询无结果"}
  107. return json.dumps(result, ensure_ascii=False)
  108. logger.info(f" ✅ SQL执行成功,返回 {len(df)} 条记录。")
  109. # 将DataFrame转换为JSON,并妥善处理datetime等特殊类型
  110. return df.to_json(orient='records', date_format='iso')
  111. except Exception as e:
  112. logger.error(f" SQL执行过程中发生异常: {e}", exc_info=True)
  113. error_result = {"status": "error", "error_message": str(e)}
  114. return json.dumps(error_result, ensure_ascii=False)
  115. # 将所有工具函数收集到一个列表中,方便Agent导入和使用
  116. sql_tools = [generate_sql, valid_sql, run_sql]