sql_execution.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # agent/tools/sql_execution.py
  2. from langchain.tools import tool
  3. from typing import Dict, Any
  4. import pandas as pd
  5. import time
  6. import functools
  7. from common.vanna_instance import get_vanna_instance
  8. from app_config import API_MAX_RETURN_ROWS
  9. from core.logging import get_agent_logger
# Module-level logger for this file, created via get_agent_logger with the
# name "SQLExecution".
logger = get_agent_logger("SQLExecution")
  12. def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: float = 2.0):
  13. """
  14. 重试装饰器
  15. Args:
  16. max_retries: 最大重试次数
  17. delay: 初始延迟时间(秒)
  18. backoff_factor: 退避因子(指数退避)
  19. """
  20. def decorator(func):
  21. @functools.wraps(func)
  22. def wrapper(*args, **kwargs):
  23. retries = 0
  24. while retries <= max_retries:
  25. try:
  26. result = func(*args, **kwargs)
  27. # 如果函数返回结果包含 can_retry 标识,检查是否需要重试
  28. if isinstance(result, dict) and result.get('can_retry', False) and not result.get('success', True):
  29. if retries < max_retries:
  30. retries += 1
  31. wait_time = delay * (backoff_factor ** (retries - 1))
  32. logger.warning(f"{func.__name__} 执行失败,等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
  33. time.sleep(wait_time)
  34. continue
  35. return result
  36. except Exception as e:
  37. retries += 1
  38. if retries <= max_retries:
  39. wait_time = delay * (backoff_factor ** (retries - 1))
  40. logger.warning(f"{func.__name__} 异常: {str(e)}, 等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
  41. time.sleep(wait_time)
  42. else:
  43. logger.error(f"{func.__name__} 达到最大重试次数 ({max_retries}),抛出异常")
  44. raise
  45. # 不应该到达这里,但为了安全性
  46. return result
  47. return wrapper
  48. return decorator
  49. @tool
  50. @retry_on_failure(max_retries=2)
  51. def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
  52. """
  53. 执行SQL查询并返回结果。
  54. Args:
  55. sql: 要执行的SQL查询语句
  56. max_rows: 最大返回行数,默认使用API_MAX_RETURN_ROWS配置
  57. Returns:
  58. 包含查询结果的字典,格式:
  59. {
  60. "success": bool,
  61. "data_result": dict或None, # 注意:工具内部仍使用data_result,但会被Agent重命名为query_result
  62. "error": str或None,
  63. "can_retry": bool
  64. }
  65. """
  66. # 设置默认的最大返回行数,与ask()接口保持一致
  67. DEFAULT_MAX_RETURN_ROWS = 200
  68. if max_rows is None:
  69. max_rows = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  70. try:
  71. logger.info(f"开始执行SQL: {sql[:100]}...")
  72. vn = get_vanna_instance()
  73. df = vn.run_sql(sql)
  74. if df is None:
  75. return {
  76. "success": False,
  77. "data_result": None,
  78. "error": "SQL执行返回空结果",
  79. "error_type": "no_result",
  80. "can_retry": False
  81. }
  82. if not isinstance(df, pd.DataFrame):
  83. return {
  84. "success": False,
  85. "data_result": None,
  86. "error": f"SQL执行返回非DataFrame类型: {type(df)}",
  87. "error_type": "invalid_result_type",
  88. "can_retry": False
  89. }
  90. if df.empty:
  91. return {
  92. "success": True,
  93. "data_result": {
  94. "rows": [],
  95. "columns": [],
  96. "row_count": 0,
  97. "message": "查询执行成功,但没有找到符合条件的数据"
  98. },
  99. "message": "查询无结果"
  100. }
  101. # 处理数据结果
  102. total_rows = len(df)
  103. limited_df = df.head(max_rows)
  104. # 转换为字典格式并处理数据类型
  105. rows = _process_dataframe_rows(limited_df.to_dict(orient="records"))
  106. columns = list(df.columns)
  107. logger.info(f"查询成功,返回 {len(rows)} 行数据")
  108. result = {
  109. "success": True,
  110. "data_result": {
  111. "rows": rows,
  112. "columns": columns,
  113. "row_count": len(rows),
  114. "total_row_count": total_rows,
  115. "is_limited": total_rows > max_rows
  116. },
  117. "message": f"查询成功,共 {total_rows} 行数据"
  118. }
  119. if total_rows > max_rows:
  120. result["message"] += f",已限制显示前 {max_rows} 行"
  121. return result
  122. except Exception as e:
  123. error_msg = str(e)
  124. logger.error(f"SQL执行异常: {error_msg}")
  125. return {
  126. "success": False,
  127. "data_result": None,
  128. "error": f"SQL执行失败: {error_msg}",
  129. "error_type": _analyze_sql_error(error_msg),
  130. "can_retry": "timeout" in error_msg.lower() or "connection" in error_msg.lower(),
  131. "sql": sql
  132. }
  133. def _process_dataframe_rows(rows: list) -> list:
  134. """处理DataFrame行数据,确保JSON序列化兼容"""
  135. processed_rows = []
  136. for row in rows:
  137. processed_row = {}
  138. for key, value in row.items():
  139. if pd.isna(value):
  140. processed_row[key] = None
  141. elif isinstance(value, (pd.Timestamp, pd.Timedelta)):
  142. processed_row[key] = str(value)
  143. elif isinstance(value, (int, float, str, bool)):
  144. processed_row[key] = value
  145. else:
  146. processed_row[key] = str(value)
  147. processed_rows.append(processed_row)
  148. return processed_rows
  149. def _analyze_sql_error(error_msg: str) -> str:
  150. """分析SQL错误类型"""
  151. error_msg_lower = error_msg.lower()
  152. if "syntax error" in error_msg_lower or "syntaxerror" in error_msg_lower:
  153. return "syntax_error"
  154. elif "table" in error_msg_lower and ("not found" in error_msg_lower or "doesn't exist" in error_msg_lower):
  155. return "table_not_found"
  156. elif "column" in error_msg_lower and ("not found" in error_msg_lower or "unknown" in error_msg_lower):
  157. return "column_not_found"
  158. elif "timeout" in error_msg_lower:
  159. return "timeout"
  160. elif "connection" in error_msg_lower:
  161. return "connection_error"
  162. elif "permission" in error_msg_lower or "access denied" in error_msg_lower:
  163. return "permission_error"
  164. else:
  165. return "unknown_error"