Explorar o código

修复training_data的相关api迁移到unified_api.py之后的问题,准备增加training_data的更新和批量加载,以及合并的api.

wangxq hai 1 mes
pai
achega
707ca80f27
Modificáronse 3 ficheiros con 1471 adicións e 85 borrados
  1. 44 19
      data_pipeline/trainer/run_training.py
  2. 1171 0
      docs/4.训练数据API改造方案.md
  3. 256 66
      unified_api.py

+ 44 - 19
data_pipeline/trainer/run_training.py

@@ -198,46 +198,53 @@ def train_formatted_question_sql_pairs(formatted_file):
     # 按双空行分割不同的问答对
     # 使用更精确的分隔符,避免误识别
     pairs = []
-    blocks = content.split("\n\nQuestion:")
+    # 使用大小写不敏感的正则表达式来分割
+    import re
+    blocks = re.split(r'\n\n(?=question\s*:)', content, flags=re.IGNORECASE)
     
     # 处理第一块(可能没有前导的"\n\nQuestion:")
     first_block = blocks[0]
-    if first_block.strip().startswith("Question:"):
+    if re.search(r'^\s*question\s*:', first_block.strip(), re.IGNORECASE):
         pairs.append(first_block.strip())
-    elif "Question:" in first_block:
+    elif re.search(r'question\s*:', first_block, re.IGNORECASE):
         # 处理文件开头没有Question:的情况
-        question_start = first_block.find("Question:")
-        pairs.append(first_block[question_start:].strip())
+        question_match = re.search(r'question\s*:', first_block, re.IGNORECASE)
+        pairs.append(first_block[question_match.start():].strip())
     
     # 处理其余块
     for block in blocks[1:]:
-        pairs.append("Question:" + block.strip())
+        pairs.append(block.strip())
     
     # 处理每个问答对
     successfully_processed = 0
     for idx, pair in enumerate(pairs, start=1):
         try:
-            if "Question:" not in pair or "SQL:" not in pair:
+            # 使用大小写不敏感的匹配
+            question_match = re.search(r'question\s*:', pair, re.IGNORECASE)
+            sql_match = re.search(r'sql\s*:', pair, re.IGNORECASE)
+            
+            if not question_match or not sql_match:
                 print(f" 跳过不符合格式的对 #{idx}")
                 continue
-                
-            # 提取问题部分
-            question_start = pair.find("Question:") + len("Question:")
-            sql_start = pair.find("SQL:", question_start)
             
-            if sql_start == -1:
+            # 确保SQL在Question之后
+            if sql_match.start() <= question_match.end():
                 print(f" SQL部分未找到,跳过对 #{idx}")
                 continue
                 
+            # 提取问题部分
+            question_start = question_match.end()
+            sql_start = sql_match.start()
+            
             question = pair[question_start:sql_start].strip()
             
             # 提取SQL部分(支持多行)
-            sql_part = pair[sql_start + len("SQL:"):].strip()
+            sql_part = pair[sql_match.end():].strip()
             
             # 检查是否存在下一个Question标记(防止解析错误)
-            next_question = pair.find("Question:", sql_start)
-            if next_question != -1:
-                sql_part = pair[sql_start + len("SQL:"):next_question].strip()
+            next_question_match = re.search(r'question\s*:', pair[sql_match.end():], re.IGNORECASE)
+            if next_question_match:
+                sql_part = pair[sql_match.end():sql_match.end() + next_question_match.start()].strip()
             
             if not question or not sql_part:
                 print(f" 问题或SQL为空,跳过对 #{idx}")
@@ -280,12 +287,30 @@ def train_json_question_sql_pairs(json_file):
         for idx, pair in enumerate(data, start=1):
             try:
                 # 检查问答对格式
-                if not isinstance(pair, dict) or "question" not in pair or "sql" not in pair:
+                if not isinstance(pair, dict):
+                    print(f" 跳过不符合格式的对 #{idx}")
+                    continue
+                
+                # 大小写不敏感地查找question和sql键
+                question_key = None
+                sql_key = None
+                question_value = None
+                sql_value = None
+                
+                for key, value in pair.items():
+                    if key.lower() == "question":
+                        question_key = key
+                        question_value = value
+                    elif key.lower() == "sql":
+                        sql_key = key
+                        sql_value = value
+                
+                if question_key is None or sql_key is None:
                     print(f" 跳过不符合格式的对 #{idx}")
                     continue
                 
-                question = pair["question"].strip()
-                sql = pair["sql"].strip()
+                question = str(question_value).strip()
+                sql = str(sql_value).strip()
                 
                 if not question or not sql:
                     print(f" 问题或SQL为空,跳过对 #{idx}")

+ 1171 - 0
docs/4.训练数据API改造方案.md

