4.训练数据API改造方案.md 42 KB

unified_api.py 训练数据管理API改造方案(详细版)

1. 需求概述

本方案基于对 unified_api.py 文件的深入分析,提出以下两个API改造与新增需求:

  • 需求一:新增 /api/v0/training_data/update,支持先删除再插入的训练数据更新操作
  • 需求二:新增 /api/v0/training_data/upload,支持上传多种格式文件并自动解析入库
  • 需求三:新增 /api/v0/training_data/combine,合并相同的 langchain_pg_embedding.document 记录

2. 当前实现分析

2.1 现有API结构

unified_api.py 目前实现了4个训练数据相关的API端点:

  1. GET /api/v0/training_data/stats (lines 1205-1234) - 基础统计信息
  2. POST /api/v0/training_data/query (lines 1236-1294) - 分页查询
  3. POST /api/v0/training_data/create (lines 1296-1404) - 批量创建(最多50条)
  4. POST /api/v0/training_data/delete (lines 1406-1471) - 批量删除(最多50条)

2.2 数据库架构

表结构关系:

-- langchain_pg_collection (集合表)
CREATE TABLE langchain_pg_collection (
    uuid uuid PRIMARY KEY,
    name varchar NOT NULL UNIQUE,
    cmetadata json
);

-- langchain_pg_embedding (训练数据表)
CREATE TABLE langchain_pg_embedding (
    id varchar PRIMARY KEY,
    collection_id uuid REFERENCES langchain_pg_collection(uuid) ON DELETE CASCADE,
    embedding vector,
    document varchar,
    cmetadata jsonb
);

集合名称映射:

  • sql 集合 → SQL查询和问答对
  • ddl 集合 → 数据库表结构定义
  • documentation 集合 → 表和字段文档
  • error_sql 集合 → 错误SQL示例

3. 详细改造方案


3.1 /api/v0/training_data/update (新增的API)

功能描述:

  • 支持单条更新训练数据
  • 采用先删除后插入的策略确保数据一致性
  • 更新方式是先删除再插入,删除是以 langchain_pg_embedding.id字段为依据的
  • langchain_pg_embedding.id字段是varchar类型,它是uuid+类型的格式,比如:"50bb6b17-d5be-48ab-8125-58ab8bb076-sql"
  • 返回新创建记录的ID,便于前端更新界面

WEB UI使用流程:

  1. 用户选中一行记录,点击修改按钮
  2. 前端显示编辑表单,用户修改内容
  3. 提交时,根据这一行的id,删除这条记录
  4. 然后调用create参数,生成一条新的记录
  5. 返回的JSON中包括这条记录的新的id

API规格:

请求参数:

{
  "id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
  "training_data_type": "sql",
  "question": "如何查询所有用户?",
  "sql": "SELECT * FROM users;"
}

参数说明:

  • id: 要删除的原始记录ID(必需)
  • training_data_type: 训练数据类型,支持 "sql", "ddl", "documentation", "error_sql"
  • 其他字段根据类型不同而不同:
    • sql 类型:question(可选), sql(必需)
    • error_sql 类型:question(必需), sql(必需)
    • documentation 类型:content(必需)
    • ddl 类型:ddl(必需)

实现代码框架:

