本方案基于对 unified_api.py
文件的深入分析,提出以下两个API改造与新增需求:
/api/v0/training_data/update
,支持先删除再插入的训练数据更新操作/api/v0/training_data/upload
,支持上传多种格式文件并自动解析入库/api/v0/training_data/combine
,合并相同的 langchain_pg_embedding.document 记录unified_api.py
目前实现了4个训练数据相关的API端点:
GET /api/v0/training_data/stats
(lines 1205-1234) - 基础统计信息POST /api/v0/training_data/query
(lines 1236-1294) - 分页查询POST /api/v0/training_data/create
(lines 1296-1404) - 批量创建(最多50条)POST /api/v0/training_data/delete
(lines 1406-1471) - 批量删除(最多50条)表结构关系:
-- langchain_pg_collection (集合表)
CREATE TABLE langchain_pg_collection (
uuid uuid PRIMARY KEY,
name varchar NOT NULL UNIQUE,
cmetadata json
);
-- langchain_pg_embedding (训练数据表)
CREATE TABLE langchain_pg_embedding (
id varchar PRIMARY KEY,
collection_id uuid REFERENCES langchain_pg_collection(uuid) ON DELETE CASCADE,
embedding vector,
document varchar,
cmetadata jsonb
);
集合名称映射:
sql
集合 → SQL查询和问答对ddl
集合 → 数据库表结构定义documentation
集合 → 表和字段文档error_sql
集合 → 错误SQL示例/api/v0/training_data/update
(新增的API)功能描述:
50bb6b17
-d5be-48ab-8125-58ab8bb076
-sql"WEB UI使用流程:
API规格:
请求参数:
{
"id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
"training_data_type": "sql",
"question": "如何查询所有用户?",
"sql": "SELECT * FROM users;"
}
参数说明:
id
: 要删除的原始记录ID(必需)training_data_type
: 训练数据类型,支持 "sql", "ddl", "documentation", "error_sql"sql
类型:question
(可选), sql
(必需)error_sql
类型:question
(必需), sql
(必需)documentation
类型:content
(必需)ddl
类型:ddl
(必需)实现代码框架:
@app.route('/api/v0/training_data/update', methods=['POST'])
def training_data_update():
"""更新训练数据API - 支持单条更新,采用先删除后插入策略"""
try:
req = request.get_json(force=True)
# 1. 参数验证
original_id = req.get('id')
if not original_id:
return jsonify(bad_request_response(
response_text="缺少必需参数:id",
missing_params=["id"]
)), 400
training_type = req.get('training_data_type')
if not training_type:
return jsonify(bad_request_response(
response_text="缺少必需参数:training_data_type",
missing_params=["training_data_type"]
)), 400
# 2. 先删除原始记录
try:
success = vn.remove_training_data(original_id)
if not success:
return jsonify(bad_request_response(
response_text=f"原始记录 {original_id} 不存在或删除失败"
)), 400
except Exception as e:
return jsonify(internal_error_response(
response_text=f"删除原始记录失败: {str(e)}"
)), 500
# 3. 根据类型验证和准备新数据
try:
if training_type == 'sql':
sql = req.get('sql')
if not sql:
return jsonify(bad_request_response(
response_text="SQL字段是必需的",
missing_params=["sql"]
)), 400
# SQL语法检查
is_valid, error_msg = validate_sql_syntax(sql)
if not is_valid:
return jsonify(bad_request_response(
response_text=f"SQL语法错误: {error_msg}"
)), 400
question = req.get('question')
if question:
training_id = vn.train(question=question, sql=sql)
else:
training_id = vn.train(sql=sql)
elif training_type == 'error_sql':
question = req.get('question')
sql = req.get('sql')
if not question or not sql:
return jsonify(bad_request_response(
response_text="question和sql字段都是必需的",
missing_params=["question", "sql"]
)), 400
training_id = vn.train_error_sql(question=question, sql=sql)
elif training_type == 'documentation':
content = req.get('content')
if not content:
return jsonify(bad_request_response(
response_text="content字段是必需的",
missing_params=["content"]
)), 400
training_id = vn.train(documentation=content)
elif training_type == 'ddl':
ddl = req.get('ddl')
if not ddl:
return jsonify(bad_request_response(
response_text="ddl字段是必需的",
missing_params=["ddl"]
)), 400
training_id = vn.train(ddl=ddl)
else:
return jsonify(bad_request_response(
response_text=f"不支持的训练数据类型: {training_type}"
)), 400
except Exception as e:
return jsonify(internal_error_response(
response_text=f"创建新训练数据失败: {str(e)}"
)), 500
# 4. 获取更新后的总记录数
current_total = get_total_training_count()
return jsonify(success_response(
response_text="训练数据更新成功",
data={
"original_id": original_id,
"new_training_id": training_id,
"type": training_type,
"current_total_count": current_total
}
))
except Exception as e:
logger.error(f"training_data_update执行失败: {str(e)}")
return jsonify(internal_error_response(
response_text="更新训练数据失败,请稍后重试"
)), 500
返回格式:
成功响应:
{
"code": 200,
"success": true,
"message": "操作成功",
"data": {
"response": "训练数据更新成功",
"original_id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
"new_training_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890-sql",
"type": "sql",
"current_total_count": 1250
}
}
错误响应示例:
{
"code": 400,
"success": false,
"message": "请求参数错误",
"data": {
"response": "原始记录 6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql 不存在或删除失败",
"error_type": "INVALID_PARAMS",
"timestamp": "2025-01-15T10:30:00.000Z",
"can_retry": false
}
}
错误响应示例:
{
"code": 400,
"success": false,
"message": "请求参数错误",
"data": {
"response": "缺少必需参数:data",
"error_type": "INVALID_PARAMS",
"timestamp": "2025-01-15T10:30:00.000Z",
"can_retry": false
}
}
安全性和性能考虑:
事务安全:
数据一致性:
错误处理:
单条处理:
性能优化:
使用建议:
前端集成:
// 更新训练数据的示例
const updateTrainingData = async (originalId, newData) => {
const response = await fetch('/api/v0/training_data/update', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
id: originalId,
training_data_type: 'sql',
question: newData.question,
sql: newData.sql
})
});
const result = await response.json();
if (result.success) {
// 更新成功,使用新的training_id更新界面
const newId = result.data.new_training_id;
updateUIWithNewId(newId);
}
};
success
字段显示详细的错误信息给用户
最佳实践:
更新前验证原始记录是否存在
更新后验证新记录是否创建成功
使用返回的新ID更新前端界面
/api/v0/training_data/upload
新增(重点设计)功能描述:
file_type
参数指定文件类型,不再依赖文件扩展名支持的文件类型:
API 支持以下5种文件类型,通过 file_type
参数指定:
file_type | 支持扩展名 | 内容格式验证 | 复用函数 | 解析方式 |
---|---|---|---|---|
ddl |
ddl/sql/txt/无扩展名 | 必须包含 CREATE 语句(大小写不敏感) |
train_ddl_statements() |
建表语句,以";"为分隔符 |
markdown |
md/markdown | 必须包含 ## 标题 |
train_documentation_blocks() |
Markdown以"##"为单位分割 |
sql_pair_json |
json/txt/无扩展名 | 必须是有效JSON格式,包含question/sql字段 | train_json_question_sql_pairs() |
JSON问答对训练 |
sql_pair |
sql/txt/无扩展名 | 必须包含 Question: /SQL: 格式 |
train_formatted_question_sql_pairs() |
格式化问答对训练 |
sql |
sql/txt/无扩展名 | 必须包含 ; 分隔符 |
train_sql_examples() |
SQL示例,以";"为分隔符 |
文件类型详细说明:
1. DDL文件 (file_type: ddl
):
建表语句,必须包含CREATE语句,以";"为分隔符。
-- 建表语句
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(255) UNIQUE
);
CREATE TABLE orders (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id),
amount DECIMAL(10,2)
);
2. Markdown文档 (file_type: markdown
):
必须包含"##"标题,以"##"为单位进行分割:
## users表
用户基本信息表,存储系统用户数据。
字段说明
- id: 用户唯一标识
- name: 用户姓名
- email: 电子邮箱
## orders表
订单信息表,记录用户购买记录。
3. JSON问答对 (file_type: sql_pair_json
):
必须是有效JSON格式,"question"/"sql" 大小写不敏感。
[
{
"question": "如何查询活跃用户数量?",
"sql": "SELECT COUNT(*) FROM users WHERE active = true;"
},
{
"question": "如何统计每个分类的产品数量?",
"sql": "SELECT category, COUNT(*) FROM products GROUP BY category;"
}
]
4. 格式化问答对 (file_type: sql_pair
):
必须包含Question/SQL格式,大小写不敏感,问答对之间用空行分隔。
Question: 如何查询观影人数超过 10000 的热门电影?
SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;
Question: 如何统计每个分类的产品数量?
SQL: SELECT category, COUNT(*) FROM products GROUP BY category;
5. SQL示例 (file_type: sql
):
纯SQL语句,必须包含";"分隔符,系统会自动生成对应问题。
SELECT COUNT(*) FROM users WHERE active = true;
SELECT category, COUNT(*) FROM products GROUP BY category;
API规格:
请求参数:
file
: 上传的文件(multipart/form-data)file_type
: 文件类型(必需),可选值:ddl
、markdown
、sql_pair_json
、sql_pair
、sql
file_type
动态确定(详见上表)文件扩展名验证规则:
系统根据 file_type
参数动态确定允许的文件扩展名:
.ddl
、.sql
、.txt
扩展名或无扩展名.md
、.markdown
扩展名(不支持无扩展名).json
、.txt
扩展名或无扩展名.sql
、.txt
扩展名或无扩展名.sql
、.txt
扩展名或无扩展名扩展名验证错误示例:
文件类型 ddl 不支持的文件扩展名:pdf,支持的扩展名:ddl, sql, txt 或无扩展名
文件类型 markdown 不支持的文件扩展名:txt,支持的扩展名:md, markdown
内容格式验证规则:
DDL类型验证:
CREATE
关键字(大小写不敏感)"文件内容不符合DDL格式,必须包含CREATE语句"
Markdown类型验证:
##
标题标记"文件内容不符合Markdown格式,必须包含##标题"
SQL_PAIR_JSON类型验证:
"文件内容不符合JSON问答对格式"
SQL_PAIR类型验证:
Question:
和 SQL:
标记(大小写不敏感)"文件内容不符合问答对格式,必须包含Question:和SQL:"
SQL类型验证:
;
分隔符"文件内容不符合SQL格式,必须包含;分隔符"
实现代码框架:
import json
import re
from werkzeug.utils import secure_filename
from data_pipeline.trainer.run_training import (
train_ddl_statements,
train_documentation_blocks,
train_json_question_sql_pairs,
train_formatted_question_sql_pairs,
train_sql_examples
)
from common.result import success_response, bad_request_response, internal_error_response
def get_allowed_extensions(file_type: str) -> list:
"""根据文件类型返回允许的扩展名"""
type_specific_extensions = {
'ddl': ['ddl', 'sql', 'txt', ''], # 支持无扩展名
'markdown': ['md', 'markdown'], # 不支持无扩展名
'sql_pair_json': ['json', 'txt', ''], # 支持无扩展名
'sql_pair': ['sql', 'txt', ''], # 支持无扩展名
'sql': ['sql', 'txt', ''] # 支持无扩展名
}
return type_specific_extensions.get(file_type, [])
@app.route('/api/v0/training_data/upload', methods=['POST'])
def upload_training_data():
"""上传训练数据文件API - 支持多种文件格式的自动解析和导入"""
try:
# 1. 参数验证
if 'file' not in request.files:
return jsonify(bad_request_response("未提供文件"))
file = request.files['file']
if file.filename == '':
return jsonify(bad_request_response("未选择文件"))
# 获取file_type参数
file_type = request.form.get('file_type')
if not file_type:
return jsonify(bad_request_response("缺少必需参数:file_type"))
# 验证file_type参数
valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql']
if file_type not in valid_file_types:
return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}"))
# 2. 文件大小验证 (500KB)
file.seek(0, 2)
file_size = file.tell()
file.seek(0)
if file_size > 500 * 1024: # 500KB
return jsonify(bad_request_response("文件大小不能超过500KB"))
# 3. 验证文件扩展名(基于file_type)
filename = secure_filename(file.filename)
allowed_extensions = get_allowed_extensions(file_type)
file_ext = filename.split('.')[-1].lower() if '.' in filename else ''
if file_ext not in allowed_extensions:
# 构建友好的错误信息
non_empty_extensions = [ext for ext in allowed_extensions if ext]
if '' in allowed_extensions:
ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名"
else:
ext_message = ', '.join(non_empty_extensions)
return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}"))
# 4. 读取文件内容并验证格式
file.seek(0)
content = file.read().decode('utf-8')
# 格式验证
validation_result = validate_file_content(content, file_type)
if not validation_result['valid']:
return jsonify(bad_request_response(validation_result['error']))
# 5. 创建临时文件(复用现有函数需要文件路径)
temp_file_path = None
try:
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file:
tmp_file.write(content)
temp_file_path = tmp_file.name
# 6. 根据文件类型调用现有的训练函数
if file_type == 'ddl':
train_ddl_statements(temp_file_path)
elif file_type == 'markdown':
train_documentation_blocks(temp_file_path)
elif file_type == 'sql_pair_json':
train_json_question_sql_pairs(temp_file_path)
elif file_type == 'sql_pair':
train_formatted_question_sql_pairs(temp_file_path)
elif file_type == 'sql':
train_sql_examples(temp_file_path)
return jsonify(success_response(
response_text=f"文件上传并训练成功:{filename}",
data={
"filename": filename,
"file_type": file_type,
"file_size": file_size,
"status": "completed"
}
))
except Exception as e:
logger.error(f"训练失败: {str(e)}")
return jsonify(internal_error_response(f"训练失败: {str(e)}"))
finally:
# 清理临时文件
if temp_file_path and os.path.exists(temp_file_path):
try:
os.unlink(temp_file_path)
except Exception as e:
logger.warning(f"清理临时文件失败: {str(e)}")
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))
def validate_file_content(content: str, file_type: str) -> dict:
"""验证文件内容格式"""
try:
if file_type == 'ddl':
# 检查是否包含CREATE语句
if not re.search(r'\bCREATE\b', content, re.IGNORECASE):
return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'}
elif file_type == 'markdown':
# 检查是否包含##标题
if '##' not in content:
return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'}
elif file_type == 'sql_pair_json':
# 检查是否为有效JSON
try:
data = json.loads(content)
if not isinstance(data, list) or not data:
return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'}
# 检查是否包含question和sql字段
for item in data:
if not isinstance(item, dict):
return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'}
has_question = any(key.lower() == 'question' for key in item.keys())
has_sql = any(key.lower() == 'sql' for key in item.keys())
if not has_question or not has_sql:
return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'}
except json.JSONDecodeError:
return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'}
elif file_type == 'sql_pair':
# 检查是否包含Question:和SQL:
if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE):
return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'}
if not re.search(r'\bSQL\s*:', content, re.IGNORECASE):
return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'}
elif file_type == 'sql':
# 检查是否包含;分隔符
if ';' not in content:
return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'}
return {'valid': True}
except Exception as e:
return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'}
返回格式:
成功响应:
{
"code": 200,
"success": true,
"message": "操作成功",
"data": {
"response": "文件上传并训练成功:training_data.sql",
"filename": "training_data.sql",
"file_type": "sql",
"file_size": 12345,
"status": "completed"
}
}
错误响应:
{
"code": 400,
"success": false,
"message": "请求参数错误",
"data": {
"response": "文件内容不符合SQL格式,必须包含;分隔符",
"error_type": "INVALID_PARAMS",
"timestamp": "2025-01-15T10:30:00.000Z",
"can_retry": false
}
}
安全性和性能考虑:
文件处理优化:
类型验证机制:
file_type
参数动态确定允许的扩展名复用成熟功能:
data_pipeline.trainer.run_training
中的成熟函数train_ddl_statements
、train_documentation_blocks
等处理逻辑train_json_question_sql_pairs
和 train_formatted_question_sql_pairs
处理问答对错误处理:
性能优化:
使用建议:
文件准备:
文件类型选择:
ddl
:用于训练表结构信息,支持 .ddl/.sql/.txt 扩展名或无扩展名,必须包含CREATE语句markdown
:用于训练表和字段说明,仅支持 .md/.markdown 扩展名,必须包含##标题sql_pair_json
:用于训练结构化问答数据,支持 .json/.txt 扩展名或无扩展名,必须是有效JSON格式sql_pair
:用于训练文本格式问答数据,支持 .sql/.txt 扩展名或无扩展名,必须包含Question:/SQL:格式sql
:用于训练SQL示例,支持 .sql/.txt 扩展名或无扩展名,必须包含;分隔符,系统会自动生成问题最佳实践:
file_type
参数Postman测试示例:
POST /api/v0/training_data/upload
Content-Type: multipart/form-data
form-data:
- file: [选择文件]
- file_type: sql_pair
示例文件内容:
Question: 如何查询观影人数超过 10000 的热门电影?
SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;
Question: 如何统计每个分类的产品数量?
SQL: SELECT category, COUNT(*) FROM products GROUP BY category;
/api/v0/training_data/combine
(新增的API)功能描述:
技术分析:
训练数据写入机制中,不同类型的数据在 document 字段中存储不同格式的内容:
{"question": "...", "sql": "..."}
合并策略采用 document 字段精确匹配,因为:
API规格:
请求参数:
{
"collection_names": ["sql", "ddl", "documentation", "error_sql"],
"dry_run": true,
"keep_strategy": "first"
}
参数说明:
collection_names
: 要处理的集合名称数组,支持 ["sql", "ddl", "documentation", "error_sql"]dry_run
: 是否为预览模式,默认 true(安全)keep_strategy
: 保留策略
"first"
: 保留第一条记录(按ID排序,推荐默认)"last"
: 保留最后一条记录(按ID排序)"by_metadata_time"
: 按 cmetadata.createdat 时间排序保留最新记录实现代码框架:
@app.route('/api/v0/training_data/combine', methods=['POST'])
def combine_training_data():
try:
# 1. 参数验证
data = request.get_json()
if not data:
return jsonify(bad_request_response("请求体不能为空"))
collection_names = data.get('collection_names', [])
if not collection_names or not isinstance(collection_names, list):
return jsonify(bad_request_response("collection_names 参数必须是非空数组"))
# 验证集合名称
valid_collections = ['sql', 'ddl', 'documentation', 'error_sql']
invalid_collections = [name for name in collection_names if name not in valid_collections]
if invalid_collections:
return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}"))
dry_run = data.get('dry_run', True)
keep_strategy = data.get('keep_strategy', 'first')
if keep_strategy not in ['first', 'last', 'by_metadata_time']:
return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'"))
# 2. 获取数据库连接
conn = get_db_connection()
cursor = conn.cursor()
# 3. 查找重复记录
duplicate_groups = []
total_before = 0
total_duplicates = 0
collections_stats = {}
for collection_name in collection_names:
# 获取集合ID
cursor.execute(
"SELECT uuid FROM langchain_pg_collection WHERE name = %s",
(collection_name,)
)
collection_result = cursor.fetchone()
if not collection_result:
continue
collection_id = collection_result[0]
# 统计该集合的记录数
cursor.execute(
"SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s",
(collection_id,)
)
collection_before = cursor.fetchone()[0]
total_before += collection_before
# 查找重复记录
cursor.execute("""
SELECT document, COUNT(*) as duplicate_count,
array_agg(id ORDER BY %s) as record_ids
FROM langchain_pg_embedding
WHERE collection_id = %s
GROUP BY document
HAVING COUNT(*) > 1
""" % (
"id" if keep_strategy in ['first', 'last'] else "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id",
"%s"
), (collection_id,))
collection_duplicates = 0
for row in cursor.fetchall():
document_content, duplicate_count, record_ids = row
collection_duplicates += duplicate_count - 1 # 减去要保留的一条
# 根据保留策略选择要保留的记录
if keep_strategy == 'first':
keep_id = record_ids[0]
remove_ids = record_ids[1:]
elif keep_strategy == 'last':
keep_id = record_ids[-1]
remove_ids = record_ids[:-1]
else: # by_metadata_time
keep_id = record_ids[0] # 已经按时间排序
remove_ids = record_ids[1:]
duplicate_groups.append({
"collection_name": collection_name,
"document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content,
"duplicate_count": duplicate_count,
"kept_record_id" if not dry_run else "records_to_keep": keep_id,
"removed_record_ids" if not dry_run else "records_to_remove": remove_ids
})
total_duplicates += collection_duplicates
collections_stats[collection_name] = {
"before": collection_before,
"after": collection_before - collection_duplicates,
"duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates
}
# 4. 执行合并操作(如果不是dry_run)
if not dry_run:
try:
conn.autocommit = False # 开始事务
for group in duplicate_groups:
remove_ids = group["removed_record_ids"]
if remove_ids:
cursor.execute(
"DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)",
(remove_ids,)
)
conn.commit()
except Exception as e:
conn.rollback()
return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
finally:
conn.autocommit = True
# 5. 构建响应
total_after = total_before - total_duplicates
summary = {
"total_records_before": total_before,
"total_records_after": total_after,
"duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates,
"collections_stats": collections_stats
}
if dry_run:
response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录"
data_key = "duplicate_groups"
else:
response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录"
data_key = "merged_groups"
return jsonify(success_response(
response_text=response_text,
data={
"dry_run": dry_run,
"collections_processed": collection_names,
"summary": summary,
data_key: duplicate_groups
}
))
except Exception as e:
return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
finally:
if 'cursor' in locals():
cursor.close()
if 'conn' in locals():
conn.close()
返回格式(dry_run: true):
{
"code": 200,
"success": true,
"message": "操作成功",
"data": {
"response": "发现 50 条重复记录,预计删除后将从 1000 条减少到 950 条记录",
"dry_run": true,
"collections_processed": ["sql", "ddl"],
"summary": {
"total_records_before": 1000,
"total_records_after": 950,
"duplicates_to_remove": 50,
"collections_stats": {
"sql": {
"before": 600,
"after": 580,
"duplicates_to_remove": 20
},
"ddl": {
"before": 400,
"after": 370,
"duplicates_to_remove": 30
}
}
},
"duplicate_groups": [
{
"collection_name": "sql",
"document_content": "SELECT * FROM users WHERE id = ?",
"duplicate_count": 3,
"records_to_keep": "uuid1-sql",
"records_to_remove": ["uuid2-sql", "uuid3-sql"]
}
]
}
}
返回格式(dry_run: false):
{
"code": 200,
"success": true,
"message": "操作成功",
"data": {
"response": "成功合并重复记录,删除了 50 条重复记录,从 1000 条减少到 950 条记录",
"dry_run": false,
"collections_processed": ["sql", "ddl"],
"summary": {
"total_records_before": 1000,
"total_records_after": 950,
"duplicates_removed": 50,
"collections_stats": {
"sql": {
"before": 600,
"after": 580,
"duplicates_removed": 20
},
"ddl": {
"before": 400,
"after": 370,
"duplicates_removed": 30
}
}
},
"merged_groups": [
{
"collection_name": "sql",
"document_content": "SELECT * FROM users WHERE id = ?",
"duplicate_count": 3,
"kept_record_id": "uuid1-sql",
"removed_record_ids": ["uuid2-sql", "uuid3-sql"]
}
]
}
}
错误响应示例:
{
"code": 400,
"success": false,
"message": "请求参数错误",
"data": {
"response": "collection_names 参数必须是非空数组",
"error_type": "INVALID_PARAMS",
"timestamp": "2025-01-15T10:30:00.000Z",
"can_retry": false
}
}
安全性和性能考虑:
事务安全:
预览模式:
保留策略:
性能优化:
错误处理:
使用建议:
推荐流程: ```bash
POST /api/v0/training_data/combine { "collection_names": ["sql", "ddl"], "dry_run": true, "keep_strategy": "first" }
# 2. 确认无误后执行实际合并 POST /api/v0/training_data/combine {
"collection_names": ["sql", "ddl"],
"dry_run": false,
"keep_strategy": "first"
}
2. **保留策略选择**:
- 一般情况:使用 `"first"` 策略
- 需要保留最新数据:使用 `"by_metadata_time"` 策略
- 测试环境:可以使用 `"last"` 策略
3. **批量处理**:
- 大量数据时建议分批处理不同集合
- 可以先处理较小的集合验证效果
- 重要数据建议先备份
---
## 4. 补充建议
### 4.1 错误处理增强
建议在所有API中添加更详细的错误处理:
```python
# 通用错误处理装饰器
def handle_api_errors(f):
@wraps(f)
def decorated_function(*args, **kwargs):
try:
return f(*args, **kwargs)
except ValueError as e:
return jsonify(bad_request_response(str(e)))
except Exception as e:
logger.error(f"API错误: {str(e)}", exc_info=True)
return jsonify(internal_error_response(f"系统错误: {str(e)}"))
return decorated_function
对于大量数据的处理,建议引入任务队列:
# 异步任务支持
@app.route('/api/v0/training_data/upload_async', methods=['POST'])
def upload_training_data_async():
# 创建后台任务
task_id = str(uuid.uuid4())
# 将任务放入队列
# 返回任务ID供客户端查询进度
return jsonify(success_response(
response_text="文件上传任务已创建",
data={"task_id": task_id}
))
@app.route('/api/v0/training_data/task_status/<task_id>', methods=['GET'])
def get_task_status(task_id):
# 查询任务状态
pass
建议添加更严格的数据验证:
# SQL语法验证
def validate_sql(sql):
try:
import sqlparse
parsed = sqlparse.parse(sql)
if not parsed:
raise ValueError("SQL语法错误")
return True
except Exception as e:
raise ValueError(f"SQL验证失败: {str(e)}")
# 数据类型验证
def validate_training_data(data_type, content):
if data_type == 'sql':
return validate_sql(content)
elif data_type == 'ddl':
return validate_ddl(content)
# 其他验证逻辑
本改造方案通过充分复用现有代码,提供了更加完整和实用的解决方案:
data_pipeline/trainer
模块的成熟函数,支持多种文件格式的自动解析和导入主要优势:
避免重复造轮子:
run_training.py
中的成熟解析函数提高代码质量:
性能优化:
特别是upload API的设计,通过复用现有的trainer模块,不仅避免了重复开发,还确保了与现有系统的完美兼容性,可以满足生产环境的需求。