vanna_trainer.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # vanna_trainer.py
  2. import os
  3. import time
  4. import threading
  5. import queue
  6. import concurrent.futures
  7. from functools import lru_cache
  8. from collections import defaultdict
  9. from typing import List, Dict, Any, Tuple, Optional, Union, Callable
  10. import sys
  11. import os
  12. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  13. import app_config
  14. from core.logging import get_data_pipeline_logger
  15. # 初始化日志
  16. logger = get_data_pipeline_logger("VannaTrainer")
  17. # 设置正确的项目根目录路径
  18. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  19. # 创建vanna实例
  20. from core.vanna_llm_factory import create_vanna_instance
  21. vn = create_vanna_instance()
  22. # 使用新的配置工具函数获取embedding配置
  23. try:
  24. from common.utils import get_current_embedding_config, get_current_model_info
  25. embedding_config = get_current_embedding_config()
  26. model_info = get_current_model_info()
  27. logger.info("===== Embedding模型信息 =====")
  28. logger.info(f"模型类型: {model_info['embedding_type']}")
  29. logger.info(f"模型名称: {model_info['embedding_model']}")
  30. logger.info(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
  31. if 'base_url' in embedding_config:
  32. logger.info(f"API服务: {embedding_config['base_url']}")
  33. logger.info("==============================")
  34. except ImportError as e:
  35. logger.warning(f"无法导入配置工具函数: {e}")
  36. logger.info("使用默认配置...")
  37. embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
  38. logger.info("===== Embedding模型信息 (默认) =====")
  39. logger.info(f"模型名称: {embedding_config.get('model_name', '未知')}")
  40. logger.info("==============================")
  41. # 从app_config获取训练批处理配置
  42. BATCH_PROCESSING_ENABLED = app_config.TRAINING_BATCH_PROCESSING_ENABLED
  43. BATCH_SIZE = app_config.TRAINING_BATCH_SIZE
  44. MAX_WORKERS = app_config.TRAINING_MAX_WORKERS
  45. # 训练数据批处理器
  46. # 专门用于优化训练过程的批处理器,将多个训练项目打包处理以提高效率
  47. class BatchProcessor:
  48. def __init__(self, batch_size=BATCH_SIZE, max_workers=MAX_WORKERS):
  49. self.batch_size = batch_size
  50. self.max_workers = max_workers
  51. self.batches = defaultdict(list)
  52. self.lock = threading.Lock() # 线程安全锁
  53. # 初始化工作线程池
  54. self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
  55. # 是否启用批处理
  56. self.batch_enabled = BATCH_PROCESSING_ENABLED
  57. logger.debug(f"训练批处理器初始化: 启用={self.batch_enabled}, 批大小={self.batch_size}, 最大工作线程={self.max_workers}")
  58. def add_item(self, batch_type: str, item: Dict[str, Any]):
  59. """添加一个项目到批处理队列"""
  60. if not self.batch_enabled:
  61. # 如果未启用批处理,直接处理
  62. self._process_single_item(batch_type, item)
  63. return
  64. with self.lock:
  65. self.batches[batch_type].append(item)
  66. if len(self.batches[batch_type]) >= self.batch_size:
  67. batch_items = self.batches[batch_type]
  68. self.batches[batch_type] = []
  69. # 提交批处理任务到线程池
  70. self.executor.submit(self._process_batch, batch_type, batch_items)
  71. def _process_single_item(self, batch_type: str, item: Dict[str, Any]):
  72. """处理单个项目"""
  73. try:
  74. if batch_type == 'ddl':
  75. vn.train(ddl=item['ddl'])
  76. elif batch_type == 'documentation':
  77. vn.train(documentation=item['documentation'])
  78. elif batch_type == 'question_sql':
  79. vn.train(question=item['question'], sql=item['sql'])
  80. logger.debug(f"单项处理成功: {batch_type}")
  81. except Exception as e:
  82. logger.error(f"处理 {batch_type} 项目失败: {e}")
  83. def _process_batch(self, batch_type: str, items: List[Dict[str, Any]]):
  84. """处理一批项目"""
  85. logger.info(f"开始批量处理 {len(items)} 个 {batch_type} 项")
  86. start_time = time.time()
  87. try:
  88. # 准备批处理数据
  89. batch_data = []
  90. if batch_type == 'ddl':
  91. for item in items:
  92. batch_data.append({
  93. 'type': 'ddl',
  94. 'content': item['ddl']
  95. })
  96. elif batch_type == 'documentation':
  97. for item in items:
  98. batch_data.append({
  99. 'type': 'documentation',
  100. 'content': item['documentation']
  101. })
  102. elif batch_type == 'question_sql':
  103. for item in items:
  104. batch_data.append({
  105. 'type': 'question_sql',
  106. 'question': item['question'],
  107. 'sql': item['sql']
  108. })
  109. # 使用批量添加方法
  110. if hasattr(vn, 'add_batch') and callable(getattr(vn, 'add_batch')):
  111. success = vn.add_batch(batch_data)
  112. if success:
  113. logger.info(f"批量处理成功: {len(items)} 个 {batch_type} 项")
  114. else:
  115. logger.warning(f"批量处理部分失败: {batch_type}")
  116. else:
  117. # 如果没有批处理方法,退回到逐条处理
  118. logger.warning(f"批处理不可用,使用逐条处理: {batch_type}")
  119. for item in items:
  120. self._process_single_item(batch_type, item)
  121. except Exception as e:
  122. logger.error(f"批处理 {batch_type} 失败: {e}")
  123. # 如果批处理失败,尝试逐条处理
  124. logger.info(f"尝试逐条处理...")
  125. for item in items:
  126. try:
  127. self._process_single_item(batch_type, item)
  128. except Exception as item_e:
  129. logger.error(f"处理项目失败: {item_e}")
  130. elapsed = time.time() - start_time
  131. logger.info(f"批处理完成 {len(items)} 个 {batch_type} 项,耗时 {elapsed:.2f} 秒")
  132. def flush_all(self):
  133. """强制处理所有剩余项目"""
  134. with self.lock:
  135. for batch_type, items in self.batches.items():
  136. if items:
  137. logger.info(f"正在处理剩余的 {len(items)} 个 {batch_type} 项")
  138. self._process_batch(batch_type, items)
  139. # 清空队列
  140. self.batches = defaultdict(list)
  141. logger.info("所有训练批处理项目已完成")
  142. def shutdown(self):
  143. """关闭处理器和线程池"""
  144. self.flush_all()
  145. self.executor.shutdown(wait=True)
  146. logger.info("训练批处理器已关闭")
  147. # 创建全局训练批处理器实例
  148. # 用于所有训练函数的批处理优化
  149. batch_processor = BatchProcessor()
  150. # 原始训练函数的批处理增强版本
  151. def train_ddl(ddl_sql: str):
  152. logger.debug(f"Training on DDL:\n{ddl_sql}")
  153. batch_processor.add_item('ddl', {'ddl': ddl_sql})
  154. def train_documentation(doc: str):
  155. logger.debug(f"Training on documentation:\n{doc}")
  156. batch_processor.add_item('documentation', {'documentation': doc})
  157. def train_sql_example(sql: str):
  158. """训练单个SQL示例,通过SQL生成相应的问题"""
  159. logger.debug(f"Training on SQL:\n{sql}")
  160. try:
  161. # 直接调用generate_question方法
  162. question = vn.generate_question(sql=sql)
  163. question = question.strip()
  164. if not question.endswith("?") and not question.endswith("?"):
  165. question += "?"
  166. except Exception as e:
  167. logger.error(f"生成问题时出错: {e}")
  168. raise Exception(f"无法为SQL生成问题: {e}")
  169. logger.debug(f"生成问题: {question}")
  170. # 使用标准方式存储问题-SQL对
  171. batch_processor.add_item('question_sql', {'question': question, 'sql': sql})
  172. def train_question_sql_pair(question: str, sql: str):
  173. logger.debug(f"Training on question-sql pair:\nquestion: {question}\nsql: {sql}")
  174. batch_processor.add_item('question_sql', {'question': question, 'sql': sql})
  175. # 完成训练后刷新所有待处理项
  176. def flush_training():
  177. """强制处理所有待处理的训练项目"""
  178. batch_processor.flush_all()
  179. # 关闭训练器
  180. def shutdown_trainer():
  181. """关闭训练器和相关资源"""
  182. batch_processor.shutdown()