@app.route('/api/v0/training_data/update', methods=['POST'])
def training_data_update():
    """更新训练数据API - 支持单条更新,采用先删除后插入策略"""
    try:
        req = request.get_json(force=True)
        
        # 1. 参数验证
        original_id = req.get('id')
        if not original_id:
            return jsonify(bad_request_response(
                response_text="缺少必需参数:id",
                missing_params=["id"]
            )), 400
        
        training_type = req.get('training_data_type')
        if not training_type:
            return jsonify(bad_request_response(
                response_text="缺少必需参数:training_data_type",
                missing_params=["training_data_type"]
            )), 400
        
        # 2. 先删除原始记录
        try:
            success = vn.remove_training_data(original_id)
            if not success:
                return jsonify(bad_request_response(
                    response_text=f"原始记录 {original_id} 不存在或删除失败"
                )), 400
        except Exception as e:
            return jsonify(internal_error_response(
                response_text=f"删除原始记录失败: {str(e)}"
            )), 500
        
        # 3. 根据类型验证和准备新数据
        try:
            if training_type == 'sql':
                sql = req.get('sql')
                if not sql:
                    return jsonify(bad_request_response(
                        response_text="SQL字段是必需的",
                        missing_params=["sql"]
                    )), 400
                
                # SQL语法检查
                is_valid, error_msg = validate_sql_syntax(sql)
                if not is_valid:
                    return jsonify(bad_request_response(
                        response_text=f"SQL语法错误: {error_msg}"
                    )), 400
                
                question = req.get('question')
                if question:
                    training_id = vn.train(question=question, sql=sql)
                else:
                    training_id = vn.train(sql=sql)
                    
            elif training_type == 'error_sql':
                question = req.get('question')
                sql = req.get('sql')
                if not question or not sql:
                    return jsonify(bad_request_response(
                        response_text="question和sql字段都是必需的",
                        missing_params=["question", "sql"]
                    )), 400
                training_id = vn.train_error_sql(question=question, sql=sql)
                
            elif training_type == 'documentation':
                content = req.get('content')
                if not content:
                    return jsonify(bad_request_response(
                        response_text="content字段是必需的",
                        missing_params=["content"]
                    )), 400
                training_id = vn.train(documentation=content)
                
            elif training_type == 'ddl':
                ddl = req.get('ddl')
                if not ddl:
                    return jsonify(bad_request_response(
                        response_text="ddl字段是必需的",
                        missing_params=["ddl"]
                    )), 400
                training_id = vn.train(ddl=ddl)
                
            else:
                return jsonify(bad_request_response(
                    response_text=f"不支持的训练数据类型: {training_type}"
                )), 400
                
        except Exception as e:
            return jsonify(internal_error_response(
                response_text=f"创建新训练数据失败: {str(e)}"
            )), 500
        
        # 4. 获取更新后的总记录数
        current_total = get_total_training_count()
        
        return jsonify(success_response(
            response_text="训练数据更新成功",
            data={
                "original_id": original_id,
                "new_training_id": training_id,
                "type": training_type,
                "current_total_count": current_total
            }
        ))
        
    except Exception as e:
        logger.error(f"training_data_update执行失败: {str(e)}")
        return jsonify(internal_error_response(
            response_text="更新训练数据失败,请稍后重试"
        )), 500

返回格式:

成功响应:

{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "训练数据更新成功",
    "original_id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
    "new_training_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890-sql",
    "type": "sql",
    "current_total_count": 1250
  }
}

错误响应示例:

{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "原始记录 6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql 不存在或删除失败",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}

错误响应示例:

{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "缺少必需参数:data",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}

安全性和性能考虑:

  1. 事务安全

    • 删除和插入操作在同一个函数中执行
    • 如果插入失败,原始记录已被删除,需要回滚或重试机制
    • 建议在应用层实现重试逻辑
  2. 数据一致性

    • 先删除后插入确保ID的唯一性
    • 新记录会获得新的UUID,避免ID冲突
    • 保持训练数据类型的正确性
  3. 错误处理

    • 详细的参数验证
    • 原始记录存在性检查
    • SQL语法验证(针对sql类型)
    • 完整的异常处理和错误信息
  4. 单条处理

    • 只支持单条更新,简化逻辑
    • 减少并发冲突的可能性
    • 更精确的错误定位和处理
  5. 性能优化

    • 复用现有的训练函数
    • 避免重复的数据库连接
    • 合理的错误处理和日志记录