@@ -0,0 +1,1171 @@
+# unified_api.py 训练数据管理API改造方案(详细版)
+
+## 1. 需求概述
+
+本方案基于对 `unified_api.py` 文件的深入分析,提出以下两个API改造与新增需求:
+
+- **需求一**:新增 `/api/v0/training_data/update`,支持先删除再插入的训练数据更新操作
+- **需求二**:新增 `/api/v0/training_data/upload`,支持上传多种格式文件并自动解析入库
+- **需求三**:新增 `/api/v0/training_data/combine`,合并相同的 langchain_pg_embedding.document 记录
+
+---
+
+## 2. 当前实现分析
+
+### 2.1 现有API结构
+
+`unified_api.py` 目前实现了4个训练数据相关的API端点:
+
+1. **`GET /api/v0/training_data/stats`** (lines 1205-1234) - 基础统计信息
+2. **`POST /api/v0/training_data/query`** (lines 1236-1294) - 分页查询
+3. **`POST /api/v0/training_data/create`** (lines 1296-1404) - 批量创建(最多50条)
+4. **`POST /api/v0/training_data/delete`** (lines 1406-1471) - 批量删除(最多50条)
+
+### 2.2 数据库架构
+
+**表结构关系:**
+```sql
+-- langchain_pg_collection (集合表)
+CREATE TABLE langchain_pg_collection (
+    uuid uuid PRIMARY KEY,
+    name varchar NOT NULL UNIQUE,
+    cmetadata json
+);
+
+-- langchain_pg_embedding (训练数据表)
+CREATE TABLE langchain_pg_embedding (
+    id varchar PRIMARY KEY,
+    collection_id uuid REFERENCES langchain_pg_collection(uuid) ON DELETE CASCADE,
+    embedding vector,
+    document varchar,
+    cmetadata jsonb
+);
+```
+
+**集合名称映射:**
+- `sql` 集合 → SQL查询和问答对
+- `ddl` 集合 → 数据库表结构定义
+- `documentation` 集合 → 表和字段文档
+- `error_sql` 集合 → 错误SQL示例
+
+---
+
+## 3. 详细改造方案
+
+---
+
+### 3.1 `/api/v0/training_data/update` (新增的API)
+
+**功能描述:**
+- 支持单条更新训练数据
+- 采用先删除后插入的策略确保数据一致性
+- 更新方式是先删除再插入,删除是以 langchain_pg_embedding.id字段为依据的
+- langchain_pg_embedding.id字段是varchar类型,它是uuid+类型的格式,比如:"50bb6b17-d5be-48ab-8125-58ab8bb076e8-sql"
+- 返回新创建记录的ID,便于前端更新界面
+
+**WEB UI使用流程:**
+1. 用户选中一行记录,点击修改按钮
+2. 前端显示编辑表单,用户修改内容
+3. 提交时,根据这一行的id,删除这条记录
+4. 然后调用create参数,生成一条新的记录
+5. 返回的JSON中包括这条记录的新的id
+
+**API规格:**
+
+**请求参数:**
+```json
+{
+  "id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
+  "training_data_type": "sql",
+  "question": "如何查询所有用户?",
+  "sql": "SELECT * FROM users;"
+}
+```
+
+**参数说明:**
+- `id`: 要删除的原始记录ID(必需)
+- `training_data_type`: 训练数据类型,支持 "sql", "ddl", "documentation", "error_sql"
+- 其他字段根据类型不同而不同:
+  - `sql` 类型:`question`(可选), `sql`(必需)
+  - `error_sql` 类型:`question`(必需), `sql`(必需)
+  - `documentation` 类型:`content`(必需)
+  - `ddl` 类型:`ddl`(必需)
+
+**实现代码框架:**
+```python
+@app.route('/api/v0/training_data/update', methods=['POST'])
+def training_data_update():
+    """更新训练数据API - 支持单条更新,采用先删除后插入策略"""
+    try:
+        req = request.get_json(force=True)
+        
+        # 1. 参数验证
+        original_id = req.get('id')
+        if not original_id:
+            return jsonify(bad_request_response(
+                response_text="缺少必需参数:id",
+                missing_params=["id"]
+            )), 400
+        
+        training_type = req.get('training_data_type')
+        if not training_type:
+            return jsonify(bad_request_response(
+                response_text="缺少必需参数:training_data_type",
+                missing_params=["training_data_type"]
+            )), 400
+        
+        # 2. 先删除原始记录
+        try:
+            success = vn.remove_training_data(original_id)
+            if not success:
+                return jsonify(bad_request_response(
+                    response_text=f"原始记录 {original_id} 不存在或删除失败"
+                )), 400
+        except Exception as e:
+            return jsonify(internal_error_response(
+                response_text=f"删除原始记录失败: {str(e)}"
+            )), 500
+        
+        # 3. 根据类型验证和准备新数据
+        try:
+            if training_type == 'sql':
+                sql = req.get('sql')
+                if not sql:
+                    return jsonify(bad_request_response(
+                        response_text="SQL字段是必需的",
+                        missing_params=["sql"]
+                    )), 400
+                
+                # SQL语法检查
+                is_valid, error_msg = validate_sql_syntax(sql)
+                if not is_valid:
+                    return jsonify(bad_request_response(
+                        response_text=f"SQL语法错误: {error_msg}"
+                    )), 400
+                
+                question = req.get('question')
+                if question:
+                    training_id = vn.train(question=question, sql=sql)
+                else:
+                    training_id = vn.train(sql=sql)
+                    
+            elif training_type == 'error_sql':
+                question = req.get('question')
+                sql = req.get('sql')
+                if not question or not sql:
+                    return jsonify(bad_request_response(
+                        response_text="question和sql字段都是必需的",
+                        missing_params=["question", "sql"]
+                    )), 400
+                training_id = vn.train_error_sql(question=question, sql=sql)
+                
+            elif training_type == 'documentation':
+                content = req.get('content')
+                if not content:
+                    return jsonify(bad_request_response(
+                        response_text="content字段是必需的",
+                        missing_params=["content"]
+                    )), 400
+                training_id = vn.train(documentation=content)
+                
+            elif training_type == 'ddl':
+                ddl = req.get('ddl')
+                if not ddl:
+                    return jsonify(bad_request_response(
+                        response_text="ddl字段是必需的",
+                        missing_params=["ddl"]
+                    )), 400
+                training_id = vn.train(ddl=ddl)
+                
+            else:
+                return jsonify(bad_request_response(
+                    response_text=f"不支持的训练数据类型: {training_type}"
+                )), 400
+                
+        except Exception as e:
+            return jsonify(internal_error_response(
+                response_text=f"创建新训练数据失败: {str(e)}"
+            )), 500
+        
+        # 4. 获取更新后的总记录数
+        current_total = get_total_training_count()
+        
+        return jsonify(success_response(
+            response_text="训练数据更新成功",
+            data={
+                "original_id": original_id,
+                "new_training_id": training_id,
+                "type": training_type,
+                "current_total_count": current_total
+            }
+        ))
+        
+    except Exception as e:
+        logger.error(f"training_data_update执行失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="更新训练数据失败,请稍后重试"
+        )), 500
+```
+
+**返回格式:**
+
+**成功响应:**
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "response": "训练数据更新成功",
+    "original_id": "6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql",
+    "new_training_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890-sql",
+    "type": "sql",
+    "current_total_count": 1250
+  }
+}
+```
+
+**错误响应示例:**
+```json
+{
+  "code": 400,
+  "success": false,
+  "message": "请求参数错误",
+  "data": {
+    "response": "原始记录 6982ba18-8d0a-4cce-9d3b-922ac6ee10ac-sql 不存在或删除失败",
+    "error_type": "INVALID_PARAMS",
+    "timestamp": "2025-01-15T10:30:00.000Z",
+    "can_retry": false
+  }
+}
+```
+
+**错误响应示例:**
+```json
+{
+  "code": 400,
+  "success": false,
+  "message": "请求参数错误",
+  "data": {
+    "response": "缺少必需参数:data",
+    "error_type": "INVALID_PARAMS",
+    "timestamp": "2025-01-15T10:30:00.000Z",
+    "can_retry": false
+  }
+}
+```
+
+**安全性和性能考虑:**
+
+1. **事务安全**:
+   - 删除和插入操作在同一个函数中执行
+   - 如果插入失败,原始记录已被删除,需要回滚或重试机制
+   - 建议在应用层实现重试逻辑
+
+2. **数据一致性**:
+   - 先删除后插入确保ID的唯一性
+   - 新记录会获得新的UUID,避免ID冲突
+   - 保持训练数据类型的正确性
+
+3. **错误处理**:
+   - 详细的参数验证
+   - 原始记录存在性检查
+   - SQL语法验证(针对sql类型)
+   - 完整的异常处理和错误信息
+
+4. **单条处理**:
+   - 只支持单条更新,简化逻辑
+   - 减少并发冲突的可能性
+   - 更精确的错误定位和处理
+
+5. **性能优化**:
+   - 复用现有的训练函数
+   - 避免重复的数据库连接
+   - 合理的错误处理和日志记录
+
+**使用建议:**
+
+1. **前端集成**:
+   ```javascript
+   // 更新训练数据的示例
+   const updateTrainingData = async (originalId, newData) => {
+     const response = await fetch('/api/v0/training_data/update', {
+       method: 'POST',
+       headers: { 'Content-Type': 'application/json' },
+       body: JSON.stringify({
+         id: originalId,
+         training_data_type: 'sql',
+         question: newData.question,
+         sql: newData.sql
+       })
+     });
+     
+     const result = await response.json();
+     if (result.success) {
+       // 更新成功,使用新的training_id更新界面
+       const newId = result.data.new_training_id;
+       updateUIWithNewId(newId);
+     }
+   };
+   ```
+
+2. **错误处理**:
+   - 检查返回的 `success` 字段
+   - 处理各种错误情况(参数错误、记录不存在、SQL语法错误等)
+   - 显示详细的错误信息给用户
+
+3. **最佳实践**:
+   - 更新前验证原始记录是否存在
+   - 更新后验证新记录是否创建成功
+   - 使用返回的新ID更新前端界面
+
+---
+
+### 3.2 `/api/v0/training_data/upload` 新增(重点设计)
+
+**功能描述:**
+
+- 支持上传多种格式的训练数据文件(.ddl/.md/.json/.sql)
+- 直接内存处理,无需临时文件存储
+- 文件大小限制500KB(约包含100-500条记录)
+- 自动识别文件类型并解析内容
+- 复用现有trainer模块的成熟功能
+- 提供详细的处理结果反馈
+
+**支持的文件类型:**
+
+基于现有 `data_pipeline/trainer` 模块的成熟功能,API 支持以下文件类型的自动识别和处理:
+
+| 文件类型 | 扩展名 | 文件命名规则 | 复用函数 | 解析方式 |
+|---------|--------|-------------|---------|-----------|
+| ddl | `.ddl` | 任意名称 | `train_ddl_file()` → `read_file_by_delimiter()` | 建表语句,以“;”为分隔符 |
+| documentation | `.md` | 任意名称 | `train_documentation_file()` → `read_markdown_file_by_sections()` | Markdown以"##"为单位分割,普通文档以"---"分隔 |
+| question_sql_pairs | `.json` | 必须包含 `_pair`或 `_pairs` | `train_json_pairs_file()` → `train_json_question_sql_pairs()` | 问答对训练(JSON格式) |
+| formatted_pairs | `.sql` | 必须包含 `_pair`或 `_pairs` | `train_formatted_pairs_file()` → `train_formatted_question_sql_pairs()` | 问答对训练(文本格式),每对Question/SQL之间使用空行分隔 |
+| sql_examples | `.sql` | 任意名称(不包含`_pair`) | `train_sql_file()` → `read_file_by_delimiter()` | SQL示例,以";"为分隔符 |
+
+**1. DDL文件 (*.ddl):**
+
+建表语句,解析时以“;”为分隔符。
+
+```sql
+-- 建表语句
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    name VARCHAR(100) NOT NULL,
+    email VARCHAR(255) UNIQUE
+);
+
+CREATE TABLE orders (
+    id SERIAL PRIMARY KEY,
+    user_id INTEGER REFERENCES users(id),
+    amount DECIMAL(10,2)
+);
+```
+
+**2. 文档文件 (*.md):**
+
+以"##"为单位进行分割,每个"##"是一个表名:
+
+```markdown
+## users表
+用户基本信息表,存储系统用户数据。
+字段说明
+- id: 用户唯一标识
+- name: 用户姓名
+- email: 电子邮箱
+
+## orders表
+订单信息表,记录用户购买记录。
+```
+
+**3. JSON问答对 (*_pair.json):**
+
+"question"/"sql" 大小写不敏感。
+
+```json
+[
+  {
+    "question": "如何查询活跃用户数量?",
+    "sql": "SELECT COUNT(*) FROM users WHERE active = true;"
+  },
+  {
+    "question": "如何统计每个分类的产品数量?",
+    "sql": "SELECT category, COUNT(*) FROM products GROUP BY category;"
+  }
+]
+```
+
+**4. 格式化问答对 (*_pair.sql):**
+
+注意,Question和SQL,大小写不敏感,但是两个问答对之间要以空行作为分隔。
+
+```
+Question: 如何查询活跃用户数量?
+SQL:SELECT COUNT(*) FROM users WHERE active = true;
+
+Question: 如何统计每个分类的产品数量?
+SQL:SELECT category, COUNT(*) FROM products GROUP BY category;
+```
+
+**API规格:**
+
+**请求参数:**
+- `file`: 上传的文件(multipart/form-data)
+- 文件大小限制:最大500KB
+- 文件编码:UTF-8
+- 自动识别文件类型:基于文件扩展名和命名规则
+
+**实现代码框架:**
+```python
+from werkzeug.utils import secure_filename
+from data_pipeline.trainer.run_training import (
+    read_file_by_delimiter,
+    read_markdown_file_by_sections,
+    train_json_question_sql_pairs,
+    train_formatted_question_sql_pairs
+)
+from data_pipeline.trainer.vanna_trainer import (
+    train_ddl, train_documentation, train_question_sql_pair, 
+    train_sql_example, flush_training
+)
+from common.result import success_response, bad_request_response, internal_error_response
+import json
+import tempfile
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+@app.route('/api/v0/training_data/upload', methods=['POST'])
+def upload_training_data():
+    try:
+        # 1. 参数验证
+        if 'file' not in request.files:
+            return jsonify(bad_request_response("未提供文件"))
+        
+        file = request.files['file']
+        if file.filename == '':
+            return jsonify(bad_request_response("未选择文件"))
+        
+        # 2. 文件大小验证 (500KB)
+        file.seek(0, 2)
+        file_size = file.tell()
+        file.seek(0)
+        
+        if file_size > 500 * 1024:  # 500KB
+            return jsonify(bad_request_response("文件大小不能超过500KB"))
+        
+        # 3. 创建临时文件(复用现有函数需要文件路径)
+        filename = secure_filename(file.filename)
+        temp_file_path = None
+        
+        try:
+            with tempfile.NamedTemporaryFile(mode='w+b', delete=False, suffix=os.path.splitext(filename)[1]) as tmp_file:
+                file.save(tmp_file.name)
+                temp_file_path = tmp_file.name
+            
+            # 4. 根据文件类型调用现有的训练函数
+            filename_lower = filename.lower()
+            success_count = 0
+            
+            if filename_lower.endswith('.ddl'):
+                success_count = train_ddl_file(temp_file_path)
+                
+            elif filename_lower.endswith(('.md', '.markdown')):
+                success_count = train_documentation_file(temp_file_path)
+                
+            elif filename_lower.endswith('_pair.json') or filename_lower.endswith('_pairs.json'):
+                success_count = train_json_pairs_file(temp_file_path)
+                
+            elif filename_lower.endswith('_pair.sql') or filename_lower.endswith('_pairs.sql'):
+                success_count = train_formatted_pairs_file(temp_file_path)
+                
+            elif filename_lower.endswith('.sql'):
+                success_count = train_sql_file(temp_file_path)
+            else:
+                return jsonify(bad_request_response("不支持的文件类型"))
+            
+            # 5. 刷新批处理
+            flush_training()
+            
+            return jsonify(success_response(
+                response_text=f"文件上传并训练成功:处理了 {success_count} 条记录",
+                data={
+                    "filename": filename,
+                    "file_size": file_size,
+                    "records_processed": success_count,
+                    "status": "completed"
+                }
+            ))
+            
+        except Exception as e:
+            logger.error(f"训练失败: {str(e)}")
+            return jsonify(internal_error_response(f"训练失败: {str(e)}"))
+        finally:
+            # 清理临时文件
+            if temp_file_path and os.path.exists(temp_file_path):
+                try:
+                    os.unlink(temp_file_path)
+                except Exception as e:
+                    logger.warning(f"清理临时文件失败: {str(e)}")
+        
+    except Exception as e:
+        logger.error(f"文件上传失败: {str(e)}")
+        return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))
+
+# 复用现有训练函数的包装器
+def train_ddl_file(file_path: str) -> int:
+    """复用 run_training.py 中的 DDL 训练逻辑"""
+    success_count = 0
+    try:
+        ddl_statements = read_file_by_delimiter(file_path, ";")
+        
+        for ddl in ddl_statements:
+            try:
+                train_ddl(ddl)
+                success_count += 1
+            except Exception as e:
+                logger.error(f"训练DDL失败: {ddl[:50]}..., 错误: {str(e)}")
+    except Exception as e:
+        logger.error(f"读取DDL文件失败: {str(e)}")
+    
+    return success_count
+
+def train_documentation_file(file_path: str) -> int:
+    """复用 run_training.py 中的文档训练逻辑"""
+    success_count = 0
+    
+    try:
+        # 检查是否为Markdown文件
+        if file_path.lower().endswith(('.md', '.markdown')):
+            sections = read_markdown_file_by_sections(file_path)
+            for section in sections:
+                try:
+                    train_documentation(section)
+                    success_count += 1
+                except Exception as e:
+                    logger.error(f"训练文档失败: {section[:50]}..., 错误: {str(e)}")
+        else:
+            # 非Markdown文件使用传统的---分隔
+            doc_blocks = read_file_by_delimiter(file_path, "---")
+            for doc in doc_blocks:
+                try:
+                    train_documentation(doc)
+                    success_count += 1
+                except Exception as e:
+                    logger.error(f"训练文档失败: {doc[:50]}..., 错误: {str(e)}")
+    except Exception as e:
+        logger.error(f"读取文档文件失败: {str(e)}")
+    
+    return success_count
+
+def train_json_pairs_file(file_path: str) -> int:
+    """复用 run_training.py 中的 JSON 问答对训练逻辑"""
+    try:
+        # 直接调用现有函数
+        train_json_question_sql_pairs(file_path)
+        
+        # 计算成功数量
+        with open(file_path, 'r', encoding='utf-8') as f:
+            pairs = json.load(f)
+            return len(pairs) if isinstance(pairs, list) else 1
+    except Exception as e:
+        logger.error(f"训练JSON问答对失败: {str(e)}")
+        return 0
+
+def train_formatted_pairs_file(file_path: str) -> int:
+    """复用 run_training.py 中的格式化问答对训练逻辑"""
+    try:
+        # 直接调用现有函数
+        train_formatted_question_sql_pairs(file_path)
+        
+        # 计算成功数量
+        with open(file_path, 'r', encoding='utf-8') as f:
+            content = f.read()
+            pairs = content.split('\n\n')
+            return len([p for p in pairs if p.strip()])
+    except Exception as e:
+        logger.error(f"训练格式化问答对失败: {str(e)}")
+        return 0
+
+def train_sql_file(file_path: str) -> int:
+    """复用 run_training.py 中的 SQL 训练逻辑"""
+    success_count = 0
+    try:
+        sql_statements = read_file_by_delimiter(file_path, ";")
+        
+        for sql in sql_statements:
+            try:
+                train_sql_example(sql)
+                success_count += 1
+            except Exception as e:
+                logger.error(f"训练SQL失败: {sql[:50]}..., 错误: {str(e)}")
+    except Exception as e:
+        logger.error(f"读取SQL文件失败: {str(e)}")
+    
+    return success_count
+
+
+```
+
+**返回格式:**
+
+**成功响应:**
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "response": "文件上传并训练成功:处理了 45 条记录",
+    "filename": "training_data.sql",
+    "file_size": 12345,
+    "records_processed": 45,
+    "status": "completed"
+  }
+}
+```
+
+**错误响应:**
+```json
+{
+  "code": 400,
+  "success": false,
+  "message": "请求参数错误",
+  "data": {
+    "response": "文件大小不能超过500KB",
+    "error_type": "INVALID_PARAMS",
+    "timestamp": "2025-01-15T10:30:00.000Z",
+    "can_retry": false
+  }
+}
+```
+
+**安全性和性能考虑:**
+
+1. **文件处理优化**:
+   - 使用临时文件处理,确保内存安全
+   - 文件大小限制500KB,避免内存压力
+   - 自动清理临时文件,防止磁盘空间泄露
+   - 完整的错误处理和日志记录
+
+2. **自动类型识别**:
+   - 基于文件扩展名和命名规则
+   - 智能解析不同格式内容
+   - 大小写不敏感的格式处理
+
+3. **复用成熟功能**:
+   - 直接使用 `data_pipeline.trainer.run_training` 中的成熟函数
+   - 复用 `read_file_by_delimiter` 和 `read_markdown_file_by_sections` 解析逻辑
+   - 调用 `train_json_question_sql_pairs` 和 `train_formatted_question_sql_pairs` 处理问答对
+   - 继承批处理优化机制,保持训练逻辑的一致性
+   - 避免重复造轮子,提高代码维护性
+
+4. **错误处理**:
+   - 详细的错误日志记录
+   - 逐条处理失败时的容错机制
+   - 统计处理成功和失败的记录数
+   - 临时文件清理保证不留垃圾文件
+
+5. **性能优化**:
+   - 复用现有的高效解析函数
+   - 批处理机制提高训练效率
+   - 自动刷新确保数据及时写入
+   - 避免重复开发,减少维护成本
+
+**使用建议:**
+
+1. **文件准备**:
+   - 确保文件使用UTF-8编码
+   - 控制文件大小在500KB以内
+   - 按照规定的格式和命名规则准备文件
+
+2. **文件类型选择**:
+   - DDL文件:用于训练表结构信息,复用 `train_ddl_statements` 逻辑
+   - Markdown文档:用于训练表和字段说明,复用 `train_documentation_blocks` 逻辑
+   - JSON问答对:用于训练结构化问答数据,复用 `train_json_question_sql_pairs` 逻辑
+   - 格式化问答对:用于训练文本格式问答数据,复用 `train_formatted_question_sql_pairs` 逻辑
+   - SQL文件:用于训练SQL示例,复用 `train_sql_examples` 逻辑
+
+3. **最佳实践**:
+   - 建议分批上传,每次处理一种类型的文件
+   - 上传前检查文件格式的正确性
+   - 关注返回结果中的处理统计信息
+   - 复用现有训练逻辑,确保数据处理的一致性
+
+---
+
+### 3.3 `/api/v0/training_data/combine` (新增的API)
+
+**功能描述:**
+- 合并 langchain_pg_embedding 表中相同 document 字段的重复记录
+- 支持按集合名称过滤合并范围
+- 提供预览模式(dry_run)和实际执行模式
+- 支持多种保留策略(first/last/by_metadata_time)
+
+**技术分析:**
+
+训练数据写入机制中,不同类型的数据在 document 字段中存储不同格式的内容:
+- SQL类型:JSON格式 `{"question": "...", "sql": "..."}`
+- DDL类型:直接存储DDL语句原文
+- Documentation类型:直接存储文档内容原文
+- Error_SQL类型:JSON格式错误SQL示例
+
+合并策略采用 document 字段精确匹配,因为:
+1. embedding 字段是向量,比较成本高且可能有微小差异
+2. document 字段存储确定性文本内容,便于精确比较
+3. 相同的 document 内容必然产生相同的训练效果
+
+**API规格:**
+
+**请求参数:**
+```json
+{
+  "collection_names": ["sql", "ddl", "documentation", "error_sql"],
+  "dry_run": true,
+  "keep_strategy": "first"
+}
+```
+
+**参数说明:**
+- `collection_names`: 要处理的集合名称数组,支持 ["sql", "ddl", "documentation", "error_sql"]
+- `dry_run`: 是否为预览模式,默认 true(安全)
+- `keep_strategy`: 保留策略
+  - `"first"`: 保留第一条记录(按ID排序,推荐默认)
+  - `"last"`: 保留最后一条记录(按ID排序)
+  - `"by_metadata_time"`: 按 cmetadata.createdat 时间排序保留最新记录
+
+**实现代码框架:**
+```python
+@app.route('/api/v0/training_data/combine', methods=['POST'])
+def combine_training_data():
+    try:
+        # 1. 参数验证
+        data = request.get_json()
+        if not data:
+            return jsonify(bad_request_response("请求体不能为空"))
+        
+        collection_names = data.get('collection_names', [])
+        if not collection_names or not isinstance(collection_names, list):
+            return jsonify(bad_request_response("collection_names 参数必须是非空数组"))
+        
+        # 验证集合名称
+        valid_collections = ['sql', 'ddl', 'documentation', 'error_sql']
+        invalid_collections = [name for name in collection_names if name not in valid_collections]
+        if invalid_collections:
+            return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}"))
+        
+        dry_run = data.get('dry_run', True)
+        keep_strategy = data.get('keep_strategy', 'first')
+        
+        if keep_strategy not in ['first', 'last', 'by_metadata_time']:
+            return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'"))
+        
+        # 2. 获取数据库连接
+        conn = get_db_connection()
+        cursor = conn.cursor()
+        
+        # 3. 查找重复记录
+        duplicate_groups = []
+        total_before = 0
+        total_duplicates = 0
+        collections_stats = {}
+        
+        for collection_name in collection_names:
+            # 获取集合ID
+            cursor.execute(
+                "SELECT uuid FROM langchain_pg_collection WHERE name = %s",
+                (collection_name,)
+            )
+            collection_result = cursor.fetchone()
+            if not collection_result:
+                continue
+            
+            collection_id = collection_result[0]
+            
+            # 统计该集合的记录数
+            cursor.execute(
+                "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s",
+                (collection_id,)
+            )
+            collection_before = cursor.fetchone()[0]
+            total_before += collection_before
+            
+            # 查找重复记录
+            cursor.execute("""
+                SELECT document, COUNT(*) as duplicate_count, 
+                       array_agg(id ORDER BY %s) as record_ids
+                FROM langchain_pg_embedding 
+                WHERE collection_id = %s 
+                GROUP BY document 
+                HAVING COUNT(*) > 1
+            """ % (
+                "id" if keep_strategy in ['first', 'last'] else "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id",
+                "%s"
+            ), (collection_id,))
+            
+            collection_duplicates = 0
+            for row in cursor.fetchall():
+                document_content, duplicate_count, record_ids = row
+                collection_duplicates += duplicate_count - 1  # 减去要保留的一条
+                
+                # 根据保留策略选择要保留的记录
+                if keep_strategy == 'first':
+                    keep_id = record_ids[0]
+                    remove_ids = record_ids[1:]
+                elif keep_strategy == 'last':
+                    keep_id = record_ids[-1]
+                    remove_ids = record_ids[:-1]
+                else:  # by_metadata_time
+                    keep_id = record_ids[0]  # 已经按时间排序
+                    remove_ids = record_ids[1:]
+                
+                duplicate_groups.append({
+                    "collection_name": collection_name,
+                    "document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content,
+                    "duplicate_count": duplicate_count,
+                    "kept_record_id" if not dry_run else "records_to_keep": keep_id,
+                    "removed_record_ids" if not dry_run else "records_to_remove": remove_ids
+                })
+            
+            total_duplicates += collection_duplicates
+            collections_stats[collection_name] = {
+                "before": collection_before,
+                "after": collection_before - collection_duplicates,
+                "duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates
+            }
+        
+        # 4. 执行合并操作(如果不是dry_run)
+        if not dry_run:
+            try:
+                conn.autocommit = False  # 开始事务
+                
+                for group in duplicate_groups:
+                    remove_ids = group["removed_record_ids"]
+                    if remove_ids:
+                        cursor.execute(
+                            "DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)",
+                            (remove_ids,)
+                        )
+                
+                conn.commit()
+                
+            except Exception as e:
+                conn.rollback()
+                return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
+            finally:
+                conn.autocommit = True
+        
+        # 5. 构建响应
+        total_after = total_before - total_duplicates
+        
+        summary = {
+            "total_records_before": total_before,
+            "total_records_after": total_after,
+            "duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates,
+            "collections_stats": collections_stats
+        }
+        
+        if dry_run:
+            response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录"
+            data_key = "duplicate_groups"
+        else:
+            response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录"
+            data_key = "merged_groups"
+        
+        return jsonify(success_response(
+            response_text=response_text,
+            data={
+                "dry_run": dry_run,
+                "collections_processed": collection_names,
+                "summary": summary,
+                data_key: duplicate_groups
+            }
+        ))
+        
+    except Exception as e:
+        return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
+    finally:
+        if 'cursor' in locals():
+            cursor.close()
+        if 'conn' in locals():
+            conn.close()
+```
+
+**返回格式(dry_run: true):**
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "response": "发现 50 条重复记录,预计删除后将从 1000 条减少到 950 条记录",
+    "dry_run": true,
+    "collections_processed": ["sql", "ddl"],
+    "summary": {
+      "total_records_before": 1000,
+      "total_records_after": 950,
+      "duplicates_to_remove": 50,
+      "collections_stats": {
+        "sql": {
+          "before": 600,
+          "after": 580,
+          "duplicates_to_remove": 20
+        },
+        "ddl": {
+          "before": 400,
+          "after": 370,
+          "duplicates_to_remove": 30
+        }
+      }
+    },
+    "duplicate_groups": [
+      {
+        "collection_name": "sql",
+        "document_content": "SELECT * FROM users WHERE id = ?",
+        "duplicate_count": 3,
+        "records_to_keep": "uuid1-sql",
+        "records_to_remove": ["uuid2-sql", "uuid3-sql"]
+      }
+    ]
+  }
+}
+```
+
+**返回格式(dry_run: false):**
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "response": "成功合并重复记录,删除了 50 条重复记录,从 1000 条减少到 950 条记录",
+    "dry_run": false,
+    "collections_processed": ["sql", "ddl"],
+    "summary": {
+      "total_records_before": 1000,
+      "total_records_after": 950,
+      "duplicates_removed": 50,
+      "collections_stats": {
+        "sql": {
+          "before": 600,
+          "after": 580,
+          "duplicates_removed": 20
+        },
+        "ddl": {
+          "before": 400,
+          "after": 370,
+          "duplicates_removed": 30
+        }
+      }
+    },
+    "merged_groups": [
+      {
+        "collection_name": "sql",
+        "document_content": "SELECT * FROM users WHERE id = ?",
+        "duplicate_count": 3,
+        "kept_record_id": "uuid1-sql",
+        "removed_record_ids": ["uuid2-sql", "uuid3-sql"]
+      }
+    ]
+  }
+}
+```
+
+**错误响应示例:**
+```json
+{
+  "code": 400,
+  "success": false,
+  "message": "请求参数错误",
+  "data": {
+    "response": "collection_names 参数必须是非空数组",
+    "error_type": "INVALID_PARAMS",
+    "timestamp": "2025-01-15T10:30:00.000Z",
+    "can_retry": false
+  }
+}
+```
+
+**安全性和性能考虑:**
+
+1. **事务安全**:
+   - 整个合并过程在数据库事务中执行
+   - 发生错误时自动回滚
+   - 避免数据不一致状态
+
+2. **预览模式**:
+   - 默认启用 dry_run 模式
+   - 用户可以先查看合并计划再决定是否执行
+   - 提供详细的影响分析
+
+3. **保留策略**:
+   - 支持多种保留策略满足不同需求
+   - first 策略最简单可靠,推荐作为默认
+   - by_metadata_time 策略使用真实创建时间
+
+4. **性能优化**:
+   - 使用数组聚合函数减少数据库查询次数
+   - 批量删除操作提高效率
+   - 分集合处理避免长时间锁表
+
+5. **错误处理**:
+   - 详细的参数验证
+   - 完整的异常处理和错误信息
+   - 数据库连接的正确关闭
+
+**使用建议:**
+
+1. **推荐流程**:
+   ```bash
+   # 1. 先预览合并计划
+   POST /api/v0/training_data/combine
+   {
+     "collection_names": ["sql", "ddl"],
+     "dry_run": true,
+     "keep_strategy": "first"
+   }
+   
+   # 2. 确认无误后执行实际合并
+   POST /api/v0/training_data/combine
+   {
+     "collection_names": ["sql", "ddl"],
+     "dry_run": false,
+     "keep_strategy": "first"
+   }
+   ```
+
+2. **保留策略选择**:
+   - 一般情况:使用 `"first"` 策略
+   - 需要保留最新数据:使用 `"by_metadata_time"` 策略
+   - 测试环境:可以使用 `"last"` 策略
+
+3. **批量处理**:
+   - 大量数据时建议分批处理不同集合
+   - 可以先处理较小的集合验证效果
+   - 重要数据建议先备份
+
+---
+
+## 4. 补充建议
+
+### 4.1 错误处理增强
+
+建议在所有API中添加更详细的错误处理:
+
+```python
+# 通用错误处理装饰器
+def handle_api_errors(f):
+    @wraps(f)
+    def decorated_function(*args, **kwargs):
+        try:
+            return f(*args, **kwargs)
+        except ValueError as e:
+            return jsonify(bad_request_response(str(e)))
+        except Exception as e:
+            logger.error(f"API错误: {str(e)}", exc_info=True)
+            return jsonify(internal_error_response(f"系统错误: {str(e)}"))
+    return decorated_function
+```
+
+### 4.2 批量操作优化
+
+对于大量数据的处理,建议引入任务队列:
+
+```python
+# 异步任务支持
+@app.route('/api/v0/training_data/upload_async', methods=['POST'])
+def upload_training_data_async():
+    # 创建后台任务
+    task_id = str(uuid.uuid4())
+    # 将任务放入队列
+    # 返回任务ID供客户端查询进度
+    return jsonify(success_response(
+        response_text="文件上传任务已创建",
+        data={"task_id": task_id}
+    ))
+
+@app.route('/api/v0/training_data/task_status/<task_id>', methods=['GET'])
+def get_task_status(task_id):
+    # 查询任务状态
+    pass
+```
+
+### 4.3 数据验证增强
+
+建议添加更严格的数据验证:
+
+```python
+# SQL语法验证
+def validate_sql(sql):
+    try:
+        import sqlparse
+        parsed = sqlparse.parse(sql)
+        if not parsed:
+            raise ValueError("SQL语法错误")
+        return True
+    except Exception as e:
+        raise ValueError(f"SQL验证失败: {str(e)}")
+
+# 数据类型验证
+def validate_training_data(data_type, content):
+    if data_type == 'sql':
+        return validate_sql(content)
+    elif data_type == 'ddl':
+        return validate_ddl(content)
+    # 其他验证逻辑
+```
+
+---
+
+## 5. 实施建议
+
+### 5.1 开发顺序
+
+1. **第一阶段**:实现增强的stats API(相对简单)
+2. **第二阶段**:实现update API(中等复杂度)
+3. **第三阶段**:实现upload API(最复杂,需要充分测试)
+
+### 5.2 测试策略
+
+1. **单元测试**:每个解析函数和训练函数
+2. **集成测试**:完整的API调用流程
+3. **性能测试**:大文件上传和处理
+4. **错误测试**:各种异常情况的处理
+
+### 5.3 部署注意事项
+
+1. **文件上传限制**:配置Nginx/Apache的文件大小限制
+2. **临时目录**:确保有足够的磁盘空间
+3. **数据库连接**:优化数据库连接池设置
+4. **日志监控**:添加详细的操作日志
+
+---
+
+## 6. 结论
+
+本改造方案通过充分复用现有代码,提供了更加完整和实用的解决方案:
+
+1. **update API**:简化为单条更新,采用先删除后插入策略,返回新记录ID
+2. **upload API**:复用 `data_pipeline/trainer` 模块的成熟函数,支持多种文件格式的自动解析和导入
+3. **combine API**:支持合并重复训练数据,提供预览和执行模式
+
+**主要优势:**
+
+1. **避免重复造轮子**:
+   - 直接复用 `run_training.py` 中的成熟解析函数
+   - 保持与现有训练逻辑的一致性
+   - 减少代码维护成本
+
+2. **提高代码质量**:
+   - 复用经过验证的文件处理逻辑
+   - 继承现有的错误处理机制
+   - 保持日志格式的统一性
+
+3. **性能优化**:
+   - 利用现有的批处理机制
+   - 复用高效的文件解析算法
+   - 减少开发和测试时间
+
+特别是upload API的设计,通过复用现有的trainer模块,不仅避免了重复开发,还确保了与现有系统的完美兼容性,可以满足生产环境的需求。

