sql_execution.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # agent/tools/sql_execution.py
  2. from langchain.tools import tool
  3. from typing import Dict, Any
  4. import pandas as pd
  5. import time
  6. import functools
  7. from common.vanna_instance import get_vanna_instance
  8. from app_config import API_MAX_RETURN_ROWS
  9. def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: float = 2.0):
  10. """
  11. 重试装饰器
  12. Args:
  13. max_retries: 最大重试次数
  14. delay: 初始延迟时间(秒)
  15. backoff_factor: 退避因子(指数退避)
  16. """
  17. def decorator(func):
  18. @functools.wraps(func)
  19. def wrapper(*args, **kwargs):
  20. retries = 0
  21. while retries <= max_retries:
  22. try:
  23. result = func(*args, **kwargs)
  24. # 如果函数返回结果包含 can_retry 标识,检查是否需要重试
  25. if isinstance(result, dict) and result.get('can_retry', False) and not result.get('success', True):
  26. if retries < max_retries:
  27. retries += 1
  28. wait_time = delay * (backoff_factor ** (retries - 1))
  29. print(f"[RETRY] {func.__name__} 执行失败,等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
  30. time.sleep(wait_time)
  31. continue
  32. return result
  33. except Exception as e:
  34. retries += 1
  35. if retries <= max_retries:
  36. wait_time = delay * (backoff_factor ** (retries - 1))
  37. print(f"[RETRY] {func.__name__} 异常: {str(e)}, 等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
  38. time.sleep(wait_time)
  39. else:
  40. print(f"[RETRY] {func.__name__} 达到最大重试次数 ({max_retries}),抛出异常")
  41. raise
  42. # 不应该到达这里,但为了安全性
  43. return result
  44. return wrapper
  45. return decorator
  46. @tool
  47. @retry_on_failure(max_retries=2)
  48. def execute_sql(sql: str, max_rows: int = None) -> Dict[str, Any]:
  49. """
  50. 执行SQL查询并返回结果。
  51. Args:
  52. sql: 要执行的SQL查询语句
  53. max_rows: 最大返回行数,默认使用API_MAX_RETURN_ROWS配置
  54. Returns:
  55. 包含查询结果的字典,格式:
  56. {
  57. "success": bool,
  58. "data_result": dict或None,
  59. "error": str或None,
  60. "can_retry": bool
  61. }
  62. """
  63. # 设置默认的最大返回行数,与ask()接口保持一致
  64. DEFAULT_MAX_RETURN_ROWS = 200
  65. if max_rows is None:
  66. max_rows = API_MAX_RETURN_ROWS if API_MAX_RETURN_ROWS is not None else DEFAULT_MAX_RETURN_ROWS
  67. try:
  68. print(f"[TOOL:execute_sql] 开始执行SQL: {sql[:100]}...")
  69. vn = get_vanna_instance()
  70. df = vn.run_sql(sql)
  71. if df is None:
  72. return {
  73. "success": False,
  74. "data_result": None,
  75. "error": "SQL执行返回空结果",
  76. "error_type": "no_result",
  77. "can_retry": False
  78. }
  79. if not isinstance(df, pd.DataFrame):
  80. return {
  81. "success": False,
  82. "data_result": None,
  83. "error": f"SQL执行返回非DataFrame类型: {type(df)}",
  84. "error_type": "invalid_result_type",
  85. "can_retry": False
  86. }
  87. if df.empty:
  88. return {
  89. "success": True,
  90. "data_result": {
  91. "rows": [],
  92. "columns": [],
  93. "row_count": 0,
  94. "message": "查询执行成功,但没有找到符合条件的数据"
  95. },
  96. "message": "查询无结果"
  97. }
  98. # 处理数据结果
  99. total_rows = len(df)
  100. limited_df = df.head(max_rows)
  101. # 转换为字典格式并处理数据类型
  102. rows = _process_dataframe_rows(limited_df.to_dict(orient="records"))
  103. columns = list(df.columns)
  104. print(f"[TOOL:execute_sql] 查询成功,返回 {len(rows)} 行数据")
  105. result = {
  106. "success": True,
  107. "data_result": {
  108. "rows": rows,
  109. "columns": columns,
  110. "row_count": len(rows),
  111. "total_row_count": total_rows,
  112. "is_limited": total_rows > max_rows,
  113. "sql": sql
  114. },
  115. "message": f"查询成功,共 {total_rows} 行数据"
  116. }
  117. if total_rows > max_rows:
  118. result["message"] += f",已限制显示前 {max_rows} 行"
  119. return result
  120. except Exception as e:
  121. error_msg = str(e)
  122. print(f"[ERROR] SQL执行异常: {error_msg}")
  123. return {
  124. "success": False,
  125. "data_result": None,
  126. "error": f"SQL执行失败: {error_msg}",
  127. "error_type": _analyze_sql_error(error_msg),
  128. "can_retry": "timeout" in error_msg.lower() or "connection" in error_msg.lower(),
  129. "sql": sql
  130. }
  131. def _process_dataframe_rows(rows: list) -> list:
  132. """处理DataFrame行数据,确保JSON序列化兼容"""
  133. processed_rows = []
  134. for row in rows:
  135. processed_row = {}
  136. for key, value in row.items():
  137. if pd.isna(value):
  138. processed_row[key] = None
  139. elif isinstance(value, (pd.Timestamp, pd.Timedelta)):
  140. processed_row[key] = str(value)
  141. elif isinstance(value, (int, float, str, bool)):
  142. processed_row[key] = value
  143. else:
  144. processed_row[key] = str(value)
  145. processed_rows.append(processed_row)
  146. return processed_rows
  147. def _analyze_sql_error(error_msg: str) -> str:
  148. """分析SQL错误类型"""
  149. error_msg_lower = error_msg.lower()
  150. if "syntax error" in error_msg_lower or "syntaxerror" in error_msg_lower:
  151. return "syntax_error"
  152. elif "table" in error_msg_lower and ("not found" in error_msg_lower or "doesn't exist" in error_msg_lower):
  153. return "table_not_found"
  154. elif "column" in error_msg_lower and ("not found" in error_msg_lower or "unknown" in error_msg_lower):
  155. return "column_not_found"
  156. elif "timeout" in error_msg_lower:
  157. return "timeout"
  158. elif "connection" in error_msg_lower:
  159. return "connection_error"
  160. elif "permission" in error_msg_lower or "access denied" in error_msg_lower:
  161. return "permission_error"
  162. else:
  163. return "unknown_error"