使用建议:

  1. 前端集成

    // 更新训练数据的示例
    const updateTrainingData = async (originalId, newData) => {
     const response = await fetch('/api/v0/training_data/update', {
       method: 'POST',
       headers: { 'Content-Type': 'application/json' },
       body: JSON.stringify({
         id: originalId,
         training_data_type: 'sql',
         question: newData.question,
         sql: newData.sql
       })
     });
         
     const result = await response.json();
     if (result.success) {
       // 更新成功,使用新的training_id更新界面
       const newId = result.data.new_training_id;
       updateUIWithNewId(newId);
     }
    };
    
    1. 错误处理
    2. 检查返回的 success 字段
    3. 处理各种错误情况(参数错误、记录不存在、SQL语法错误等)
    4. 显示详细的错误信息给用户

    5. 最佳实践

    6. 更新前验证原始记录是否存在

    7. 更新后验证新记录是否创建成功

    8. 使用返回的新ID更新前端界面


    3.2 /api/v0/training_data/upload 新增(重点设计)

    功能描述:

    • 支持上传多种格式的训练数据文件(txt/md/sql/json 或无扩展名)
    • 通过 file_type 参数指定文件类型,不再依赖文件扩展名
    • 文件大小限制500KB(约包含100-500条记录)
    • 对文件内容进行格式验证,确保符合指定类型要求
    • 复用现有trainer模块的成熟功能
    • 提供详细的处理结果反馈

    支持的文件类型:

    API 支持以下5种文件类型,通过 file_type 参数指定:

    file_type 支持扩展名 内容格式验证 复用函数 解析方式
    ddl ddl/sql/txt/无扩展名 必须包含 CREATE 语句(大小写不敏感) train_ddl_statements() 建表语句,以";"为分隔符
    markdown md/markdown 必须包含 ## 标题 train_documentation_blocks() Markdown以"##"为单位分割
    sql_pair_json json/txt/无扩展名 必须是有效JSON格式,包含question/sql字段 train_json_question_sql_pairs() JSON问答对训练
    sql_pair sql/txt/无扩展名 必须包含 Question:/SQL: 格式 train_formatted_question_sql_pairs() 格式化问答对训练
    sql sql/txt/无扩展名 必须包含 ; 分隔符 train_sql_examples() SQL示例,以";"为分隔符

    文件类型详细说明:

    1. DDL文件 (file_type: ddl):

    建表语句,必须包含CREATE语句,以";"为分隔符。

    -- 建表语句
    CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    email VARCHAR(255) UNIQUE
    );
    
    CREATE TABLE orders (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id),
    amount DECIMAL(10,2)
    );
    

2. Markdown文档 (file_type: markdown):

必须包含"##"标题,以"##"为单位进行分割:

## users表
用户基本信息表,存储系统用户数据。
字段说明
- id: 用户唯一标识
- name: 用户姓名
- email: 电子邮箱

## orders表
订单信息表,记录用户购买记录。

3. JSON问答对 (file_type: sql_pair_json):

必须是有效JSON格式,"question"/"sql" 大小写不敏感。

[
  {
    "question": "如何查询活跃用户数量?",
    "sql": "SELECT COUNT(*) FROM users WHERE active = true;"
  },
  {
    "question": "如何统计每个分类的产品数量?",
    "sql": "SELECT category, COUNT(*) FROM products GROUP BY category;"
  }
]

4. 格式化问答对 (file_type: sql_pair):

必须包含Question/SQL格式,大小写不敏感,问答对之间用空行分隔。

Question: 如何查询观影人数超过 10000 的热门电影?
SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;

Question: 如何统计每个分类的产品数量?
SQL: SELECT category, COUNT(*) FROM products GROUP BY category;

5. SQL示例 (file_type: sql):

纯SQL语句,必须包含";"分隔符,系统会自动生成对应问题。

SELECT COUNT(*) FROM users WHERE active = true;
SELECT category, COUNT(*) FROM products GROUP BY category;

API规格:

请求参数:

  • file: 上传的文件(multipart/form-data)
  • file_type: 文件类型(必需),可选值:ddlmarkdownsql_pair_jsonsql_pairsql
  • 文件大小限制:最大500KB
  • 文件编码:UTF-8
  • 支持的文件扩展名:根据 file_type 动态确定(详见上表)

文件扩展名验证规则:

系统根据 file_type 参数动态确定允许的文件扩展名:

  1. DDL类型:允许 .ddl.sql.txt 扩展名或无扩展名
  2. Markdown类型:仅允许 .md.markdown 扩展名(不支持无扩展名)
  3. SQL_PAIR_JSON类型:允许 .json.txt 扩展名或无扩展名
  4. SQL_PAIR类型:允许 .sql.txt 扩展名或无扩展名
  5. SQL类型:允许 .sql.txt 扩展名或无扩展名

扩展名验证错误示例:

文件类型 ddl 不支持的文件扩展名:pdf,支持的扩展名:ddl, sql, txt 或无扩展名
文件类型 markdown 不支持的文件扩展名:txt,支持的扩展名:md, markdown