+ 256 - 66
unified_api.py

@@ -1202,6 +1202,110 @@ def validate_sql_syntax(sql: str) -> tuple[bool, str]:
     except Exception as e:
         return False, f"SQL语法错误:{str(e)}"
 
+def paginate_data(data_list: list, page: int, page_size: int):
+    """分页算法"""
+    total = len(data_list)
+    start_idx = (page - 1) * page_size
+    end_idx = start_idx + page_size
+    page_data = data_list[start_idx:end_idx]
+    total_pages = (total + page_size - 1) // page_size
+    
+    return {
+        "data": page_data,
+        "pagination": {
+            "page": page,
+            "page_size": page_size,
+            "total": total,
+            "total_pages": total_pages,
+            "has_next": end_idx < total,
+            "has_prev": page > 1
+        }
+    }
+
+def filter_by_type(data_list: list, training_data_type: str):
+    """按类型筛选算法"""
+    if not training_data_type:
+        return data_list
+    
+    return [
+        record for record in data_list 
+        if record.get('training_data_type') == training_data_type
+    ]
+
+def search_in_data(data_list: list, search_keyword: str):
+    """在数据中搜索关键词"""
+    if not search_keyword:
+        return data_list
+    
+    keyword_lower = search_keyword.lower()
+    return [
+        record for record in data_list
+        if (record.get('question') and keyword_lower in record['question'].lower()) or
+           (record.get('content') and keyword_lower in record['content'].lower())
+    ]
+
+def get_total_training_count():
+    """获取当前训练数据总数"""
+    try:
+        training_data = vn.get_training_data()
+        if training_data is not None and not training_data.empty:
+            return len(training_data)
+        return 0
+    except Exception as e:
+        logger.warning(f"获取训练数据总数失败: {e}")
+        return 0
+
+def process_single_training_item(item: dict, index: int) -> dict:
+    """处理单个训练数据项"""
+    training_type = item.get('training_data_type')
+    
+    if training_type == 'sql':
+        sql = item.get('sql')
+        if not sql:
+            raise ValueError("SQL字段是必需的")
+        
+        # SQL语法检查
+        is_valid, error_msg = validate_sql_syntax(sql)
+        if not is_valid:
+            raise ValueError(error_msg)
+        
+        question = item.get('question')
+        if question:
+            training_id = vn.train(question=question, sql=sql)
+        else:
+            training_id = vn.train(sql=sql)
+            
+    elif training_type == 'error_sql':
+        # error_sql不需要语法检查
+        question = item.get('question')
+        sql = item.get('sql')
+        if not question or not sql:
+            raise ValueError("question和sql字段都是必需的")
+        training_id = vn.train_error_sql(question=question, sql=sql)
+        
+    elif training_type == 'documentation':
+        content = item.get('content')
+        if not content:
+            raise ValueError("content字段是必需的")
+        training_id = vn.train(documentation=content)
+        
+    elif training_type == 'ddl':
+        ddl = item.get('ddl')
+        if not ddl:
+            raise ValueError("ddl字段是必需的")
+        training_id = vn.train(ddl=ddl)
+        
+    else:
+        raise ValueError(f"不支持的训练数据类型: {training_type}")
+    
+    return {
+        "index": index,
+        "success": True,
+        "training_id": training_id,
+        "type": training_type,
+        "message": f"{training_type}训练数据创建成功"
+    }
+
 @app.route('/api/v0/training_data/stats', methods=['GET'])
 def training_data_stats():
     """获取训练数据统计信息API"""
