run_training.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. # run_training.py
  2. import os
  3. import time
  4. import re
  5. import json
  6. import sys
  7. import requests
  8. import pandas as pd
  9. import argparse
  10. from pathlib import Path
  11. from sqlalchemy import create_engine
  12. from vanna_trainer import (
  13. train_ddl,
  14. train_documentation,
  15. train_sql_example,
  16. train_question_sql_pair,
  17. flush_training,
  18. shutdown_trainer
  19. )
  20. def check_embedding_model_connection():
  21. """检查嵌入模型连接是否可用
  22. 如果无法连接到嵌入模型,则终止程序执行
  23. Returns:
  24. bool: 连接成功返回True,否则终止程序
  25. """
  26. from core.embedding_function import test_embedding_connection
  27. print("正在检查嵌入模型连接...")
  28. # 使用专门的测试函数进行连接测试
  29. test_result = test_embedding_connection()
  30. if test_result["success"]:
  31. print(f"可以继续训练过程。")
  32. return True
  33. else:
  34. print(f"\n错误: 无法连接到嵌入模型: {test_result['message']}")
  35. print("训练过程终止。请检查配置和API服务可用性。")
  36. sys.exit(1)
  37. def read_file_by_delimiter(filepath, delimiter="---"):
  38. """通用读取:将文件按分隔符切片为多个段落"""
  39. with open(filepath, "r", encoding="utf-8") as f:
  40. content = f.read()
  41. blocks = [block.strip() for block in content.split(delimiter) if block.strip()]
  42. return blocks
  43. def read_markdown_file_by_sections(filepath):
  44. """专门用于Markdown文件:按标题(#、##、###)分割文档
  45. Args:
  46. filepath (str): Markdown文件路径
  47. Returns:
  48. list: 分割后的Markdown章节列表
  49. """
  50. with open(filepath, "r", encoding="utf-8") as f:
  51. content = f.read()
  52. # 确定文件是否为Markdown
  53. is_markdown = filepath.lower().endswith('.md') or filepath.lower().endswith('.markdown')
  54. if not is_markdown:
  55. # 非Markdown文件使用默认的---分隔
  56. return read_file_by_delimiter(filepath, "---")
  57. # 直接按照标题级别分割内容,处理#、##和###
  58. sections = []
  59. # 匹配所有级别的标题(#、##或###开头)
  60. header_pattern = r'(?:^|\n)((?:#|##|###)[^#].*?)(?=\n(?:#|##|###)[^#]|\Z)'
  61. all_sections = re.findall(header_pattern, content, re.DOTALL)
  62. for section in all_sections:
  63. section = section.strip()
  64. if section:
  65. sections.append(section)
  66. # 处理没有匹配到标题的情况
  67. if not sections and content.strip():
  68. sections = [content.strip()]
  69. return sections
  70. def train_ddl_statements(ddl_file):
  71. """训练DDL语句
  72. Args:
  73. ddl_file (str): DDL文件路径
  74. """
  75. print(f"开始训练 DDL: {ddl_file}")
  76. if not os.path.exists(ddl_file):
  77. print(f"DDL 文件不存在: {ddl_file}")
  78. return
  79. for idx, ddl in enumerate(read_file_by_delimiter(ddl_file, ";"), start=1):
  80. try:
  81. print(f"\n DDL 训练 {idx}")
  82. train_ddl(ddl)
  83. except Exception as e:
  84. print(f"错误:DDL #{idx} - {e}")
  85. def train_documentation_blocks(doc_file):
  86. """训练文档块
  87. Args:
  88. doc_file (str): 文档文件路径
  89. """
  90. print(f"开始训练 文档: {doc_file}")
  91. if not os.path.exists(doc_file):
  92. print(f"文档文件不存在: {doc_file}")
  93. return
  94. # 检查是否为Markdown文件
  95. is_markdown = doc_file.lower().endswith('.md') or doc_file.lower().endswith('.markdown')
  96. if is_markdown:
  97. # 使用Markdown专用分割器
  98. sections = read_markdown_file_by_sections(doc_file)
  99. print(f" Markdown文档已分割为 {len(sections)} 个章节")
  100. for idx, section in enumerate(sections, start=1):
  101. try:
  102. section_title = section.split('\n', 1)[0].strip()
  103. print(f"\n Markdown章节训练 {idx}: {section_title}")
  104. # 检查部分长度并提供警告
  105. if len(section) > 2000:
  106. print(f" 章节 {idx} 长度为 {len(section)} 字符,接近API限制(2048)")
  107. train_documentation(section)
  108. except Exception as e:
  109. print(f" 错误:章节 #{idx} - {e}")
  110. else:
  111. # 非Markdown文件使用传统的---分隔
  112. for idx, doc in enumerate(read_file_by_delimiter(doc_file, "---"), start=1):
  113. try:
  114. print(f"\n 文档训练 {idx}")
  115. train_documentation(doc)
  116. except Exception as e:
  117. print(f" 错误:文档 #{idx} - {e}")
  118. def train_sql_examples(sql_file):
  119. """训练SQL示例
  120. Args:
  121. sql_file (str): SQL示例文件路径
  122. """
  123. print(f" 开始训练 SQL 示例: {sql_file}")
  124. if not os.path.exists(sql_file):
  125. print(f" SQL 示例文件不存在: {sql_file}")
  126. return
  127. for idx, sql in enumerate(read_file_by_delimiter(sql_file, ";"), start=1):
  128. try:
  129. print(f"\n SQL 示例训练 {idx}")
  130. train_sql_example(sql)
  131. except Exception as e:
  132. print(f" 错误:SQL #{idx} - {e}")
  133. def train_question_sql_pairs(qs_file):
  134. """训练问答对
  135. Args:
  136. qs_file (str): 问答对文件路径
  137. """
  138. print(f" 开始训练 问答对: {qs_file}")
  139. if not os.path.exists(qs_file):
  140. print(f" 问答文件不存在: {qs_file}")
  141. return
  142. try:
  143. with open(qs_file, "r", encoding="utf-8") as f:
  144. lines = f.readlines()
  145. for idx, line in enumerate(lines, start=1):
  146. if "::" not in line:
  147. continue
  148. question, sql = line.strip().split("::", 1)
  149. print(f"\n 问答训练 {idx}")
  150. train_question_sql_pair(question.strip(), sql.strip())
  151. except Exception as e:
  152. print(f" 错误:问答训练 - {e}")