内容格式验证规则:

  1. DDL类型验证

    • 文件内容必须包含 CREATE 关键字(大小写不敏感)
    • 验证失败返回:"文件内容不符合DDL格式,必须包含CREATE语句"
  2. Markdown类型验证

    • 文件内容必须包含 ## 标题标记
    • 验证失败返回:"文件内容不符合Markdown格式,必须包含##标题"
  3. SQL_PAIR_JSON类型验证

    • 文件必须是有效的JSON格式
    • JSON内容必须包含question和sql字段
    • 验证失败返回:"文件内容不符合JSON问答对格式"
  4. SQL_PAIR类型验证

    • 文件内容必须包含 Question:SQL: 标记(大小写不敏感)
    • 验证失败返回:"文件内容不符合问答对格式,必须包含Question:和SQL:"
  5. SQL类型验证

    • 文件内容必须包含 ; 分隔符
    • 验证失败返回:"文件内容不符合SQL格式,必须包含;分隔符"

实现代码框架:

import json
import re
from werkzeug.utils import secure_filename
from data_pipeline.trainer.run_training import (
    train_ddl_statements,
    train_documentation_blocks, 
    train_json_question_sql_pairs,
    train_formatted_question_sql_pairs,
    train_sql_examples
)
from common.result import success_response, bad_request_response, internal_error_response

def get_allowed_extensions(file_type: str) -> list:
    """根据文件类型返回允许的扩展名"""
    type_specific_extensions = {
        'ddl': ['ddl', 'sql', 'txt', ''],  # 支持无扩展名
        'markdown': ['md', 'markdown'],    # 不支持无扩展名
        'sql_pair_json': ['json', 'txt', ''],  # 支持无扩展名
        'sql_pair': ['sql', 'txt', ''],   # 支持无扩展名
        'sql': ['sql', 'txt', '']         # 支持无扩展名
    }
    
    return type_specific_extensions.get(file_type, [])

@app.route('/api/v0/training_data/upload', methods=['POST'])
def upload_training_data():
    """上传训练数据文件API - 支持多种文件格式的自动解析和导入"""
    try:
        # 1. 参数验证
        if 'file' not in request.files:
            return jsonify(bad_request_response("未提供文件"))
        
        file = request.files['file']
        if file.filename == '':
            return jsonify(bad_request_response("未选择文件"))
        
        # 获取file_type参数
        file_type = request.form.get('file_type')
        if not file_type:
            return jsonify(bad_request_response("缺少必需参数:file_type"))
        
        # 验证file_type参数
        valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql']
        if file_type not in valid_file_types:
            return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}"))
        
        # 2. 文件大小验证 (500KB)
        file.seek(0, 2)
        file_size = file.tell()
        file.seek(0)
        
        if file_size > 500 * 1024:  # 500KB
            return jsonify(bad_request_response("文件大小不能超过500KB"))
        
        # 3. 验证文件扩展名(基于file_type)
        filename = secure_filename(file.filename)
        allowed_extensions = get_allowed_extensions(file_type)
        file_ext = filename.split('.')[-1].lower() if '.' in filename else ''
        
        if file_ext not in allowed_extensions:
            # 构建友好的错误信息
            non_empty_extensions = [ext for ext in allowed_extensions if ext]
            if '' in allowed_extensions:
                ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名"
            else:
                ext_message = ', '.join(non_empty_extensions)
            return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}"))
        
        # 4. 读取文件内容并验证格式
        file.seek(0)
        content = file.read().decode('utf-8')
        
        # 格式验证
        validation_result = validate_file_content(content, file_type)
        if not validation_result['valid']:
            return jsonify(bad_request_response(validation_result['error']))
        
        # 5. 创建临时文件(复用现有函数需要文件路径)
        temp_file_path = None
        
        try:
            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file:
                tmp_file.write(content)
                temp_file_path = tmp_file.name
            
            # 6. 根据文件类型调用现有的训练函数
            if file_type == 'ddl':
                train_ddl_statements(temp_file_path)
                
            elif file_type == 'markdown':
                train_documentation_blocks(temp_file_path)
                
            elif file_type == 'sql_pair_json':
                train_json_question_sql_pairs(temp_file_path)
                
            elif file_type == 'sql_pair':
                train_formatted_question_sql_pairs(temp_file_path)
                
            elif file_type == 'sql':
                train_sql_examples(temp_file_path)
            
            return jsonify(success_response(
                response_text=f"文件上传并训练成功:{filename}",
                data={
                    "filename": filename,
                    "file_type": file_type,
                    "file_size": file_size,
                    "status": "completed"
                }
            ))
            
        except Exception as e:
            logger.error(f"训练失败: {str(e)}")
            return jsonify(internal_error_response(f"训练失败: {str(e)}"))
        finally:
            # 清理临时文件
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except Exception as e:
                    logger.warning(f"清理临时文件失败: {str(e)}")
        
    except Exception as e:
        logger.error(f"文件上传失败: {str(e)}")
        return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))

