# async_sql_tools.py (7.2 KB)
"""
Async versions of the SQL tools - resolves the async conflict in Vector search.
Synchronous operations are executed in a thread pool to avoid LangGraph event-loop conflicts.
"""
  5. import json
  6. import asyncio
  7. from typing import List, Dict, Any
  8. from concurrent.futures import ThreadPoolExecutor
  9. from langchain_core.tools import tool
  10. from pydantic import BaseModel, Field
  11. from core.logging import get_react_agent_logger
# Module-level logger shared by every tool in this file.
logger = get_react_agent_logger("AsyncSQLTools")

# Thread pool (3 workers) that _run_in_executor submits synchronous calls to,
# so the blocking work does not run on the event loop.
_executor = ThreadPoolExecutor(max_workers=3)
# Argument schema handed to the generate_sql tool through ``args_schema``.
# Documented with ``#`` comments rather than a class docstring so the schema
# that pydantic generates from this model is left untouched.
class GenerateSqlArgs(BaseModel):
    # The natural-language question to be turned into SQL.
    question: str = Field(description="The user's question in natural language")
    # Earlier conversation messages; each instance gets its own empty list by
    # default. generate_sql reads the 'type' and 'content' keys of each entry.
    history_messages: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="The conversation history messages for context."
    )
  21. async def _run_in_executor(func, *args, **kwargs):
  22. """在线程池中运行同步函数,避免事件循环冲突"""
  23. loop = asyncio.get_event_loop()
  24. return await loop.run_in_executor(_executor, func, *args, **kwargs)
  25. @tool(args_schema=GenerateSqlArgs)
  26. async def generate_sql(question: str, history_messages: List[Dict[str, Any]] = None) -> str:
  27. """
  28. 异步生成 SQL 查询 - 通过线程池调用同步的 Vanna
  29. Generates an SQL query based on the user's question and the conversation history.
  30. """
  31. logger.info(f"🔧 [Async Tool] generate_sql - Question: '{question}'")
  32. # 在线程池中执行,避免事件循环冲突
  33. def _sync_generate():
  34. from common.vanna_instance import get_vanna_instance
  35. if history_messages is None:
  36. history_messages_local = []
  37. else:
  38. history_messages_local = history_messages
  39. logger.info(f" History contains {len(history_messages_local)} messages.")
  40. # 构建增强问题(与同步版本相同的逻辑)
  41. if history_messages_local:
  42. history_str = "\n".join([f"{msg['type']}: {msg.get('content', '') or ''}" for msg in history_messages_local])
  43. enriched_question = f"""Previous conversation context:
  44. {history_str}
  45. Current user question:
  46. human: {question}
  47. Please analyze the conversation history to understand any references (like "this service area", "that branch", etc.) in the current question, and generate the appropriate SQL query."""
  48. else:
  49. enriched_question = question
  50. # 记录 Vanna 输入
  51. logger.info("📝 [Async Vanna Input] Complete question being sent to Vanna:")
  52. logger.info("--- BEGIN VANNA INPUT ---")
  53. logger.info(enriched_question)
  54. logger.info("--- END VANNA INPUT ---")
  55. try:
  56. vn = get_vanna_instance()
  57. sql = vn.generate_sql(enriched_question)
  58. if not sql or sql.strip() == "":
  59. if hasattr(vn, 'last_llm_explanation') and vn.last_llm_explanation:
  60. error_info = vn.last_llm_explanation
  61. logger.warning(f" Vanna returned an explanation instead of SQL: {error_info}")
  62. return f"Database query failed. Reason: {error_info}"
  63. else:
  64. logger.warning(" Vanna failed to generate SQL and provided no explanation.")
  65. return "Could not generate SQL: The question may not be suitable for a database query."
  66. sql_upper = sql.upper().strip()
  67. if not any(keyword in sql_upper for keyword in ['SELECT', 'WITH']):
  68. logger.warning(f" Vanna returned a message that does not appear to be a valid SQL query: {sql}")
  69. return f"Database query failed. Reason: {sql}"
  70. logger.info(f" ✅ SQL Generated Successfully:")
  71. logger.info(f" {sql}")
  72. return sql
  73. except Exception as e:
  74. logger.error(f" An exception occurred during SQL generation: {e}", exc_info=True)
  75. return f"SQL generation failed: {str(e)}"
  76. # 在线程池中执行
  77. return await _run_in_executor(_sync_generate)
  78. # 导入同步版本的验证函数
  79. def _import_validation_functions():
  80. """动态导入验证函数,避免循环导入"""
  81. from react_agent.sql_tools import _check_basic_syntax, _check_table_existence, _validate_with_limit_zero
  82. return _check_basic_syntax, _check_table_existence, _validate_with_limit_zero
  83. @tool
  84. async def valid_sql(sql: str) -> str:
  85. """
  86. 异步验证 SQL 语句的有效性
  87. Validates the SQL statement by checking syntax and executing with LIMIT 0.
  88. """
  89. logger.info(f"🔧 [Async Tool] valid_sql - Validating SQL")
  90. def _sync_validate():
  91. # 导入验证函数
  92. _check_basic_syntax, _check_table_existence, _validate_with_limit_zero = _import_validation_functions()
  93. # 规则1:基本语法检查
  94. if not _check_basic_syntax(sql):
  95. logger.warning(f" SQL基本语法检查失败: {sql[:100]}...")
  96. return json.dumps({
  97. "result": "invalid",
  98. "error": "SQL语句格式错误:必须是SELECT或WITH开头的查询语句"
  99. })
  100. # 规则2:表存在性检查
  101. if not _check_table_existence(sql):
  102. logger.warning(f" SQL表存在性检查失败")
  103. return json.dumps({
  104. "result": "invalid",
  105. "error": "SQL中引用的表不存在于数据库中"
  106. })
  107. # 规则3:LIMIT 0执行测试
  108. return _validate_with_limit_zero(sql)
  109. return await _run_in_executor(_sync_validate)
  110. @tool
  111. async def run_sql(sql: str) -> str:
  112. """
  113. 异步执行 SQL 查询并返回结果
  114. 执行SQL查询并以JSON字符串格式返回结果。
  115. Args:
  116. sql: 待执行的SQL语句。
  117. Returns:
  118. JSON字符串格式的查询结果,或包含错误的JSON字符串。
  119. """
  120. logger.info(f"🔧 [Async Tool] run_sql - 待执行SQL:")
  121. logger.info(f" {sql}")
  122. def _sync_run():
  123. from common.vanna_instance import get_vanna_instance
  124. try:
  125. vn = get_vanna_instance()
  126. df = vn.run_sql(sql)
  127. logger.debug(f"SQL执行结果:\n{df}")
  128. if df is None:
  129. logger.warning(" SQL执行成功,但查询结果为空。")
  130. result = {"status": "success", "data": [], "message": "查询无结果"}
  131. return json.dumps(result, ensure_ascii=False)
  132. logger.info(f" ✅ SQL执行成功,返回 {len(df)} 条记录。")
  133. # 将DataFrame转换为JSON,并妥善处理datetime等特殊类型
  134. return df.to_json(orient='records', date_format='iso')
  135. except Exception as e:
  136. logger.error(f" SQL执行过程中发生异常: {e}", exc_info=True)
  137. error_result = {"status": "error", "error_message": str(e)}
  138. return json.dumps(error_result, ensure_ascii=False)
  139. return await _run_in_executor(_sync_run)
# Collect all async tool functions defined in this module into one list.
async_sql_tools = [generate_sql, valid_sql, run_sql]
  142. # 清理函数(可选)
  143. def cleanup():
  144. """清理线程池资源"""
  145. global _executor
  146. if _executor:
  147. _executor.shutdown(wait=False)
  148. logger.info("异步SQL工具线程池已关闭")