def train_formatted_question_sql_pairs(formatted_file):
    """Train question/SQL pairs from a "Question:/SQL:" formatted file.

    Two layouts are supported:
    1. Question: xxx\nSQL: xxx           (single-line SQL)
    2. Question: xxx\nSQL:\nxxx\nxxx     (multi-line SQL)

    Args:
        formatted_file (str): Path of the formatted question/SQL file.
    """
    print(f" 开始训练 格式化问答对: {formatted_file}")
    if not os.path.exists(formatted_file):
        print(f" 格式化问答文件不存在: {formatted_file}")
        return
    # Read the whole file; pairs are re-assembled from delimiter splits below.
    with open(formatted_file, "r", encoding="utf-8") as f:
        content = f.read()
    # Pairs are separated by a blank line followed by "Question:"; splitting
    # on that exact marker avoids misidentifying other blank lines.
    pairs = []
    blocks = content.split("\n\nQuestion:")
    # The first block may or may not carry the leading "\n\nQuestion:".
    first_block = blocks[0]
    if first_block.strip().startswith("Question:"):
        pairs.append(first_block.strip())
    elif "Question:" in first_block:
        # File starts with preamble text before the first "Question:" —
        # keep only the part from that marker onward.
        question_start = first_block.find("Question:")
        pairs.append(first_block[question_start:].strip())
    # Remaining blocks lost their "Question:" prefix to the split; restore it.
    for block in blocks[1:]:
        pairs.append("Question:" + block.strip())
    # Process each re-assembled pair; failures are per-pair, not fatal.
    successfully_processed = 0
    for idx, pair in enumerate(pairs, start=1):
        try:
            if "Question:" not in pair or "SQL:" not in pair:
                print(f" 跳过不符合格式的对 #{idx}")
                continue
            # Question text runs from after "Question:" up to the "SQL:" marker.
            question_start = pair.find("Question:") + len("Question:")
            sql_start = pair.find("SQL:", question_start)
            if sql_start == -1:
                print(f" SQL部分未找到,跳过对 #{idx}")
                continue
            question = pair[question_start:sql_start].strip()
            # SQL part may span multiple lines; take everything after "SQL:".
            sql_part = pair[sql_start + len("SQL:"):].strip()
            # Guard against a stray "Question:" inside the same block (would
            # indicate a mis-split); truncate the SQL before it if present.
            next_question = pair.find("Question:", sql_start)
            if next_question != -1:
                sql_part = pair[sql_start + len("SQL:"):next_question].strip()
            if not question or not sql_part:
                print(f" 问题或SQL为空,跳过对 #{idx}")
                continue
            # Train the validated pair.
            print(f"\n格式化问答训练 {idx}")
            print(f"问题: {question}")
            print(f"SQL: {sql_part}")
            train_question_sql_pair(question, sql_part)
            successfully_processed += 1
        except Exception as e:
            print(f" 错误:格式化问答训练对 #{idx} - {e}")
    print(f"格式化问答训练完成,共成功处理 {successfully_processed} 对问答(总计 {len(pairs)} 对)")
  215. def train_json_question_sql_pairs(json_file):
  216. """训练JSON格式的问答对
  217. Args:
  218. json_file (str): JSON格式问答对文件路径
  219. """
  220. print(f" 开始训练 JSON格式问答对: {json_file}")
  221. if not os.path.exists(json_file):
  222. print(f" JSON问答文件不存在: {json_file}")
  223. return
  224. try:
  225. # 读取JSON文件
  226. with open(json_file, "r", encoding="utf-8") as f:
  227. data = json.load(f)
  228. # 确保数据是列表格式
  229. if not isinstance(data, list):
  230. print(f" 错误: JSON文件格式不正确,应为问答对列表")
  231. return
  232. successfully_processed = 0
  233. for idx, pair in enumerate(data, start=1):
  234. try:
  235. # 检查问答对格式
  236. if not isinstance(pair, dict) or "question" not in pair or "sql" not in pair:
  237. print(f" 跳过不符合格式的对 #{idx}")
  238. continue
  239. question = pair["question"].strip()
  240. sql = pair["sql"].strip()
  241. if not question or not sql:
  242. print(f" 问题或SQL为空,跳过对 #{idx}")
  243. continue
  244. # 训练问答对
  245. print(f"\n JSON格式问答训练 {idx}")
  246. print(f"问题: {question}")
  247. print(f"SQL: {sql}")
  248. train_question_sql_pair(question, sql)
  249. successfully_processed += 1
  250. except Exception as e:
  251. print(f" 错误:JSON问答训练对 #{idx} - {e}")
  252. print(f"JSON格式问答训练完成,共成功处理 {successfully_processed} 对问答(总计 {len(data)} 对)")
  253. except json.JSONDecodeError as e:
  254. print(f" 错误:JSON解析失败 - {e}")
  255. except Exception as e:
  256. print(f" 错误:处理JSON问答训练 - {e}")
  257. def process_training_files(data_path):
  258. """处理指定路径下的所有训练文件
  259. Args:
  260. data_path (str): 训练数据目录路径
  261. """
  262. print(f"\n===== 扫描训练数据目录: {os.path.abspath(data_path)} =====")
  263. # 检查目录是否存在
  264. if not os.path.exists(data_path):
  265. print(f"错误: 训练数据目录不存在: {data_path}")
  266. return False
  267. # 初始化统计计数器
  268. stats = {
  269. "ddl": 0,
  270. "documentation": 0,
  271. "sql_example": 0,
  272. "question_sql_formatted": 0,
  273. "question_sql_json": 0
  274. }
  275. # 只扫描指定目录下的直接文件,不扫描子目录
  276. try:
  277. items = os.listdir(data_path)
  278. for item in items:
  279. item_path = os.path.join(data_path, item)
  280. # 只处理文件,跳过目录
  281. if not os.path.isfile(item_path):
  282. print(f"跳过子目录: {item}")
  283. continue
  284. file_lower = item.lower()
  285. # 根据文件类型调用相应的处理函数
  286. try:
  287. if file_lower.endswith(".ddl"):
  288. print(f"\n处理DDL文件: {item_path}")
  289. train_ddl_statements(item_path)
  290. stats["ddl"] += 1
  291. elif file_lower.endswith(".md") or file_lower.endswith(".markdown"):
  292. print(f"\n处理文档文件: {item_path}")
  293. train_documentation_blocks(item_path)
  294. stats["documentation"] += 1
  295. elif file_lower.endswith("_pair.json") or file_lower.endswith("_pairs.json"):
  296. print(f"\n处理JSON问答对文件: {item_path}")
  297. train_json_question_sql_pairs(item_path)
  298. stats["question_sql_json"] += 1
  299. elif file_lower.endswith("_pair.sql") or file_lower.endswith("_pairs.sql"):
  300. print(f"\n处理格式化问答对文件: {item_path}")
  301. train_formatted_question_sql_pairs(item_path)
  302. stats["question_sql_formatted"] += 1
  303. elif file_lower.endswith(".sql") and not (file_lower.endswith("_pair.sql") or file_lower.endswith("_pairs.sql")):
  304. print(f"\n处理SQL示例文件: {item_path}")
  305. train_sql_examples(item_path)
  306. stats["sql_example"] += 1
  307. else:
  308. print(f"跳过不支持的文件类型: {item}")
  309. except Exception as e:
  310. print(f"处理文件 {item_path} 时出错: {e}")
  311. except OSError as e:
  312. print(f"读取目录失败: {e}")
  313. return False
  314. # 打印处理统计
  315. print("\n===== 训练文件处理统计 =====")
  316. print(f"DDL文件: {stats['ddl']}个")
  317. print(f"文档文件: {stats['documentation']}个")
  318. print(f"SQL示例文件: {stats['sql_example']}个")
  319. print(f"格式化问答对文件: {stats['question_sql_formatted']}个")
  320. print(f"JSON问答对文件: {stats['question_sql_json']}个")
  321. total_files = sum(stats.values())
  322. if total_files == 0:
  323. print(f"警告: 在目录 {data_path} 中未找到任何可训练的文件")
  324. return False
  325. return True
  326. def check_pgvector_connection():
  327. """检查 PgVector 数据库连接是否可用
  328. Returns:
  329. bool: 连接成功返回True,否则返回False
  330. """
  331. import app_config
  332. from sqlalchemy import create_engine, text
  333. try:
  334. # 构建连接字符串
  335. pg_config = app_config.PGVECTOR_CONFIG
  336. connection_string = f"postgresql://{pg_config['user']}:{pg_config['password']}@{pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}"
  337. print(f"正在测试 PgVector 数据库连接...")
  338. print(f"连接地址: {pg_config['host']}:{pg_config['port']}/{pg_config['dbname']}")
  339. # 创建数据库引擎并测试连接
  340. engine = create_engine(connection_string)
  341. with engine.connect() as connection:
  342. # 测试基本连接
  343. result = connection.execute(text("SELECT 1"))
  344. result.fetchone()
  345. # 检查是否安装了 pgvector 扩展
  346. try:
  347. result = connection.execute(text("SELECT extname FROM pg_extension WHERE extname = 'vector'"))
  348. extension_exists = result.fetchone() is not None
  349. if extension_exists:
  350. print("✓ PgVector 扩展已安装")
  351. else:
  352. print("⚠ 警告: PgVector 扩展未安装,请确保已安装 pgvector 扩展")
  353. except Exception as ext_e:
  354. print(f"⚠ 无法检查 pgvector 扩展状态: {ext_e}")
  355. # 检查训练数据表是否存在
  356. try:
  357. result = connection.execute(text("SELECT tablename FROM pg_tables WHERE tablename = 'langchain_pg_embedding'"))
  358. table_exists = result.fetchone() is not None
  359. if table_exists:
  360. # 获取表中的记录数
  361. result = connection.execute(text("SELECT COUNT(*) FROM langchain_pg_embedding"))
  362. count = result.fetchone()[0]
  363. print(f"✓ 训练数据表存在,当前包含 {count} 条记录")
  364. else:
  365. print("ℹ 训练数据表尚未创建(首次训练时会自动创建)")
  366. except Exception as table_e:
  367. print(f"⚠ 无法检查训练数据表状态: {table_e}")
  368. print("✓ PgVector 数据库连接测试成功")
  369. return True
  370. except Exception as e:
  371. print(f"✗ PgVector 数据库连接失败: {e}")
  372. return False
