sql_generation.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # agent/tools/sql_generation.py
  2. from langchain.tools import tool
  3. from typing import Dict, Any
  4. from common.vanna_instance import get_vanna_instance
  5. @tool
  6. def generate_sql(question: str, allow_llm_to_see_data: bool = True) -> Dict[str, Any]:
  7. """
  8. 将自然语言问题转换为SQL查询。
  9. Args:
  10. question: 需要转换为SQL的自然语言问题
  11. allow_llm_to_see_data: 是否允许LLM查看数据,默认True
  12. Returns:
  13. 包含SQL生成结果的字典,格式:
  14. {
  15. "success": bool,
  16. "sql": str或None,
  17. "error": str或None,
  18. "can_retry": bool
  19. }
  20. """
  21. try:
  22. print(f"[TOOL:generate_sql] 开始生成SQL: {question}")
  23. vn = get_vanna_instance()
  24. sql = vn.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
  25. if sql is None:
  26. # 检查是否有LLM解释性文本
  27. explanation = getattr(vn, 'last_llm_explanation', None)
  28. if explanation:
  29. return {
  30. "success": False,
  31. "sql": None,
  32. "error": explanation,
  33. "error_type": "generation_failed_with_explanation",
  34. "can_retry": True
  35. }
  36. else:
  37. return {
  38. "success": False,
  39. "sql": None,
  40. "error": "无法生成SQL查询,可能是问题描述不够明确或数据表结构不匹配",
  41. "error_type": "generation_failed",
  42. "can_retry": True
  43. }
  44. # 检查SQL质量
  45. sql_clean = sql.strip()
  46. if not sql_clean:
  47. return {
  48. "success": False,
  49. "sql": sql,
  50. "error": "生成的SQL为空",
  51. "error_type": "empty_sql",
  52. "can_retry": True
  53. }
  54. # 检查是否返回了错误信息而非SQL
  55. error_indicators = [
  56. "insufficient context", "无法生成", "sorry", "cannot generate",
  57. "not enough information", "unclear", "unable to"
  58. ]
  59. if any(indicator in sql_clean.lower() for indicator in error_indicators):
  60. return {
  61. "success": False,
  62. "sql": None,
  63. "error": sql_clean,
  64. "error_type": "llm_explanation",
  65. "can_retry": False
  66. }
  67. print(f"[TOOL:generate_sql] 成功生成SQL: {sql}")
  68. return {
  69. "success": True,
  70. "sql": sql,
  71. "error": None,
  72. "message": "SQL生成成功"
  73. }
  74. except Exception as e:
  75. print(f"[ERROR] SQL生成异常: {str(e)}")
  76. return {
  77. "success": False,
  78. "sql": None,
  79. "error": f"SQL生成过程异常: {str(e)}",
  80. "error_type": "exception",
  81. "can_retry": True
  82. }