def validate_file_content(content: str, file_type: str) -> dict:
    """验证文件内容格式"""
    try:
        if file_type == 'ddl':
            # 检查是否包含CREATE语句
            if not re.search(r'\bCREATE\b', content, re.IGNORECASE):
                return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'}
            
        elif file_type == 'markdown':
            # 检查是否包含##标题
            if '##' not in content:
                return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'}
            
        elif file_type == 'sql_pair_json':
            # 检查是否为有效JSON
            try:
                data = json.loads(content)
                if not isinstance(data, list) or not data:
                    return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'}
                
                # 检查是否包含question和sql字段
                for item in data:
                    if not isinstance(item, dict):
                        return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'}
                    
                    has_question = any(key.lower() == 'question' for key in item.keys())
                    has_sql = any(key.lower() == 'sql' for key in item.keys())
                    
                    if not has_question or not has_sql:
                        return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'}
                        
            except json.JSONDecodeError:
                return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'}
            
        elif file_type == 'sql_pair':
            # 检查是否包含Question:和SQL:
            if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE):
                return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'}
            if not re.search(r'\bSQL\s*:', content, re.IGNORECASE):
                return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'}
            
        elif file_type == 'sql':
            # 检查是否包含;分隔符
            if ';' not in content:
                return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'}
        
        return {'valid': True}
        
    except Exception as e:
        return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'}

返回格式:

成功响应:

{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "文件上传并训练成功:training_data.sql",
    "filename": "training_data.sql",
    "file_type": "sql",
    "file_size": 12345,
    "status": "completed"
  }
}

错误响应:

{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "文件内容不符合SQL格式,必须包含;分隔符",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}

安全性和性能考虑:

  1. 文件处理优化

    • 使用临时文件处理,确保内存安全
    • 文件大小限制500KB,避免内存压力
    • 自动清理临时文件,防止磁盘空间泄露
    • 完整的错误处理和日志记录
  2. 类型验证机制

    • 基于 file_type 参数动态确定允许的扩展名
    • 对文件内容进行格式验证
    • 大小写不敏感的格式处理
    • 详细的验证错误信息,包含具体的文件类型和支持的扩展名
  3. 复用成熟功能

    • 直接使用 data_pipeline.trainer.run_training 中的成熟函数
    • 复用 train_ddl_statementstrain_documentation_blocks 等处理逻辑
    • 调用 train_json_question_sql_pairstrain_formatted_question_sql_pairs 处理问答对
    • 继承批处理优化机制,保持训练逻辑的一致性
    • 避免重复造轮子,提高代码维护性
  4. 错误处理

    • 详细的错误日志记录
    • 文件格式验证失败时的明确提示
    • 基于文件类型的精确扩展名验证和友好提示
    • 临时文件清理保证不留垃圾文件
  5. 性能优化

    • 复用现有的高效解析函数
    • 批处理机制提高训练效率
    • 避免重复开发,减少维护成本

使用建议:

  1. 文件准备

    • 确保文件使用UTF-8编码
    • 控制文件大小在500KB以内
    • 按照指定的格式准备文件内容
  2. 文件类型选择

    • ddl:用于训练表结构信息,支持 .ddl/.sql/.txt 扩展名或无扩展名,必须包含CREATE语句
    • markdown:用于训练表和字段说明,仅支持 .md/.markdown 扩展名,必须包含##标题
    • sql_pair_json:用于训练结构化问答数据,支持 .json/.txt 扩展名或无扩展名,必须是有效JSON格式
    • sql_pair:用于训练文本格式问答数据,支持 .sql/.txt 扩展名或无扩展名,必须包含Question:/SQL:格式
    • sql:用于训练SQL示例,支持 .sql/.txt 扩展名或无扩展名,必须包含;分隔符,系统会自动生成问题
  3. 最佳实践

    • 建议分批上传,每次处理一种类型的文件
    • 上传前检查文件格式的正确性
    • 明确指定正确的 file_type 参数
    • 关注返回结果中的处理状态信息

Postman测试示例:

POST /api/v0/training_data/upload
Content-Type: multipart/form-data

form-data:
- file: [选择文件]
- file_type: sql_pair

