@@ -0,0 +1,473 @@
+# run_training.py
+import os
+import time
+import re
+import json
+import sys
+import requests
+import pandas as pd
+import argparse
+from pathlib import Path
+from sqlalchemy import create_engine
+
+
+from vanna_trainer import (
+    train_ddl,
+    train_documentation,
+    train_sql_example,
+    train_question_sql_pair,
+    flush_training,
+    shutdown_trainer
+)
+
+def check_embedding_model_connection():
+    """Check whether the embedding model connection is available.
+    If the embedding model cannot be reached, the program is terminated.
+    Returns:
+        bool: True if the connection succeeds; otherwise the program exits.
+    """
+    from embedding_function import test_embedding_connection
+
+    print("Checking embedding model connection...")
+
+    # Use the dedicated test helper to verify the connection
+    test_result = test_embedding_connection()
+
+    if test_result["success"]:
+        print("The connection is OK; the training process can continue.")
+        return True
+    else:
+        print(f"\nError: unable to connect to the embedding model: {test_result['message']}")
+        print("Training aborted. Please check the configuration and the availability of the API service.")
+        sys.exit(1)
+
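+# Note: the check above assumes test_embedding_connection() returns a dict with
+# "success" and "message" keys, e.g. (illustrative values only, inferred from the
+# keys accessed above):
+#   {"success": True, "message": "embedding service reachable"}
+#   {"success": False, "message": "connection refused"}
+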
+def read_file_by_delimiter(filepath, delimiter="---"):
|
|
|
+ """通用读取:将文件按分隔符切片为多个段落"""
|
|
|
+ with open(filepath, "r", encoding="utf-8") as f:
|
|
|
+ content = f.read()
|
|
|
+ blocks = [block.strip() for block in content.split(delimiter) if block.strip()]
|
|
|
+ return blocks
|
|
|
+
|
|
|
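+# Illustrative sketch of read_file_by_delimiter() (file name and DDL are made-up
+# examples): a file tables.ddl containing
+#   CREATE TABLE users (id INT PRIMARY KEY, name TEXT);
+#   CREATE TABLE orders (id INT PRIMARY KEY, user_id INT);
+# split with read_file_by_delimiter("tables.ddl", ";") yields two blocks, one per
+# CREATE TABLE statement, with surrounding whitespace stripped and the trailing
+# empty piece discarded.
+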
+def read_markdown_file_by_sections(filepath):
+    """For Markdown files: split the document by headings (#, ##, ###).
+
+    Args:
+        filepath (str): Path to the Markdown file
+
+    Returns:
+        list: List of Markdown sections after splitting
+    """
+    with open(filepath, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # Determine whether the file is Markdown
+    is_markdown = filepath.lower().endswith('.md') or filepath.lower().endswith('.markdown')
+
+    if not is_markdown:
+        # Non-Markdown files fall back to the default "---" delimiter
+        return read_file_by_delimiter(filepath, "---")
+
+    # Split the content directly by heading level, handling #, ## and ###
+    sections = []
+
+    # Match headings of all supported levels (lines starting with #, ## or ###)
+    header_pattern = r'(?:^|\n)((?:#|##|###)[^#].*?)(?=\n(?:#|##|###)[^#]|\Z)'
+    all_sections = re.findall(header_pattern, content, re.DOTALL)
+
+    for section in all_sections:
+        section = section.strip()
+        if section:
+            sections.append(section)
+
+    # Handle the case where no headings were matched
+    if not sections and content.strip():
+        sections = [content.strip()]
+
+    return sections
+
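+# Illustrative sketch of the Markdown splitter (the document below is a made-up
+# example): a file containing
+#   # Overview
+#   Intro text.
+#   ## Tables
+#   Notes about tables.
+# is returned as two sections:
+#   ["# Overview\nIntro text.", "## Tables\nNotes about tables."]
+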
+def train_ddl_statements(ddl_file):
+    """Train DDL statements.
+    Args:
+        ddl_file (str): Path to the DDL file
+    """
+    print(f"Start training DDL: {ddl_file}")
+    if not os.path.exists(ddl_file):
+        print(f"DDL file does not exist: {ddl_file}")
+        return
+    for idx, ddl in enumerate(read_file_by_delimiter(ddl_file, ";"), start=1):
+        try:
+            print(f"\n DDL training {idx}")
+            train_ddl(ddl)
+        except Exception as e:
+            print(f"Error: DDL #{idx} - {e}")
+
+def train_documentation_blocks(doc_file):
+    """Train documentation blocks.
+    Args:
+        doc_file (str): Path to the documentation file
+    """
+    print(f"Start training documentation: {doc_file}")
+    if not os.path.exists(doc_file):
+        print(f"Documentation file does not exist: {doc_file}")
+        return
+
+    # Check whether the file is Markdown
+    is_markdown = doc_file.lower().endswith('.md') or doc_file.lower().endswith('.markdown')
+
+    if is_markdown:
+        # Use the Markdown-specific splitter
+        sections = read_markdown_file_by_sections(doc_file)
+        print(f" Markdown document split into {len(sections)} sections")
+
+        for idx, section in enumerate(sections, start=1):
+            try:
+                section_title = section.split('\n', 1)[0].strip()
+                print(f"\n Markdown section training {idx}: {section_title}")
+
+                # Check the section length and warn if it is close to the limit
+                if len(section) > 2000:
+                    print(f" Section {idx} is {len(section)} characters long, close to the API limit (2048)")
+
+                train_documentation(section)
+            except Exception as e:
+                print(f" Error: section #{idx} - {e}")
+    else:
+        # Non-Markdown files use the traditional "---" delimiter
+        for idx, doc in enumerate(read_file_by_delimiter(doc_file, "---"), start=1):
+            try:
+                print(f"\n Documentation training {idx}")
+                train_documentation(doc)
+            except Exception as e:
+                print(f" Error: document #{idx} - {e}")
+
+def train_sql_examples(sql_file):
+    """Train SQL examples.
+    Args:
+        sql_file (str): Path to the SQL examples file
+    """
+    print(f" Start training SQL examples: {sql_file}")
+    if not os.path.exists(sql_file):
+        print(f" SQL examples file does not exist: {sql_file}")
+        return
+    for idx, sql in enumerate(read_file_by_delimiter(sql_file, ";"), start=1):
+        try:
+            print(f"\n SQL example training {idx}")
+            train_sql_example(sql)
+        except Exception as e:
+            print(f" Error: SQL #{idx} - {e}")
+
+def train_question_sql_pairs(qs_file):
+    """Train question-SQL pairs.
+    Args:
+        qs_file (str): Path to the question-SQL pair file
+    """
+    print(f" Start training question-SQL pairs: {qs_file}")
+    if not os.path.exists(qs_file):
+        print(f" Question-SQL pair file does not exist: {qs_file}")
+        return
+    try:
+        with open(qs_file, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        for idx, line in enumerate(lines, start=1):
+            if "::" not in line:
+                continue
+            question, sql = line.strip().split("::", 1)
+            print(f"\n Question-SQL training {idx}")
+            train_question_sql_pair(question.strip(), sql.strip())
+    except Exception as e:
+        print(f" Error: question-SQL training - {e}")
+
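+# Illustrative sketch of the line format expected by train_question_sql_pairs()
+# (question and SQL are made-up examples). One pair per line, with "::" separating
+# the question from the SQL; lines without "::" are skipped:
+#   How many users signed up last month?::SELECT COUNT(*) FROM users WHERE signup_date >= '2024-01-01';
+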
+def train_formatted_question_sql_pairs(formatted_file):
+    """Train a formatted question-SQL pair file.
+    Two formats are supported:
+    1. Question: xxx\nSQL: xxx (single-line SQL)
+    2. Question: xxx\nSQL:\nxxx\nxxx (multi-line SQL)
+
+    Args:
+        formatted_file (str): Path to the formatted question-SQL pair file
+    """
+    print(f" Start training formatted question-SQL pairs: {formatted_file}")
+    if not os.path.exists(formatted_file):
+        print(f" Formatted question-SQL pair file does not exist: {formatted_file}")
+        return
+
+    # Read the entire file
+    with open(formatted_file, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # Split the content into individual question-SQL pairs; use "\n\nQuestion:" as
+    # the separator (more precise than a bare blank line) to avoid false splits
+    pairs = []
+    blocks = content.split("\n\nQuestion:")
+
+    # Handle the first block (it may not be preceded by "\n\nQuestion:")
+    first_block = blocks[0]
+    if first_block.strip().startswith("Question:"):
+        pairs.append(first_block.strip())
+    elif "Question:" in first_block:
+        # Handle the case where the file does not start with "Question:"
+        question_start = first_block.find("Question:")
+        pairs.append(first_block[question_start:].strip())
+
+    # Handle the remaining blocks
+    for block in blocks[1:]:
+        pairs.append("Question:" + block.strip())
+
+    # Process each pair
+    successfully_processed = 0
+    for idx, pair in enumerate(pairs, start=1):
+        try:
+            if "Question:" not in pair or "SQL:" not in pair:
+                print(f" Skipping malformed pair #{idx}")
+                continue
+
+            # Extract the question part
+            question_start = pair.find("Question:") + len("Question:")
+            sql_start = pair.find("SQL:", question_start)
+
+            if sql_start == -1:
+                print(f" SQL part not found, skipping pair #{idx}")
+                continue
+
+            question = pair[question_start:sql_start].strip()
+
+            # Extract the SQL part (multi-line SQL is supported)
+            sql_part = pair[sql_start + len("SQL:"):].strip()
+
+            # Guard against a following "Question:" marker leaking into the SQL
+            next_question = pair.find("Question:", sql_start)
+            if next_question != -1:
+                sql_part = pair[sql_start + len("SQL:"):next_question].strip()
+
+            if not question or not sql_part:
+                print(f" Question or SQL is empty, skipping pair #{idx}")
+                continue
+
+            # Train the pair
+            print(f"\nFormatted question-SQL training {idx}")
+            print(f"Question: {question}")
+            print(f"SQL: {sql_part}")
+            train_question_sql_pair(question, sql_part)
+            successfully_processed += 1
+
+        except Exception as e:
+            print(f" Error: formatted question-SQL pair #{idx} - {e}")
+
+    print(f"Formatted question-SQL training finished: {successfully_processed} of {len(pairs)} pairs processed successfully")
+
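+# Illustrative sketch of the block format expected by
+# train_formatted_question_sql_pairs() (questions and SQL are made-up examples):
+#
+#   Question: How many orders were placed today?
+#   SQL: SELECT COUNT(*) FROM orders WHERE order_date = CURRENT_DATE;
+#
+#   Question: List the ten most recent users.
+#   SQL:
+#   SELECT id, name
+#   FROM users
+#   ORDER BY created_at DESC
+#   LIMIT 10;
+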
+def train_json_question_sql_pairs(json_file):
+    """Train question-SQL pairs stored in JSON format.
+
+    Args:
+        json_file (str): Path to the JSON question-SQL pair file
+    """
+    print(f" Start training JSON question-SQL pairs: {json_file}")
+    if not os.path.exists(json_file):
+        print(f" JSON question-SQL pair file does not exist: {json_file}")
+        return
+
+    try:
+        # Read the JSON file
+        with open(json_file, "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        # The data must be a list of pairs
+        if not isinstance(data, list):
+            print(" Error: invalid JSON file format; expected a list of question-SQL pairs")
+            return
+
+        successfully_processed = 0
+        for idx, pair in enumerate(data, start=1):
+            try:
+                # Validate the pair structure
+                if not isinstance(pair, dict) or "question" not in pair or "sql" not in pair:
+                    print(f" Skipping malformed pair #{idx}")
+                    continue
+
+                question = pair["question"].strip()
+                sql = pair["sql"].strip()
+
+                if not question or not sql:
+                    print(f" Question or SQL is empty, skipping pair #{idx}")
+                    continue
+
+                # Train the pair
+                print(f"\n JSON question-SQL training {idx}")
+                print(f"Question: {question}")
+                print(f"SQL: {sql}")
+                train_question_sql_pair(question, sql)
+                successfully_processed += 1
+
+            except Exception as e:
+                print(f" Error: JSON question-SQL pair #{idx} - {e}")
+
+        print(f"JSON question-SQL training finished: {successfully_processed} of {len(data)} pairs processed successfully")
+
+    except json.JSONDecodeError as e:
+        print(f" Error: failed to parse JSON - {e}")
+    except Exception as e:
+        print(f" Error: JSON question-SQL training - {e}")
+
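+# Illustrative sketch of the JSON layout expected by train_json_question_sql_pairs()
+# (entries are made-up examples):
+#   [
+#     {"question": "How many orders were placed today?",
+#      "sql": "SELECT COUNT(*) FROM orders WHERE order_date = CURRENT_DATE;"},
+#     {"question": "List the ten most recent users.",
+#      "sql": "SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10;"}
+#   ]
+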
+def process_training_files(data_path):
+    """Process all training files under the given path.
+
+    Args:
+        data_path (str): Path to the training data directory
+    """
+    print(f"\n===== Scanning training data directory: {os.path.abspath(data_path)} =====")
+
+    # Check that the directory exists
+    if not os.path.exists(data_path):
+        print(f"Error: training data directory does not exist: {data_path}")
+        return False
+
+    # Initialise per-type counters
+    stats = {
+        "ddl": 0,
+        "documentation": 0,
+        "sql_example": 0,
+        "question_sql_formatted": 0,
+        "question_sql_json": 0
+    }
+
+    # Walk the directory tree recursively
+    for root, _, files in os.walk(data_path):
+        for file in files:
+            file_path = os.path.join(root, file)
+            file_lower = file.lower()
+
+            # Dispatch to the handler matching the file type
+            try:
+                if file_lower.endswith(".ddl"):
+                    print(f"\nProcessing DDL file: {file_path}")
+                    train_ddl_statements(file_path)
+                    stats["ddl"] += 1
+
+                elif file_lower.endswith(".md") or file_lower.endswith(".markdown"):
+                    print(f"\nProcessing documentation file: {file_path}")
+                    train_documentation_blocks(file_path)
+                    stats["documentation"] += 1
+
+                elif file_lower.endswith("_pair.json") or file_lower.endswith("_pairs.json"):
+                    print(f"\nProcessing JSON question-SQL pair file: {file_path}")
+                    train_json_question_sql_pairs(file_path)
+                    stats["question_sql_json"] += 1
+
+                elif file_lower.endswith("_sql_pair.sql") or file_lower.endswith("_sql_pairs.sql"):
+                    print(f"\nProcessing formatted question-SQL pair file: {file_path}")
+                    train_formatted_question_sql_pairs(file_path)
+                    stats["question_sql_formatted"] += 1
+
+                elif file_lower.endswith(".sql") and not (file_lower.endswith("_sql_pair.sql") or file_lower.endswith("_sql_pairs.sql")):
+                    print(f"\nProcessing SQL example file: {file_path}")
+                    train_sql_examples(file_path)
+                    stats["sql_example"] += 1
+            except Exception as e:
+                print(f"Error while processing file {file_path}: {e}")
+
+    # Print processing statistics
+    print("\n===== Training file statistics =====")
+    print(f"DDL files: {stats['ddl']}")
+    print(f"Documentation files: {stats['documentation']}")
+    print(f"SQL example files: {stats['sql_example']}")
+    print(f"Formatted question-SQL pair files: {stats['question_sql_formatted']}")
+    print(f"JSON question-SQL pair files: {stats['question_sql_json']}")
+
+    total_files = sum(stats.values())
+    if total_files == 0:
+        print(f"Warning: no trainable files were found in directory {data_path}")
+        return False
+
+    return True
+
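+# Illustrative sketch of a training data directory and how process_training_files()
+# routes each file (all file names are made-up examples):
+#   data/
+#     tables.ddl             -> train_ddl_statements()
+#     schema_notes.md        -> train_documentation_blocks()
+#     examples.sql           -> train_sql_examples()
+#     report_sql_pairs.sql   -> train_formatted_question_sql_pairs()
+#     report_pairs.json      -> train_json_question_sql_pairs()
+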
+def main():
+    """Main entry point: configure and run the training pipeline."""
+
+    # Import the modules needed here first
+    import os
+    import app_config
+
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description='Train the Vanna NL2SQL model')
+    parser.add_argument('--data_path', type=str, default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data'),
+                        help='Path to the training data directory (default: training/data)')
+    args = parser.parse_args()
+
+    # Use a Path object for cross-platform path handling
+    data_path = Path(args.data_path)
+
+    # Resolve the project root directory
+    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+    # Check the embedding model connection
+    check_embedding_model_connection()
+
+    # Print ChromaDB information
+    try:
+        try:
+            import chromadb
+            chroma_version = chromadb.__version__
+        except ImportError:
+            chroma_version = "unknown"
+
+        # Try to locate the ChromaDB file currently in use
+        chroma_file = "chroma.sqlite3"  # default file name
+
+        # The ChromaDB file is expected in the project root directory
+        db_file_path = os.path.join(project_root, chroma_file)
+
+        if os.path.exists(db_file_path):
+            file_size = os.path.getsize(db_file_path) / 1024  # KB
+            print(f"\n===== ChromaDB database: {os.path.abspath(db_file_path)} (size: {file_size:.2f} KB) =====")
+        else:
+            print(f"\n===== ChromaDB database file not found at: {os.path.abspath(db_file_path)} =====")
+
+        # Print the ChromaDB version
+        print(f"===== ChromaDB client library version: {chroma_version} =====\n")
+    except Exception as e:
+        print(f"\n===== Unable to obtain ChromaDB information: {e} =====\n")
+
+    # Process the training files
+    process_successful = process_training_files(data_path)
+
+    if process_successful:
+        # Training finished; flush and shut down the batch processor
+        print("\n===== Training finished, processing remaining batches =====")
+        flush_training()
+        shutdown_trainer()
+
+        # Verify that the data was written successfully
+        print("\n===== Verifying training data =====")
+        from vanna_llm_factory import create_vanna_instance
+        vn = create_vanna_instance()
+
+        # Verification logic would differ by vector store type;
+        # since only ChromaDB is used here, this part is kept simple
+        try:
+            training_data = vn.get_training_data()
+            if training_data is not None and not training_data.empty:
+                # get_training_data usually prints counts itself; add a summary here
+                print(f"Retrieved {len(training_data)} training records from ChromaDB for verification.")
+            elif training_data is not None and training_data.empty:
+                print("No training data found in ChromaDB.")
+            else:  # training_data is None
+                print("Could not fetch training data from Vanna (None may have been returned). Check the connection and the Vanna implementation.")
+
+        except Exception as e:
+            print(f"Failed to verify training data: {e}")
+            print("Please check the ChromaDB connection and table structure.")
+    else:
+        print("\n===== No trainable files were found or processed; training terminated =====")
+
+    # Print embedding model information
+    print("\n===== Embedding model information =====")
+    print(f"Model name: {app_config.EMBEDDING_CONFIG.get('model_name')}")
+    print(f"Vector dimension: {app_config.EMBEDDING_CONFIG.get('embedding_dimension')}")
+    print(f"API service: {app_config.EMBEDDING_CONFIG.get('base_url')}")
+    # Print the ChromaDB path
+    chroma_display_path = os.path.abspath(project_root)
+    print(f"Vector store: ChromaDB ({chroma_display_path})")
+    print("===== Training pipeline complete =====\n")
+
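+# Example invocation (the path is illustrative; by default the script uses the
+# "data" directory next to this file):
+#   python run_training.py --data_path ./training/data
+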
+if __name__ == "__main__":
|
|
|
+ main()
|