@@ -1213,16 +1317,44 @@ def training_data_stats():
                 response_text="统计信息获取成功",
                 data={
                     "total_count": 0,
+                    "type_breakdown": {
+                        "sql": 0,
+                        "documentation": 0,
+                        "ddl": 0,
+                        "error_sql": 0
+                    },
+                    "type_percentages": {
+                        "sql": 0.0,
+                        "documentation": 0.0,
+                        "ddl": 0.0,
+                        "error_sql": 0.0
+                    },
                     "last_updated": datetime.now().isoformat()
                 }
             ))
         
         total_count = len(training_data)
         
+        # 统计各类型数量
+        type_breakdown = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
+        
+        if 'training_data_type' in training_data.columns:
+            type_counts = training_data['training_data_type'].value_counts()
+            for data_type, count in type_counts.items():
+                if data_type in type_breakdown:
+                    type_breakdown[data_type] = int(count)
+        
+        # 计算百分比
+        type_percentages = {}
+        for data_type, count in type_breakdown.items():
+            type_percentages[data_type] = round(count / max(total_count, 1) * 100, 2)
+        
         return jsonify(success_response(
             response_text="统计信息获取成功",
             data={
                 "total_count": total_count,
+                "type_breakdown": type_breakdown,
+                "type_percentages": type_percentages,
                 "last_updated": datetime.now().isoformat()
             }
         ))