示例文件内容:
Question: 如何查询观影人数超过 10000 的热门电影?
SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;

Question: 如何统计每个分类的产品数量?
SQL: SELECT category, COUNT(*) FROM products GROUP BY category;

3.3 /api/v0/training_data/combine (新增的API)

功能描述:

  • 合并 langchain_pg_embedding 表中相同 document 字段的重复记录
  • 支持按集合名称过滤合并范围
  • 提供预览模式(dry_run)和实际执行模式
  • 支持多种保留策略(first/last/by_metadata_time)

技术分析:

训练数据写入机制中,不同类型的数据在 document 字段中存储不同格式的内容:

  • SQL类型:JSON格式 {"question": "...", "sql": "..."}
  • DDL类型:直接存储DDL语句原文
  • Documentation类型:直接存储文档内容原文
  • Error_SQL类型:JSON格式错误SQL示例

合并策略采用 document 字段精确匹配,因为:

  1. embedding 字段是向量,比较成本高且可能有微小差异
  2. document 字段存储确定性文本内容,便于精确比较
  3. 相同的 document 内容必然产生相同的训练效果

API规格:

请求参数:

{
  "collection_names": ["sql", "ddl", "documentation", "error_sql"],
  "dry_run": true,
  "keep_strategy": "first"
}

参数说明:

  • collection_names: 要处理的集合名称数组,支持 ["sql", "ddl", "documentation", "error_sql"]
  • dry_run: 是否为预览模式,默认 true(安全)
  • keep_strategy: 保留策略
    • "first": 保留第一条记录(按ID排序,推荐默认)
    • "last": 保留最后一条记录(按ID排序)
    • "by_metadata_time": 按 cmetadata.createdat 时间排序保留最新记录

实现代码框架:

@app.route('/api/v0/training_data/combine', methods=['POST'])
def combine_training_data():
    try:
        # 1. 参数验证
        data = request.get_json()
        if not data:
            return jsonify(bad_request_response("请求体不能为空"))
        
        collection_names = data.get('collection_names', [])
        if not collection_names or not isinstance(collection_names, list):
            return jsonify(bad_request_response("collection_names 参数必须是非空数组"))
        
        # 验证集合名称
        valid_collections = ['sql', 'ddl', 'documentation', 'error_sql']
        invalid_collections = [name for name in collection_names if name not in valid_collections]
        if invalid_collections:
            return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}"))
        
        dry_run = data.get('dry_run', True)
        keep_strategy = data.get('keep_strategy', 'first')
        
        if keep_strategy not in ['first', 'last', 'by_metadata_time']:
            return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'"))
        
        # 2. 获取数据库连接
        conn = get_db_connection()
        cursor = conn.cursor()
        
        # 3. 查找重复记录
        duplicate_groups = []
        total_before = 0
        total_duplicates = 0
        collections_stats = {}
        
        for collection_name in collection_names:
            # 获取集合ID
            cursor.execute(
                "SELECT uuid FROM langchain_pg_collection WHERE name = %s",
                (collection_name,)
            )
            collection_result = cursor.fetchone()
            if not collection_result:
                continue
            
            collection_id = collection_result[0]
            
            # 统计该集合的记录数
            cursor.execute(
                "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s",
                (collection_id,)
            )
            collection_before = cursor.fetchone()[0]
            total_before += collection_before
            
            # 查找重复记录
            cursor.execute("""
                SELECT document, COUNT(*) as duplicate_count, 
                       array_agg(id ORDER BY %s) as record_ids
                FROM langchain_pg_embedding 
                WHERE collection_id = %s 
                GROUP BY document 
                HAVING COUNT(*) > 1
            """ % (
                "id" if keep_strategy in ['first', 'last'] else "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id",
                "%s"
            ), (collection_id,))
            
            collection_duplicates = 0
            for row in cursor.fetchall():
                document_content, duplicate_count, record_ids = row
                collection_duplicates += duplicate_count - 1  # 减去要保留的一条
                
                # 根据保留策略选择要保留的记录
                if keep_strategy == 'first':
                    keep_id = record_ids[0]
                    remove_ids = record_ids[1:]
                elif keep_strategy == 'last':
                    keep_id = record_ids[-1]
                    remove_ids = record_ids[:-1]
                else:  # by_metadata_time
                    keep_id = record_ids[0]  # 已经按时间排序
                    remove_ids = record_ids[1:]
                
                duplicate_groups.append({
                    "collection_name": collection_name,
                    "document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content,
                    "duplicate_count": duplicate_count,
                    "kept_record_id" if not dry_run else "records_to_keep": keep_id,
                    "removed_record_ids" if not dry_run else "records_to_remove": remove_ids
                })
            
            total_duplicates += collection_duplicates
            collections_stats[collection_name] = {
                "before": collection_before,
                "after": collection_before - collection_duplicates,
                "duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates
            }
        
        # 4. 执行合并操作(如果不是dry_run)
        if not dry_run:
            try:
                conn.autocommit = False  # 开始事务
                
                for group in duplicate_groups:
                    remove_ids = group["removed_record_ids"]
                    if remove_ids:
                        cursor.execute(
                            "DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)",
                            (remove_ids,)
                        )
                
                conn.commit()
                
            except Exception as e:
                conn.rollback()
                return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
            finally:
                conn.autocommit = True
        
        # 5. 构建响应
        total_after = total_before - total_duplicates
        
        summary = {
            "total_records_before": total_before,
            "total_records_after": total_after,
            "duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates,
            "collections_stats": collections_stats
        }
        
        if dry_run:
            response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录"
            data_key = "duplicate_groups"
        else:
            response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录"
            data_key = "merged_groups"
        
        return jsonify(success_response(
            response_text=response_text,
            data={
                "dry_run": dry_run,
                "collections_processed": collection_names,
                "summary": summary,
                data_key: duplicate_groups
            }
        ))
        
    except Exception as e:
        return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
    finally:
        if 'cursor' in locals():
            cursor.close()
        if 'conn' in locals():
            conn.close()

