|
@@ -325,28 +325,30 @@ def training_data_update():
|
|
|
|
|
|
**功能描述:**
|
|
|
|
|
|
-- 支持上传多种格式的训练数据文件(.ddl/.md/.json/.sql)
|
|
|
-- 直接内存处理,无需临时文件存储
|
|
|
+- 支持上传多种格式的训练数据文件(txt/md/sql/json 或无扩展名)
|
|
|
+- 通过 `file_type` 参数指定文件类型,不再依赖文件扩展名
|
|
|
- 文件大小限制500KB(约包含100-500条记录)
|
|
|
-- 自动识别文件类型并解析内容
|
|
|
+- 对文件内容进行格式验证,确保符合指定类型要求
|
|
|
- 复用现有trainer模块的成熟功能
|
|
|
- 提供详细的处理结果反馈
|
|
|
|
|
|
**支持的文件类型:**
|
|
|
|
|
|
-基于现有 `data_pipeline/trainer` 模块的成熟功能,API 支持以下文件类型的自动识别和处理:
|
|
|
+API 支持以下5种文件类型,通过 `file_type` 参数指定:
|
|
|
|
|
|
-| 文件类型 | 扩展名 | 文件命名规则 | 复用函数 | 解析方式 |
|
|
|
-|---------|--------|-------------|---------|-----------|
|
|
|
-| ddl | `.ddl` | 任意名称 | `train_ddl_file()` → `read_file_by_delimiter()` | 建表语句,以“;”为分隔符 |
|
|
|
-| documentation | `.md` | 任意名称 | `train_documentation_file()` → `read_markdown_file_by_sections()` | Markdown以"##"为单位分割,普通文档以"---"分隔 |
|
|
|
-| question_sql_pairs | `.json` | 必须包含 `_pair`或 `_pairs` | `train_json_pairs_file()` → `train_json_question_sql_pairs()` | 问答对训练(JSON格式) |
|
|
|
-| formatted_pairs | `.sql` | 必须包含 `_pair`或 `_pairs` | `train_formatted_pairs_file()` → `train_formatted_question_sql_pairs()` | 问答对训练(文本格式),每对Question/SQL之间使用空行分隔 |
|
|
|
-| sql_examples | `.sql` | 任意名称(不包含`_pair`) | `train_sql_file()` → `read_file_by_delimiter()` | SQL示例,以";"为分隔符 |
|
|
|
+| file_type | 支持扩展名 | 内容格式验证 | 复用函数 | 解析方式 |
|
|
|
+|-----------|------------|-------------|---------|-----------|
|
|
|
+| `ddl` | **ddl**/sql/txt/无扩展名 | 必须包含 `CREATE` 语句(大小写不敏感) | `train_ddl_statements()` | 建表语句,以";"为分隔符 |
|
|
|
+| `markdown` | **md/markdown** | 必须包含 `##` 标题 | `train_documentation_blocks()` | Markdown以"##"为单位分割 |
|
|
|
+| `sql_pair_json` | **json**/txt/无扩展名 | 必须是有效JSON格式,包含question/sql字段 | `train_json_question_sql_pairs()` | JSON问答对训练 |
|
|
|
+| `sql_pair` | **sql**/txt/无扩展名 | 必须包含 `Question:`/`SQL:` 格式 | `train_formatted_question_sql_pairs()` | 格式化问答对训练 |
|
|
|
+| `sql` | **sql**/txt/无扩展名 | 必须包含 `;` 分隔符 | `train_sql_examples()` | SQL示例,以";"为分隔符 |
|
|
|
|
|
|
-**1. DDL文件 (*.ddl):**
|
|
|
+**文件类型详细说明:**
|
|
|
|
|
|
-建表语句,解析时以“;”为分隔符。
|
|
|
+**1. DDL文件 (file_type: `ddl`):**
|
|
|
+
|
|
|
+建表语句,必须包含CREATE语句,以";"为分隔符。
|
|
|
|
|
|
```sql
|
|
|
-- 建表语句
|
|
@@ -363,9 +365,9 @@ CREATE TABLE orders (
|
|
|
);
|
|
|
```
|
|
|
|
|
|
-**2. 文档文件 (*.md):**
|
|
|
+**2. Markdown文档 (file_type: `markdown`):**
|
|
|
|
|
|
-以"##"为单位进行分割,每个"##"是一个表名:
|
|
|
+必须包含"##"标题,以"##"为单位进行分割:
|
|
|
|
|
|
```markdown
|
|
|
## users表
|
|
@@ -379,9 +381,9 @@ CREATE TABLE orders (
|
|
|
订单信息表,记录用户购买记录。
|
|
|
```
|
|
|
|
|
|
-**3. JSON问答对 (*_pair.json):**
|
|
|
+**3. JSON问答对 (file_type: `sql_pair_json`):**
|
|
|
|
|
|
-"question"/"sql" 大小写不敏感。
|
|
|
+必须是有效JSON格式,"question"/"sql" 大小写不敏感。
|
|
|
|
|
|
```json
|
|
|
[
|
|
@@ -396,49 +398,104 @@ CREATE TABLE orders (
|
|
|
]
|
|
|
```
|
|
|
|
|
|
-**4. 格式化问答对 (*_pair.sql):**
|
|
|
+**4. 格式化问答对 (file_type: `sql_pair`):**
|
|
|
|
|
|
-注意,Question和SQL,大小写不敏感,但是两个问答对之间要以空行作为分隔。
|
|
|
+必须包含Question/SQL格式,大小写不敏感,问答对之间用空行分隔。
|
|
|
|
|
|
```
|
|
|
-Question: 如何查询活跃用户数量?
|
|
|
-SQL:SELECT COUNT(*) FROM users WHERE active = true;
|
|
|
+Question: 如何查询观影人数超过 10000 的热门电影?
|
|
|
+SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;
|
|
|
|
|
|
Question: 如何统计每个分类的产品数量?
|
|
|
-SQL:SELECT category, COUNT(*) FROM products GROUP BY category;
|
|
|
+SQL: SELECT category, COUNT(*) FROM products GROUP BY category;
|
|
|
+```
|
|
|
+
|
|
|
+**5. SQL示例 (file_type: `sql`):**
|
|
|
+
|
|
|
+纯SQL语句,必须包含";"分隔符,系统会自动生成对应问题。
|
|
|
+
|
|
|
+```sql
|
|
|
+SELECT COUNT(*) FROM users WHERE active = true;
|
|
|
+SELECT category, COUNT(*) FROM products GROUP BY category;
|
|
|
```
|
|
|
|
|
|
**API规格:**
|
|
|
|
|
|
**请求参数:**
|
|
|
- `file`: 上传的文件(multipart/form-data)
|
|
|
+- `file_type`: 文件类型(必需),可选值:`ddl`、`markdown`、`sql_pair_json`、`sql_pair`、`sql`
|
|
|
- 文件大小限制:最大500KB
|
|
|
- 文件编码:UTF-8
|
|
|
-- 自动识别文件类型:基于文件扩展名和命名规则
|
|
|
+- 支持的文件扩展名:根据 `file_type` 动态确定(详见上表)
|
|
|
+
|
|
|
+**文件扩展名验证规则:**
|
|
|
+
|
|
|
+系统根据 `file_type` 参数动态确定允许的文件扩展名:
|
|
|
+
|
|
|
+1. **DDL类型**:允许 `.ddl`、`.sql`、`.txt` 扩展名或无扩展名
|
|
|
+2. **Markdown类型**:仅允许 `.md`、`.markdown` 扩展名(不支持无扩展名)
|
|
|
+3. **SQL_PAIR_JSON类型**:允许 `.json`、`.txt` 扩展名或无扩展名
|
|
|
+4. **SQL_PAIR类型**:允许 `.sql`、`.txt` 扩展名或无扩展名
|
|
|
+5. **SQL类型**:允许 `.sql`、`.txt` 扩展名或无扩展名
|
|
|
+
|
|
|
+**扩展名验证错误示例:**
|
|
|
+```
|
|
|
+文件类型 ddl 不支持的文件扩展名:pdf,支持的扩展名:ddl, sql, txt 或无扩展名
|
|
|
+文件类型 markdown 不支持的文件扩展名:txt,支持的扩展名:md, markdown
|
|
|
+```
|
|
|
+
|
|
|
+**内容格式验证规则:**
|
|
|
+
|
|
|
+1. **DDL类型验证**:
|
|
|
+ - 文件内容必须包含 `CREATE` 关键字(大小写不敏感)
|
|
|
+ - 验证失败返回:`"文件内容不符合DDL格式,必须包含CREATE语句"`
|
|
|
+
|
|
|
+2. **Markdown类型验证**:
|
|
|
+ - 文件内容必须包含 `##` 标题标记
|
|
|
+ - 验证失败返回:`"文件内容不符合Markdown格式,必须包含##标题"`
|
|
|
+
|
|
|
+3. **SQL_PAIR_JSON类型验证**:
|
|
|
+ - 文件必须是有效的JSON格式
|
|
|
+ - JSON内容必须包含question和sql字段
|
|
|
+ - 验证失败返回:`"文件内容不符合JSON问答对格式"`
|
|
|
+
|
|
|
+4. **SQL_PAIR类型验证**:
|
|
|
+ - 文件内容必须包含 `Question:` 和 `SQL:` 标记(大小写不敏感)
|
|
|
+ - 验证失败返回:`"文件内容不符合问答对格式,必须包含Question:和SQL:"`
|
|
|
+
|
|
|
+5. **SQL类型验证**:
|
|
|
+ - 文件内容必须包含 `;` 分隔符
|
|
|
+ - 验证失败返回:`"文件内容不符合SQL格式,必须包含;分隔符"`
|
|
|
|
|
|
**实现代码框架:**
|
|
|
```python
|
|
|
+import json
|
|
|
+import re
|
|
|
from werkzeug.utils import secure_filename
|
|
|
from data_pipeline.trainer.run_training import (
|
|
|
- read_file_by_delimiter,
|
|
|
- read_markdown_file_by_sections,
|
|
|
+ train_ddl_statements,
|
|
|
+ train_documentation_blocks,
|
|
|
train_json_question_sql_pairs,
|
|
|
- train_formatted_question_sql_pairs
|
|
|
-)
|
|
|
-from data_pipeline.trainer.vanna_trainer import (
|
|
|
- train_ddl, train_documentation, train_question_sql_pair,
|
|
|
- train_sql_example, flush_training
|
|
|
+ train_formatted_question_sql_pairs,
|
|
|
+ train_sql_examples
|
|
|
)
|
|
|
from common.result import success_response, bad_request_response, internal_error_response
|
|
|
-import json
|
|
|
-import tempfile
|
|
|
-import os
|
|
|
-import logging
|
|
|
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
+def get_allowed_extensions(file_type: str) -> list:
|
|
|
+ """根据文件类型返回允许的扩展名"""
|
|
|
+ type_specific_extensions = {
|
|
|
+ 'ddl': ['ddl', 'sql', 'txt', ''], # 支持无扩展名
|
|
|
+ 'markdown': ['md', 'markdown'], # 不支持无扩展名
|
|
|
+ 'sql_pair_json': ['json', 'txt', ''], # 支持无扩展名
|
|
|
+ 'sql_pair': ['sql', 'txt', ''], # 支持无扩展名
|
|
|
+ 'sql': ['sql', 'txt', ''] # 支持无扩展名
|
|
|
+ }
|
|
|
+
|
|
|
+ return type_specific_extensions.get(file_type, [])
|
|
|
|
|
|
@app.route('/api/v0/training_data/upload', methods=['POST'])
|
|
|
def upload_training_data():
|
|
|
+ """上传训练数据文件API - 支持多种文件格式的自动解析和导入"""
|
|
|
try:
|
|
|
# 1. 参数验证
|
|
|
if 'file' not in request.files:
|
|
@@ -448,6 +505,16 @@ def upload_training_data():
|
|
|
if file.filename == '':
|
|
|
return jsonify(bad_request_response("未选择文件"))
|
|
|
|
|
|
+ # 获取file_type参数
|
|
|
+ file_type = request.form.get('file_type')
|
|
|
+ if not file_type:
|
|
|
+ return jsonify(bad_request_response("缺少必需参数:file_type"))
|
|
|
+
|
|
|
+ # 验证file_type参数
|
|
|
+ valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql']
|
|
|
+ if file_type not in valid_file_types:
|
|
|
+ return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}"))
|
|
|
+
|
|
|
# 2. 文件大小验证 (500KB)
|
|
|
file.seek(0, 2)
|
|
|
file_size = file.tell()
|
|
@@ -456,45 +523,59 @@ def upload_training_data():
|
|
|
if file_size > 500 * 1024: # 500KB
|
|
|
return jsonify(bad_request_response("文件大小不能超过500KB"))
|
|
|
|
|
|
- # 3. 创建临时文件(复用现有函数需要文件路径)
|
|
|
+ # 3. 验证文件扩展名(基于file_type)
|
|
|
filename = secure_filename(file.filename)
|
|
|
+ allowed_extensions = get_allowed_extensions(file_type)
|
|
|
+ file_ext = filename.split('.')[-1].lower() if '.' in filename else ''
|
|
|
+
|
|
|
+ if file_ext not in allowed_extensions:
|
|
|
+ # 构建友好的错误信息
|
|
|
+ non_empty_extensions = [ext for ext in allowed_extensions if ext]
|
|
|
+ if '' in allowed_extensions:
|
|
|
+ ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名"
|
|
|
+ else:
|
|
|
+ ext_message = ', '.join(non_empty_extensions)
|
|
|
+ return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}"))
|
|
|
+
|
|
|
+ # 4. 读取文件内容并验证格式
|
|
|
+ file.seek(0)
|
|
|
+ content = file.read().decode('utf-8')
|
|
|
+
|
|
|
+ # 格式验证
|
|
|
+ validation_result = validate_file_content(content, file_type)
|
|
|
+ if not validation_result['valid']:
|
|
|
+ return jsonify(bad_request_response(validation_result['error']))
|
|
|
+
|
|
|
+ # 5. 创建临时文件(复用现有函数需要文件路径)
|
|
|
temp_file_path = None
|
|
|
|
|
|
try:
|
|
|
- with tempfile.NamedTemporaryFile(mode='w+b', delete=False, suffix=os.path.splitext(filename)[1]) as tmp_file:
|
|
|
- file.save(tmp_file.name)
|
|
|
+ with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file:
|
|
|
+ tmp_file.write(content)
|
|
|
temp_file_path = tmp_file.name
|
|
|
|
|
|
- # 4. 根据文件类型调用现有的训练函数
|
|
|
- filename_lower = filename.lower()
|
|
|
- success_count = 0
|
|
|
-
|
|
|
- if filename_lower.endswith('.ddl'):
|
|
|
- success_count = train_ddl_file(temp_file_path)
|
|
|
+ # 6. 根据文件类型调用现有的训练函数
|
|
|
+ if file_type == 'ddl':
|
|
|
+ train_ddl_statements(temp_file_path)
|
|
|
|
|
|
- elif filename_lower.endswith(('.md', '.markdown')):
|
|
|
- success_count = train_documentation_file(temp_file_path)
|
|
|
+ elif file_type == 'markdown':
|
|
|
+ train_documentation_blocks(temp_file_path)
|
|
|
|
|
|
- elif filename_lower.endswith('_pair.json') or filename_lower.endswith('_pairs.json'):
|
|
|
- success_count = train_json_pairs_file(temp_file_path)
|
|
|
+ elif file_type == 'sql_pair_json':
|
|
|
+ train_json_question_sql_pairs(temp_file_path)
|
|
|
|
|
|
- elif filename_lower.endswith('_pair.sql') or filename_lower.endswith('_pairs.sql'):
|
|
|
- success_count = train_formatted_pairs_file(temp_file_path)
|
|
|
+ elif file_type == 'sql_pair':
|
|
|
+ train_formatted_question_sql_pairs(temp_file_path)
|
|
|
|
|
|
- elif filename_lower.endswith('.sql'):
|
|
|
- success_count = train_sql_file(temp_file_path)
|
|
|
- else:
|
|
|
- return jsonify(bad_request_response("不支持的文件类型"))
|
|
|
-
|
|
|
- # 5. 刷新批处理
|
|
|
- flush_training()
|
|
|
+ elif file_type == 'sql':
|
|
|
+ train_sql_examples(temp_file_path)
|
|
|
|
|
|
return jsonify(success_response(
|
|
|
- response_text=f"文件上传并训练成功:处理了 {success_count} 条记录",
|
|
|
+ response_text=f"文件上传并训练成功:{filename}",
|
|
|
data={
|
|
|
"filename": filename,
|
|
|
+ "file_type": file_type,
|
|
|
"file_size": file_size,
|
|
|
- "records_processed": success_count,
|
|
|
"status": "completed"
|
|
|
}
|
|
|
))
|
|
@@ -514,99 +595,56 @@ def upload_training_data():
|
|
|
logger.error(f"文件上传失败: {str(e)}")
|
|
|
return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))
|
|
|
|
|
|
-# 复用现有训练函数的包装器
|
|
|
-def train_ddl_file(file_path: str) -> int:
|
|
|
- """复用 run_training.py 中的 DDL 训练逻辑"""
|
|
|
- success_count = 0
|
|
|
+def validate_file_content(content: str, file_type: str) -> dict:
|
|
|
+ """验证文件内容格式"""
|
|
|
try:
|
|
|
- ddl_statements = read_file_by_delimiter(file_path, ";")
|
|
|
-
|
|
|
- for ddl in ddl_statements:
|
|
|
+ if file_type == 'ddl':
|
|
|
+ # 检查是否包含CREATE语句
|
|
|
+ if not re.search(r'\bCREATE\b', content, re.IGNORECASE):
|
|
|
+ return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'}
|
|
|
+
|
|
|
+ elif file_type == 'markdown':
|
|
|
+ # 检查是否包含##标题
|
|
|
+ if '##' not in content:
|
|
|
+ return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'}
|
|
|
+
|
|
|
+ elif file_type == 'sql_pair_json':
|
|
|
+ # 检查是否为有效JSON
|
|
|
try:
|
|
|
- train_ddl(ddl)
|
|
|
- success_count += 1
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"训练DDL失败: {ddl[:50]}..., 错误: {str(e)}")
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"读取DDL文件失败: {str(e)}")
|
|
|
-
|
|
|
- return success_count
|
|
|
-
|
|
|
-def train_documentation_file(file_path: str) -> int:
|
|
|
- """复用 run_training.py 中的文档训练逻辑"""
|
|
|
- success_count = 0
|
|
|
-
|
|
|
- try:
|
|
|
- # 检查是否为Markdown文件
|
|
|
- if file_path.lower().endswith(('.md', '.markdown')):
|
|
|
- sections = read_markdown_file_by_sections(file_path)
|
|
|
- for section in sections:
|
|
|
- try:
|
|
|
- train_documentation(section)
|
|
|
- success_count += 1
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"训练文档失败: {section[:50]}..., 错误: {str(e)}")
|
|
|
- else:
|
|
|
- # 非Markdown文件使用传统的---分隔
|
|
|
- doc_blocks = read_file_by_delimiter(file_path, "---")
|
|
|
- for doc in doc_blocks:
|
|
|
- try:
|
|
|
- train_documentation(doc)
|
|
|
- success_count += 1
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"训练文档失败: {doc[:50]}..., 错误: {str(e)}")
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"读取文档文件失败: {str(e)}")
|
|
|
-
|
|
|
- return success_count
|
|
|
-
|
|
|
-def train_json_pairs_file(file_path: str) -> int:
|
|
|
- """复用 run_training.py 中的 JSON 问答对训练逻辑"""
|
|
|
- try:
|
|
|
- # 直接调用现有函数
|
|
|
- train_json_question_sql_pairs(file_path)
|
|
|
-
|
|
|
- # 计算成功数量
|
|
|
- with open(file_path, 'r', encoding='utf-8') as f:
|
|
|
- pairs = json.load(f)
|
|
|
- return len(pairs) if isinstance(pairs, list) else 1
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"训练JSON问答对失败: {str(e)}")
|
|
|
- return 0
|
|
|
-
|
|
|
-def train_formatted_pairs_file(file_path: str) -> int:
|
|
|
- """复用 run_training.py 中的格式化问答对训练逻辑"""
|
|
|
- try:
|
|
|
- # 直接调用现有函数
|
|
|
- train_formatted_question_sql_pairs(file_path)
|
|
|
+ data = json.loads(content)
|
|
|
+ if not isinstance(data, list) or not data:
|
|
|
+ return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'}
|
|
|
+
|
|
|
+ # 检查是否包含question和sql字段
|
|
|
+ for item in data:
|
|
|
+ if not isinstance(item, dict):
|
|
|
+ return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'}
|
|
|
+
|
|
|
+ has_question = any(key.lower() == 'question' for key in item.keys())
|
|
|
+ has_sql = any(key.lower() == 'sql' for key in item.keys())
|
|
|
+
|
|
|
+ if not has_question or not has_sql:
|
|
|
+ return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'}
|
|
|
+
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'}
|
|
|
+
|
|
|
+ elif file_type == 'sql_pair':
|
|
|
+ # 检查是否包含Question:和SQL:
|
|
|
+ if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE):
|
|
|
+ return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'}
|
|
|
+ if not re.search(r'\bSQL\s*:', content, re.IGNORECASE):
|
|
|
+ return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'}
|
|
|
+
|
|
|
+ elif file_type == 'sql':
|
|
|
+ # 检查是否包含;分隔符
|
|
|
+ if ';' not in content:
|
|
|
+ return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'}
|
|
|
|
|
|
- # 计算成功数量
|
|
|
- with open(file_path, 'r', encoding='utf-8') as f:
|
|
|
- content = f.read()
|
|
|
- pairs = content.split('\n\n')
|
|
|
- return len([p for p in pairs if p.strip()])
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"训练格式化问答对失败: {str(e)}")
|
|
|
- return 0
|
|
|
-
|
|
|
-def train_sql_file(file_path: str) -> int:
|
|
|
- """复用 run_training.py 中的 SQL 训练逻辑"""
|
|
|
- success_count = 0
|
|
|
- try:
|
|
|
- sql_statements = read_file_by_delimiter(file_path, ";")
|
|
|
+ return {'valid': True}
|
|
|
|
|
|
- for sql in sql_statements:
|
|
|
- try:
|
|
|
- train_sql_example(sql)
|
|
|
- success_count += 1
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"训练SQL失败: {sql[:50]}..., 错误: {str(e)}")
|
|
|
except Exception as e:
|
|
|
- logger.error(f"读取SQL文件失败: {str(e)}")
|
|
|
-
|
|
|
- return success_count
|
|
|
-
|
|
|
-
|
|
|
+ return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'}
|
|
|
```
|
|
|
|
|
|
**返回格式:**
|
|
@@ -618,10 +656,10 @@ def train_sql_file(file_path: str) -> int:
|
|
|
"success": true,
|
|
|
"message": "操作成功",
|
|
|
"data": {
|
|
|
- "response": "文件上传并训练成功:处理了 45 条记录",
|
|
|
+ "response": "文件上传并训练成功:training_data.sql",
|
|
|
"filename": "training_data.sql",
|
|
|
+ "file_type": "sql",
|
|
|
"file_size": 12345,
|
|
|
- "records_processed": 45,
|
|
|
"status": "completed"
|
|
|
}
|
|
|
}
|
|
@@ -634,7 +672,7 @@ def train_sql_file(file_path: str) -> int:
|
|
|
"success": false,
|
|
|
"message": "请求参数错误",
|
|
|
"data": {
|
|
|
- "response": "文件大小不能超过500KB",
|
|
|
+ "response": "文件内容不符合SQL格式,必须包含;分隔符",
|
|
|
"error_type": "INVALID_PARAMS",
|
|
|
"timestamp": "2025-01-15T10:30:00.000Z",
|
|
|
"can_retry": false
|
|
@@ -650,28 +688,28 @@ def train_sql_file(file_path: str) -> int:
|
|
|
- 自动清理临时文件,防止磁盘空间泄露
|
|
|
- 完整的错误处理和日志记录
|
|
|
|
|
|
-2. **自动类型识别**:
|
|
|
- - 基于文件扩展名和命名规则
|
|
|
- - 智能解析不同格式内容
|
|
|
+2. **类型验证机制**:
|
|
|
+ - 基于 `file_type` 参数动态确定允许的扩展名
|
|
|
+ - 对文件内容进行格式验证
|
|
|
- 大小写不敏感的格式处理
|
|
|
+ - 详细的验证错误信息,包含具体的文件类型和支持的扩展名
|
|
|
|
|
|
3. **复用成熟功能**:
|
|
|
- 直接使用 `data_pipeline.trainer.run_training` 中的成熟函数
|
|
|
- - 复用 `read_file_by_delimiter` 和 `read_markdown_file_by_sections` 解析逻辑
|
|
|
+ - 复用 `train_ddl_statements`、`train_documentation_blocks` 等处理逻辑
|
|
|
- 调用 `train_json_question_sql_pairs` 和 `train_formatted_question_sql_pairs` 处理问答对
|
|
|
- 继承批处理优化机制,保持训练逻辑的一致性
|
|
|
- 避免重复造轮子,提高代码维护性
|
|
|
|
|
|
4. **错误处理**:
|
|
|
- 详细的错误日志记录
|
|
|
- - 逐条处理失败时的容错机制
|
|
|
- - 统计处理成功和失败的记录数
|
|
|
+ - 文件格式验证失败时的明确提示
|
|
|
+ - 基于文件类型的精确扩展名验证和友好提示
|
|
|
- 临时文件清理保证不留垃圾文件
|
|
|
|
|
|
5. **性能优化**:
|
|
|
- 复用现有的高效解析函数
|
|
|
- 批处理机制提高训练效率
|
|
|
- - 自动刷新确保数据及时写入
|
|
|
- 避免重复开发,减少维护成本
|
|
|
|
|
|
**使用建议:**
|
|
@@ -679,20 +717,38 @@ def train_sql_file(file_path: str) -> int:
|
|
|
1. **文件准备**:
|
|
|
- 确保文件使用UTF-8编码
|
|
|
- 控制文件大小在500KB以内
|
|
|
- - 按照规定的格式和命名规则准备文件
|
|
|
+ - 按照指定的格式准备文件内容
|
|
|
|
|
|
2. **文件类型选择**:
|
|
|
- - DDL文件:用于训练表结构信息,复用 `train_ddl_statements` 逻辑
|
|
|
- - Markdown文档:用于训练表和字段说明,复用 `train_documentation_blocks` 逻辑
|
|
|
- - JSON问答对:用于训练结构化问答数据,复用 `train_json_question_sql_pairs` 逻辑
|
|
|
- - 格式化问答对:用于训练文本格式问答数据,复用 `train_formatted_question_sql_pairs` 逻辑
|
|
|
- - SQL文件:用于训练SQL示例,复用 `train_sql_examples` 逻辑
|
|
|
+ - `ddl`:用于训练表结构信息,支持 .ddl/.sql/.txt 扩展名或无扩展名,必须包含CREATE语句
|
|
|
+ - `markdown`:用于训练表和字段说明,仅支持 .md/.markdown 扩展名,必须包含##标题
|
|
|
+ - `sql_pair_json`:用于训练结构化问答数据,支持 .json/.txt 扩展名或无扩展名,必须是有效JSON格式
|
|
|
+ - `sql_pair`:用于训练文本格式问答数据,支持 .sql/.txt 扩展名或无扩展名,必须包含Question:/SQL:格式
|
|
|
+ - `sql`:用于训练SQL示例,支持 .sql/.txt 扩展名或无扩展名,必须包含;分隔符,系统会自动生成问题
|
|
|
|
|
|
3. **最佳实践**:
|
|
|
- 建议分批上传,每次处理一种类型的文件
|
|
|
- 上传前检查文件格式的正确性
|
|
|
- - 关注返回结果中的处理统计信息
|
|
|
- - 复用现有训练逻辑,确保数据处理的一致性
|
|
|
+ - 明确指定正确的 `file_type` 参数
|
|
|
+ - 关注返回结果中的处理状态信息
|
|
|
+
|
|
|
+**Postman测试示例:**
|
|
|
+
|
|
|
+```
|
|
|
+POST /api/v0/training_data/upload
|
|
|
+Content-Type: multipart/form-data
|
|
|
+
|
|
|
+form-data:
|
|
|
+- file: [选择文件]
|
|
|
+- file_type: sql_pair
|
|
|
+
|
|
|
+示例文件内容:
|
|
|
+Question: 如何查询观影人数超过 10000 的热门电影?
|
|
|
+SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;
|
|
|
+
|
|
|
+Question: 如何统计每个分类的产品数量?
|
|
|
+SQL: SELECT category, COUNT(*) FROM products GROUP BY category;
|
|
|
+```
|
|
|
|
|
|
---
|
|
|
|