sql_execution.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # agent/tools/sql_execution.py
  2. from langchain.tools import tool
  3. from typing import Dict, Any
  4. import pandas as pd
  5. import time
  6. import functools
  7. from common.vanna_instance import get_vanna_instance
  8. def retry_on_failure(max_retries: int = 2, delay: float = 1.0, backoff_factor: float = 2.0):
  9. """
  10. 重试装饰器
  11. Args:
  12. max_retries: 最大重试次数
  13. delay: 初始延迟时间(秒)
  14. backoff_factor: 退避因子(指数退避)
  15. """
  16. def decorator(func):
  17. @functools.wraps(func)
  18. def wrapper(*args, **kwargs):
  19. retries = 0
  20. while retries <= max_retries:
  21. try:
  22. result = func(*args, **kwargs)
  23. # 如果函数返回结果包含 can_retry 标识,检查是否需要重试
  24. if isinstance(result, dict) and result.get('can_retry', False) and not result.get('success', True):
  25. if retries < max_retries:
  26. retries += 1
  27. wait_time = delay * (backoff_factor ** (retries - 1))
  28. print(f"[RETRY] {func.__name__} 执行失败,等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
  29. time.sleep(wait_time)
  30. continue
  31. return result
  32. except Exception as e:
  33. retries += 1
  34. if retries <= max_retries:
  35. wait_time = delay * (backoff_factor ** (retries - 1))
  36. print(f"[RETRY] {func.__name__} 异常: {str(e)}, 等待 {wait_time:.1f} 秒后重试 ({retries}/{max_retries})")
  37. time.sleep(wait_time)
  38. else:
  39. print(f"[RETRY] {func.__name__} 达到最大重试次数 ({max_retries}),抛出异常")
  40. raise
  41. # 不应该到达这里,但为了安全性
  42. return result
  43. return wrapper
  44. return decorator
  45. @tool
  46. @retry_on_failure(max_retries=2)
  47. def execute_sql(sql: str, max_rows: int = 200) -> Dict[str, Any]:
  48. """
  49. 执行SQL查询并返回结果。
  50. Args:
  51. sql: 要执行的SQL查询语句
  52. max_rows: 最大返回行数,默认200
  53. Returns:
  54. 包含查询结果的字典,格式:
  55. {
  56. "success": bool,
  57. "data_result": dict或None,
  58. "error": str或None,
  59. "can_retry": bool
  60. }
  61. """
  62. try:
  63. print(f"[TOOL:execute_sql] 开始执行SQL: {sql[:100]}...")
  64. vn = get_vanna_instance()
  65. df = vn.run_sql(sql)
  66. if df is None:
  67. return {
  68. "success": False,
  69. "data_result": None,
  70. "error": "SQL执行返回空结果",
  71. "error_type": "no_result",
  72. "can_retry": False
  73. }
  74. if not isinstance(df, pd.DataFrame):
  75. return {
  76. "success": False,
  77. "data_result": None,
  78. "error": f"SQL执行返回非DataFrame类型: {type(df)}",
  79. "error_type": "invalid_result_type",
  80. "can_retry": False
  81. }
  82. if df.empty:
  83. return {
  84. "success": True,
  85. "data_result": {
  86. "rows": [],
  87. "columns": [],
  88. "row_count": 0,
  89. "message": "查询执行成功,但没有找到符合条件的数据"
  90. },
  91. "message": "查询无结果"
  92. }
  93. # 处理数据结果
  94. total_rows = len(df)
  95. limited_df = df.head(max_rows)
  96. # 转换为字典格式并处理数据类型
  97. rows = _process_dataframe_rows(limited_df.to_dict(orient="records"))
  98. columns = list(df.columns)
  99. print(f"[TOOL:execute_sql] 查询成功,返回 {len(rows)} 行数据")
  100. result = {
  101. "success": True,
  102. "data_result": {
  103. "rows": rows,
  104. "columns": columns,
  105. "row_count": len(rows),
  106. "total_row_count": total_rows,
  107. "is_limited": total_rows > max_rows,
  108. "sql": sql
  109. },
  110. "message": f"查询成功,共 {total_rows} 行数据"
  111. }
  112. if total_rows > max_rows:
  113. result["message"] += f",已限制显示前 {max_rows} 行"
  114. return result
  115. except Exception as e:
  116. error_msg = str(e)
  117. print(f"[ERROR] SQL执行异常: {error_msg}")
  118. return {
  119. "success": False,
  120. "data_result": None,
  121. "error": f"SQL执行失败: {error_msg}",
  122. "error_type": _analyze_sql_error(error_msg),
  123. "can_retry": "timeout" in error_msg.lower() or "connection" in error_msg.lower(),
  124. "sql": sql
  125. }
  126. def _process_dataframe_rows(rows: list) -> list:
  127. """处理DataFrame行数据,确保JSON序列化兼容"""
  128. processed_rows = []
  129. for row in rows:
  130. processed_row = {}
  131. for key, value in row.items():
  132. if pd.isna(value):
  133. processed_row[key] = None
  134. elif isinstance(value, (pd.Timestamp, pd.Timedelta)):
  135. processed_row[key] = str(value)
  136. elif isinstance(value, (int, float, str, bool)):
  137. processed_row[key] = value
  138. else:
  139. processed_row[key] = str(value)
  140. processed_rows.append(processed_row)
  141. return processed_rows
  142. def _analyze_sql_error(error_msg: str) -> str:
  143. """分析SQL错误类型"""
  144. error_msg_lower = error_msg.lower()
  145. if "syntax error" in error_msg_lower or "syntaxerror" in error_msg_lower:
  146. return "syntax_error"
  147. elif "table" in error_msg_lower and ("not found" in error_msg_lower or "doesn't exist" in error_msg_lower):
  148. return "table_not_found"
  149. elif "column" in error_msg_lower and ("not found" in error_msg_lower or "unknown" in error_msg_lower):
  150. return "column_not_found"
  151. elif "timeout" in error_msg_lower:
  152. return "timeout"
  153. elif "connection" in error_msg_lower:
  154. return "connection_error"
  155. elif "permission" in error_msg_lower or "access denied" in error_msg_lower:
  156. return "permission_error"
  157. else:
  158. return "unknown_error"