返回格式(dry_run: true):

{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "发现 50 条重复记录,预计删除后将从 1000 条减少到 950 条记录",
    "dry_run": true,
    "collections_processed": ["sql", "ddl"],
    "summary": {
      "total_records_before": 1000,
      "total_records_after": 950,
      "duplicates_to_remove": 50,
      "collections_stats": {
        "sql": {
          "before": 600,
          "after": 580,
          "duplicates_to_remove": 20
        },
        "ddl": {
          "before": 400,
          "after": 370,
          "duplicates_to_remove": 30
        }
      }
    },
    "duplicate_groups": [
      {
        "collection_name": "sql",
        "document_content": "SELECT * FROM users WHERE id = ?",
        "duplicate_count": 3,
        "records_to_keep": "uuid1-sql",
        "records_to_remove": ["uuid2-sql", "uuid3-sql"]
      }
    ]
  }
}

返回格式(dry_run: false):

{
  "code": 200,
  "success": true,
  "message": "操作成功",
  "data": {
    "response": "成功合并重复记录,删除了 50 条重复记录,从 1000 条减少到 950 条记录",
    "dry_run": false,
    "collections_processed": ["sql", "ddl"],
    "summary": {
      "total_records_before": 1000,
      "total_records_after": 950,
      "duplicates_removed": 50,
      "collections_stats": {
        "sql": {
          "before": 600,
          "after": 580,
          "duplicates_removed": 20
        },
        "ddl": {
          "before": 400,
          "after": 370,
          "duplicates_removed": 30
        }
      }
    },
    "merged_groups": [
      {
        "collection_name": "sql",
        "document_content": "SELECT * FROM users WHERE id = ?",
        "duplicate_count": 3,
        "kept_record_id": "uuid1-sql",
        "removed_record_ids": ["uuid2-sql", "uuid3-sql"]
      }
    ]
  }
}

错误响应示例:

{
  "code": 400,
  "success": false,
  "message": "请求参数错误",
  "data": {
    "response": "collection_names 参数必须是非空数组",
    "error_type": "INVALID_PARAMS",
    "timestamp": "2025-01-15T10:30:00.000Z",
    "can_retry": false
  }
}

安全性和性能考虑:

  1. 事务安全

    • 整个合并过程在数据库事务中执行
    • 发生错误时自动回滚
    • 避免数据不一致状态
  2. 预览模式

    • 默认启用 dry_run 模式
    • 用户可以先查看合并计划再决定是否执行
    • 提供详细的影响分析
  3. 保留策略

    • 支持多种保留策略满足不同需求
    • first 策略最简单可靠,推荐作为默认
    • by_metadata_time 策略使用真实创建时间
  4. 性能优化

    • 使用数组聚合函数减少数据库查询次数
    • 批量删除操作提高效率
    • 分集合处理避免长时间锁表
  5. 错误处理

    • 详细的参数验证
    • 完整的异常处理和错误信息
    • 数据库连接的正确关闭