@@ -1235,18 +1367,38 @@ def training_data_stats():
 
 @app.route('/api/v0/training_data/query', methods=['POST'])
 def training_data_query():
-    """分页查询训练数据API"""
+    """分页查询训练数据API - 支持类型筛选、搜索和排序功能"""
     try:
         req = request.get_json(force=True)
         
+        # 解析参数,设置默认值
         page = req.get('page', 1)
         page_size = req.get('page_size', 20)
+        training_data_type = req.get('training_data_type')
+        sort_by = req.get('sort_by', 'id')
+        sort_order = req.get('sort_order', 'desc')
+        search_keyword = req.get('search_keyword')
         
-        if page < 1 or page_size < 1 or page_size > 100:
+        # 参数验证
+        if page < 1:
             return jsonify(bad_request_response(
-                response_text="参数错误"
+                response_text="页码必须大于0",
+                missing_params=["page"]
             )), 400
         
+        if page_size < 1 or page_size > 100:
+            return jsonify(bad_request_response(
+                response_text="每页大小必须在1-100之间",
+                missing_params=["page_size"]
+            )), 400
+        
+        if search_keyword and len(search_keyword) > 100:
+            return jsonify(bad_request_response(
+                response_text="搜索关键词最大长度为100字符",
+                missing_params=["search_keyword"]
+            )), 400
+        
+        # 获取训练数据
         training_data = vn.get_training_data()
         
         if training_data is None or training_data.empty:
