sql_generation.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # agent/tools/sql_generation.py
  2. from langchain.tools import tool
  3. from typing import Dict, Any
  4. from common.vanna_instance import get_vanna_instance
  5. @tool
  6. def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str, Any]:
  7. """
  8. 将自然语言问题转换为SQL查询。
  9. Args:
  10. question: 需要转换为SQL的自然语言问题
  11. allow_llm_to_see_data: 是否允许LLM查看数据,默认True
  12. Returns:
  13. 包含SQL生成结果的字典,格式:
  14. {
  15. "success": bool,
  16. "sql": str或None,
  17. "error": str或None,
  18. "can_retry": bool
  19. }
  20. """
  21. try:
  22. print(f"[TOOL:generate_sql] 开始生成SQL: {question}")
  23. vn = get_vanna_instance()
  24. sql = vn.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
  25. if sql is None:
  26. # 检查是否有LLM解释性文本(已在base_llm_chat.py中处理thinking内容)
  27. explanation = getattr(vn, 'last_llm_explanation', None)
  28. if explanation:
  29. return {
  30. "success": False,
  31. "sql": None,
  32. "error": explanation,
  33. "error_type": "generation_failed_with_explanation",
  34. "can_retry": True
  35. }
  36. else:
  37. return {
  38. "success": False,
  39. "sql": None,
  40. "error": "无法生成SQL查询,可能是问题描述不够明确或数据表结构不匹配",
  41. "error_type": "generation_failed",
  42. "can_retry": True
  43. }
  44. # 检查SQL质量
  45. sql_clean = sql.strip()
  46. if not sql_clean:
  47. return {
  48. "success": False,
  49. "sql": sql,
  50. "error": "生成的SQL为空",
  51. "error_type": "empty_sql",
  52. "can_retry": True
  53. }
  54. print(f"[TOOL:generate_sql] 成功生成SQL: {sql}")
  55. return {
  56. "success": True,
  57. "sql": sql,
  58. "error": None,
  59. "message": "SQL生成成功"
  60. }
  61. except Exception as e:
  62. print(f"[ERROR] SQL生成异常: {str(e)}")
  63. return {
  64. "success": False,
  65. "sql": None,
  66. "error": f"SQL生成过程异常: {str(e)}",
  67. "error_type": "exception",
  68. "can_retry": True
  69. }