使用建议:

  1. 推荐流程: ```bash

    1. 先预览合并计划

    POST /api/v0/training_data/combine { "collection_names": ["sql", "ddl"], "dry_run": true, "keep_strategy": "first" }

# 2. 确认无误后执行实际合并 POST /api/v0/training_data/combine {

 "collection_names": ["sql", "ddl"],
 "dry_run": false,
 "keep_strategy": "first"

}


2. **保留策略选择**:
   - 一般情况:使用 `"first"` 策略
   - 需要保留最新数据:使用 `"by_metadata_time"` 策略
   - 测试环境:可以使用 `"last"` 策略

3. **批量处理**:
   - 大量数据时建议分批处理不同集合
   - 可以先处理较小的集合验证效果
   - 重要数据建议先备份

---

## 4. 补充建议

### 4.1 错误处理增强

建议在所有API中添加更详细的错误处理:

```python
# 通用错误处理装饰器
def handle_api_errors(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except ValueError as e:
            return jsonify(bad_request_response(str(e)))
        except Exception as e:
            logger.error(f"API错误: {str(e)}", exc_info=True)
            return jsonify(internal_error_response(f"系统错误: {str(e)}"))
    return decorated_function

4.2 批量操作优化

对于大量数据的处理,建议引入任务队列:

# 异步任务支持
@app.route('/api/v0/training_data/upload_async', methods=['POST'])
def upload_training_data_async():
    # 创建后台任务
    task_id = str(uuid.uuid4())
    # 将任务放入队列
    # 返回任务ID供客户端查询进度
    return jsonify(success_response(
        response_text="文件上传任务已创建",
        data={"task_id": task_id}
    ))

@app.route('/api/v0/training_data/task_status/<task_id>', methods=['GET'])
def get_task_status(task_id):
    # 查询任务状态
    pass

4.3 数据验证增强

建议添加更严格的数据验证:

# SQL语法验证
def validate_sql(sql):
    try:
        import sqlparse
        parsed = sqlparse.parse(sql)
        if not parsed:
            raise ValueError("SQL语法错误")
        return True
    except Exception as e:
        raise ValueError(f"SQL验证失败: {str(e)}")

# 数据类型验证
def validate_training_data(data_type, content):
    if data_type == 'sql':
        return validate_sql(content)
    elif data_type == 'ddl':
        return validate_ddl(content)
    # 其他验证逻辑

5. 实施建议

5.1 开发顺序

  1. 第一阶段:实现增强的stats API(相对简单)
  2. 第二阶段:实现update API(中等复杂度)
  3. 第三阶段:实现upload API(最复杂,需要充分测试)

5.2 测试策略

  1. 单元测试:每个解析函数和训练函数
  2. 集成测试:完整的API调用流程
  3. 性能测试:大文件上传和处理
  4. 错误测试:各种异常情况的处理

5.3 部署注意事项

  1. 文件上传限制:配置Nginx/Apache的文件大小限制
  2. 临时目录:确保有足够的磁盘空间
  3. 数据库连接:优化数据库连接池设置
  4. 日志监控:添加详细的操作日志

6. 结论

本改造方案通过充分复用现有代码,提供了更加完整和实用的解决方案:

  1. update API:简化为单条更新,采用先删除后插入策略,返回新记录ID
  2. upload API:复用 data_pipeline/trainer 模块的成熟函数,支持多种文件格式的自动解析和导入
  3. combine API:支持合并重复训练数据,提供预览和执行模式

主要优势:

  1. 避免重复造轮子

    • 直接复用 run_training.py 中的成熟解析函数
    • 保持与现有训练逻辑的一致性
    • 减少代码维护成本
  2. 提高代码质量

    • 复用经过验证的文件处理逻辑
    • 继承现有的错误处理机制
    • 保持日志格式的统一性
  3. 性能优化

    • 利用现有的批处理机制
    • 复用高效的文件解析算法
    • 减少开发和测试时间

特别是upload API的设计,通过复用现有的trainer模块,不仅避免了重复开发,还确保了与现有系统的完美兼容性,可以满足生产环境的需求。