def main():
    """Entry point: configure and run the Vanna training pipeline.

    Resolves the training-data path (CLI flag or app_config), verifies
    the embedding model and the configured vector store, trains every
    supported file in the data directory, then validates the stored
    training data and prints an environment summary.
    """
    # Import the modules this function depends on.
    import os
    import app_config
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='训练Vanna NL2SQL模型')
    # Resolve the default data path with some smarts about relative paths.
    def resolve_training_data_path():
        """Resolve TRAINING_DATA_PATH to an absolute path."""
        config_path = getattr(app_config, 'TRAINING_DATA_PATH', './training/data')
        # Absolute paths are used as-is.
        if os.path.isabs(config_path):
            return config_path
        # Paths starting with ./ or ../ are resolved against the project root
        # (one directory above this file's directory).
        if config_path.startswith('./') or config_path.startswith('../'):
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            return os.path.join(project_root, config_path)
        # Anything else is also treated as relative to the project root.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(project_root, config_path)
    default_path = resolve_training_data_path()
    parser.add_argument('--data_path', type=str, default=default_path,
                        help='训练数据目录路径 (默认: 从app_config.TRAINING_DATA_PATH)')
    args = parser.parse_args()
    # Use a Path object for cross-platform path handling.
    data_path = Path(args.data_path)
    # Report how the data path was resolved.
    print(f"\n===== 训练数据路径配置 =====")
    print(f"配置文件中的路径: {getattr(app_config, 'TRAINING_DATA_PATH', '未配置')}")
    print(f"解析后的绝对路径: {os.path.abspath(data_path)}")
    print("==============================")
    # Project root = one directory above this file's directory.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # Abort early (sys.exit inside) if the embedding model is unreachable.
    check_embedding_model_connection()
    # Show information for whichever vector store is configured.
    vector_db_type = app_config.VECTOR_DB_TYPE.lower()
    if vector_db_type == "chromadb":
        # ChromaDB: report the client version and the on-disk database file.
        try:
            try:
                import chromadb
                chroma_version = chromadb.__version__
            except ImportError:
                chroma_version = "未知"
            # Default ChromaDB sqlite file, expected in the project root.
            chroma_file = "chroma.sqlite3"
            db_file_path = os.path.join(project_root, chroma_file)
            if os.path.exists(db_file_path):
                file_size = os.path.getsize(db_file_path) / 1024  # size in KB
                print(f"\n===== ChromaDB数据库: {os.path.abspath(db_file_path)} (大小: {file_size:.2f} KB) =====")
            else:
                print(f"\n===== 未找到ChromaDB数据库文件于: {os.path.abspath(db_file_path)} =====")
            print(f"===== ChromaDB客户端库版本: {chroma_version} =====\n")
        except Exception as e:
            print(f"\n===== 无法获取ChromaDB信息: {e} =====\n")
    elif vector_db_type == "pgvector":
        # PgVector: print connection details and require a working connection.
        print(f"\n===== PgVector数据库配置 =====")
        pg_config = app_config.PGVECTOR_CONFIG
        print(f"数据库地址: {pg_config['host']}:{pg_config['port']}")
        print(f"数据库名称: {pg_config['dbname']}")
        print(f"用户名: {pg_config['user']}")
        print("==============================\n")
        # A failed PgVector connection aborts the whole run.
        if not check_pgvector_connection():
            print("PgVector 数据库连接失败,训练过程终止。")
            sys.exit(1)
    else:
        print(f"\n===== 未知的向量数据库类型: {vector_db_type} =====\n")
    # Train from every supported file in the data directory.
    process_successful = process_training_files(data_path)
    if process_successful:
        # Flush any remaining batches and shut the batch trainer down.
        print("\n===== 训练完成,处理剩余批次 =====")
        flush_training()
        shutdown_trainer()
        # Verify that training data was actually written to the store.
        print("\n===== 验证训练数据 =====")
        from core.vanna_llm_factory import create_vanna_instance
        vn = create_vanna_instance()
        # Validation is uniform; only the reported store name differs.
        try:
            training_data = vn.get_training_data()
            if training_data is not None and not training_data.empty:
                print(f"✓ 已从{vector_db_type.upper()}中检索到 {len(training_data)} 条训练数据进行验证。")
                # Per-type breakdown, when the column is present.
                if 'training_data_type' in training_data.columns:
                    type_counts = training_data['training_data_type'].value_counts()
                    print("训练数据类型统计:")
                    for data_type, count in type_counts.items():
                        print(f" {data_type}: {count} 条")
            elif training_data is not None and training_data.empty:
                print(f"⚠ 在{vector_db_type.upper()}中未找到任何训练数据。")
            else:  # training_data is None
                print(f"⚠ 无法从Vanna获取训练数据 (可能返回了None)。请检查{vector_db_type.upper()}连接和Vanna实现。")
        except Exception as e:
            print(f"✗ 验证训练数据失败: {e}")
            print(f"请检查{vector_db_type.upper()}连接和表结构。")
    else:
        print("\n===== 未能找到或处理任何训练文件,训练过程终止 =====")
    # Report which embedding model was used.
    print("\n===== Embedding模型信息 =====")
    try:
        from common.utils import get_current_embedding_config, get_current_model_info
        embedding_config = get_current_embedding_config()
        model_info = get_current_model_info()
        print(f"模型类型: {model_info['embedding_type']}")
        print(f"模型名称: {model_info['embedding_model']}")
        print(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
        if 'base_url' in embedding_config:
            print(f"API服务: {embedding_config['base_url']}")
    except ImportError as e:
        print(f"警告: 无法导入配置工具函数: {e}")
        # Fallback: read the legacy config attribute directly.
        embedding_config = getattr(app_config, 'API_EMBEDDING_CONFIG', {})
        print(f"模型名称: {embedding_config.get('model_name', '未知')}")
        print(f"向量维度: {embedding_config.get('embedding_dimension', '未知')}")
        print(f"API服务: {embedding_config.get('base_url', '未知')}")
    # Report which vector store was used.
    if vector_db_type == "chromadb":
        chroma_display_path = os.path.abspath(project_root)
        print(f"向量数据库: ChromaDB ({chroma_display_path})")
    elif vector_db_type == "pgvector":
        pg_config = app_config.PGVECTOR_CONFIG
        print(f"向量数据库: PgVector ({pg_config['host']}:{pg_config['port']}/{pg_config['dbname']})")
    print("===== 训练流程完成 =====\n")
# Script entry point: run the training pipeline when executed directly.
if __name__ == "__main__":
    main()