# run_training.py
import os
import time
import re
import json
import sys
import requests
import pandas as pd
import argparse
from pathlib import Path
from sqlalchemy import create_engine

from .vanna_trainer import (
    train_ddl,
    train_documentation,
    train_sql_example,
    train_question_sql_pair,
    flush_training,
    shutdown_trainer
)


def check_embedding_model_connection():
    """Check whether the embedding model connection is available.

    Terminates the program if the embedding model cannot be reached.

    Returns:
        bool: True on success; otherwise the program exits.
    """
    from core.embedding_function import test_embedding_connection

    print("Checking embedding model connection...")
    # Use the dedicated test helper to probe the connection
    test_result = test_embedding_connection()
    if test_result["success"]:
        print("Embedding model connection OK; training can proceed.")
        return True
    else:
        print(f"\nError: cannot connect to the embedding model: {test_result['message']}")
        print("Training aborted. Check the configuration and API service availability.")
        sys.exit(1)


def read_file_by_delimiter(filepath, delimiter="---"):
    """Generic reader: split a file into blocks on a delimiter."""
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()
    blocks = [block.strip() for block in content.split(delimiter) if block.strip()]
    return blocks
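
# Illustrative example (not executed): with delimiter=";", a file containing
#   CREATE TABLE a (id int);
#   CREATE TABLE b (id int);
# yields ["CREATE TABLE a (id int)", "CREATE TABLE b (id int)"]. Note that this
# simple split does not account for semicolons inside string literals.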


def read_markdown_file_by_sections(filepath):
    """Markdown-specific reader: split a document on its headers (#, ##, ###).

    Args:
        filepath (str): Path to the Markdown file.

    Returns:
        list: The Markdown sections after splitting.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    # Determine whether the file is actually Markdown
    is_markdown = filepath.lower().endswith('.md') or filepath.lower().endswith('.markdown')
    if not is_markdown:
        # Non-Markdown files fall back to the default "---" delimiter
        return read_file_by_delimiter(filepath, "---")

    # Split the content directly on header levels #, ## and ###
    sections = []
    # Match headers of all three levels (lines starting with #, ## or ###)
    header_pattern = r'(?:^|\n)((?:#|##|###)[^#].*?)(?=\n(?:#|##|###)[^#]|\Z)'
    all_sections = re.findall(header_pattern, content, re.DOTALL)
    for section in all_sections:
        section = section.strip()
        if section:
            sections.append(section)

    # Fall back to the whole document if no header matched
    if not sections and content.strip():
        sections = [content.strip()]
    return sections
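
# Illustrative example (not executed): a document such as
#   # Overview
#   intro text
#   ## Tables
#   table notes
# splits into "# Overview\nintro text" and "## Tables\ntable notes"; each
# section keeps its header as the first line.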


def train_ddl_statements(ddl_file):
    """Train on DDL statements.

    Args:
        ddl_file (str): Path to the DDL file.
    """
    print(f"Training DDL: {ddl_file}")
    if not os.path.exists(ddl_file):
        print(f"DDL file not found: {ddl_file}")
        return
    for idx, ddl in enumerate(read_file_by_delimiter(ddl_file, ";"), start=1):
        try:
            print(f"\n DDL training {idx}")
            train_ddl(ddl)
        except Exception as e:
            print(f"Error: DDL #{idx} - {e}")


def train_documentation_blocks(doc_file):
    """Train on documentation blocks.

    Args:
        doc_file (str): Path to the documentation file.
    """
    print(f"Training documentation: {doc_file}")
    if not os.path.exists(doc_file):
        print(f"Documentation file not found: {doc_file}")
        return

    # Check whether this is a Markdown file
    is_markdown = doc_file.lower().endswith('.md') or doc_file.lower().endswith('.markdown')
    if is_markdown:
        # Use the Markdown-aware splitter
        sections = read_markdown_file_by_sections(doc_file)
        print(f" Markdown document split into {len(sections)} sections")
        for idx, section in enumerate(sections, start=1):
            try:
                section_title = section.split('\n', 1)[0].strip()
                print(f"\n Markdown section training {idx}: {section_title}")
                # Warn when a section approaches the API length limit
                if len(section) > 2000:
                    print(f" Section {idx} is {len(section)} characters, close to the API limit (2048)")
                train_documentation(section)
            except Exception as e:
                print(f" Error: section #{idx} - {e}")
    else:
        # Non-Markdown files use the traditional "---" delimiter
        for idx, doc in enumerate(read_file_by_delimiter(doc_file, "---"), start=1):
            try:
                print(f"\n Documentation training {idx}")
                train_documentation(doc)
            except Exception as e:
                print(f" Error: documentation #{idx} - {e}")


def train_sql_examples(sql_file):
    """Train on SQL examples.

    Args:
        sql_file (str): Path to the SQL examples file.
    """
    print(f" Training SQL examples: {sql_file}")
    if not os.path.exists(sql_file):
        print(f" SQL examples file not found: {sql_file}")
        return
    for idx, sql in enumerate(read_file_by_delimiter(sql_file, ";"), start=1):
        try:
            print(f"\n SQL example training {idx}")
            train_sql_example(sql)
        except Exception as e:
            print(f" Error: SQL #{idx} - {e}")


def train_question_sql_pairs(qs_file):
    """Train on question-SQL pairs.

    Args:
        qs_file (str): Path to the question-SQL pairs file.
    """
    print(f" Training question-SQL pairs: {qs_file}")
    if not os.path.exists(qs_file):
        print(f" Question-SQL file not found: {qs_file}")
        return
    try:
        with open(qs_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for idx, line in enumerate(lines, start=1):
            if "::" not in line:
                continue
            question, sql = line.strip().split("::", 1)
            print(f"\n Question-SQL training {idx}")
            train_question_sql_pair(question.strip(), sql.strip())
    except Exception as e:
        print(f" Error: question-SQL training - {e}")
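
# Illustrative line format (not executed); only the first "::" separates the
# question from the SQL, so later "::" occurrences (e.g. Postgres casts) stay
# inside the SQL:
#   How many orders were placed today?::SELECT COUNT(*) FROM orders WHERE created_at::date = CURRENT_DATE;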


def train_formatted_question_sql_pairs(formatted_file):
    """Train on a formatted question-SQL pairs file.

    Two layouts are supported:
    1. "Question: xxx" followed by "SQL: xxx" on one line (single-line SQL)
    2. "Question: xxx" followed by "SQL:" with the statement on the following lines (multi-line SQL)

    Args:
        formatted_file (str): Path to the formatted question-SQL pairs file.
    """
    print(f" Training formatted question-SQL pairs: {formatted_file}")
    if not os.path.exists(formatted_file):
        print(f" Formatted question-SQL file not found: {formatted_file}")
        return

    # Read the whole file
    with open(formatted_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Split the pairs on a blank line followed by a "Question:" marker;
    # the lookahead keeps the split from firing on unrelated blank lines
    pairs = []
    # Case-insensitive split
    blocks = re.split(r'\n\n(?=question\s*:)', content, flags=re.IGNORECASE)

    # Handle the first block (it may not be preceded by "\n\nQuestion:")
    first_block = blocks[0]
    if re.search(r'^\s*question\s*:', first_block.strip(), re.IGNORECASE):
        pairs.append(first_block.strip())
    elif re.search(r'question\s*:', first_block, re.IGNORECASE):
        # Handle files that do not start with "Question:"
        question_match = re.search(r'question\s*:', first_block, re.IGNORECASE)
        pairs.append(first_block[question_match.start():].strip())

    # Handle the remaining blocks
    for block in blocks[1:]:
        pairs.append(block.strip())

    # Process each pair
    successfully_processed = 0
    for idx, pair in enumerate(pairs, start=1):
        try:
            # Case-insensitive matching of the markers
            question_match = re.search(r'question\s*:', pair, re.IGNORECASE)
            sql_match = re.search(r'sql\s*:', pair, re.IGNORECASE)
            if not question_match or not sql_match:
                print(f" Skipping malformed pair #{idx}")
                continue
            # Make sure the SQL marker comes after the question
            if sql_match.start() <= question_match.end():
                print(f" SQL part not found, skipping pair #{idx}")
                continue
            # Extract the question
            question_start = question_match.end()
            sql_start = sql_match.start()
            question = pair[question_start:sql_start].strip()
            # Extract the SQL (multi-line supported)
            sql_part = pair[sql_match.end():].strip()
            # Guard against a following "Question:" marker leaking into the SQL
            next_question_match = re.search(r'question\s*:', pair[sql_match.end():], re.IGNORECASE)
            if next_question_match:
                sql_part = pair[sql_match.end():sql_match.end() + next_question_match.start()].strip()
            if not question or not sql_part:
                print(f" Empty question or SQL, skipping pair #{idx}")
                continue
            # Train the pair
            print(f"\nFormatted question-SQL training {idx}")
            print(f"Question: {question}")
            print(f"SQL: {sql_part}")
            train_question_sql_pair(question, sql_part)
            successfully_processed += 1
        except Exception as e:
            print(f" Error: formatted question-SQL pair #{idx} - {e}")
    print(f"Formatted question-SQL training finished: {successfully_processed} of {len(pairs)} pairs processed successfully")
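
# Illustrative file layout (not executed): pairs are separated by a blank line
# and the SQL may span multiple lines:
#   Question: which city has the most customers?
#   SQL:
#   SELECT city, COUNT(*) AS n
#   FROM customers
#   GROUP BY city
#   ORDER BY n DESC
#   LIMIT 1;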


def _is_valid_training_file(filename: str) -> bool:
    """Return True if the filename looks like a valid training file."""
    filename_lower = filename.lower()
    # Exclude files with a numeric suffix
    if re.search(r'\.(ddl|md)_\d+$', filename_lower):
        return False
    # Exclude files ending in _old
    if filename_lower.endswith('_old'):
        return False
    # Exclude backup-related files
    if '.backup' in filename_lower:
        return False
    return True
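
# Illustrative checks (not executed):
#   _is_valid_training_file("tables.ddl")      -> True
#   _is_valid_training_file("tables.ddl_1")    -> False (numeric suffix)
#   _is_valid_training_file("notes.md_old")    -> False (_old suffix)
#   _is_valid_training_file("qs.json.backup")  -> False (.backup marker)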


def train_json_question_sql_pairs(json_file):
    """Train on JSON-formatted question-SQL pairs.

    Args:
        json_file (str): Path to the JSON question-SQL pairs file.
    """
    print(f" Training JSON question-SQL pairs: {json_file}")
    if not os.path.exists(json_file):
        print(f" JSON question-SQL file not found: {json_file}")
        return
    try:
        # Read the JSON file
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        # The data must be a list of pairs
        if not isinstance(data, list):
            print(" Error: malformed JSON file, expected a list of question-SQL pairs")
            return

        successfully_processed = 0
        for idx, pair in enumerate(data, start=1):
            try:
                # Each pair must be an object
                if not isinstance(pair, dict):
                    print(f" Skipping malformed pair #{idx}")
                    continue
                # Look up the question and sql keys case-insensitively
                question_key = None
                sql_key = None
                question_value = None
                sql_value = None
                for key, value in pair.items():
                    if key.lower() == "question":
                        question_key = key
                        question_value = value
                    elif key.lower() == "sql":
                        sql_key = key
                        sql_value = value
                if question_key is None or sql_key is None:
                    print(f" Skipping malformed pair #{idx}")
                    continue
                question = str(question_value).strip()
                sql = str(sql_value).strip()
                if not question or not sql:
                    print(f" Empty question or SQL, skipping pair #{idx}")
                    continue
                # Train the pair
                print(f"\n JSON question-SQL training {idx}")
                print(f"Question: {question}")
                print(f"SQL: {sql}")
                train_question_sql_pair(question, sql)
                successfully_processed += 1
            except Exception as e:
                print(f" Error: JSON question-SQL pair #{idx} - {e}")
        print(f"JSON question-SQL training finished: {successfully_processed} of {len(data)} pairs processed successfully")
    except json.JSONDecodeError as e:
        print(f" Error: JSON parsing failed - {e}")
    except Exception as e:
        print(f" Error: processing JSON question-SQL training - {e}")
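
# Illustrative JSON layout (not executed); key lookup is case-insensitive, so
# "Question"/"SQL" also work:
#   [
#     {"question": "how many users signed up last month?",
#      "sql": "SELECT COUNT(*) FROM users WHERE signup_date >= now() - interval '1 month'"}
#   ]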


def process_training_files(data_path, task_id=None, backup_vector_tables=False, truncate_vector_tables=False, skip_training=False):
    """Process all training files under the given path.

    Args:
        data_path (str): Training data directory path.
        task_id (str): Task ID, used for logging.
        backup_vector_tables (bool): Whether to back up the vector table data.
        truncate_vector_tables (bool): Whether to truncate the vector table data.
        skip_training (bool): Skip training file processing and only run vector table management.

    Returns:
        tuple: (success flag, vector table management statistics)
    """
    # Initialize logging
    if task_id:
        from data_pipeline.dp_logging import get_logger
        logger = get_logger("TrainingDataLoader", task_id)
        logger.info(f"Scanning training data directory: {os.path.abspath(data_path)}")
    else:
        # Backwards-compatible plain printing
        print(f"\n===== Scanning training data directory: {os.path.abspath(data_path)} =====")
        logger = None

    # Check that the directory exists
    if not os.path.exists(data_path):
        error_msg = f"Error: training data directory not found: {data_path}"
        if logger:
            logger.error(error_msg)
        else:
            print(error_msg)
        return False, None

    # Logging helper
    def log_message(message, level="info"):
        if logger:
            getattr(logger, level)(message)
        else:
            print(message)

    # Vector table management (runs before training)
    vector_stats = None
    if backup_vector_tables or truncate_vector_tables:
        # Truncating implies backing up first
        if truncate_vector_tables:
            backup_vector_tables = True
        try:
            from data_pipeline.trainer.vector_table_manager import VectorTableManager
            log_message("🗂️ Starting vector table management...")
            vector_manager = VectorTableManager(data_path, task_id)
            vector_stats = vector_manager.execute_vector_management(backup_vector_tables, truncate_vector_tables)
            log_message("✅ Vector table management finished")
        except Exception as e:
            log_message(f"❌ Vector table management failed: {e}", "error")
            return False, None
        # In skip-training mode, stop after vector table management
        if skip_training:
            log_message("✅ Vector table management finished, skipping training files (skip_training=True)")
            return True, vector_stats
    elif skip_training:
        # skip_training was set without any vector operation: warn and do nothing
        log_message("⚠️ skip_training=True was set without any vector operation, skipping all processing")
        return True, None

    # Initialize the per-type counters
    stats = {
        "ddl": 0,
        "documentation": 0,
        "sql_example": 0,
        "question_sql_formatted": 0,
        "question_sql_json": 0
    }

    # Scan only the files directly under the directory; subdirectories are skipped
    try:
        items = os.listdir(data_path)
        for item in items:
            item_path = os.path.join(data_path, item)
            # Process files only, skip directories
            if not os.path.isfile(item_path):
                log_message(f"Skipping subdirectory: {item}")
                continue
            file_lower = item.lower()
            # Dispatch on file type
            try:
                # Check whether this is a valid training file
                if not _is_valid_training_file(item):
                    log_message(f"Skipping invalid training file: {item}")
                    continue
                if file_lower.endswith(".ddl"):
                    log_message(f"Processing DDL file: {item_path}")
                    train_ddl_statements(item_path)
                    stats["ddl"] += 1
                elif file_lower.endswith(".md") or file_lower.endswith(".markdown"):
                    log_message(f"Processing documentation file: {item_path}")
                    train_documentation_blocks(item_path)
                    stats["documentation"] += 1
                elif file_lower.endswith("_pair.json") or file_lower.endswith("_pairs.json"):
                    log_message(f"Processing JSON question-SQL pair file: {item_path}")
                    train_json_question_sql_pairs(item_path)
                    stats["question_sql_json"] += 1
                elif file_lower.endswith("_pair.sql") or file_lower.endswith("_pairs.sql"):
                    log_message(f"Processing formatted question-SQL pair file: {item_path}")
                    train_formatted_question_sql_pairs(item_path)
                    stats["question_sql_formatted"] += 1
                elif file_lower.endswith(".sql"):
                    # Plain SQL example files (pair files were matched above)
                    log_message(f"Processing SQL example file: {item_path}")
                    train_sql_examples(item_path)
                    stats["sql_example"] += 1
                else:
                    log_message(f"Skipping unsupported file type: {item}")
            except Exception as e:
                log_message(f"Error while processing {item_path}: {e}", "error")
    except OSError as e:
        log_message(f"Failed to read directory: {e}", "error")
        return False, vector_stats

    # Print the processing statistics
    log_message("Training file statistics:")
    log_message(f"DDL files: {stats['ddl']}")
    log_message(f"Documentation files: {stats['documentation']}")
    log_message(f"SQL example files: {stats['sql_example']}")
    log_message(f"Formatted question-SQL pair files: {stats['question_sql_formatted']}")
    log_message(f"JSON question-SQL pair files: {stats['question_sql_json']}")

    total_files = sum(stats.values())
    if total_files == 0:
        log_message(f"Warning: no trainable files found in directory {data_path}", "warning")
        return False, vector_stats
    return True, vector_stats
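
# Illustrative call (not executed), assuming ./training_data holds the files:
#   ok, vector_stats = process_training_files("./training_data",
#                                             backup_vector_tables=True,
#                                             truncate_vector_tables=True)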


def check_pgvector_connection():
    """Check whether the PgVector database is reachable.

    Returns:
        bool: True on success, False otherwise.
    """
    import app_config
    from sqlalchemy import text
    try:
        # Build the connection string
        pg_config = app_config.PGVECTOR_CONFIG
        connection_string = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
        print("Testing PgVector database connection...")
        print(f"Connection target: {pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}")

        # Create the engine and test the connection
        engine = create_engine(connection_string)
        with engine.connect() as connection:
            # Basic connectivity test
            result = connection.execute(text("SELECT 1"))
            result.fetchone()

            # Check whether the pgvector extension is installed
            try:
                result = connection.execute(text("SELECT extname FROM pg_extension WHERE extname = 'vector'"))
                extension_exists = result.fetchone() is not None
                if extension_exists:
                    print("✓ PgVector extension is installed")
                else:
                    print("⚠ Warning: PgVector extension is not installed; make sure pgvector is set up")
            except Exception as ext_e:
                print(f"⚠ Could not check pgvector extension status: {ext_e}")

            # Check whether the training data table exists
            try:
                result = connection.execute(text("SELECT tablename FROM pg_tables WHERE tablename = 'langchain_pg_embedding'"))
                table_exists = result.fetchone() is not None
                if table_exists:
                    # Count the rows in the table
                    result = connection.execute(text("SELECT COUNT(*) FROM langchain_pg_embedding"))
                    count = result.fetchone()[0]
                    print(f"✓ Training data table exists and currently holds {count} rows")
                else:
                    print("ℹ Training data table does not exist yet (it is created on the first training run)")
            except Exception as table_e:
                print(f"⚠ Could not check training data table status: {table_e}")
        print("✓ PgVector database connection test succeeded")
        return True
    except Exception as e:
        print(f"✗ PgVector database connection failed: {e}")
        return False
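
# Illustrative shape (hypothetical values) of the app_config.PGVECTOR_CONFIG
# dict this check reads:
#   PGVECTOR_CONFIG = {
#       "user": "postgres", "password": "...",
#       "host": "localhost", "port": 5432, "dbname": "vanna",
#   }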


def main():
    """Entry point: configure and run the training pipeline."""
    # Import modules needed below
    import app_config

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Train the Vanna NL2SQL model')

    # Resolve the default training data path
    def resolve_training_data_path():
        """Resolve the training data path from configuration."""
        # Use the unified data_pipeline configuration
        try:
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            config_path = SCHEMA_TOOLS_CONFIG.get("output_directory", './data_pipeline/training_data/')
        except ImportError:
            # Fall back to the default when the data_pipeline config cannot be imported
            config_path = './data_pipeline/training_data/'

        # Absolute paths are returned as-is
        if os.path.isabs(config_path):
            return config_path
        # Relative paths (including ./ and ../) are resolved against the project root
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        return os.path.join(project_root, config_path)

    def resolve_data_path_with_task_id(task_id):
        """Build the training data path from a task_id."""
        # Use the unified data_pipeline configuration
        try:
            from data_pipeline.config import SCHEMA_TOOLS_CONFIG
            base_dir = SCHEMA_TOOLS_CONFIG.get("output_directory", './data_pipeline/training_data/')
        except ImportError:
            # Fall back to the default when the data_pipeline config cannot be imported
            base_dir = './data_pipeline/training_data/'

        # Resolve relative paths against the project root
        if not Path(base_dir).is_absolute():
            project_root = Path(__file__).parent.parent.parent
            base_dir = project_root / base_dir
        return str(Path(base_dir) / task_id)

    default_path = resolve_training_data_path()

    # Argument definitions
    parser.add_argument(
        '--task-id',
        help='Task ID; when given, the training data directory becomes <base dir>/<task_id>'
    )
    parser.add_argument('--data_path', type=str, default=default_path,
                        help='Training data directory path (default: from data_pipeline.config.SCHEMA_TOOLS_CONFIG)')
    parser.add_argument('--backup-vector-tables', action='store_true',
                        help='Back up the vector table data')
    parser.add_argument('--truncate-vector-tables', action='store_true',
                        help='Truncate the vector table data (implies backup)')
    parser.add_argument('--skip-training', action='store_true',
                        help='Skip training file processing and only run vector table management')
    args = parser.parse_args()

    # Reconcile task_id and data_path
    if args.task_id:
        # A task_id overrides data_path
        data_path = Path(resolve_data_path_with_task_id(args.task_id))
        print(f"Building path from task_id: {args.task_id}")
    else:
        # Use the given or default data_path
        data_path = Path(args.data_path)

    # Show how the path was resolved
    print("\n===== Training data path configuration =====")
    try:
        from data_pipeline.config import SCHEMA_TOOLS_CONFIG
        config_value = SCHEMA_TOOLS_CONFIG.get("output_directory", "not configured")
        print(f"data_pipeline configured path: {config_value}")
    except ImportError:
        print("data_pipeline config: not importable")
    if args.task_id:
        print(f"Given task_id: {args.task_id}")
    print(f"Resolved absolute path: {os.path.abspath(data_path)}")
    print("==============================")

    # Locate the project root
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    # Check the embedding model connection
    check_embedding_model_connection()

    # Show vector-store details for the configured backend
    vector_db_type = app_config.VECTOR_DB_TYPE.lower()
    if vector_db_type == "chromadb":
        # Print ChromaDB details
        try:
            try:
                import chromadb
                chroma_version = chromadb.__version__
            except ImportError:
                chroma_version = "unknown"
            # Look for the ChromaDB database file currently in use
            chroma_file = "chroma.sqlite3"  # default file name
            # The ChromaDB file is expected in the project root
            db_file_path = os.path.join(project_root, chroma_file)
            if os.path.exists(db_file_path):
                file_size = os.path.getsize(db_file_path) / 1024  # KB
                print(f"\n===== ChromaDB database: {os.path.abspath(db_file_path)} (size: {file_size:.2f} KB) =====")
            else:
                print(f"\n===== ChromaDB database file not found at: {os.path.abspath(db_file_path)} =====")
            # Print the ChromaDB client library version
            print(f"===== ChromaDB client library version: {chroma_version} =====\n")
        except Exception as e:
            print(f"\n===== Could not gather ChromaDB details: {e} =====\n")
    elif vector_db_type == "pgvector":
        # Print PgVector details and test the connection
        print("\n===== PgVector database configuration =====")
        pg_config = app_config.PGVECTOR_CONFIG
        print(f"Database host: {pg_config['host']}:{pg_config['port']}")
        print(f"Database name: {pg_config['dbname']}")
        print(f"User: {pg_config['user']}")
        print("==============================\n")
        # Test the PgVector connection
        if not check_pgvector_connection():
            print("PgVector database connection failed; training aborted.")
            sys.exit(1)
    else:
        print(f"\n===== Unknown vector database type: {vector_db_type} =====\n")

    # Process the training files
    process_successful, vector_stats = process_training_files(data_path, args.task_id,
                                                              args.backup_vector_tables,
                                                              args.truncate_vector_tables,
                                                              args.skip_training)

    if process_successful:
        # Training done: flush and shut down the batch trainer
        print("\n===== Training finished, flushing remaining batches =====")
        flush_training()
        shutdown_trainer()

        # Verify that the data was written successfully
        print("\n===== Verifying training data =====")
        from core.vanna_llm_factory import create_vanna_instance
        vn = create_vanna_instance()

        # Run verification according to the vector database type
        try:
            training_data = vn.get_training_data()
            if training_data is not None and not training_data.empty:
                print(f"✓ Retrieved {len(training_data)} training records from {vector_db_type.upper()} for verification.")
                # Show statistics per training data type
                if 'training_data_type' in training_data.columns:
                    type_counts = training_data['training_data_type'].value_counts()
                    print("Training data by type:")
                    for data_type, count in type_counts.items():
                        print(f"  {data_type}: {count}")
            elif training_data is not None and training_data.empty:
                print(f"⚠ No training data found in {vector_db_type.upper()}.")
            else:  # training_data is None
                print(f"⚠ Could not fetch training data from Vanna (it may have returned None). Check the {vector_db_type.upper()} connection and the Vanna implementation.")
        except Exception as e:
            print(f"✗ Training data verification failed: {e}")
            print(f"Check the {vector_db_type.upper()} connection and table structure.")
    else:
        print("\n===== No training files found or processed; training aborted =====")

    # Vector table management summary
    print("\n===== Vector table management statistics =====")
    if vector_stats:
        if vector_stats.get("backup_performed", False):
            tables_info = vector_stats.get("tables_backed_up", {})
            print(f"✓ Backup: {len(tables_info)} tables backed up")
            for table_name, info in tables_info.items():
                if info.get("success", False):
                    print(f"  - {table_name}: {info['row_count']} rows -> {info['backup_file']} ({info['file_size']})")
                else:
                    print(f"  - {table_name}: backup failed - {info.get('error', 'unknown error')}")
        else:
            print("- Backup: not performed")
        if vector_stats.get("truncate_performed", False):
            truncate_info = vector_stats.get("truncate_results", {})
            print("✓ Truncate: langchain_pg_embedding table cleared")
            for table_name, info in truncate_info.items():
                if info.get("success", False):
                    print(f"  - {table_name}: {info['rows_before']} rows -> 0 rows")
                else:
                    print(f"  - {table_name}: truncate failed - {info.get('error', 'unknown error')}")
        else:
            print("- Truncate: not performed")
        print(f"✓ Total duration: {vector_stats.get('duration', 0):.1f}s")
        if vector_stats.get("errors"):
            print(f"⚠ Errors: {'; '.join(vector_stats['errors'])}")
    else:
        print("- No vector table management operations were performed")
    print("===========================")

    # Print embedding model details
    print("\n===== Embedding model details =====")
    try:
        from common.utils import get_current_embedding_config, get_current_model_info
        embedding_config = get_current_embedding_config()
        model_info = get_current_model_info()
        print(f"Model type: {model_info['embedding_type']}")
        print(f"Model name: {model_info['embedding_model']}")
        print(f"Vector dimension: {embedding_config.get('embedding_dimension', 'unknown')}")
        if 'base_url' in embedding_config:
            print(f"API endpoint: {embedding_config['base_url']}")
    except ImportError as e:
        print(f"Warning: could not import the config helper functions: {e}")
        # Fall back to the legacy config access
        embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
        print(f"Model name: {embedding_config.get('model_name', 'unknown')}")
        print(f"Vector dimension: {embedding_config.get('embedding_dimension', 'unknown')}")
        print(f"API endpoint: {embedding_config.get('base_url', 'unknown')}")

    # Show the vector store location for the configured backend
    if vector_db_type == "chromadb":
        chroma_display_path = os.path.abspath(project_root)
        print(f"Vector database: ChromaDB ({chroma_display_path})")
    elif vector_db_type == "pgvector":
        pg_config = app_config.PGVECTOR_CONFIG
        print(f"Vector database: PgVector ({pg_config['host']}:{pg_config['port']}/{pg_config['dbname']})")
    print("===== Training pipeline complete =====\n")


if __name__ == "__main__":
    main()
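
# Illustrative invocations (assuming the module lives at
# data_pipeline/trainer/run_training.py, as the package imports suggest):
#   python -m data_pipeline.trainer.run_training --task-id <task_id>
#   python -m data_pipeline.trainer.run_training --data_path ./data_pipeline/training_data/
#   python -m data_pipeline.trainer.run_training --truncate-vector-tables --skip-training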