@@ -1261,28 +1413,40 @@ def training_data_query():
                         "total_pages": 0,
                         "has_next": False,
                         "has_prev": False
+                    },
+                    "filters_applied": {
+                        "training_data_type": training_data_type,
+                        "search_keyword": search_keyword
                     }
                 }
             ))
         
+        # 转换为列表格式
         records = training_data.to_dict(orient="records")
-        total = len(records)
-        start_idx = (page - 1) * page_size
-        end_idx = start_idx + page_size
-        page_data = records[start_idx:end_idx]
-        total_pages = (total + page_size - 1) // page_size
+        
+        # 应用筛选条件
+        if training_data_type:
+            records = filter_by_type(records, training_data_type)
+        
+        if search_keyword:
+            records = search_in_data(records, search_keyword)
+        
+        # 排序
+        if sort_by in ['id', 'training_data_type']:
+            reverse = (sort_order.lower() == 'desc')
+            records.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
+        
+        # 分页
+        paginated_result = paginate_data(records, page, page_size)
         
         return jsonify(success_response(
-            response_text=f"查询成功,共找到 {total} 条记录",
+            response_text=f"查询成功,共找到 {paginated_result['pagination']['total']} 条记录",
             data={
-                "records": page_data,
-                "pagination": {
-                    "page": page,
-                    "page_size": page_size,
-                    "total": total,
-                    "total_pages": total_pages,
-                    "has_next": end_idx < total,
-                    "has_prev": page > 1
+                "records": paginated_result["data"],
+                "pagination": paginated_result["pagination"],
+                "filters_applied": {
+                    "training_data_type": training_data_type,
+                    "search_keyword": search_keyword
                 }
             }
         ))
@@ -1295,16 +1459,18 @@ def training_data_query():
 
 @app.route('/api/v0/training_data/create', methods=['POST'])
 def training_data_create():
-    """创建训练数据API"""
+    """创建训练数据API - 支持单条和批量创建,支持四种数据类型"""
     try:
         req = request.get_json(force=True)
         data = req.get('data')
         
         if not data:
             return jsonify(bad_request_response(
-                response_text="缺少必需参数:data"
+                response_text="缺少必需参数:data",
+                missing_params=["data"]
             )), 400
         
+        # 统一处理为列表格式
         if isinstance(data, dict):
             data_list = [data]
         elif isinstance(data, list):
@@ -1314,6 +1480,7 @@ def training_data_create():
                 response_text="data字段格式错误,应为对象或数组"
             )), 400
         
