test_schema_tools.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """
  2. 测试Schema Tools模块
  3. """
  4. import asyncio
  5. import os
  6. import sys
  7. from pathlib import Path
  8. # 添加项目根目录到Python路径
  9. sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  10. async def test_basic_functionality():
  11. """测试基本功能"""
  12. print("===== 测试 Schema Tools =====")
  13. # 1. 测试配置
  14. from schema_tools.config import SCHEMA_TOOLS_CONFIG, validate_config
  15. print("\n1. 测试配置验证...")
  16. try:
  17. validate_config()
  18. print("✅ 配置验证通过")
  19. except Exception as e:
  20. print(f"❌ 配置验证失败: {e}")
  21. return
  22. # 2. 测试工具注册
  23. from schema_tools.tools import ToolRegistry
  24. print("\n2. 已注册的工具:")
  25. tools = ToolRegistry.list_tools()
  26. for tool in tools:
  27. print(f" - {tool}")
  28. # 3. 创建测试表清单文件
  29. test_tables_file = "test_tables.txt"
  30. with open(test_tables_file, 'w', encoding='utf-8') as f:
  31. f.write("# 测试表清单\n")
  32. f.write("public.users\n")
  33. f.write("public.orders\n")
  34. f.write("hr.employees\n")
  35. print(f"\n3. 创建测试表清单文件: {test_tables_file}")
  36. # 4. 测试权限检查(仅模拟)
  37. print("\n4. 测试数据库权限检查...")
  38. # 这里需要真实的数据库连接字符串
  39. # 从环境变量或app_config获取
  40. try:
  41. import app_config
  42. if hasattr(app_config, 'PGVECTOR_CONFIG'):
  43. pg_config = app_config.PGVECTOR_CONFIG
  44. db_connection = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
  45. print(f"使用PgVector数据库配置")
  46. else:
  47. print("⚠️ 未找到数据库配置,跳过权限测试")
  48. db_connection = None
  49. except:
  50. print("⚠️ 无法导入app_config,跳过权限测试")
  51. db_connection = None
  52. if db_connection:
  53. from schema_tools.training_data_agent import SchemaTrainingDataAgent
  54. try:
  55. agent = SchemaTrainingDataAgent(
  56. db_connection=db_connection,
  57. table_list_file=test_tables_file,
  58. business_context="测试业务系统"
  59. )
  60. permissions = await agent.check_database_permissions()
  61. print(f"数据库权限: {permissions}")
  62. except Exception as e:
  63. print(f"❌ 权限检查失败: {e}")
  64. # 清理测试文件
  65. if os.path.exists(test_tables_file):
  66. os.remove(test_tables_file)
  67. print("\n===== 测试完成 =====")
  68. async def test_table_parser():
  69. """测试表清单解析器"""
  70. print("\n===== 测试表清单解析器 =====")
  71. from schema_tools.utils.table_parser import TableListParser
  72. parser = TableListParser()
  73. # 测试字符串解析
  74. test_cases = [
  75. "public.users",
  76. "hr.employees,sales.orders",
  77. "users\norders\nproducts",
  78. "schema.table_name"
  79. ]
  80. for test_str in test_cases:
  81. result = parser.parse_string(test_str)
  82. print(f"输入: {repr(test_str)}")
  83. print(f"结果: {result}")
  84. print()
  85. async def test_system_filter():
  86. """测试系统表过滤器"""
  87. print("\n===== 测试系统表过滤器 =====")
  88. from schema_tools.utils.system_filter import SystemTableFilter
  89. filter = SystemTableFilter()
  90. test_tables = [
  91. "pg_class",
  92. "information_schema.tables",
  93. "public.users",
  94. "hr.employees",
  95. "pg_temp_1.temp_table",
  96. "my_table"
  97. ]
  98. for table in test_tables:
  99. if '.' in table:
  100. schema, name = table.split('.', 1)
  101. else:
  102. schema, name = 'public', table
  103. is_system = filter.is_system_table(schema, name)
  104. print(f"{table}: {'系统表' if is_system else '用户表'}")
  105. if __name__ == "__main__":
  106. print("Schema Tools 测试脚本\n")
  107. # 运行测试
  108. asyncio.run(test_basic_functionality())
  109. asyncio.run(test_table_parser())
  110. asyncio.run(test_system_filter())