123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- #!/usr/bin/env python3
- """
- 简化版 valid_sql 测试脚本
- 只测试三种错误场景:table不存在、column不存在、语法错误
- """
- import asyncio
- import logging
- # 配置日志
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
- # 导入必要的模块
- try:
- from agent import CustomReactAgent
- from sql_tools import valid_sql
- from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
- except ImportError as e:
- logger.error(f"导入失败: {e}")
- logger.info("请确保在正确的目录下运行此脚本")
- exit(1)
- class SimpleValidSqlTester:
- """简化版 valid_sql 测试类"""
-
- def __init__(self):
- self.agent = None
-
- async def setup(self):
- """初始化 Agent"""
- logger.info("🚀 初始化 CustomReactAgent...")
- try:
- self.agent = await CustomReactAgent.create()
- logger.info("✅ Agent 初始化完成")
- except Exception as e:
- logger.error(f"❌ Agent 初始化失败: {e}")
- raise
-
- async def cleanup(self):
- """清理资源"""
- if self.agent:
- await self.agent.close()
- logger.info("✅ Agent 资源已清理")
-
- def test_valid_sql_direct(self, sql: str) -> str:
- """直接测试 valid_sql 工具"""
- logger.info(f"🔧 直接测试 valid_sql 工具")
- logger.info(f"SQL: {sql}")
-
- result = valid_sql(sql)
- logger.info(f"结果: {result}")
- return result
-
- async def test_llm_response_to_error(self, question: str, error_sql: str, error_message: str):
- """测试 LLM 对验证错误的响应"""
- logger.info(f"🧠 测试 LLM 对验证错误的响应")
- logger.info(f"问题: {question}")
- logger.info(f"错误SQL: {error_sql}")
- logger.info(f"错误信息: {error_message}")
-
- # 创建模拟的 state
- state = {
- "thread_id": "test_thread",
- "messages": [
- HumanMessage(content=question),
- ToolMessage(
- content=error_sql,
- name="generate_sql",
- tool_call_id="test_call_1"
- ),
- ToolMessage(
- content=error_message,
- name="valid_sql",
- tool_call_id="test_call_2"
- )
- ],
- "suggested_next_step": "analyze_validation_error"
- }
-
- try:
- # 调用 Agent 的内部方法来测试处理逻辑
- messages_for_llm = list(state["messages"])
-
- # 添加验证错误指导
- error_guidance = self.agent._generate_validation_error_guidance(error_message)
- messages_for_llm.append(SystemMessage(content=error_guidance))
-
- logger.info(f"📝 添加的错误指导: {error_guidance}")
-
- # 调用 LLM 看如何处理
- response = await self.agent.llm_with_tools.ainvoke(messages_for_llm)
- logger.info(f"🤖 LLM 响应: {response.content}")
-
- return response
-
- except Exception as e:
- logger.error(f"❌ 测试失败: {e}")
- return None
- async def test_three_scenarios():
- """测试三种错误场景"""
- logger.info("🧪 测试三种 valid_sql 错误场景")
-
- # 三种测试用例
- test_cases = [
- # {
- # "name": "表不存在",
- # "question": "查询员工表的信息",
- # "sql": "SELECT * FROM non_existent_table LIMIT 1"
- # },
- # {
- # "name": "字段不存在",
- # "question": "查询每个服务区的经理姓名",
- # "sql": "SELECT non_existent_field FROM bss_business_day_data LIMIT 1"
- # },
- {
- "name": "语法错误",
- "question": "查询服务区数据 WHERE",
- "sql": "SELECT service_name, pay_sum FROM bss_business_day_data WHERE service_name = '庐山服务区' AS service_alias"
- }
- ]
-
- tester = SimpleValidSqlTester()
-
- try:
- await tester.setup()
-
- for i, test_case in enumerate(test_cases, 1):
- logger.info(f"\n{'='*50}")
- logger.info(f"测试用例 {i}: {test_case['name']}")
- logger.info(f"{'='*50}")
-
- # 1. 直接测试 valid_sql
- direct_result = tester.test_valid_sql_direct(test_case["sql"])
-
- # 2. 测试 LLM 响应
- llm_response = await tester.test_llm_response_to_error(
- test_case["question"],
- test_case["sql"],
- direct_result
- )
-
- # 简单的结果分析
- logger.info(f"\n📊 结果分析:")
- if "失败" in direct_result:
- logger.info("✅ valid_sql 正确捕获错误")
- else:
- logger.warning("⚠️ valid_sql 可能未正确捕获错误")
-
- if llm_response and ("错误" in llm_response.content or "失败" in llm_response.content):
- logger.info("✅ LLM 正确处理验证错误")
- else:
- logger.warning("⚠️ LLM 可能未正确处理验证错误")
-
- except Exception as e:
- logger.error(f"❌ 测试失败: {e}")
- import traceback
- traceback.print_exc()
-
- finally:
- await tester.cleanup()
- async def main():
- """主函数"""
- logger.info("🚀 简化版 valid_sql 测试")
- await test_three_scenarios()
- logger.info("\n✅ 测试完成")
- if __name__ == "__main__":
- asyncio.run(main())
|