#!/usr/bin/env python3
"""
Integration checks for the code in the ``training`` directory
(test_training_integration.py).

Verifies that the training-related modules work correctly.
"""
  6. def test_training_imports():
  7. """测试训练模块的导入"""
  8. print("=== 测试训练模块导入 ===")
  9. try:
  10. # 测试从training包导入
  11. from training import (
  12. train_ddl,
  13. train_documentation,
  14. train_sql_example,
  15. train_question_sql_pair,
  16. flush_training,
  17. shutdown_trainer
  18. )
  19. print("✅ 成功从training包导入所有函数")
  20. # 测试直接导入
  21. from training.vanna_trainer import BatchProcessor
  22. print("✅ 成功导入BatchProcessor类")
  23. return True
  24. except ImportError as e:
  25. print(f"❌ 导入失败: {e}")
  26. import traceback
  27. traceback.print_exc()
  28. return False
  29. def test_config_access():
  30. """测试配置访问"""
  31. print("\n=== 测试配置访问 ===")
  32. try:
  33. import app_config
  34. # 测试训练批处理配置
  35. batch_enabled = getattr(app_config, 'TRAINING_BATCH_PROCESSING_ENABLED', None)
  36. batch_size = getattr(app_config, 'TRAINING_BATCH_SIZE', None)
  37. max_workers = getattr(app_config, 'TRAINING_MAX_WORKERS', None)
  38. print(f"批处理启用: {batch_enabled}")
  39. print(f"批处理大小: {batch_size}")
  40. print(f"最大工作线程: {max_workers}")
  41. if batch_enabled is not None and batch_size is not None and max_workers is not None:
  42. print("✅ 训练批处理配置正常")
  43. else:
  44. print("⚠️ 部分训练批处理配置缺失")
  45. # 测试向量数据库配置
  46. vector_db_name = getattr(app_config, 'VECTOR_DB_NAME', None)
  47. print(f"向量数据库类型: {vector_db_name}")
  48. if vector_db_name == "pgvector":
  49. pgvector_config = getattr(app_config, 'PGVECTOR_CONFIG', None)
  50. if pgvector_config:
  51. print("✅ PgVector配置存在")
  52. else:
  53. print("❌ PgVector配置缺失")
  54. # 测试新的配置工具函数
  55. try:
  56. from common.utils import get_current_embedding_config, get_current_model_info
  57. embedding_config = get_current_embedding_config()
  58. model_info = get_current_model_info()
  59. print(f"当前embedding类型: {model_info['embedding_type']}")
  60. print(f"当前embedding模型: {model_info['embedding_model']}")
  61. print("✅ 新配置工具函数正常工作")
  62. except Exception as e:
  63. print(f"⚠️ 新配置工具函数测试失败: {e}")
  64. return True
  65. except Exception as e:
  66. print(f"❌ 配置访问测试失败: {e}")
  67. import traceback
  68. traceback.print_exc()
  69. return False
  70. def test_vanna_instance_creation():
  71. """测试Vanna实例创建"""
  72. print("\n=== 测试Vanna实例创建 ===")
  73. try:
  74. from vanna_llm_factory import create_vanna_instance
  75. print("尝试创建Vanna实例...")
  76. vn = create_vanna_instance()
  77. print(f"✅ 成功创建Vanna实例: {type(vn).__name__}")
  78. # 测试基本方法是否存在
  79. required_methods = ['train', 'generate_question', 'get_training_data']
  80. for method in required_methods:
  81. if hasattr(vn, method):
  82. print(f"✅ 方法 {method} 存在")
  83. else:
  84. print(f"⚠️ 方法 {method} 不存在")
  85. return True
  86. except Exception as e:
  87. print(f"❌ Vanna实例创建失败: {e}")
  88. import traceback
  89. traceback.print_exc()
  90. return False
  91. def test_batch_processor():
  92. """测试批处理器"""
  93. print("\n=== 测试批处理器 ===")
  94. try:
  95. from training.vanna_trainer import BatchProcessor
  96. import app_config
  97. # 创建测试批处理器
  98. batch_size = getattr(app_config, 'TRAINING_BATCH_SIZE', 5)
  99. max_workers = getattr(app_config, 'TRAINING_MAX_WORKERS', 2)
  100. processor = BatchProcessor(batch_size=batch_size, max_workers=max_workers)
  101. print(f"✅ 成功创建BatchProcessor实例")
  102. print(f" 批处理大小: {processor.batch_size}")
  103. print(f" 最大工作线程: {processor.max_workers}")
  104. print(f" 批处理启用: {processor.batch_enabled}")
  105. # 测试关闭
  106. processor.shutdown()
  107. print("✅ 批处理器关闭成功")
  108. return True
  109. except Exception as e:
  110. print(f"❌ 批处理器测试失败: {e}")
  111. import traceback
  112. traceback.print_exc()
  113. return False
  114. def test_training_functions():
  115. """测试训练函数(不实际训练)"""
  116. print("\n=== 测试训练函数 ===")
  117. try:
  118. from training import (
  119. train_ddl,
  120. train_documentation,
  121. train_sql_example,
  122. train_question_sql_pair,
  123. flush_training,
  124. shutdown_trainer
  125. )
  126. print("✅ 所有训练函数导入成功")
  127. # 测试函数是否可调用
  128. functions_to_test = [
  129. ('train_ddl', train_ddl),
  130. ('train_documentation', train_documentation),
  131. ('train_sql_example', train_sql_example),
  132. ('train_question_sql_pair', train_question_sql_pair),
  133. ('flush_training', flush_training),
  134. ('shutdown_trainer', shutdown_trainer)
  135. ]
  136. for func_name, func in functions_to_test:
  137. if callable(func):
  138. print(f"✅ {func_name} 是可调用的")
  139. else:
  140. print(f"❌ {func_name} 不可调用")
  141. return True
  142. except Exception as e:
  143. print(f"❌ 训练函数测试失败: {e}")
  144. import traceback
  145. traceback.print_exc()
  146. return False
  147. def test_embedding_connection():
  148. """测试embedding连接"""
  149. print("\n=== 测试Embedding连接 ===")
  150. try:
  151. from embedding_function import test_embedding_connection
  152. print("测试embedding模型连接...")
  153. result = test_embedding_connection()
  154. if result["success"]:
  155. print(f"✅ Embedding连接成功: {result['message']}")
  156. else:
  157. print(f"⚠️ Embedding连接失败: {result['message']}")
  158. print(" 这可能是因为API服务未启动或配置不正确")
  159. return True
  160. except Exception as e:
  161. print(f"❌ Embedding连接测试失败: {e}")
  162. import traceback
  163. traceback.print_exc()
  164. return False
  165. def test_run_training_script():
  166. """测试run_training.py脚本的基本功能"""
  167. print("\n=== 测试run_training.py脚本 ===")
  168. try:
  169. # 导入run_training模块
  170. import sys
  171. import os
  172. # 添加training目录到路径
  173. training_dir = os.path.join(os.path.dirname(__file__), 'training')
  174. if training_dir not in sys.path:
  175. sys.path.insert(0, training_dir)
  176. # 导入run_training模块的函数
  177. from training.run_training import (
  178. read_file_by_delimiter,
  179. read_markdown_file_by_sections,
  180. check_pgvector_connection
  181. )
  182. print("✅ 成功导入run_training模块的函数")
  183. # 测试文件读取函数
  184. test_content = "section1---section2---section3"
  185. with open("test_temp.txt", "w", encoding="utf-8") as f:
  186. f.write(test_content)
  187. try:
  188. sections = read_file_by_delimiter("test_temp.txt", "---")
  189. if len(sections) == 3:
  190. print("✅ read_file_by_delimiter 函数正常工作")
  191. else:
  192. print(f"⚠️ read_file_by_delimiter 返回了 {len(sections)} 个部分,期望 3 个")
  193. finally:
  194. if os.path.exists("test_temp.txt"):
  195. os.remove("test_temp.txt")
  196. return True
  197. except Exception as e:
  198. print(f"❌ run_training.py脚本测试失败: {e}")
  199. import traceback
  200. traceback.print_exc()
  201. return False
  202. def main():
  203. """主测试函数"""
  204. print("开始测试training目录的代码集成...")
  205. print("=" * 60)
  206. results = []
  207. # 运行所有测试
  208. results.append(("训练模块导入", test_training_imports()))
  209. results.append(("配置访问", test_config_access()))
  210. results.append(("Vanna实例创建", test_vanna_instance_creation()))
  211. results.append(("批处理器", test_batch_processor()))
  212. results.append(("训练函数", test_training_functions()))
  213. results.append(("Embedding连接", test_embedding_connection()))
  214. results.append(("run_training脚本", test_run_training_script()))
  215. # 总结
  216. print(f"\n{'='*60}")
  217. print("测试结果总结:")
  218. print("=" * 60)
  219. for test_name, success in results:
  220. status = "✅ 通过" if success else "❌ 失败"
  221. print(f"{test_name}: {status}")
  222. total_passed = sum(1 for _, success in results if success)
  223. print(f"\n总计: {total_passed}/{len(results)} 个测试通过")
  224. if total_passed == len(results):
  225. print("🎉 所有测试都通过了!training目录的代码可以正常工作。")
  226. elif total_passed >= len(results) - 1:
  227. print("✅ 大部分测试通过,training目录的代码基本可以正常工作。")
  228. print(" 部分失败可能是由于服务未启动或配置问题。")
  229. else:
  230. print("⚠️ 多个测试失败,请检查相关依赖和配置。")
# Run the integration checks only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()