+        # 批量操作限制
         if len(data_list) > 50:
             return jsonify(bad_request_response(
                 response_text="批量操作最大支持50条记录"
@@ -1321,50 +1488,15 @@ def training_data_create():
         
         results = []
         successful_count = 0
+        type_summary = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
         
         for index, item in enumerate(data_list):
             try:
-                training_type = item.get('training_data_type')
-                
-                if training_type == 'sql':
-                    sql = item.get('sql')
-                    if not sql:
-                        raise ValueError("SQL字段是必需的")
-                    
-                    is_valid, error_msg = validate_sql_syntax(sql)
-                    if not is_valid:
-                        raise ValueError(error_msg)
-                    
-                    question = item.get('question')
-                    if question:
-                        training_id = vn.train(question=question, sql=sql)
-                    else:
-                        training_id = vn.train(sql=sql)
-                        
-                elif training_type == 'documentation':
-                    content = item.get('content')
-                    if not content:
-                        raise ValueError("content字段是必需的")
-                    training_id = vn.train(documentation=content)
-                    
-                elif training_type == 'ddl':
-                    ddl = item.get('ddl')
-                    if not ddl:
-                        raise ValueError("ddl字段是必需的")
-                    training_id = vn.train(ddl=ddl)
-                    
-                else:
-                    raise ValueError(f"不支持的训练数据类型: {training_type}")
-                
-                results.append({
-                    "index": index,
-                    "success": True,
-                    "training_id": training_id,
-                    "type": training_type,
-                    "message": f"{training_type}训练数据创建成功"
-                })
-                successful_count += 1
-                
+                result = process_single_training_item(item, index)
+                results.append(result)
+                if result['success']:
+                    successful_count += 1
+                    type_summary[result['type']] += 1
             except Exception as e:
                 results.append({
                     "index": index,
@@ -1374,26 +1506,49 @@ def training_data_create():
                     "message": "创建失败"
                 })
         
+        # 获取创建后的总记录数
+        current_total = get_total_training_count()
+        
+        # 根据实际执行结果决定响应状态
         failed_count = len(data_list) - successful_count
         
         if failed_count == 0:
+            # 全部成功
             return jsonify(success_response(
                 response_text="训练数据创建完成",
                 data={
                     "total_requested": len(data_list),
                     "successfully_created": successful_count,
                     "failed_count": failed_count,
-                    "results": results
+                    "results": results,
+                    "summary": type_summary,
+                    "current_total_count": current_total
                 }
             ))
+        elif successful_count == 0:
+            # 全部失败
+            return jsonify(error_response(
+                response_text="训练数据创建失败",
+                data={
+                    "total_requested": len(data_list),
+                    "successfully_created": successful_count,
+                    "failed_count": failed_count,
+                    "results": results,
+                    "summary": type_summary,
+                    "current_total_count": current_total
+                }
+            )), 400
         else:
+            # 部分成功,部分失败
             return jsonify(error_response(
                 response_text=f"训练数据创建部分成功,成功{successful_count}条,失败{failed_count}条",
                 data={
                     "total_requested": len(data_list),
                     "successfully_created": successful_count,
                     "failed_count": failed_count,
-                    "results": results
+                    "results": results,
+                    "summary": type_summary,
+                    "current_total_count": current_total
                 }
             )), 207
         
@@ -1405,7 +1560,7 @@ def training_data_create():
 
 @app.route('/api/v0/training_data/delete', methods=['POST'])
 def training_data_delete():
-    """删除训练数据API"""
+    """删除训练数据API - 支持批量删除"""
     try:
         req = request.get_json(force=True)
         ids = req.get('ids', [])
@@ -1413,7 +1568,8 @@ def training_data_delete():
         
         if not ids or not isinstance(ids, list):
             return jsonify(bad_request_response(
-                response_text="缺少有效的ID列表"
+                response_text="缺少有效的ID列表",
+                missing_params=["ids"]
             )), 400
         
         if not confirm:
@@ -1421,6 +1577,7 @@ def training_data_delete():
                 response_text="删除操作需要确认,请设置confirm为true"
             )), 400
         
