# unified_api.py 训练数据管理API改造方案(详细版)

## 1. 需求概述

本方案基于对 `unified_api.py` 文件的深入分析,提出以下三个API新增需求:

- **需求一**:新增 `/api/v0/training_data/update`,支持先删除再插入的训练数据更新操作
- **需求二**:新增 `/api/v0/training_data/upload`,支持上传多种格式文件并自动解析入库
- **需求三**:新增 `/api/v0/training_data/combine`,合并相同的 langchain_pg_embedding.document 记录

---

## 2. 当前实现分析

### 2.1 现有API结构

`unified_api.py` 目前实现了4个训练数据相关的API端点:

1. **`GET /api/v0/training_data/stats`** (lines 1205-1234) - 基础统计信息
2. **`POST /api/v0/training_data/query`** (lines 1236-1294) - 分页查询
3. **`POST /api/v0/training_data/create`** (lines 1296-1404) - 批量创建(最多50条)
4. **`POST /api/v0/training_data/delete`** (lines 1406-1471) - 批量删除(最多50条)

### 2.2 数据库架构

**表结构关系:**
```sql
-- langchain_pg_collection (集合表)
CREATE TABLE langchain_pg_collection (
    uuid uuid PRIMARY KEY,
    name varchar NOT NULL UNIQUE,
    cmetadata json
);

-- langchain_pg_embedding (训练数据表)
CREATE TABLE langchain_pg_embedding (
    id varchar PRIMARY KEY,
    collection_id uuid REFERENCES langchain_pg_collection(uuid) ON DELETE CASCADE,
    embedding vector,
    document varchar,
    cmetadata jsonb
);
```

**集合名称映射:**
- `sql` 集合 → SQL查询和问答对
- `ddl` 集合 → 数据库表结构定义
- `documentation` 集合 → 表和字段文档
- `error_sql` 集合 → 错误SQL示例

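为便于理解两张表的关联关系,下面给出一个按集合统计训练数据条数的查询示意(最小示例,假设使用 psycopg2 直连该 PGVector 库,连接参数为占位值,并非现有代码):

```python
import psycopg2

# 假设的连接参数,仅作示意
conn = psycopg2.connect(host="localhost", dbname="pgvector_db",
                        user="postgres", password="***")

with conn.cursor() as cur:
    # 按集合名称分组,统计 langchain_pg_embedding 中各类训练数据的条数
    cur.execute("""
        SELECT c.name, COUNT(e.id)
        FROM langchain_pg_collection c
        LEFT JOIN langchain_pg_embedding e ON e.collection_id = c.uuid
        WHERE c.name IN ('sql', 'ddl', 'documentation', 'error_sql')
        GROUP BY c.name
    """)
    for name, cnt in cur.fetchall():
        print(f"{name}: {cnt} 条")

conn.close()
```
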
---

## 3. 详细改造方案

---

### 3.1 `/api/v0/training_data/update` (新增的API)

**功能描述:**
- 支持单条更新训练数据
- 采用先删除后插入的策略确保数据一致性,删除以 langchain_pg_embedding.id 字段为依据
- langchain_pg_embedding.id 字段是 varchar 类型,格式为"uuid + 类型后缀",例如:"50bb6b17-d5be-48ab-8125-58ab8bb076e8-sql"(解析示意见下方代码)
- 返回新创建记录的ID,便于前端更新界面

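下面是一个从记录ID中解析出训练数据类型后缀的最小示意(纯字符串处理,假设后缀始终为上述四种类型之一,函数名为示例自拟):

```python
# 最小示意:从 langchain_pg_embedding.id 中拆出 uuid 与类型后缀
# 假设 id 的格式固定为 "<uuid>-<type>",type 为 sql/ddl/documentation/error_sql 之一
KNOWN_TYPES = ("error_sql", "documentation", "ddl", "sql")

def parse_training_id(record_id: str):
    for t in KNOWN_TYPES:
        suffix = f"-{t}"
        if record_id.endswith(suffix):
            return record_id[: -len(suffix)], t
    raise ValueError(f"无法识别的训练数据ID格式: {record_id}")

# 用法示例
uuid_part, data_type = parse_training_id("50bb6b17-d5be-48ab-8125-58ab8bb076e8-sql")
# uuid_part == "50bb6b17-d5be-48ab-8125-58ab8bb076e8", data_type == "sql"
```
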
**WEB UI使用流程:**
1. 用户选中一行记录,点击修改按钮
2. 前端显示编辑表单,用户修改内容
3. 提交时,根据这一行的id删除原记录
4. 然后按照create的逻辑,用修改后的内容生成一条新记录
5. 返回的JSON中包含这条新记录的id

**API规格:**

**请求参数:**
```json
{
  "id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
  "training_data_type": "sql",
  "question": "如何查询所有用户?",
  "sql": "SELECT * FROM users;"
}
```

**参数说明:**
- `id`: 要删除的原始记录ID(必需)
- `training_data_type`: 训练数据类型,支持 "sql", "ddl", "documentation", "error_sql"
- 其他字段根据类型不同而不同:
  - `sql` 类型:`question`(可选), `sql`(必需)
  - `error_sql` 类型:`question`(必需), `sql`(必需)
  - `documentation` 类型:`content`(必需)
  - `ddl` 类型:`ddl`(必需)

**实现代码框架:**
```python
@app.route('/api/v0/training_data/update', methods=['POST'])
def training_data_update():
    """更新训练数据API - 支持单条更新,采用先删除后插入策略"""
    try:
        req = request.get_json(force=True)

        # 1. 参数验证
        original_id = req.get('id')
        if not original_id:
            return jsonify(bad_request_response(
                response_text="缺少必需参数:id",
                missing_params=["id"]
            )), 400

        training_type = req.get('training_data_type')
        if not training_type:
            return jsonify(bad_request_response(
                response_text="缺少必需参数:training_data_type",
                missing_params=["training_data_type"]
            )), 400

        # 2. 先删除原始记录
        try:
            success = vn.remove_training_data(original_id)
            if not success:
                return jsonify(bad_request_response(
                    response_text=f"原始记录 {original_id} 不存在或删除失败"
                )), 400
        except Exception as e:
            return jsonify(internal_error_response(
                response_text=f"删除原始记录失败: {str(e)}"
            )), 500

        # 3. 根据类型验证和准备新数据
        try:
            if training_type == 'sql':
                sql = req.get('sql')
                if not sql:
                    return jsonify(bad_request_response(
                        response_text="SQL字段是必需的",
                        missing_params=["sql"]
                    )), 400

                # SQL语法检查
                is_valid, error_msg = validate_sql_syntax(sql)
                if not is_valid:
                    return jsonify(bad_request_response(
                        response_text=f"SQL语法错误: {error_msg}"
                    )), 400

                question = req.get('question')
                if question:
                    training_id = vn.train(question=question, sql=sql)
                else:
                    training_id = vn.train(sql=sql)

            elif training_type == 'error_sql':
                question = req.get('question')
                sql = req.get('sql')
                if not question or not sql:
                    return jsonify(bad_request_response(
                        response_text="question和sql字段都是必需的",
                        missing_params=["question", "sql"]
                    )), 400
                training_id = vn.train_error_sql(question=question, sql=sql)

            elif training_type == 'documentation':
                content = req.get('content')
                if not content:
                    return jsonify(bad_request_response(
                        response_text="content字段是必需的",
                        missing_params=["content"]
                    )), 400
                training_id = vn.train(documentation=content)

            elif training_type == 'ddl':
                ddl = req.get('ddl')
                if not ddl:
                    return jsonify(bad_request_response(
                        response_text="ddl字段是必需的",
                        missing_params=["ddl"]
                    )), 400
                training_id = vn.train(ddl=ddl)

            else:
                return jsonify(bad_request_response(
                    response_text=f"不支持的训练数据类型: {training_type}"
                )), 400

        except Exception as e:
            return jsonify(internal_error_response(
                response_text=f"创建新训练数据失败: {str(e)}"
            )), 500

        # 4. 获取更新后的总记录数
        current_total = get_total_training_count()

        return jsonify(success_response(
            response_text="训练数据更新成功",
            data={
                "original_id": original_id,
                "new_training_id": training_id,
                "type": training_type,
                "current_total_count": current_total
            }
        ))

    except Exception as e:
        logger.error(f"training_data_update执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="更新训练数据失败,请稍后重试"
        )), 500
```

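上述框架中引用的 `validate_sql_syntax()` 与 `get_total_training_count()` 在本文档中没有给出实现,下面是一个参考性的最小草图(基于 sqlparse 的假设性实现;`get_db_connection()` 沿用后文 combine API 中引用的连接函数,实际应复用项目中已有的同名逻辑):

```python
import sqlparse

def validate_sql_syntax(sql: str):
    """最简 SQL 语法检查示意:仅验证能否被 sqlparse 解析出语句"""
    try:
        statements = [s for s in sqlparse.parse(sql) if s.tokens]
        if not statements:
            return False, "无法解析出有效的SQL语句"
        return True, ""
    except Exception as e:
        return False, str(e)

def get_total_training_count() -> int:
    """统计 langchain_pg_embedding 表的总记录数(假设 get_db_connection 已存在)"""
    conn = get_db_connection()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT COUNT(*) FROM langchain_pg_embedding")
            return cur.fetchone()[0]
    finally:
        conn.close()
```
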
**返回格式:**

**成功响应:**
```json
{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "训练数据更新成功",
    "original_id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
    "new_training_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890-sql",
    "type": "sql",
    "current_total_count": 1250
  }
}
```

**错误响应示例(记录不存在):**
```json
{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "原始记录 6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql 不存在或删除失败",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}
```

**错误响应示例(缺少必需参数):**
```json
{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "缺少必需参数:training_data_type",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}
```

**安全性和性能考虑:**

1. **事务安全**:
   - 删除和插入操作在同一个函数中执行
   - 如果插入失败,原始记录已被删除,需要回滚或重试机制
   - 建议在应用层实现重试逻辑

2. **数据一致性**:
   - 先删除后插入确保ID的唯一性
   - 新记录会获得新的UUID,避免ID冲突
   - 保持训练数据类型的正确性

3. **错误处理**:
   - 详细的参数验证
   - 原始记录存在性检查
   - SQL语法验证(针对sql类型)
   - 完整的异常处理和错误信息

4. **单条处理**:
   - 只支持单条更新,简化逻辑
   - 减少并发冲突的可能性
   - 更精确的错误定位和处理

5. **性能优化**:
   - 复用现有的训练函数
   - 避免重复的数据库连接
   - 合理的错误处理和日志记录

**使用建议:**

1. **前端集成**:
   ```javascript
   // 更新训练数据的示例
   const updateTrainingData = async (originalId, newData) => {
     const response = await fetch('/api/v0/training_data/update', {
       method: 'POST',
       headers: { 'Content-Type': 'application/json' },
       body: JSON.stringify({
         id: originalId,
         training_data_type: 'sql',
         question: newData.question,
         sql: newData.sql
       })
     });

     const result = await response.json();
     if (result.success) {
       // 更新成功,使用新的training_id更新界面
       const newId = result.data.new_training_id;
       updateUIWithNewId(newId);
     }
   };
   ```

2. **错误处理**:
   - 检查返回的 `success` 字段
   - 处理各种错误情况(参数错误、记录不存在、SQL语法错误等)
   - 显示详细的错误信息给用户

3. **最佳实践**:
   - 更新前验证原始记录是否存在
   - 更新后验证新记录是否创建成功
   - 使用返回的新ID更新前端界面

---

### 3.2 `/api/v0/training_data/upload`(新增的API,重点设计)

**功能描述:**

- 支持上传多种格式的训练数据文件(.ddl/.md/.json/.sql)
- 上传内容通过临时文件中转处理,处理完成后自动清理,不做持久化存储
- 文件大小限制500KB(约包含100-500条记录)
- 自动识别文件类型并解析内容
- 复用现有trainer模块的成熟功能
- 提供详细的处理结果反馈

**支持的文件类型:**

基于现有 `data_pipeline/trainer` 模块的成熟功能,API 支持以下文件类型的自动识别和处理:

| 文件类型 | 扩展名 | 文件命名规则 | 复用函数 | 解析方式 |
|---------|--------|-------------|---------|-----------|
| ddl | `.ddl` | 任意名称 | `train_ddl_file()` → `read_file_by_delimiter()` | 建表语句,以“;”为分隔符 |
| documentation | `.md` | 任意名称 | `train_documentation_file()` → `read_markdown_file_by_sections()` | Markdown以"##"为单位分割,普通文档以"---"分隔 |
| question_sql_pairs | `.json` | 必须包含 `_pair` 或 `_pairs` | `train_json_pairs_file()` → `train_json_question_sql_pairs()` | 问答对训练(JSON格式) |
| formatted_pairs | `.sql` | 必须包含 `_pair` 或 `_pairs` | `train_formatted_pairs_file()` → `train_formatted_question_sql_pairs()` | 问答对训练(文本格式),每对Question/SQL之间使用空行分隔 |
| sql_examples | `.sql` | 任意名称(不包含`_pair`) | `train_sql_file()` → `read_file_by_delimiter()` | SQL示例,以";"为分隔符 |

**1. DDL文件 (*.ddl):**

建表语句,解析时以“;”为分隔符。

```sql
-- 建表语句
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    email VARCHAR(255) UNIQUE
);

CREATE TABLE orders (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id),
    amount DECIMAL(10,2)
);
```

**2. 文档文件 (*.md):**

Markdown文档以"##"为单位进行分割,每个"##"小节对应一张表的说明:

```markdown
## users表
用户基本信息表,存储系统用户数据。
字段说明
- id: 用户唯一标识
- name: 用户姓名
- email: 电子邮箱

## orders表
订单信息表,记录用户购买记录。
```

**3. JSON问答对 (*_pair.json):**

JSON中的 "question"/"sql" 键名大小写不敏感。

```json
[
  {
    "question": "如何查询活跃用户数量?",
    "sql": "SELECT COUNT(*) FROM users WHERE active = true;"
  },
  {
    "question": "如何统计每个分类的产品数量?",
    "sql": "SELECT category, COUNT(*) FROM products GROUP BY category;"
  }
]
```

**4. 格式化问答对 (*_pair.sql):**

注意:Question 和 SQL 标签大小写不敏感,但两个问答对之间必须以空行分隔。

```
Question: 如何查询活跃用户数量?
SQL:SELECT COUNT(*) FROM users WHERE active = true;

Question: 如何统计每个分类的产品数量?
SQL:SELECT category, COUNT(*) FROM products GROUP BY category;
```

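为了更直观地说明这种文本格式,下面给出一个解析 *_pair.sql 内容的最小示意(仅根据上述格式约定编写的草图,并非 `train_formatted_question_sql_pairs()` 的实际实现):

```python
import re

def parse_formatted_pairs(text: str):
    """按空行切块,再按大小写不敏感的 Question:/SQL: 标签取值(示意实现)"""
    pairs = []
    for block in re.split(r"\n\s*\n", text.strip()):
        q_match = re.search(r"question\s*[::]\s*(.+)", block, re.IGNORECASE)
        s_match = re.search(r"sql\s*[::]\s*(.+)", block, re.IGNORECASE | re.DOTALL)
        if q_match and s_match:
            pairs.append({"question": q_match.group(1).strip(),
                          "sql": s_match.group(1).strip()})
    return pairs

sample = """Question: 如何查询活跃用户数量?
SQL:SELECT COUNT(*) FROM users WHERE active = true;

Question: 如何统计每个分类的产品数量?
SQL:SELECT category, COUNT(*) FROM products GROUP BY category;"""

print(parse_formatted_pairs(sample))  # 解析出两个问答对
```
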
**API规格:**

**请求参数:**
- `file`: 上传的文件(multipart/form-data)
- 文件大小限制:最大500KB
- 文件编码:UTF-8
- 自动识别文件类型:基于文件扩展名和命名规则

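调用方式示意(假设服务运行在 http://localhost:8084,使用 requests 库上传本地文件,地址与文件名均为示例值):

```python
import requests

# 假设的服务地址,仅作示例
url = "http://localhost:8084/api/v0/training_data/upload"

# 以 multipart/form-data 方式上传一个 DDL 文件
with open("tables.ddl", "rb") as f:
    resp = requests.post(url, files={"file": ("tables.ddl", f)})

result = resp.json()
if result.get("success"):
    print("处理记录数:", result["data"]["records_processed"])
else:
    print("上传失败:", result["data"]["response"])
```
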
**实现代码框架:**
```python
from werkzeug.utils import secure_filename
from data_pipeline.trainer.run_training import (
    read_file_by_delimiter,
    read_markdown_file_by_sections,
    train_json_question_sql_pairs,
    train_formatted_question_sql_pairs
)
from data_pipeline.trainer.vanna_trainer import (
    train_ddl, train_documentation, train_question_sql_pair,
    train_sql_example, flush_training
)
from common.result import success_response, bad_request_response, internal_error_response
import json
import tempfile
import os
import logging

logger = logging.getLogger(__name__)

@app.route('/api/v0/training_data/upload', methods=['POST'])
def upload_training_data():
    try:
        # 1. 参数验证
        if 'file' not in request.files:
            return jsonify(bad_request_response("未提供文件"))

        file = request.files['file']
        if file.filename == '':
            return jsonify(bad_request_response("未选择文件"))

        # 2. 文件大小验证 (500KB)
        file.seek(0, 2)
        file_size = file.tell()
        file.seek(0)

        if file_size > 500 * 1024:  # 500KB
            return jsonify(bad_request_response("文件大小不能超过500KB"))

        # 3. 创建临时文件(复用现有函数需要文件路径)
        filename = secure_filename(file.filename)
        temp_file_path = None

        try:
            with tempfile.NamedTemporaryFile(mode='w+b', delete=False, suffix=os.path.splitext(filename)[1]) as tmp_file:
                file.save(tmp_file.name)
                temp_file_path = tmp_file.name

            # 4. 根据文件类型调用现有的训练函数
            filename_lower = filename.lower()
            success_count = 0

            if filename_lower.endswith('.ddl'):
                success_count = train_ddl_file(temp_file_path)

            elif filename_lower.endswith(('.md', '.markdown')):
                success_count = train_documentation_file(temp_file_path)

            elif filename_lower.endswith('_pair.json') or filename_lower.endswith('_pairs.json'):
                success_count = train_json_pairs_file(temp_file_path)

            elif filename_lower.endswith('_pair.sql') or filename_lower.endswith('_pairs.sql'):
                success_count = train_formatted_pairs_file(temp_file_path)

            elif filename_lower.endswith('.sql'):
                success_count = train_sql_file(temp_file_path)
            else:
                return jsonify(bad_request_response("不支持的文件类型"))

            # 5. 刷新批处理
            flush_training()

            return jsonify(success_response(
                response_text=f"文件上传并训练成功:处理了 {success_count} 条记录",
                data={
                    "filename": filename,
                    "file_size": file_size,
                    "records_processed": success_count,
                    "status": "completed"
                }
            ))

        except Exception as e:
            logger.error(f"训练失败: {str(e)}")
            return jsonify(internal_error_response(f"训练失败: {str(e)}"))
        finally:
            # 清理临时文件
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except Exception as e:
                    logger.warning(f"清理临时文件失败: {str(e)}")

    except Exception as e:
        logger.error(f"文件上传失败: {str(e)}")
        return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))

# 复用现有训练函数的包装器
def train_ddl_file(file_path: str) -> int:
    """复用 run_training.py 中的 DDL 训练逻辑"""
    success_count = 0
    try:
        ddl_statements = read_file_by_delimiter(file_path, ";")

        for ddl in ddl_statements:
            try:
                train_ddl(ddl)
                success_count += 1
            except Exception as e:
                logger.error(f"训练DDL失败: {ddl[:50]}..., 错误: {str(e)}")
    except Exception as e:
        logger.error(f"读取DDL文件失败: {str(e)}")

    return success_count

def train_documentation_file(file_path: str) -> int:
    """复用 run_training.py 中的文档训练逻辑"""
    success_count = 0

    try:
        # 检查是否为Markdown文件
        if file_path.lower().endswith(('.md', '.markdown')):
            sections = read_markdown_file_by_sections(file_path)
            for section in sections:
                try:
                    train_documentation(section)
                    success_count += 1
                except Exception as e:
                    logger.error(f"训练文档失败: {section[:50]}..., 错误: {str(e)}")
        else:
            # 非Markdown文件使用传统的---分隔
            doc_blocks = read_file_by_delimiter(file_path, "---")
            for doc in doc_blocks:
                try:
                    train_documentation(doc)
                    success_count += 1
                except Exception as e:
                    logger.error(f"训练文档失败: {doc[:50]}..., 错误: {str(e)}")
    except Exception as e:
        logger.error(f"读取文档文件失败: {str(e)}")

    return success_count

def train_json_pairs_file(file_path: str) -> int:
    """复用 run_training.py 中的 JSON 问答对训练逻辑"""
    try:
        # 直接调用现有函数
        train_json_question_sql_pairs(file_path)

        # 计算成功数量
        with open(file_path, 'r', encoding='utf-8') as f:
            pairs = json.load(f)
            return len(pairs) if isinstance(pairs, list) else 1
    except Exception as e:
        logger.error(f"训练JSON问答对失败: {str(e)}")
        return 0

def train_formatted_pairs_file(file_path: str) -> int:
    """复用 run_training.py 中的格式化问答对训练逻辑"""
    try:
        # 直接调用现有函数
        train_formatted_question_sql_pairs(file_path)

        # 计算成功数量
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
            pairs = content.split('\n\n')
            return len([p for p in pairs if p.strip()])
    except Exception as e:
        logger.error(f"训练格式化问答对失败: {str(e)}")
        return 0

def train_sql_file(file_path: str) -> int:
    """复用 run_training.py 中的 SQL 训练逻辑"""
    success_count = 0
    try:
        sql_statements = read_file_by_delimiter(file_path, ";")

        for sql in sql_statements:
            try:
                train_sql_example(sql)
                success_count += 1
            except Exception as e:
                logger.error(f"训练SQL失败: {sql[:50]}..., 错误: {str(e)}")
    except Exception as e:
        logger.error(f"读取SQL文件失败: {str(e)}")

    return success_count
```

**返回格式:**

**成功响应:**
```json
{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "文件上传并训练成功:处理了 45 条记录",
    "filename": "training_data.sql",
    "file_size": 12345,
    "records_processed": 45,
    "status": "completed"
  }
}
```

**错误响应:**
```json
{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "文件大小不能超过500KB",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}
```

**安全性和性能考虑:**

1. **文件处理优化**:
   - 使用临时文件处理,确保内存安全
   - 文件大小限制500KB,避免内存压力
   - 自动清理临时文件,防止磁盘空间泄露
   - 完整的错误处理和日志记录

2. **自动类型识别**:
   - 基于文件扩展名和命名规则
   - 智能解析不同格式内容
   - 大小写不敏感的格式处理

3. **复用成熟功能**:
   - 直接使用 `data_pipeline.trainer.run_training` 中的成熟函数
   - 复用 `read_file_by_delimiter` 和 `read_markdown_file_by_sections` 解析逻辑
   - 调用 `train_json_question_sql_pairs` 和 `train_formatted_question_sql_pairs` 处理问答对
   - 继承批处理优化机制,保持训练逻辑的一致性
   - 避免重复造轮子,提高代码维护性

4. **错误处理**:
   - 详细的错误日志记录
   - 逐条处理失败时的容错机制
   - 统计处理成功和失败的记录数
   - 临时文件清理保证不留垃圾文件

5. **性能优化**:
   - 复用现有的高效解析函数
   - 批处理机制提高训练效率
   - 自动刷新确保数据及时写入
   - 避免重复开发,减少维护成本

**使用建议:**

1. **文件准备**:
   - 确保文件使用UTF-8编码
   - 控制文件大小在500KB以内
   - 按照规定的格式和命名规则准备文件

2. **文件类型选择**:
   - DDL文件:用于训练表结构信息,复用 `train_ddl_statements` 逻辑
   - Markdown文档:用于训练表和字段说明,复用 `train_documentation_blocks` 逻辑
   - JSON问答对:用于训练结构化问答数据,复用 `train_json_question_sql_pairs` 逻辑
   - 格式化问答对:用于训练文本格式问答数据,复用 `train_formatted_question_sql_pairs` 逻辑
   - SQL文件:用于训练SQL示例,复用 `train_sql_examples` 逻辑

3. **最佳实践**:
   - 建议分批上传,每次处理一种类型的文件
   - 上传前检查文件格式的正确性
   - 关注返回结果中的处理统计信息
   - 复用现有训练逻辑,确保数据处理的一致性

---

### 3.3 `/api/v0/training_data/combine` (新增的API)

**功能描述:**
- 合并 langchain_pg_embedding 表中相同 document 字段的重复记录
- 支持按集合名称过滤合并范围
- 提供预览模式(dry_run)和实际执行模式
- 支持多种保留策略(first/last/by_metadata_time)

**技术分析:**

训练数据写入机制中,不同类型的数据在 document 字段中存储不同格式的内容:
- SQL类型:JSON格式 `{"question": "...", "sql": "..."}`
- DDL类型:直接存储DDL语句原文
- Documentation类型:直接存储文档内容原文
- Error_SQL类型:JSON格式错误SQL示例

合并策略采用 document 字段精确匹配,因为:
1. embedding 字段是向量,比较成本高且可能有微小差异
2. document 字段存储确定性文本内容,便于精确比较
3. 相同的 document 内容必然产生相同的训练效果

**API规格:**

**请求参数:**
```json
{
  "collection_names": ["sql", "ddl", "documentation", "error_sql"],
  "dry_run": true,
  "keep_strategy": "first"
}
```

**参数说明:**
- `collection_names`: 要处理的集合名称数组,支持 ["sql", "ddl", "documentation", "error_sql"]
- `dry_run`: 是否为预览模式,默认 true(安全),推荐的"先预览后执行"调用流程见下方示例
- `keep_strategy`: 保留策略
  - `"first"`: 保留第一条记录(按ID排序,推荐默认)
  - `"last"`: 保留最后一条记录(按ID排序)
  - `"by_metadata_time"`: 按 cmetadata.createdat 时间排序保留最新记录

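下面是"先 dry_run 预览、确认后再实际执行"这一推荐流程的客户端调用示意(假设使用 requests,服务地址为示例值):

```python
import requests

# 假设的服务地址,仅作示例
url = "http://localhost:8084/api/v0/training_data/combine"
payload = {"collection_names": ["sql", "ddl"], "keep_strategy": "first"}

# 第一步:dry_run 预览,查看将被删除的重复记录数量
preview = requests.post(url, json={**payload, "dry_run": True}).json()
to_remove = preview["data"]["summary"]["duplicates_to_remove"]
print(f"预计删除 {to_remove} 条重复记录")

# 第二步:人工确认无误后,关闭 dry_run 实际执行合并
if to_remove > 0:
    result = requests.post(url, json={**payload, "dry_run": False}).json()
    print(result["data"]["response"])
```
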
**实现代码框架:**
```python
@app.route('/api/v0/training_data/combine', methods=['POST'])
def combine_training_data():
    try:
        # 1. 参数验证
        data = request.get_json()
        if not data:
            return jsonify(bad_request_response("请求体不能为空"))

        collection_names = data.get('collection_names', [])
        if not collection_names or not isinstance(collection_names, list):
            return jsonify(bad_request_response("collection_names 参数必须是非空数组"))

        # 验证集合名称
        valid_collections = ['sql', 'ddl', 'documentation', 'error_sql']
        invalid_collections = [name for name in collection_names if name not in valid_collections]
        if invalid_collections:
            return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}"))

        dry_run = data.get('dry_run', True)
        keep_strategy = data.get('keep_strategy', 'first')

        if keep_strategy not in ['first', 'last', 'by_metadata_time']:
            return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'"))

        # 2. 获取数据库连接
        conn = get_db_connection()
        cursor = conn.cursor()

        # 3. 查找重复记录
        duplicate_groups = []
        total_before = 0
        total_duplicates = 0
        collections_stats = {}

        for collection_name in collection_names:
            # 获取集合ID
            cursor.execute(
                "SELECT uuid FROM langchain_pg_collection WHERE name = %s",
                (collection_name,)
            )
            collection_result = cursor.fetchone()
            if not collection_result:
                continue

            collection_id = collection_result[0]

            # 统计该集合的记录数
            cursor.execute(
                "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s",
                (collection_id,)
            )
            collection_before = cursor.fetchone()[0]
            total_before += collection_before

            # 查找重复记录:根据保留策略确定 array_agg 的排序表达式
            # (排序表达式只拼接常量字符串,collection_id 仍通过占位符传参)
            if keep_strategy in ['first', 'last']:
                order_by = "id"
            else:  # by_metadata_time:按 cmetadata.createdat 倒序,最新的排在最前
                order_by = "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id"

            cursor.execute(f"""
                SELECT document, COUNT(*) as duplicate_count,
                       array_agg(id ORDER BY {order_by}) as record_ids
                FROM langchain_pg_embedding
                WHERE collection_id = %s
                GROUP BY document
                HAVING COUNT(*) > 1
            """, (collection_id,))

            collection_duplicates = 0
            for row in cursor.fetchall():
                document_content, duplicate_count, record_ids = row
                collection_duplicates += duplicate_count - 1  # 减去要保留的一条

                # 根据保留策略选择要保留的记录
                if keep_strategy == 'first':
                    keep_id = record_ids[0]
                    remove_ids = record_ids[1:]
                elif keep_strategy == 'last':
                    keep_id = record_ids[-1]
                    remove_ids = record_ids[:-1]
                else:  # by_metadata_time
                    keep_id = record_ids[0]  # 已经按时间排序
                    remove_ids = record_ids[1:]

                duplicate_groups.append({
                    "collection_name": collection_name,
                    "document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content,
                    "duplicate_count": duplicate_count,
                    "kept_record_id" if not dry_run else "records_to_keep": keep_id,
                    "removed_record_ids" if not dry_run else "records_to_remove": remove_ids
                })

            total_duplicates += collection_duplicates
            collections_stats[collection_name] = {
                "before": collection_before,
                "after": collection_before - collection_duplicates,
                "duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates
            }

        # 4. 执行合并操作(如果不是dry_run)
        if not dry_run:
            try:
                conn.autocommit = False  # 开始事务

                for group in duplicate_groups:
                    remove_ids = group["removed_record_ids"]
                    if remove_ids:
                        cursor.execute(
                            "DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)",
                            (remove_ids,)
                        )

                conn.commit()

            except Exception as e:
                conn.rollback()
                return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
            finally:
                conn.autocommit = True

        # 5. 构建响应
        total_after = total_before - total_duplicates

        summary = {
            "total_records_before": total_before,
            "total_records_after": total_after,
            "duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates,
            "collections_stats": collections_stats
        }

        if dry_run:
            response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录"
            data_key = "duplicate_groups"
        else:
            response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录"
            data_key = "merged_groups"

        return jsonify(success_response(
            response_text=response_text,
            data={
                "dry_run": dry_run,
                "collections_processed": collection_names,
                "summary": summary,
                data_key: duplicate_groups
            }
        ))

    except Exception as e:
        return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
    finally:
        if 'cursor' in locals():
            cursor.close()
        if 'conn' in locals():
            conn.close()
```

**返回格式(dry_run: true):**
```json
{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "发现 50 条重复记录,预计删除后将从 1000 条减少到 950 条记录",
    "dry_run": true,
    "collections_processed": ["sql", "ddl"],
    "summary": {
      "total_records_before": 1000,
      "total_records_after": 950,
      "duplicates_to_remove": 50,
      "collections_stats": {
        "sql": {
          "before": 600,
          "after": 580,
          "duplicates_to_remove": 20
        },
        "ddl": {
          "before": 400,
          "after": 370,
          "duplicates_to_remove": 30
        }
      }
    },
    "duplicate_groups": [
      {
        "collection_name": "sql",
        "document_content": "SELECT * FROM users WHERE id = ?",
        "duplicate_count": 3,
        "records_to_keep": "uuid1-sql",
        "records_to_remove": ["uuid2-sql", "uuid3-sql"]
      }
    ]
  }
}
```

**返回格式(dry_run: false):**
```json
{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "成功合并重复记录,删除了 50 条重复记录,从 1000 条减少到 950 条记录",
    "dry_run": false,
    "collections_processed": ["sql", "ddl"],
    "summary": {
      "total_records_before": 1000,
      "total_records_after": 950,
      "duplicates_removed": 50,
      "collections_stats": {
        "sql": {
          "before": 600,
          "after": 580,
          "duplicates_removed": 20
        },
        "ddl": {
          "before": 400,
          "after": 370,
          "duplicates_removed": 30
        }
      }
    },
    "merged_groups": [
      {
        "collection_name": "sql",
        "document_content": "SELECT * FROM users WHERE id = ?",
        "duplicate_count": 3,
        "kept_record_id": "uuid1-sql",
        "removed_record_ids": ["uuid2-sql", "uuid3-sql"]
      }
    ]
  }
}
```

**错误响应示例:**
```json
{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "collection_names 参数必须是非空数组",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}
```

**安全性和性能考虑:**

1. **事务安全**:
   - 整个合并过程在数据库事务中执行
   - 发生错误时自动回滚
   - 避免数据不一致状态

2. **预览模式**:
   - 默认启用 dry_run 模式
   - 用户可以先查看合并计划再决定是否执行
   - 提供详细的影响分析

3. **保留策略**:
   - 支持多种保留策略满足不同需求
   - first 策略最简单可靠,推荐作为默认
   - by_metadata_time 策略使用真实创建时间

4. **性能优化**:
   - 使用数组聚合函数减少数据库查询次数
   - 批量删除操作提高效率
   - 分集合处理避免长时间锁表

5. **错误处理**:
   - 详细的参数验证
   - 完整的异常处理和错误信息
   - 数据库连接的正确关闭

**使用建议:**

1. **推荐流程**:
   ```bash
   # 1. 先预览合并计划
   POST /api/v0/training_data/combine
   {
     "collection_names": ["sql", "ddl"],
     "dry_run": true,
     "keep_strategy": "first"
   }

   # 2. 确认无误后执行实际合并
   POST /api/v0/training_data/combine
   {
     "collection_names": ["sql", "ddl"],
     "dry_run": false,
     "keep_strategy": "first"
   }
   ```

2. **保留策略选择**:
   - 一般情况:使用 `"first"` 策略
   - 需要保留最新数据:使用 `"by_metadata_time"` 策略
   - 测试环境:可以使用 `"last"` 策略

3. **批量处理**:
   - 大量数据时建议分批处理不同集合
   - 可以先处理较小的集合验证效果
   - 重要数据建议先备份

---

## 4. 补充建议

### 4.1 错误处理增强

建议在所有API中添加更详细的错误处理:

```python
from functools import wraps

# 通用错误处理装饰器
def handle_api_errors(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except ValueError as e:
            return jsonify(bad_request_response(str(e)))
        except Exception as e:
            logger.error(f"API错误: {str(e)}", exc_info=True)
            return jsonify(internal_error_response(f"系统错误: {str(e)}"))
    return decorated_function
```

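该装饰器的挂载方式示意(以前文的 update 端点为例,仅演示用法,省略具体业务逻辑):

```python
@app.route('/api/v0/training_data/update', methods=['POST'])
@handle_api_errors
def training_data_update():
    # 业务逻辑中可以直接抛出 ValueError,由装饰器统一转换为400响应
    req = request.get_json(force=True)
    if not req.get('id'):
        raise ValueError("缺少必需参数:id")
    ...
```
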
### 4.2 批量操作优化

对于大量数据的处理,建议引入任务队列:

```python
import uuid

# 异步任务支持
@app.route('/api/v0/training_data/upload_async', methods=['POST'])
def upload_training_data_async():
    # 创建后台任务
    task_id = str(uuid.uuid4())
    # 将任务放入队列
    # 返回任务ID供客户端查询进度
    return jsonify(success_response(
        response_text="文件上传任务已创建",
        data={"task_id": task_id}
    ))

@app.route('/api/v0/training_data/task_status/<task_id>', methods=['GET'])
def get_task_status(task_id):
    # 查询任务状态
    pass
```

### 4.3 数据验证增强

建议添加更严格的数据验证:

```python
# SQL语法验证
def validate_sql(sql):
    try:
        import sqlparse
        parsed = sqlparse.parse(sql)
        if not parsed:
            raise ValueError("SQL语法错误")
        return True
    except Exception as e:
        raise ValueError(f"SQL验证失败: {str(e)}")

# 数据类型验证
def validate_training_data(data_type, content):
    if data_type == 'sql':
        return validate_sql(content)
    elif data_type == 'ddl':
        return validate_ddl(content)
    # 其他验证逻辑
```

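其中 `validate_ddl` 并未给出实现,下面是一个假设性的最小草图(仅检查语句是否以常见 DDL 关键字开头,实际可按需要替换为更严格的校验):

```python
def validate_ddl(ddl: str) -> bool:
    """假设性的 DDL 校验:非空且以 CREATE/ALTER/DROP/COMMENT 开头"""
    ddl = (ddl or "").strip()
    if not ddl:
        raise ValueError("DDL内容不能为空")
    first_word = ddl.split(None, 1)[0].upper()
    if first_word not in ("CREATE", "ALTER", "DROP", "COMMENT"):
        raise ValueError(f"不支持的DDL语句: {first_word}")
    return True
```
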
---

## 5. 实施建议

### 5.1 开发顺序

1. **第一阶段**:实现update API(相对简单,复用现有训练函数)
2. **第二阶段**:实现combine API(中等复杂度,涉及直接的数据库操作)
3. **第三阶段**:实现upload API(最复杂,需要充分测试)

### 5.2 测试策略

1. **单元测试**:每个解析函数和训练函数
2. **集成测试**:完整的API调用流程
3. **性能测试**:大文件上传和处理
4. **错误测试**:各种异常情况的处理

### 5.3 部署注意事项

1. **文件上传限制**:配置Nginx/Apache的文件大小限制(应用层配置示例见下)
2. **临时目录**:确保有足够的磁盘空间
3. **数据库连接**:优化数据库连接池设置
4. **日志监控**:添加详细的操作日志

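针对第1条,除了反向代理层的限制外,也可以在 Flask 应用层直接限制请求体大小(使用标准的 MAX_CONTENT_LENGTH 配置,数值与本方案的500KB保持一致,错误响应的转换方式为示意):

```python
# Flask 应用层的请求体大小限制,与API内部的500KB校验保持一致
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024  # 500KB

# 超出限制时 Flask 会直接返回 413,可统一转换为本项目的错误响应格式
@app.errorhandler(413)
def handle_payload_too_large(e):
    return jsonify(bad_request_response("文件大小不能超过500KB")), 413
```
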
---

## 6. 结论

本改造方案通过充分复用现有代码,提供了更加完整和实用的解决方案:

1. **update API**:简化为单条更新,采用先删除后插入策略,返回新记录ID
2. **upload API**:复用 `data_pipeline/trainer` 模块的成熟函数,支持多种文件格式的自动解析和导入
3. **combine API**:支持合并重复训练数据,提供预览和执行模式

**主要优势:**

1. **避免重复造轮子**:
   - 直接复用 `run_training.py` 中的成熟解析函数
   - 保持与现有训练逻辑的一致性
   - 减少代码维护成本

2. **提高代码质量**:
   - 复用经过验证的文件处理逻辑
   - 继承现有的错误处理机制
   - 保持日志格式的统一性

3. **性能优化**:
   - 利用现有的批处理机制
   - 复用高效的文件解析算法
   - 减少开发和测试时间

特别是upload API的设计,通过复用现有的trainer模块,不仅避免了重复开发,还确保了与现有系统的完美兼容性,可以满足生产环境的需求。