# test_valid_sql_simple.py (5.7 KB)
  1. #!/usr/bin/env python3
  2. """
  3. 简化版 valid_sql 测试脚本
  4. 只测试三种错误场景:table不存在、column不存在、语法错误
  5. """
  6. import asyncio
  7. import logging
  8. # 配置日志
  9. logging.basicConfig(
  10. level=logging.INFO,
  11. format='%(asctime)s - %(levelname)s - %(message)s'
  12. )
  13. logger = logging.getLogger(__name__)
  14. # 导入必要的模块
  15. try:
  16. from agent import CustomReactAgent
  17. from sql_tools import valid_sql
  18. from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
  19. except ImportError as e:
  20. logger.error(f"导入失败: {e}")
  21. logger.info("请确保在正确的目录下运行此脚本")
  22. exit(1)
  23. class SimpleValidSqlTester:
  24. """简化版 valid_sql 测试类"""
  25. def __init__(self):
  26. self.agent = None
  27. async def setup(self):
  28. """初始化 Agent"""
  29. logger.info("🚀 初始化 CustomReactAgent...")
  30. try:
  31. self.agent = await CustomReactAgent.create()
  32. logger.info("✅ Agent 初始化完成")
  33. except Exception as e:
  34. logger.error(f"❌ Agent 初始化失败: {e}")
  35. raise
  36. async def cleanup(self):
  37. """清理资源"""
  38. if self.agent:
  39. await self.agent.close()
  40. logger.info("✅ Agent 资源已清理")
  41. def test_valid_sql_direct(self, sql: str) -> str:
  42. """直接测试 valid_sql 工具"""
  43. logger.info(f"🔧 直接测试 valid_sql 工具")
  44. logger.info(f"SQL: {sql}")
  45. result = valid_sql(sql)
  46. logger.info(f"结果: {result}")
  47. return result
  48. async def test_llm_response_to_error(self, question: str, error_sql: str, error_message: str):
  49. """测试 LLM 对验证错误的响应"""
  50. logger.info(f"🧠 测试 LLM 对验证错误的响应")
  51. logger.info(f"问题: {question}")
  52. logger.info(f"错误SQL: {error_sql}")
  53. logger.info(f"错误信息: {error_message}")
  54. # 创建模拟的 state
  55. state = {
  56. "thread_id": "test_thread",
  57. "messages": [
  58. HumanMessage(content=question),
  59. ToolMessage(
  60. content=error_sql,
  61. name="generate_sql",
  62. tool_call_id="test_call_1"
  63. ),
  64. ToolMessage(
  65. content=error_message,
  66. name="valid_sql",
  67. tool_call_id="test_call_2"
  68. )
  69. ],
  70. "suggested_next_step": "analyze_validation_error"
  71. }
  72. try:
  73. # 调用 Agent 的内部方法来测试处理逻辑
  74. messages_for_llm = list(state["messages"])
  75. # 添加验证错误指导
  76. error_guidance = self.agent._generate_validation_error_guidance(error_message)
  77. messages_for_llm.append(SystemMessage(content=error_guidance))
  78. logger.info(f"📝 添加的错误指导: {error_guidance}")
  79. # 调用 LLM 看如何处理
  80. response = await self.agent.llm_with_tools.ainvoke(messages_for_llm)
  81. logger.info(f"🤖 LLM 响应: {response.content}")
  82. return response
  83. except Exception as e:
  84. logger.error(f"❌ 测试失败: {e}")
  85. return None
  86. async def test_three_scenarios():
  87. """测试三种错误场景"""
  88. logger.info("🧪 测试三种 valid_sql 错误场景")
  89. # 三种测试用例
  90. test_cases = [
  91. # {
  92. # "name": "表不存在",
  93. # "question": "查询员工表的信息",
  94. # "sql": "SELECT * FROM non_existent_table LIMIT 1"
  95. # },
  96. # {
  97. # "name": "字段不存在",
  98. # "question": "查询每个服务区的经理姓名",
  99. # "sql": "SELECT non_existent_field FROM bss_business_day_data LIMIT 1"
  100. # },
  101. {
  102. "name": "语法错误",
  103. "question": "查询服务区数据 WHERE",
  104. "sql": "SELECT service_name, pay_sum FROM bss_business_day_data WHERE service_name = '庐山服务区' AS service_alias"
  105. }
  106. ]
  107. tester = SimpleValidSqlTester()
  108. try:
  109. await tester.setup()
  110. for i, test_case in enumerate(test_cases, 1):
  111. logger.info(f"\n{'='*50}")
  112. logger.info(f"测试用例 {i}: {test_case['name']}")
  113. logger.info(f"{'='*50}")
  114. # 1. 直接测试 valid_sql
  115. direct_result = tester.test_valid_sql_direct(test_case["sql"])
  116. # 2. 测试 LLM 响应
  117. llm_response = await tester.test_llm_response_to_error(
  118. test_case["question"],
  119. test_case["sql"],
  120. direct_result
  121. )
  122. # 简单的结果分析
  123. logger.info(f"\n📊 结果分析:")
  124. if "失败" in direct_result:
  125. logger.info("✅ valid_sql 正确捕获错误")
  126. else:
  127. logger.warning("⚠️ valid_sql 可能未正确捕获错误")
  128. if llm_response and ("错误" in llm_response.content or "失败" in llm_response.content):
  129. logger.info("✅ LLM 正确处理验证错误")
  130. else:
  131. logger.warning("⚠️ LLM 可能未正确处理验证错误")
  132. except Exception as e:
  133. logger.error(f"❌ 测试失败: {e}")
  134. import traceback
  135. traceback.print_exc()
  136. finally:
  137. await tester.cleanup()
  138. async def main():
  139. """主函数"""
  140. logger.info("🚀 简化版 valid_sql 测试")
  141. await test_three_scenarios()
  142. logger.info("\n✅ 测试完成")
  143. if __name__ == "__main__":
  144. asyncio.run(main())