+        # 批量操作限制
         if len(ids) > 50:
             return jsonify(bad_request_response(
                 response_text="批量删除最大支持50条记录"
@@ -1428,6 +1585,7 @@ def training_data_delete():
         
         deleted_ids = []
         failed_ids = []
+        failed_details = []
         
         for training_id in ids:
             try:
@@ -1436,12 +1594,25 @@ def training_data_delete():
                     deleted_ids.append(training_id)
                 else:
                     failed_ids.append(training_id)
+                    failed_details.append({
+                        "id": training_id,
+                        "error": "记录不存在或删除失败"
+                    })
             except Exception as e:
                 failed_ids.append(training_id)
+                failed_details.append({
+                    "id": training_id,
+                    "error": str(e)
+                })
         
+        # 获取删除后的总记录数
+        current_total = get_total_training_count()
+        
+        # 根据实际执行结果决定响应状态
         failed_count = len(failed_ids)
         
         if failed_count == 0:
+            # 全部成功
             return jsonify(success_response(
                 response_text="训练数据删除完成",
                 data={
@@ -1449,10 +1620,27 @@ def training_data_delete():
                     "successfully_deleted": len(deleted_ids),
                     "failed_count": failed_count,
                     "deleted_ids": deleted_ids,
-                    "failed_ids": failed_ids
+                    "failed_ids": failed_ids,
+                    "failed_details": failed_details,
+                    "current_total_count": current_total
                 }
             ))
+        elif len(deleted_ids) == 0:
+            # 全部失败
+            return jsonify(error_response(
+                response_text="训练数据删除失败",
+                data={
+                    "total_requested": len(ids),
+                    "successfully_deleted": len(deleted_ids),
+                    "failed_count": failed_count,
+                    "deleted_ids": deleted_ids,
+                    "failed_ids": failed_ids,
+                    "failed_details": failed_details,
+                    "current_total_count": current_total
+                }
+            )), 400
         else:
+            # 部分成功,部分失败
             return jsonify(error_response(
                 response_text=f"训练数据删除部分成功,成功{len(deleted_ids)}条,失败{failed_count}条",
                 data={
@@ -1460,7 +1648,9 @@ def training_data_delete():
                     "successfully_deleted": len(deleted_ids),
                     "failed_count": failed_count,
                     "deleted_ids": deleted_ids,
-                    "failed_ids": failed_ids
+                    "failed_ids": failed_ids,
+                    "failed_details": failed_details,
+                    "current_total_count": current_total
                 }
             )), 207