#!/usr/bin/env python3
"""
Integration checks for the code in the training directory.

Verifies that the training-related modules can be imported and work as expected.
"""
def test_training_imports():
    """Check that the training package and its batch processor import cleanly.

    Imports six training helpers from the ``training`` package and the
    ``BatchProcessor`` class from ``training.vanna_trainer``. The imported
    names are only resolved, never called.

    Returns:
        bool: True if every import succeeds, False if any of them raises.
    """
    print("=== 测试训练模块导入 ===")

    try:
        # Package-level import of the training helpers.
        from training import (  # noqa: F401 -- imported only to prove they resolve
            train_ddl,
            train_documentation,
            train_sql_example,
            train_question_sql_pair,
            flush_training,
            shutdown_trainer,
        )
        print("✅ 成功从training包导入所有函数")

        # Direct import from the module that defines the class.
        from training.vanna_trainer import BatchProcessor  # noqa: F401
        print("✅ 成功导入BatchProcessor类")

        return True

    except Exception as e:
        # Importing executes the package's module-level code, which can raise
        # more than ImportError. Catching Exception (as the other checks in
        # this file do) reports the failure instead of aborting the whole run.
        print(f"❌ 导入失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_config_access():
    """Check that ``app_config`` exposes the settings the training code reads.

    Prints the batch-processing settings and the vector database type, then
    tries the embedding/model helpers from ``common.utils``. Missing batch
    settings and a failing helper lookup are reported as warnings only.

    Returns:
        bool: False if ``app_config`` cannot be imported, if an unexpected
        error occurs, or if ``VECTOR_DB_NAME`` is ``"pgvector"`` while
        ``PGVECTOR_CONFIG`` is missing or empty; True otherwise.
    """
    print("\n=== 测试配置访问 ===")

    try:
        import app_config

        passed = True

        # Batch-processing settings used by the trainer.
        batch_enabled = getattr(app_config, 'TRAINING_BATCH_PROCESSING_ENABLED', None)
        batch_size = getattr(app_config, 'TRAINING_BATCH_SIZE', None)
        max_workers = getattr(app_config, 'TRAINING_MAX_WORKERS', None)

        print(f"批处理启用: {batch_enabled}")
        print(f"批处理大小: {batch_size}")
        print(f"最大工作线程: {max_workers}")

        if all(value is not None for value in (batch_enabled, batch_size, max_workers)):
            print("✅ 训练批处理配置正常")
        else:
            # Soft check: reported as a warning, does not fail the test.
            print("⚠️ 部分训练批处理配置缺失")

        # Vector database settings.
        vector_db_name = getattr(app_config, 'VECTOR_DB_NAME', None)
        print(f"向量数据库类型: {vector_db_name}")

        if vector_db_name == "pgvector":
            if getattr(app_config, 'PGVECTOR_CONFIG', None):
                print("✅ PgVector配置存在")
            else:
                # Reported as an error, so the result must not be a pass.
                print("❌ PgVector配置缺失")
                passed = False

        # Newer configuration helper functions; a failure here is a warning only.
        try:
            from common.utils import get_current_embedding_config, get_current_model_info

            # Called only to confirm it runs; its return value is not inspected.
            get_current_embedding_config()
            model_info = get_current_model_info()

            print(f"当前embedding类型: {model_info['embedding_type']}")
            print(f"当前embedding模型: {model_info['embedding_model']}")
            print("✅ 新配置工具函数正常工作")

        except Exception as e:
            print(f"⚠️ 新配置工具函数测试失败: {e}")

        return passed

    except Exception as e:
        print(f"❌ 配置访问测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_vanna_instance_creation():
    """Check that the factory in ``vanna_llm_factory`` can build an instance.

    After building the instance, reports whether it has the ``train``,
    ``generate_question`` and ``get_training_data`` attributes. A missing
    attribute is only a warning.

    Returns:
        bool: True if the instance was created, False if the import or the
        factory call raised.
    """
    print("\n=== 测试Vanna实例创建 ===")

    try:
        from vanna_llm_factory import create_vanna_instance

        print("尝试创建Vanna实例...")
        instance = create_vanna_instance()

        print(f"✅ 成功创建Vanna实例: {type(instance).__name__}")

        # Presence check only; the methods are never invoked.
        for name in ('train', 'generate_question', 'get_training_data'):
            if not hasattr(instance, name):
                print(f"⚠️ 方法 {name} 不存在")
                continue
            print(f"✅ 方法 {name} 存在")

    except Exception as exc:
        print(f"❌ Vanna实例创建失败: {exc}")
        import traceback
        traceback.print_exc()
        return False

    return True
def test_batch_processor():
    """Check that a ``BatchProcessor`` can be created, inspected and shut down.

    The batch size and worker count come from ``app_config``
    (``TRAINING_BATCH_SIZE`` / ``TRAINING_MAX_WORKERS``), falling back to
    5 and 2 when those attributes are absent.

    Returns:
        bool: True if creation, attribute access and shutdown all succeed,
        False if any step raises.
    """
    print("\n=== 测试批处理器 ===")

    try:
        from training.vanna_trainer import BatchProcessor
        import app_config

        batch_size = getattr(app_config, 'TRAINING_BATCH_SIZE', 5)
        max_workers = getattr(app_config, 'TRAINING_MAX_WORKERS', 2)

        processor = BatchProcessor(batch_size=batch_size, max_workers=max_workers)
        try:
            print("✅ 成功创建BatchProcessor实例")
            print(f" 批处理大小: {processor.batch_size}")
            print(f" 最大工作线程: {processor.max_workers}")
            print(f" 批处理启用: {processor.batch_enabled}")
        finally:
            # Always shut the processor down, even when reading one of its
            # attributes fails, so a failed check does not leave it running.
            processor.shutdown()
        print("✅ 批处理器关闭成功")

        return True

    except Exception as e:
        print(f"❌ 批处理器测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_training_functions():
    """Check that the training entry points import and are callable.

    No training is performed: each name imported from the ``training``
    package is only passed to ``callable()``.

    Returns:
        bool: True if every imported name is callable, False if the import
        raises or any name is not callable.
    """
    print("\n=== 测试训练函数 ===")

    try:
        from training import (
            train_ddl,
            train_documentation,
            train_sql_example,
            train_question_sql_pair,
            flush_training,
            shutdown_trainer,
        )

        print("✅ 所有训练函数导入成功")

        functions_to_test = [
            ('train_ddl', train_ddl),
            ('train_documentation', train_documentation),
            ('train_sql_example', train_sql_example),
            ('train_question_sql_pair', train_question_sql_pair),
            ('flush_training', flush_training),
            ('shutdown_trainer', shutdown_trainer),
        ]

        all_callable = True
        for func_name, func in functions_to_test:
            if callable(func):
                print(f"✅ {func_name} 是可调用的")
            else:
                # Reported as an error, so the overall result must reflect it.
                print(f"❌ {func_name} 不可调用")
                all_callable = False

        return all_callable

    except Exception as e:
        print(f"❌ 训练函数测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_embedding_connection():
    """Run the connection probe provided by the ``embedding_function`` module.

    The probe's result is read as a mapping with ``"success"`` and
    ``"message"`` keys. An unsuccessful connection is reported as a warning
    and does not fail this check.

    Returns:
        bool: True if the probe ran and its result could be read, False if
        importing, calling or reading the result raised.
    """
    print("\n=== 测试Embedding连接 ===")

    try:
        # Aliased so the imported probe is not confused with this function.
        from embedding_function import test_embedding_connection as probe_connection

        print("测试embedding模型连接...")
        outcome = probe_connection()

        if not outcome["success"]:
            print(f"⚠️ Embedding连接失败: {outcome['message']}")
            print(" 这可能是因为API服务未启动或配置不正确")
        else:
            print(f"✅ Embedding连接成功: {outcome['message']}")

    except Exception as exc:
        print(f"❌ Embedding连接测试失败: {exc}")
        import traceback
        traceback.print_exc()
        return False

    return True
def test_run_training_script():
    """Check that ``training.run_training`` imports and can split a file.

    Imports three helpers from ``training.run_training`` and exercises
    ``read_file_by_delimiter`` on a scratch file holding three ``---``
    separated sections. An unexpected section count is only a warning.

    Returns:
        bool: True if the import and the file read complete, False if any
        step raises.
    """
    print("\n=== 测试run_training.py脚本 ===")

    try:
        import os
        import sys
        import tempfile

        # Add the "training" directory next to this file to the import path.
        training_dir = os.path.join(os.path.dirname(__file__), 'training')
        if training_dir not in sys.path:
            sys.path.insert(0, training_dir)

        from training.run_training import (  # noqa: F401 -- two are import-checked only
            read_file_by_delimiter,
            read_markdown_file_by_sections,
            check_pgvector_connection,
        )

        print("✅ 成功导入run_training模块的函数")

        test_content = "section1---section2---section3"
        # Use a private scratch directory instead of the current working
        # directory, so an existing "test_temp.txt" there is never overwritten
        # or deleted and a read-only working directory does not fail the check.
        with tempfile.TemporaryDirectory() as scratch_dir:
            temp_path = os.path.join(scratch_dir, "test_temp.txt")
            with open(temp_path, "w", encoding="utf-8") as f:
                f.write(test_content)

            sections = read_file_by_delimiter(temp_path, "---")
            if len(sections) == 3:
                print("✅ read_file_by_delimiter 函数正常工作")
            else:
                print(f"⚠️ read_file_by_delimiter 返回了 {len(sections)} 个部分,期望 3 个")

        return True

    except Exception as e:
        print(f"❌ run_training.py脚本测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run every integration check in order and print a pass/fail summary.

    Each check's return value is treated as a success flag. The flags are
    listed per check, counted, and followed by an overall verdict line.
    """
    print("开始测试training目录的代码集成...")
    print("=" * 60)

    # Label/check pairs, executed from top to bottom.
    suite = (
        ("训练模块导入", test_training_imports),
        ("配置访问", test_config_access),
        ("Vanna实例创建", test_vanna_instance_creation),
        ("批处理器", test_batch_processor),
        ("训练函数", test_training_functions),
        ("Embedding连接", test_embedding_connection),
        ("run_training脚本", test_run_training_script),
    )
    outcomes = [(label, check()) for label, check in suite]

    # Summary section.
    print(f"\n{'='*60}")
    print("测试结果总结:")
    print("=" * 60)

    for label, ok in outcomes:
        verdict = "✅ 通过" if ok else "❌ 失败"
        print(f"{label}: {verdict}")

    total = len(outcomes)
    passed = len([label for label, ok in outcomes if ok])
    print(f"\n总计: {passed}/{total} 个测试通过")

    # At most one failure still counts as "mostly working".
    failed = total - passed
    if failed == 0:
        print("🎉 所有测试都通过了!training目录的代码可以正常工作。")
    elif failed <= 1:
        print("✅ 大部分测试通过,training目录的代码基本可以正常工作。")
        print(" 部分失败可能是由于服务未启动或配置问题。")
    else:
        print("⚠️ 多个测试失败,请检查相关依赖和配置。")
# Run the checks only when this file is executed directly as a script.
if __name__ == "__main__":
    main()