# unified_api.py 训练数据管理API改造方案(详细版) ## 1. 需求概述 本方案基于对 `unified_api.py` 文件的深入分析,提出以下两个API改造与新增需求: - **需求一**:新增 `/api/v0/training_data/update`,支持先删除再插入的训练数据更新操作 - **需求二**:新增 `/api/v0/training_data/upload`,支持上传多种格式文件并自动解析入库 - **需求三**:新增 `/api/v0/training_data/combine`,合并相同的 langchain_pg_embedding.document 记录 --- ## 2. 当前实现分析 ### 2.1 现有API结构 `unified_api.py` 目前实现了4个训练数据相关的API端点: 1. **`GET /api/v0/training_data/stats`** (lines 1205-1234) - 基础统计信息 2. **`POST /api/v0/training_data/query`** (lines 1236-1294) - 分页查询 3. **`POST /api/v0/training_data/create`** (lines 1296-1404) - 批量创建(最多50条) 4. **`POST /api/v0/training_data/delete`** (lines 1406-1471) - 批量删除(最多50条) ### 2.2 数据库架构 **表结构关系:** ```sql -- langchain_pg_collection (集合表) CREATE TABLE langchain_pg_collection ( uuid uuid PRIMARY KEY, name varchar NOT NULL UNIQUE, cmetadata json ); -- langchain_pg_embedding (训练数据表) CREATE TABLE langchain_pg_embedding ( id varchar PRIMARY KEY, collection_id uuid REFERENCES langchain_pg_collection(uuid) ON DELETE CASCADE, embedding vector, document varchar, cmetadata jsonb ); ``` **集合名称映射:** - `sql` 集合 → SQL查询和问答对 - `ddl` 集合 → 数据库表结构定义 - `documentation` 集合 → 表和字段文档 - `error_sql` 集合 → 错误SQL示例 --- ## 3. 详细改造方案 --- ### 3.1 `/api/v0/training_data/update` (新增的API) **功能描述:** - 支持单条更新训练数据 - 采用先删除后插入的策略确保数据一致性 - 更新方式是先删除再插入,删除是以 langchain_pg_embedding.id字段为依据的 - langchain_pg_embedding.id字段是varchar类型,它是uuid+类型的格式,比如:"50bb6b17-d5be-48ab-8125-58ab8bb076e8-sql" - 返回新创建记录的ID,便于前端更新界面 **WEB UI使用流程:** 1. 用户选中一行记录,点击修改按钮 2. 前端显示编辑表单,用户修改内容 3. 提交时,根据这一行的id,删除这条记录 4. 然后调用create参数,生成一条新的记录 5. 返回的JSON中包括这条记录的新的id **API规格:** **请求参数:** ```json { "id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql", "training_data_type": "sql", "question": "如何查询所有用户?", "sql": "SELECT * FROM users;" } ``` **参数说明:** - `id`: 要删除的原始记录ID(必需) - `training_data_type`: 训练数据类型,支持 "sql", "ddl", "documentation", "error_sql" - 其他字段根据类型不同而不同: - `sql` 类型:`question`(可选), `sql`(必需) - `error_sql` 类型:`question`(必需), `sql`(必需) - `documentation` 类型:`content`(必需) - `ddl` 类型:`ddl`(必需) **实现代码框架:** ```python @app.route('/api/v0/training_data/update', methods=['POST']) def training_data_update(): """更新训练数据API - 支持单条更新,采用先删除后插入策略""" try: req = request.get_json(force=True) # 1. 参数验证 original_id = req.get('id') if not original_id: return jsonify(bad_request_response( response_text="缺少必需参数:id", missing_params=["id"] )), 400 training_type = req.get('training_data_type') if not training_type: return jsonify(bad_request_response( response_text="缺少必需参数:training_data_type", missing_params=["training_data_type"] )), 400 # 2. 先删除原始记录 try: success = vn.remove_training_data(original_id) if not success: return jsonify(bad_request_response( response_text=f"原始记录 {original_id} 不存在或删除失败" )), 400 except Exception as e: return jsonify(internal_error_response( response_text=f"删除原始记录失败: {str(e)}" )), 500 # 3. 根据类型验证和准备新数据 try: if training_type == 'sql': sql = req.get('sql') if not sql: return jsonify(bad_request_response( response_text="SQL字段是必需的", missing_params=["sql"] )), 400 # SQL语法检查 is_valid, error_msg = validate_sql_syntax(sql) if not is_valid: return jsonify(bad_request_response( response_text=f"SQL语法错误: {error_msg}" )), 400 question = req.get('question') if question: training_id = vn.train(question=question, sql=sql) else: training_id = vn.train(sql=sql) elif training_type == 'error_sql': question = req.get('question') sql = req.get('sql') if not question or not sql: return jsonify(bad_request_response( response_text="question和sql字段都是必需的", missing_params=["question", "sql"] )), 400 training_id = vn.train_error_sql(question=question, sql=sql) elif training_type == 'documentation': content = req.get('content') if not content: return jsonify(bad_request_response( response_text="content字段是必需的", missing_params=["content"] )), 400 training_id = vn.train(documentation=content) elif training_type == 'ddl': ddl = req.get('ddl') if not ddl: return jsonify(bad_request_response( response_text="ddl字段是必需的", missing_params=["ddl"] )), 400 training_id = vn.train(ddl=ddl) else: return jsonify(bad_request_response( response_text=f"不支持的训练数据类型: {training_type}" )), 400 except Exception as e: return jsonify(internal_error_response( response_text=f"创建新训练数据失败: {str(e)}" )), 500 # 4. 获取更新后的总记录数 current_total = get_total_training_count() return jsonify(success_response( response_text="训练数据更新成功", data={ "original_id": original_id, "new_training_id": training_id, "type": training_type, "current_total_count": current_total } )) except Exception as e: logger.error(f"training_data_update执行失败: {str(e)}") return jsonify(internal_error_response( response_text="更新训练数据失败,请稍后重试" )), 500 ``` **返回格式:** **成功响应:** ```json { "code": 200, "success": true, "message": "操作成功", "data": { "response": "训练数据更新成功", "original_id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql", "new_training_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890-sql", "type": "sql", "current_total_count": 1250 } } ``` **错误响应示例:** ```json { "code": 400, "success": false, "message": "请求参数错误", "data": { "response": "原始记录 6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql 不存在或删除失败", "error_type": "INVALID_PARAMS", "timestamp": "2025-01-15T10:30:00.000Z", "can_retry": false } } ``` **错误响应示例:** ```json { "code": 400, "success": false, "message": "请求参数错误", "data": { "response": "缺少必需参数:data", "error_type": "INVALID_PARAMS", "timestamp": "2025-01-15T10:30:00.000Z", "can_retry": false } } ``` **安全性和性能考虑:** 1. **事务安全**: - 删除和插入操作在同一个函数中执行 - 如果插入失败,原始记录已被删除,需要回滚或重试机制 - 建议在应用层实现重试逻辑 2. **数据一致性**: - 先删除后插入确保ID的唯一性 - 新记录会获得新的UUID,避免ID冲突 - 保持训练数据类型的正确性 3. **错误处理**: - 详细的参数验证 - 原始记录存在性检查 - SQL语法验证(针对sql类型) - 完整的异常处理和错误信息 4. **单条处理**: - 只支持单条更新,简化逻辑 - 减少并发冲突的可能性 - 更精确的错误定位和处理 5. **性能优化**: - 复用现有的训练函数 - 避免重复的数据库连接 - 合理的错误处理和日志记录 **使用建议:** 1. **前端集成**: ```javascript // 更新训练数据的示例 const updateTrainingData = async (originalId, newData) => { const response = await fetch('/api/v0/training_data/update', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ id: originalId, training_data_type: 'sql', question: newData.question, sql: newData.sql }) }); const result = await response.json(); if (result.success) { // 更新成功,使用新的training_id更新界面 const newId = result.data.new_training_id; updateUIWithNewId(newId); } }; ``` 2. **错误处理**: - 检查返回的 `success` 字段 - 处理各种错误情况(参数错误、记录不存在、SQL语法错误等) - 显示详细的错误信息给用户 3. **最佳实践**: - 更新前验证原始记录是否存在 - 更新后验证新记录是否创建成功 - 使用返回的新ID更新前端界面 --- ### 3.2 `/api/v0/training_data/upload` 新增(重点设计) **功能描述:** - 支持上传多种格式的训练数据文件(txt/md/sql/json 或无扩展名) - 通过 `file_type` 参数指定文件类型,不再依赖文件扩展名 - 文件大小限制500KB(约包含100-500条记录) - 对文件内容进行格式验证,确保符合指定类型要求 - 复用现有trainer模块的成熟功能 - 提供详细的处理结果反馈 **支持的文件类型:** API 支持以下5种文件类型,通过 `file_type` 参数指定: | file_type | 支持扩展名 | 内容格式验证 | 复用函数 | 解析方式 | |-----------|------------|-------------|---------|-----------| | `ddl` | **ddl**/sql/txt/无扩展名 | 必须包含 `CREATE` 语句(大小写不敏感) | `train_ddl_statements()` | 建表语句,以";"为分隔符 | | `markdown` | **md/markdown** | 必须包含 `##` 标题 | `train_documentation_blocks()` | Markdown以"##"为单位分割 | | `sql_pair_json` | **json**/txt/无扩展名 | 必须是有效JSON格式,包含question/sql字段 | `train_json_question_sql_pairs()` | JSON问答对训练 | | `sql_pair` | **sql**/txt/无扩展名 | 必须包含 `Question:`/`SQL:` 格式 | `train_formatted_question_sql_pairs()` | 格式化问答对训练 | | `sql` | **sql**/txt/无扩展名 | 必须包含 `;` 分隔符 | `train_sql_examples()` | SQL示例,以";"为分隔符 | **文件类型详细说明:** **1. DDL文件 (file_type: `ddl`):** 建表语句,必须包含CREATE语句,以";"为分隔符。 ```sql -- 建表语句 CREATE TABLE users ( id SERIAL PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(255) UNIQUE ); CREATE TABLE orders ( id SERIAL PRIMARY KEY, user_id INTEGER REFERENCES users(id), amount DECIMAL(10,2) ); ``` **2. Markdown文档 (file_type: `markdown`):** 必须包含"##"标题,以"##"为单位进行分割: ```markdown ## users表 用户基本信息表,存储系统用户数据。 字段说明 - id: 用户唯一标识 - name: 用户姓名 - email: 电子邮箱 ## orders表 订单信息表,记录用户购买记录。 ``` **3. JSON问答对 (file_type: `sql_pair_json`):** 必须是有效JSON格式,"question"/"sql" 大小写不敏感。 ```json [ { "question": "如何查询活跃用户数量?", "sql": "SELECT COUNT(*) FROM users WHERE active = true;" }, { "question": "如何统计每个分类的产品数量?", "sql": "SELECT category, COUNT(*) FROM products GROUP BY category;" } ] ``` **4. 格式化问答对 (file_type: `sql_pair`):** 必须包含Question/SQL格式,大小写不敏感,问答对之间用空行分隔。 ``` Question: 如何查询观影人数超过 10000 的热门电影? SQL: SELECT name, viewers FROM movies WHERE viewers > 10000; Question: 如何统计每个分类的产品数量? SQL: SELECT category, COUNT(*) FROM products GROUP BY category; ``` **5. SQL示例 (file_type: `sql`):** 纯SQL语句,必须包含";"分隔符,系统会自动生成对应问题。 ```sql SELECT COUNT(*) FROM users WHERE active = true; SELECT category, COUNT(*) FROM products GROUP BY category; ``` **API规格:** **请求参数:** - `file`: 上传的文件(multipart/form-data) - `file_type`: 文件类型(必需),可选值:`ddl`、`markdown`、`sql_pair_json`、`sql_pair`、`sql` - 文件大小限制:最大500KB - 文件编码:UTF-8 - 支持的文件扩展名:根据 `file_type` 动态确定(详见上表) **文件扩展名验证规则:** 系统根据 `file_type` 参数动态确定允许的文件扩展名: 1. **DDL类型**:允许 `.ddl`、`.sql`、`.txt` 扩展名或无扩展名 2. **Markdown类型**:仅允许 `.md`、`.markdown` 扩展名(不支持无扩展名) 3. **SQL_PAIR_JSON类型**:允许 `.json`、`.txt` 扩展名或无扩展名 4. **SQL_PAIR类型**:允许 `.sql`、`.txt` 扩展名或无扩展名 5. **SQL类型**:允许 `.sql`、`.txt` 扩展名或无扩展名 **扩展名验证错误示例:** ``` 文件类型 ddl 不支持的文件扩展名:pdf,支持的扩展名:ddl, sql, txt 或无扩展名 文件类型 markdown 不支持的文件扩展名:txt,支持的扩展名:md, markdown ``` **内容格式验证规则:** 1. **DDL类型验证**: - 文件内容必须包含 `CREATE` 关键字(大小写不敏感) - 验证失败返回:`"文件内容不符合DDL格式,必须包含CREATE语句"` 2. **Markdown类型验证**: - 文件内容必须包含 `##` 标题标记 - 验证失败返回:`"文件内容不符合Markdown格式,必须包含##标题"` 3. **SQL_PAIR_JSON类型验证**: - 文件必须是有效的JSON格式 - JSON内容必须包含question和sql字段 - 验证失败返回:`"文件内容不符合JSON问答对格式"` 4. **SQL_PAIR类型验证**: - 文件内容必须包含 `Question:` 和 `SQL:` 标记(大小写不敏感) - 验证失败返回:`"文件内容不符合问答对格式,必须包含Question:和SQL:"` 5. **SQL类型验证**: - 文件内容必须包含 `;` 分隔符 - 验证失败返回:`"文件内容不符合SQL格式,必须包含;分隔符"` **实现代码框架:** ```python import json import re from werkzeug.utils import secure_filename from data_pipeline.trainer.run_training import ( train_ddl_statements, train_documentation_blocks, train_json_question_sql_pairs, train_formatted_question_sql_pairs, train_sql_examples ) from common.result import success_response, bad_request_response, internal_error_response def get_allowed_extensions(file_type: str) -> list: """根据文件类型返回允许的扩展名""" type_specific_extensions = { 'ddl': ['ddl', 'sql', 'txt', ''], # 支持无扩展名 'markdown': ['md', 'markdown'], # 不支持无扩展名 'sql_pair_json': ['json', 'txt', ''], # 支持无扩展名 'sql_pair': ['sql', 'txt', ''], # 支持无扩展名 'sql': ['sql', 'txt', ''] # 支持无扩展名 } return type_specific_extensions.get(file_type, []) @app.route('/api/v0/training_data/upload', methods=['POST']) def upload_training_data(): """上传训练数据文件API - 支持多种文件格式的自动解析和导入""" try: # 1. 参数验证 if 'file' not in request.files: return jsonify(bad_request_response("未提供文件")) file = request.files['file'] if file.filename == '': return jsonify(bad_request_response("未选择文件")) # 获取file_type参数 file_type = request.form.get('file_type') if not file_type: return jsonify(bad_request_response("缺少必需参数:file_type")) # 验证file_type参数 valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql'] if file_type not in valid_file_types: return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}")) # 2. 文件大小验证 (500KB) file.seek(0, 2) file_size = file.tell() file.seek(0) if file_size > 500 * 1024: # 500KB return jsonify(bad_request_response("文件大小不能超过500KB")) # 3. 验证文件扩展名(基于file_type) filename = secure_filename(file.filename) allowed_extensions = get_allowed_extensions(file_type) file_ext = filename.split('.')[-1].lower() if '.' in filename else '' if file_ext not in allowed_extensions: # 构建友好的错误信息 non_empty_extensions = [ext for ext in allowed_extensions if ext] if '' in allowed_extensions: ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名" else: ext_message = ', '.join(non_empty_extensions) return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}")) # 4. 读取文件内容并验证格式 file.seek(0) content = file.read().decode('utf-8') # 格式验证 validation_result = validate_file_content(content, file_type) if not validation_result['valid']: return jsonify(bad_request_response(validation_result['error'])) # 5. 创建临时文件(复用现有函数需要文件路径) temp_file_path = None try: with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file: tmp_file.write(content) temp_file_path = tmp_file.name # 6. 根据文件类型调用现有的训练函数 if file_type == 'ddl': train_ddl_statements(temp_file_path) elif file_type == 'markdown': train_documentation_blocks(temp_file_path) elif file_type == 'sql_pair_json': train_json_question_sql_pairs(temp_file_path) elif file_type == 'sql_pair': train_formatted_question_sql_pairs(temp_file_path) elif file_type == 'sql': train_sql_examples(temp_file_path) return jsonify(success_response( response_text=f"文件上传并训练成功:{filename}", data={ "filename": filename, "file_type": file_type, "file_size": file_size, "status": "completed" } )) except Exception as e: logger.error(f"训练失败: {str(e)}") return jsonify(internal_error_response(f"训练失败: {str(e)}")) finally: # 清理临时文件 if temp_file_path and os.path.exists(temp_file_path): try: os.unlink(temp_file_path) except Exception as e: logger.warning(f"清理临时文件失败: {str(e)}") except Exception as e: logger.error(f"文件上传失败: {str(e)}") return jsonify(internal_error_response(f"文件上传失败: {str(e)}")) def validate_file_content(content: str, file_type: str) -> dict: """验证文件内容格式""" try: if file_type == 'ddl': # 检查是否包含CREATE语句 if not re.search(r'\bCREATE\b', content, re.IGNORECASE): return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'} elif file_type == 'markdown': # 检查是否包含##标题 if '##' not in content: return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'} elif file_type == 'sql_pair_json': # 检查是否为有效JSON try: data = json.loads(content) if not isinstance(data, list) or not data: return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'} # 检查是否包含question和sql字段 for item in data: if not isinstance(item, dict): return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'} has_question = any(key.lower() == 'question' for key in item.keys()) has_sql = any(key.lower() == 'sql' for key in item.keys()) if not has_question or not has_sql: return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'} except json.JSONDecodeError: return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'} elif file_type == 'sql_pair': # 检查是否包含Question:和SQL: if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE): return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'} if not re.search(r'\bSQL\s*:', content, re.IGNORECASE): return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'} elif file_type == 'sql': # 检查是否包含;分隔符 if ';' not in content: return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'} return {'valid': True} except Exception as e: return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'} ``` **返回格式:** **成功响应:** ```json { "code": 200, "success": true, "message": "操作成功", "data": { "response": "文件上传并训练成功:training_data.sql", "filename": "training_data.sql", "file_type": "sql", "file_size": 12345, "status": "completed" } } ``` **错误响应:** ```json { "code": 400, "success": false, "message": "请求参数错误", "data": { "response": "文件内容不符合SQL格式,必须包含;分隔符", "error_type": "INVALID_PARAMS", "timestamp": "2025-01-15T10:30:00.000Z", "can_retry": false } } ``` **安全性和性能考虑:** 1. **文件处理优化**: - 使用临时文件处理,确保内存安全 - 文件大小限制500KB,避免内存压力 - 自动清理临时文件,防止磁盘空间泄露 - 完整的错误处理和日志记录 2. **类型验证机制**: - 基于 `file_type` 参数动态确定允许的扩展名 - 对文件内容进行格式验证 - 大小写不敏感的格式处理 - 详细的验证错误信息,包含具体的文件类型和支持的扩展名 3. **复用成熟功能**: - 直接使用 `data_pipeline.trainer.run_training` 中的成熟函数 - 复用 `train_ddl_statements`、`train_documentation_blocks` 等处理逻辑 - 调用 `train_json_question_sql_pairs` 和 `train_formatted_question_sql_pairs` 处理问答对 - 继承批处理优化机制,保持训练逻辑的一致性 - 避免重复造轮子,提高代码维护性 4. **错误处理**: - 详细的错误日志记录 - 文件格式验证失败时的明确提示 - 基于文件类型的精确扩展名验证和友好提示 - 临时文件清理保证不留垃圾文件 5. **性能优化**: - 复用现有的高效解析函数 - 批处理机制提高训练效率 - 避免重复开发,减少维护成本 **使用建议:** 1. **文件准备**: - 确保文件使用UTF-8编码 - 控制文件大小在500KB以内 - 按照指定的格式准备文件内容 2. **文件类型选择**: - `ddl`:用于训练表结构信息,支持 .ddl/.sql/.txt 扩展名或无扩展名,必须包含CREATE语句 - `markdown`:用于训练表和字段说明,仅支持 .md/.markdown 扩展名,必须包含##标题 - `sql_pair_json`:用于训练结构化问答数据,支持 .json/.txt 扩展名或无扩展名,必须是有效JSON格式 - `sql_pair`:用于训练文本格式问答数据,支持 .sql/.txt 扩展名或无扩展名,必须包含Question:/SQL:格式 - `sql`:用于训练SQL示例,支持 .sql/.txt 扩展名或无扩展名,必须包含;分隔符,系统会自动生成问题 3. **最佳实践**: - 建议分批上传,每次处理一种类型的文件 - 上传前检查文件格式的正确性 - 明确指定正确的 `file_type` 参数 - 关注返回结果中的处理状态信息 **Postman测试示例:** ``` POST /api/v0/training_data/upload Content-Type: multipart/form-data form-data: - file: [选择文件] - file_type: sql_pair 示例文件内容: Question: 如何查询观影人数超过 10000 的热门电影? SQL: SELECT name, viewers FROM movies WHERE viewers > 10000; Question: 如何统计每个分类的产品数量? SQL: SELECT category, COUNT(*) FROM products GROUP BY category; ``` --- ### 3.3 `/api/v0/training_data/combine` (新增的API) **功能描述:** - 合并 langchain_pg_embedding 表中相同 document 字段的重复记录 - 支持按集合名称过滤合并范围 - 提供预览模式(dry_run)和实际执行模式 - 支持多种保留策略(first/last/by_metadata_time) **技术分析:** 训练数据写入机制中,不同类型的数据在 document 字段中存储不同格式的内容: - SQL类型:JSON格式 `{"question": "...", "sql": "..."}` - DDL类型:直接存储DDL语句原文 - Documentation类型:直接存储文档内容原文 - Error_SQL类型:JSON格式错误SQL示例 合并策略采用 document 字段精确匹配,因为: 1. embedding 字段是向量,比较成本高且可能有微小差异 2. document 字段存储确定性文本内容,便于精确比较 3. 相同的 document 内容必然产生相同的训练效果 **API规格:** **请求参数:** ```json { "collection_names": ["sql", "ddl", "documentation", "error_sql"], "dry_run": true, "keep_strategy": "first" } ``` **参数说明:** - `collection_names`: 要处理的集合名称数组,支持 ["sql", "ddl", "documentation", "error_sql"] - `dry_run`: 是否为预览模式,默认 true(安全) - `keep_strategy`: 保留策略 - `"first"`: 保留第一条记录(按ID排序,推荐默认) - `"last"`: 保留最后一条记录(按ID排序) - `"by_metadata_time"`: 按 cmetadata.createdat 时间排序保留最新记录 **实现代码框架:** ```python @app.route('/api/v0/training_data/combine', methods=['POST']) def combine_training_data(): try: # 1. 参数验证 data = request.get_json() if not data: return jsonify(bad_request_response("请求体不能为空")) collection_names = data.get('collection_names', []) if not collection_names or not isinstance(collection_names, list): return jsonify(bad_request_response("collection_names 参数必须是非空数组")) # 验证集合名称 valid_collections = ['sql', 'ddl', 'documentation', 'error_sql'] invalid_collections = [name for name in collection_names if name not in valid_collections] if invalid_collections: return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}")) dry_run = data.get('dry_run', True) keep_strategy = data.get('keep_strategy', 'first') if keep_strategy not in ['first', 'last', 'by_metadata_time']: return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'")) # 2. 获取数据库连接 conn = get_db_connection() cursor = conn.cursor() # 3. 查找重复记录 duplicate_groups = [] total_before = 0 total_duplicates = 0 collections_stats = {} for collection_name in collection_names: # 获取集合ID cursor.execute( "SELECT uuid FROM langchain_pg_collection WHERE name = %s", (collection_name,) ) collection_result = cursor.fetchone() if not collection_result: continue collection_id = collection_result[0] # 统计该集合的记录数 cursor.execute( "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s", (collection_id,) ) collection_before = cursor.fetchone()[0] total_before += collection_before # 查找重复记录 cursor.execute(""" SELECT document, COUNT(*) as duplicate_count, array_agg(id ORDER BY %s) as record_ids FROM langchain_pg_embedding WHERE collection_id = %s GROUP BY document HAVING COUNT(*) > 1 """ % ( "id" if keep_strategy in ['first', 'last'] else "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id", "%s" ), (collection_id,)) collection_duplicates = 0 for row in cursor.fetchall(): document_content, duplicate_count, record_ids = row collection_duplicates += duplicate_count - 1 # 减去要保留的一条 # 根据保留策略选择要保留的记录 if keep_strategy == 'first': keep_id = record_ids[0] remove_ids = record_ids[1:] elif keep_strategy == 'last': keep_id = record_ids[-1] remove_ids = record_ids[:-1] else: # by_metadata_time keep_id = record_ids[0] # 已经按时间排序 remove_ids = record_ids[1:] duplicate_groups.append({ "collection_name": collection_name, "document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content, "duplicate_count": duplicate_count, "kept_record_id" if not dry_run else "records_to_keep": keep_id, "removed_record_ids" if not dry_run else "records_to_remove": remove_ids }) total_duplicates += collection_duplicates collections_stats[collection_name] = { "before": collection_before, "after": collection_before - collection_duplicates, "duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates } # 4. 执行合并操作(如果不是dry_run) if not dry_run: try: conn.autocommit = False # 开始事务 for group in duplicate_groups: remove_ids = group["removed_record_ids"] if remove_ids: cursor.execute( "DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)", (remove_ids,) ) conn.commit() except Exception as e: conn.rollback() return jsonify(internal_error_response(f"合并操作失败: {str(e)}")) finally: conn.autocommit = True # 5. 构建响应 total_after = total_before - total_duplicates summary = { "total_records_before": total_before, "total_records_after": total_after, "duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates, "collections_stats": collections_stats } if dry_run: response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录" data_key = "duplicate_groups" else: response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录" data_key = "merged_groups" return jsonify(success_response( response_text=response_text, data={ "dry_run": dry_run, "collections_processed": collection_names, "summary": summary, data_key: duplicate_groups } )) except Exception as e: return jsonify(internal_error_response(f"合并操作失败: {str(e)}")) finally: if 'cursor' in locals(): cursor.close() if 'conn' in locals(): conn.close() ``` **返回格式(dry_run: true):** ```json { "code": 200, "success": true, "message": "操作成功", "data": { "response": "发现 50 条重复记录,预计删除后将从 1000 条减少到 950 条记录", "dry_run": true, "collections_processed": ["sql", "ddl"], "summary": { "total_records_before": 1000, "total_records_after": 950, "duplicates_to_remove": 50, "collections_stats": { "sql": { "before": 600, "after": 580, "duplicates_to_remove": 20 }, "ddl": { "before": 400, "after": 370, "duplicates_to_remove": 30 } } }, "duplicate_groups": [ { "collection_name": "sql", "document_content": "SELECT * FROM users WHERE id = ?", "duplicate_count": 3, "records_to_keep": "uuid1-sql", "records_to_remove": ["uuid2-sql", "uuid3-sql"] } ] } } ``` **返回格式(dry_run: false):** ```json { "code": 200, "success": true, "message": "操作成功", "data": { "response": "成功合并重复记录,删除了 50 条重复记录,从 1000 条减少到 950 条记录", "dry_run": false, "collections_processed": ["sql", "ddl"], "summary": { "total_records_before": 1000, "total_records_after": 950, "duplicates_removed": 50, "collections_stats": { "sql": { "before": 600, "after": 580, "duplicates_removed": 20 }, "ddl": { "before": 400, "after": 370, "duplicates_removed": 30 } } }, "merged_groups": [ { "collection_name": "sql", "document_content": "SELECT * FROM users WHERE id = ?", "duplicate_count": 3, "kept_record_id": "uuid1-sql", "removed_record_ids": ["uuid2-sql", "uuid3-sql"] } ] } } ``` **错误响应示例:** ```json { "code": 400, "success": false, "message": "请求参数错误", "data": { "response": "collection_names 参数必须是非空数组", "error_type": "INVALID_PARAMS", "timestamp": "2025-01-15T10:30:00.000Z", "can_retry": false } } ``` **安全性和性能考虑:** 1. **事务安全**: - 整个合并过程在数据库事务中执行 - 发生错误时自动回滚 - 避免数据不一致状态 2. **预览模式**: - 默认启用 dry_run 模式 - 用户可以先查看合并计划再决定是否执行 - 提供详细的影响分析 3. **保留策略**: - 支持多种保留策略满足不同需求 - first 策略最简单可靠,推荐作为默认 - by_metadata_time 策略使用真实创建时间 4. **性能优化**: - 使用数组聚合函数减少数据库查询次数 - 批量删除操作提高效率 - 分集合处理避免长时间锁表 5. **错误处理**: - 详细的参数验证 - 完整的异常处理和错误信息 - 数据库连接的正确关闭 **使用建议:** 1. **推荐流程**: ```bash # 1. 先预览合并计划 POST /api/v0/training_data/combine { "collection_names": ["sql", "ddl"], "dry_run": true, "keep_strategy": "first" } # 2. 确认无误后执行实际合并 POST /api/v0/training_data/combine { "collection_names": ["sql", "ddl"], "dry_run": false, "keep_strategy": "first" } ``` 2. **保留策略选择**: - 一般情况:使用 `"first"` 策略 - 需要保留最新数据:使用 `"by_metadata_time"` 策略 - 测试环境:可以使用 `"last"` 策略 3. **批量处理**: - 大量数据时建议分批处理不同集合 - 可以先处理较小的集合验证效果 - 重要数据建议先备份 --- ## 4. 补充建议 ### 4.1 错误处理增强 建议在所有API中添加更详细的错误处理: ```python # 通用错误处理装饰器 def handle_api_errors(f): @wraps(f) def decorated_function(*args, **kwargs): try: return f(*args, **kwargs) except ValueError as e: return jsonify(bad_request_response(str(e))) except Exception as e: logger.error(f"API错误: {str(e)}", exc_info=True) return jsonify(internal_error_response(f"系统错误: {str(e)}")) return decorated_function ``` ### 4.2 批量操作优化 对于大量数据的处理,建议引入任务队列: ```python # 异步任务支持 @app.route('/api/v0/training_data/upload_async', methods=['POST']) def upload_training_data_async(): # 创建后台任务 task_id = str(uuid.uuid4()) # 将任务放入队列 # 返回任务ID供客户端查询进度 return jsonify(success_response( response_text="文件上传任务已创建", data={"task_id": task_id} )) @app.route('/api/v0/training_data/task_status/', methods=['GET']) def get_task_status(task_id): # 查询任务状态 pass ``` ### 4.3 数据验证增强 建议添加更严格的数据验证: ```python # SQL语法验证 def validate_sql(sql): try: import sqlparse parsed = sqlparse.parse(sql) if not parsed: raise ValueError("SQL语法错误") return True except Exception as e: raise ValueError(f"SQL验证失败: {str(e)}") # 数据类型验证 def validate_training_data(data_type, content): if data_type == 'sql': return validate_sql(content) elif data_type == 'ddl': return validate_ddl(content) # 其他验证逻辑 ``` --- ## 5. 实施建议 ### 5.1 开发顺序 1. **第一阶段**:实现增强的stats API(相对简单) 2. **第二阶段**:实现update API(中等复杂度) 3. **第三阶段**:实现upload API(最复杂,需要充分测试) ### 5.2 测试策略 1. **单元测试**:每个解析函数和训练函数 2. **集成测试**:完整的API调用流程 3. **性能测试**:大文件上传和处理 4. **错误测试**:各种异常情况的处理 ### 5.3 部署注意事项 1. **文件上传限制**:配置Nginx/Apache的文件大小限制 2. **临时目录**:确保有足够的磁盘空间 3. **数据库连接**:优化数据库连接池设置 4. **日志监控**:添加详细的操作日志 --- ## 6. 结论 本改造方案通过充分复用现有代码,提供了更加完整和实用的解决方案: 1. **update API**:简化为单条更新,采用先删除后插入策略,返回新记录ID 2. **upload API**:复用 `data_pipeline/trainer` 模块的成熟函数,支持多种文件格式的自动解析和导入 3. **combine API**:支持合并重复训练数据,提供预览和执行模式 **主要优势:** 1. **避免重复造轮子**: - 直接复用 `run_training.py` 中的成熟解析函数 - 保持与现有训练逻辑的一致性 - 减少代码维护成本 2. **提高代码质量**: - 复用经过验证的文件处理逻辑 - 继承现有的错误处理机制 - 保持日志格式的统一性 3. **性能优化**: - 利用现有的批处理机制 - 复用高效的文件解析算法 - 减少开发和测试时间 特别是upload API的设计,通过复用现有的trainer模块,不仅避免了重复开发,还确保了与现有系统的完美兼容性,可以满足生产环境的需求。