Pārlūkot izejas kodu

增加data_training的三个API,增加langgraph的递归数量的计数器。

wangxq 1 mēnesi atpakaļ
vecāks
revīzija
27ee386b02
5 mainītis faili ar 792 papildinājumiem un 200 dzēšanām
  1. 1 1
      config/logging_config.yaml
  2. 221 165
      docs/4.训练数据API改造方案.md
  3. 101 34
      react_agent/agent.py
  4. 3 0
      react_agent/config.py
  5. 466 0
      unified_api.py

+ 1 - 1
config/logging_config.yaml

@@ -91,7 +91,7 @@ modules:
     level: DEBUG
     console:
       enabled: true
-      level: DEBUG
+      level: INFO
       format: "%(asctime)s [%(levelname)s] ReactAgent: %(message)s"
     file:
       enabled: true

+ 221 - 165
docs/4.训练数据API改造方案.md

@@ -325,28 +325,30 @@ def training_data_update():
 
 **功能描述:**
 
-- 支持上传多种格式的训练数据文件(.ddl/.md/.json/.sql
-- 直接内存处理,无需临时文件存储
+- 支持上传多种格式的训练数据文件(txt/md/sql/json 或无扩展名
+- 通过 `file_type` 参数指定文件类型,不再依赖文件扩展名
 - 文件大小限制500KB(约包含100-500条记录)
-- 自动识别文件类型并解析内容
+- 对文件内容进行格式验证,确保符合指定类型要求
 - 复用现有trainer模块的成熟功能
 - 提供详细的处理结果反馈
 
 **支持的文件类型:**
 
-基于现有 `data_pipeline/trainer` 模块的成熟功能,API 支持以下文件类型的自动识别和处理
+API 支持以下5种文件类型,通过 `file_type` 参数指定
 
-| 文件类型 | 扩展名 | 文件命名规则 | 复用函数 | 解析方式 |
-|---------|--------|-------------|---------|-----------|
-| ddl | `.ddl` | 任意名称 | `train_ddl_file()` → `read_file_by_delimiter()` | 建表语句,以“;”为分隔符 |
-| documentation | `.md` | 任意名称 | `train_documentation_file()` → `read_markdown_file_by_sections()` | Markdown以"##"为单位分割,普通文档以"---"分隔 |
-| question_sql_pairs | `.json` | 必须包含 `_pair`或 `_pairs` | `train_json_pairs_file()` → `train_json_question_sql_pairs()` | 问答对训练(JSON格式) |
-| formatted_pairs | `.sql` | 必须包含 `_pair`或 `_pairs` | `train_formatted_pairs_file()` → `train_formatted_question_sql_pairs()` | 问答对训练(文本格式),每对Question/SQL之间使用空行分隔 |
-| sql_examples | `.sql` | 任意名称(不包含`_pair`) | `train_sql_file()` → `read_file_by_delimiter()` | SQL示例,以";"为分隔符 |
+| file_type | 支持扩展名 | 内容格式验证 | 复用函数 | 解析方式 |
+|-----------|------------|-------------|---------|-----------|
+| `ddl` | **ddl**/sql/txt/无扩展名 | 必须包含 `CREATE` 语句(大小写不敏感) | `train_ddl_statements()` | 建表语句,以";"为分隔符 |
+| `markdown` | **md/markdown** | 必须包含 `##` 标题 | `train_documentation_blocks()` | Markdown以"##"为单位分割 |
+| `sql_pair_json` | **json**/txt/无扩展名 | 必须是有效JSON格式,包含question/sql字段 | `train_json_question_sql_pairs()` | JSON问答对训练 |
+| `sql_pair` | **sql**/txt/无扩展名 | 必须包含 `Question:`/`SQL:` 格式 | `train_formatted_question_sql_pairs()` | 格式化问答对训练 |
+| `sql` | **sql**/txt/无扩展名 | 必须包含 `;` 分隔符 | `train_sql_examples()` | SQL示例,以";"为分隔符 |
 
-**1. DDL文件 (*.ddl):**
+**文件类型详细说明:**
 
-建表语句,解析时以“;”为分隔符。
+**1. DDL文件 (file_type: `ddl`):**
+
+建表语句,必须包含CREATE语句,以";"为分隔符。
 
 ```sql
 -- 建表语句
@@ -363,9 +365,9 @@ CREATE TABLE orders (
 );
 ```
 
-**2. 文档文件 (*.md):**
+**2. Markdown文档 (file_type: `markdown`):**
 
-以"##"为单位进行分割,每个"##"是一个表名
+必须包含"##"标题,以"##"为单位进行分割:
 
 ```markdown
 ## users表
@@ -379,9 +381,9 @@ CREATE TABLE orders (
 订单信息表,记录用户购买记录。
 ```
 
-**3. JSON问答对 (*_pair.json):**
+**3. JSON问答对 (file_type: `sql_pair_json`):**
 
-"question"/"sql" 大小写不敏感。
+必须是有效JSON格式,"question"/"sql" 大小写不敏感。
 
 ```json
 [
@@ -396,49 +398,104 @@ CREATE TABLE orders (
 ]
 ```
 
-**4. 格式化问答对 (*_pair.sql):**
+**4. 格式化问答对 (file_type: `sql_pair`):**
 
-注意,Question和SQL,大小写不敏感,但是两个问答对之间要以空行作为分隔。
+必须包含Question/SQL格式,大小写不敏感,问答对之间用空行分隔。
 
 ```
-Question: 如何查询活跃用户数量
-SQL:SELECT COUNT(*) FROM users WHERE active = true;
+Question: 如何查询观影人数超过 10000 的热门电影
+SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;
 
 Question: 如何统计每个分类的产品数量?
-SQL:SELECT category, COUNT(*) FROM products GROUP BY category;
+SQL: SELECT category, COUNT(*) FROM products GROUP BY category;
+```
+
+**5. SQL示例 (file_type: `sql`):**
+
+纯SQL语句,必须包含";"分隔符,系统会自动生成对应问题。
+
+```sql
+SELECT COUNT(*) FROM users WHERE active = true;
+SELECT category, COUNT(*) FROM products GROUP BY category;
 ```
 
 **API规格:**
 
 **请求参数:**
 - `file`: 上传的文件(multipart/form-data)
+- `file_type`: 文件类型(必需),可选值:`ddl`、`markdown`、`sql_pair_json`、`sql_pair`、`sql`
 - 文件大小限制:最大500KB
 - 文件编码:UTF-8
-- 自动识别文件类型:基于文件扩展名和命名规则
+- 支持的文件扩展名:根据 `file_type` 动态确定(详见上表)
+
+**文件扩展名验证规则:**
+
+系统根据 `file_type` 参数动态确定允许的文件扩展名:
+
+1. **DDL类型**:允许 `.ddl`、`.sql`、`.txt` 扩展名或无扩展名
+2. **Markdown类型**:仅允许 `.md`、`.markdown` 扩展名(不支持无扩展名)
+3. **SQL_PAIR_JSON类型**:允许 `.json`、`.txt` 扩展名或无扩展名
+4. **SQL_PAIR类型**:允许 `.sql`、`.txt` 扩展名或无扩展名
+5. **SQL类型**:允许 `.sql`、`.txt` 扩展名或无扩展名
+
+**扩展名验证错误示例:**
+```
+文件类型 ddl 不支持的文件扩展名:pdf,支持的扩展名:ddl, sql, txt 或无扩展名
+文件类型 markdown 不支持的文件扩展名:txt,支持的扩展名:md, markdown
+```
+
+**内容格式验证规则:**
+
+1. **DDL类型验证**:
+   - 文件内容必须包含 `CREATE` 关键字(大小写不敏感)
+   - 验证失败返回:`"文件内容不符合DDL格式,必须包含CREATE语句"`
+
+2. **Markdown类型验证**:
+   - 文件内容必须包含 `##` 标题标记
+   - 验证失败返回:`"文件内容不符合Markdown格式,必须包含##标题"`
+
+3. **SQL_PAIR_JSON类型验证**:
+   - 文件必须是有效的JSON格式
+   - JSON内容必须包含question和sql字段
+   - 验证失败返回:`"文件内容不符合JSON问答对格式"`
+
+4. **SQL_PAIR类型验证**:
+   - 文件内容必须包含 `Question:` 和 `SQL:` 标记(大小写不敏感)
+   - 验证失败返回:`"文件内容不符合问答对格式,必须包含Question:和SQL:"`
+
+5. **SQL类型验证**:
+   - 文件内容必须包含 `;` 分隔符
+   - 验证失败返回:`"文件内容不符合SQL格式,必须包含;分隔符"`
 
 **实现代码框架:**
 ```python
+import json
+import re
 from werkzeug.utils import secure_filename
 from data_pipeline.trainer.run_training import (
-    read_file_by_delimiter,
-    read_markdown_file_by_sections,
+    train_ddl_statements,
+    train_documentation_blocks, 
     train_json_question_sql_pairs,
-    train_formatted_question_sql_pairs
-)
-from data_pipeline.trainer.vanna_trainer import (
-    train_ddl, train_documentation, train_question_sql_pair, 
-    train_sql_example, flush_training
+    train_formatted_question_sql_pairs,
+    train_sql_examples
 )
 from common.result import success_response, bad_request_response, internal_error_response
-import json
-import tempfile
-import os
-import logging
 
-logger = logging.getLogger(__name__)
+def get_allowed_extensions(file_type: str) -> list:
+    """根据文件类型返回允许的扩展名"""
+    type_specific_extensions = {
+        'ddl': ['ddl', 'sql', 'txt', ''],  # 支持无扩展名
+        'markdown': ['md', 'markdown'],    # 不支持无扩展名
+        'sql_pair_json': ['json', 'txt', ''],  # 支持无扩展名
+        'sql_pair': ['sql', 'txt', ''],   # 支持无扩展名
+        'sql': ['sql', 'txt', '']         # 支持无扩展名
+    }
+    
+    return type_specific_extensions.get(file_type, [])
 
 @app.route('/api/v0/training_data/upload', methods=['POST'])
 def upload_training_data():
+    """上传训练数据文件API - 支持多种文件格式的自动解析和导入"""
     try:
         # 1. 参数验证
         if 'file' not in request.files:
@@ -448,6 +505,16 @@ def upload_training_data():
         if file.filename == '':
             return jsonify(bad_request_response("未选择文件"))
         
+        # 获取file_type参数
+        file_type = request.form.get('file_type')
+        if not file_type:
+            return jsonify(bad_request_response("缺少必需参数:file_type"))
+        
+        # 验证file_type参数
+        valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql']
+        if file_type not in valid_file_types:
+            return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}"))
+        
         # 2. 文件大小验证 (500KB)
         file.seek(0, 2)
         file_size = file.tell()
@@ -456,45 +523,59 @@ def upload_training_data():
         if file_size > 500 * 1024:  # 500KB
             return jsonify(bad_request_response("文件大小不能超过500KB"))
         
-        # 3. 创建临时文件(复用现有函数需要文件路径
+        # 3. 验证文件扩展名(基于file_type
         filename = secure_filename(file.filename)
+        allowed_extensions = get_allowed_extensions(file_type)
+        file_ext = filename.split('.')[-1].lower() if '.' in filename else ''
+        
+        if file_ext not in allowed_extensions:
+            # 构建友好的错误信息
+            non_empty_extensions = [ext for ext in allowed_extensions if ext]
+            if '' in allowed_extensions:
+                ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名"
+            else:
+                ext_message = ', '.join(non_empty_extensions)
+            return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}"))
+        
+        # 4. 读取文件内容并验证格式
+        file.seek(0)
+        content = file.read().decode('utf-8')
+        
+        # 格式验证
+        validation_result = validate_file_content(content, file_type)
+        if not validation_result['valid']:
+            return jsonify(bad_request_response(validation_result['error']))
+        
+        # 5. 创建临时文件(复用现有函数需要文件路径)
         temp_file_path = None
         
         try:
-            with tempfile.NamedTemporaryFile(mode='w+b', delete=False, suffix=os.path.splitext(filename)[1]) as tmp_file:
-                file.save(tmp_file.name)
+            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file:
+                tmp_file.write(content)
                 temp_file_path = tmp_file.name
             
-            # 4. 根据文件类型调用现有的训练函数
-            filename_lower = filename.lower()
-            success_count = 0
-            
-            if filename_lower.endswith('.ddl'):
-                success_count = train_ddl_file(temp_file_path)
+            # 6. 根据文件类型调用现有的训练函数
+            if file_type == 'ddl':
+                train_ddl_statements(temp_file_path)
                 
-            elif filename_lower.endswith(('.md', '.markdown')):
-                success_count = train_documentation_file(temp_file_path)
+            elif file_type == 'markdown':
+                train_documentation_blocks(temp_file_path)
                 
-            elif filename_lower.endswith('_pair.json') or filename_lower.endswith('_pairs.json'):
-                success_count = train_json_pairs_file(temp_file_path)
+            elif file_type == 'sql_pair_json':
+                train_json_question_sql_pairs(temp_file_path)
                 
-            elif filename_lower.endswith('_pair.sql') or filename_lower.endswith('_pairs.sql'):
-                success_count = train_formatted_pairs_file(temp_file_path)
+            elif file_type == 'sql_pair':
+                train_formatted_question_sql_pairs(temp_file_path)
                 
-            elif filename_lower.endswith('.sql'):
-                success_count = train_sql_file(temp_file_path)
-            else:
-                return jsonify(bad_request_response("不支持的文件类型"))
-            
-            # 5. 刷新批处理
-            flush_training()
+            elif file_type == 'sql':
+                train_sql_examples(temp_file_path)
             
             return jsonify(success_response(
-                response_text=f"文件上传并训练成功:处理了 {success_count} 条记录",
+                response_text=f"文件上传并训练成功:{filename}",
                 data={
                     "filename": filename,
+                    "file_type": file_type,
                     "file_size": file_size,
-                    "records_processed": success_count,
                     "status": "completed"
                 }
             ))
@@ -514,99 +595,56 @@ def upload_training_data():
         logger.error(f"文件上传失败: {str(e)}")
         return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))
 
-# 复用现有训练函数的包装器
-def train_ddl_file(file_path: str) -> int:
-    """复用 run_training.py 中的 DDL 训练逻辑"""
-    success_count = 0
+def validate_file_content(content: str, file_type: str) -> dict:
+    """验证文件内容格式"""
     try:
-        ddl_statements = read_file_by_delimiter(file_path, ";")
-        
-        for ddl in ddl_statements:
+        if file_type == 'ddl':
+            # 检查是否包含CREATE语句
+            if not re.search(r'\bCREATE\b', content, re.IGNORECASE):
+                return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'}
+            
+        elif file_type == 'markdown':
+            # 检查是否包含##标题
+            if '##' not in content:
+                return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'}
+            
+        elif file_type == 'sql_pair_json':
+            # 检查是否为有效JSON
             try:
-                train_ddl(ddl)
-                success_count += 1
-            except Exception as e:
-                logger.error(f"训练DDL失败: {ddl[:50]}..., 错误: {str(e)}")
-    except Exception as e:
-        logger.error(f"读取DDL文件失败: {str(e)}")
-    
-    return success_count
-
-def train_documentation_file(file_path: str) -> int:
-    """复用 run_training.py 中的文档训练逻辑"""
-    success_count = 0
-    
-    try:
-        # 检查是否为Markdown文件
-        if file_path.lower().endswith(('.md', '.markdown')):
-            sections = read_markdown_file_by_sections(file_path)
-            for section in sections:
-                try:
-                    train_documentation(section)
-                    success_count += 1
-                except Exception as e:
-                    logger.error(f"训练文档失败: {section[:50]}..., 错误: {str(e)}")
-        else:
-            # 非Markdown文件使用传统的---分隔
-            doc_blocks = read_file_by_delimiter(file_path, "---")
-            for doc in doc_blocks:
-                try:
-                    train_documentation(doc)
-                    success_count += 1
-                except Exception as e:
-                    logger.error(f"训练文档失败: {doc[:50]}..., 错误: {str(e)}")
-    except Exception as e:
-        logger.error(f"读取文档文件失败: {str(e)}")
-    
-    return success_count
-
-def train_json_pairs_file(file_path: str) -> int:
-    """复用 run_training.py 中的 JSON 问答对训练逻辑"""
-    try:
-        # 直接调用现有函数
-        train_json_question_sql_pairs(file_path)
-        
-        # 计算成功数量
-        with open(file_path, 'r', encoding='utf-8') as f:
-            pairs = json.load(f)
-            return len(pairs) if isinstance(pairs, list) else 1
-    except Exception as e:
-        logger.error(f"训练JSON问答对失败: {str(e)}")
-        return 0
-
-def train_formatted_pairs_file(file_path: str) -> int:
-    """复用 run_training.py 中的格式化问答对训练逻辑"""
-    try:
-        # 直接调用现有函数
-        train_formatted_question_sql_pairs(file_path)
+                data = json.loads(content)
+                if not isinstance(data, list) or not data:
+                    return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'}
+                
+                # 检查是否包含question和sql字段
+                for item in data:
+                    if not isinstance(item, dict):
+                        return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'}
+                    
+                    has_question = any(key.lower() == 'question' for key in item.keys())
+                    has_sql = any(key.lower() == 'sql' for key in item.keys())
+                    
+                    if not has_question or not has_sql:
+                        return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'}
+                        
+            except json.JSONDecodeError:
+                return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'}
+            
+        elif file_type == 'sql_pair':
+            # 检查是否包含Question:和SQL:
+            if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE):
+                return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'}
+            if not re.search(r'\bSQL\s*:', content, re.IGNORECASE):
+                return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'}
+            
+        elif file_type == 'sql':
+            # 检查是否包含;分隔符
+            if ';' not in content:
+                return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'}
         
-        # 计算成功数量
-        with open(file_path, 'r', encoding='utf-8') as f:
-            content = f.read()
-            pairs = content.split('\n\n')
-            return len([p for p in pairs if p.strip()])
-    except Exception as e:
-        logger.error(f"训练格式化问答对失败: {str(e)}")
-        return 0
-
-def train_sql_file(file_path: str) -> int:
-    """复用 run_training.py 中的 SQL 训练逻辑"""
-    success_count = 0
-    try:
-        sql_statements = read_file_by_delimiter(file_path, ";")
+        return {'valid': True}
         
-        for sql in sql_statements:
-            try:
-                train_sql_example(sql)
-                success_count += 1
-            except Exception as e:
-                logger.error(f"训练SQL失败: {sql[:50]}..., 错误: {str(e)}")
     except Exception as e:
-        logger.error(f"读取SQL文件失败: {str(e)}")
-    
-    return success_count
-
-
+        return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'}
 ```
 
 **返回格式:**
@@ -618,10 +656,10 @@ def train_sql_file(file_path: str) -> int:
   "success": true,
   "message": "操作成功",
   "data": {
-    "response": "文件上传并训练成功:处理了 45 条记录",
+    "response": "文件上传并训练成功:training_data.sql",
     "filename": "training_data.sql",
+    "file_type": "sql",
     "file_size": 12345,
-    "records_processed": 45,
     "status": "completed"
   }
 }
@@ -634,7 +672,7 @@ def train_sql_file(file_path: str) -> int:
   "success": false,
   "message": "请求参数错误",
   "data": {
-    "response": "文件大小不能超过500KB",
+    "response": "文件内容不符合SQL格式,必须包含;分隔符",
     "error_type": "INVALID_PARAMS",
     "timestamp": "2025-01-15T10:30:00.000Z",
     "can_retry": false
@@ -650,28 +688,28 @@ def train_sql_file(file_path: str) -> int:
    - 自动清理临时文件,防止磁盘空间泄露
    - 完整的错误处理和日志记录
 
-2. **自动类型识别**:
-   - 基于文件扩展名和命名规则
-   - 智能解析不同格式内容
+2. **类型验证机制**:
+   - 基于 `file_type` 参数动态确定允许的扩展名
+   - 对文件内容进行格式验证
    - 大小写不敏感的格式处理
+   - 详细的验证错误信息,包含具体的文件类型和支持的扩展名
 
 3. **复用成熟功能**:
    - 直接使用 `data_pipeline.trainer.run_training` 中的成熟函数
-   - 复用 `read_file_by_delimiter` 和 `read_markdown_file_by_sections` 解析逻辑
+   - 复用 `train_ddl_statements`、`train_documentation_blocks` 等处理逻辑
    - 调用 `train_json_question_sql_pairs` 和 `train_formatted_question_sql_pairs` 处理问答对
    - 继承批处理优化机制,保持训练逻辑的一致性
    - 避免重复造轮子,提高代码维护性
 
 4. **错误处理**:
    - 详细的错误日志记录
-   - 逐条处理失败时的容错机制
-   - 统计处理成功和失败的记录数
+   - 文件格式验证失败时的明确提示
+   - 基于文件类型的精确扩展名验证和友好提示
    - 临时文件清理保证不留垃圾文件
 
 5. **性能优化**:
    - 复用现有的高效解析函数
    - 批处理机制提高训练效率
-   - 自动刷新确保数据及时写入
    - 避免重复开发,减少维护成本
 
 **使用建议:**
@@ -679,20 +717,38 @@ def train_sql_file(file_path: str) -> int:
 1. **文件准备**:
    - 确保文件使用UTF-8编码
    - 控制文件大小在500KB以内
-   - 按照规定的格式和命名规则准备文件
+   - 按照指定的格式准备文件内容
 
 2. **文件类型选择**:
-   - DDL文件:用于训练表结构信息,复用 `train_ddl_statements` 逻辑
-   - Markdown文档:用于训练表和字段说明,复用 `train_documentation_blocks` 逻辑
-   - JSON问答对:用于训练结构化问答数据,复用 `train_json_question_sql_pairs` 逻辑
-   - 格式化问答对:用于训练文本格式问答数据,复用 `train_formatted_question_sql_pairs` 逻辑
-   - SQL文件:用于训练SQL示例,复用 `train_sql_examples` 逻辑
+   - `ddl`:用于训练表结构信息,支持 .ddl/.sql/.txt 扩展名或无扩展名,必须包含CREATE语句
+   - `markdown`:用于训练表和字段说明,仅支持 .md/.markdown 扩展名,必须包含##标题
+   - `sql_pair_json`:用于训练结构化问答数据,支持 .json/.txt 扩展名或无扩展名,必须是有效JSON格式
+   - `sql_pair`:用于训练文本格式问答数据,支持 .sql/.txt 扩展名或无扩展名,必须包含Question:/SQL:格式
+   - `sql`:用于训练SQL示例,支持 .sql/.txt 扩展名或无扩展名,必须包含;分隔符,系统会自动生成问题
 
 3. **最佳实践**:
    - 建议分批上传,每次处理一种类型的文件
    - 上传前检查文件格式的正确性
-   - 关注返回结果中的处理统计信息
-   - 复用现有训练逻辑,确保数据处理的一致性
+   - 明确指定正确的 `file_type` 参数
+   - 关注返回结果中的处理状态信息
+
+**Postman测试示例:**
+
+```
+POST /api/v0/training_data/upload
+Content-Type: multipart/form-data
+
+form-data:
+- file: [选择文件]
+- file_type: sql_pair
+
+示例文件内容:
+Question: 如何查询观影人数超过 10000 的热门电影?
+SQL: SELECT name, viewers FROM movies WHERE viewers > 10000;
+
+Question: 如何统计每个分类的产品数量?
+SQL: SELECT category, COUNT(*) FROM products GROUP BY category;
+```
 
 ---
 

+ 101 - 34
react_agent/agent.py

@@ -209,16 +209,66 @@ class CustomReactAgent:
 
     async def _async_should_continue(self, state: AgentState) -> str:
         """异步判断是继续调用工具还是结束。"""
-        last_message = state["messages"][-1]
-        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
+        thread_id = state.get("thread_id", "unknown")
+        messages = state["messages"]
+        total_messages = len(messages)
+        
+        # 显示当前递归计数
+        current_count = getattr(self, '_recursion_count', 0)
+        
+        logger.info(f"🔄 [Decision] _async_should_continue - Thread: {thread_id} | 递归计数: {current_count}/{config.RECURSION_LIMIT}")
+        logger.info(f"   消息总数: {total_messages}")
+        
+        if not messages:
+            logger.warning("   ⚠️ 消息列表为空,返回 'end'")
+            return "end"
+        
+        last_message = messages[-1]
+        message_type = type(last_message).__name__
+        
+        logger.info(f"   最后消息类型: {message_type}")
+        
+        # 检查是否有tool_calls
+        has_tool_calls = hasattr(last_message, "tool_calls") and last_message.tool_calls
+        
+        if has_tool_calls:
+            tool_calls_count = len(last_message.tool_calls)
+            logger.info(f"   发现工具调用: {tool_calls_count} 个")
+            
+            # 详细记录每个工具调用
+            for i, tool_call in enumerate(last_message.tool_calls):
+                tool_name = tool_call.get('name', 'unknown')
+                tool_id = tool_call.get('id', 'unknown')
+                logger.info(f"     工具调用[{i}]: {tool_name} (ID: {tool_id})")
+            
+            logger.info("   🔄 决策: continue (继续工具调用)")
             return "continue"
-        return "end"
+        else:
+            logger.info("   ✅ 无工具调用")
+            
+            # 检查消息内容以了解为什么结束
+            if hasattr(last_message, 'content'):
+                content_preview = str(last_message.content)[:100] + "..." if len(str(last_message.content)) > 100 else str(last_message.content)
+                logger.info(f"   消息内容预览: {content_preview}")
+            
+            logger.info("   🏁 决策: end (结束对话)")
+            return "end"
 
     async def _async_agent_node(self, state: AgentState) -> Dict[str, Any]:
         """异步Agent 节点:使用异步LLM调用。"""
-        logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']}")
+        # 增加递归计数
+        if hasattr(self, '_recursion_count'):
+            self._recursion_count += 1
+        else:
+            self._recursion_count = 1
+            
+        logger.info(f"🧠 [Async Node] agent - Thread: {state['thread_id']} | 递归计数: {self._recursion_count}/{config.RECURSION_LIMIT}")
+        
+        # 获取建议的下一步操作
+        next_step = state.get("suggested_next_step")
         
-        messages_for_llm = list(state["messages"])
+        # 构建发送给LLM的消息列表
+        messages_for_llm = state["messages"].copy()
         
         # 🎯 添加数据库范围系统提示词(每次用户提问时添加)
         if isinstance(state["messages"][-1], HumanMessage):
@@ -228,7 +278,6 @@ class CustomReactAgent:
                 logger.info("   ✅ 已添加数据库范围判断提示词")
         
         # 检查是否需要分析验证错误
-        next_step = state.get("suggested_next_step")
         
         # 行为指令与工具建议分离
         real_tools = {'valid_sql', 'run_sql'}
@@ -525,12 +574,18 @@ class CustomReactAgent:
         logger.info(" ~" * 10 + " State Print End" + " ~" * 10)
 
     async def _async_prepare_tool_input_node(self, state: AgentState) -> Dict[str, Any]:
-        """
-        准备工具输入。
-        - 强制修正generate_sql的question参数,确保使用用户原始问题。
-        - 为generate_sql注入经过严格过滤的、干净的对话历史。
-        """
-        last_message = state['messages'][-1]
+        """异步准备工具输入节点:为generate_sql工具注入history_messages。"""
+        # 增加递归计数
+        if hasattr(self, '_recursion_count'):
+            self._recursion_count += 1
+        else:
+            self._recursion_count = 1
+            
+        logger.info(f"🔧 [Async Node] prepare_tool_input - Thread: {state['thread_id']} | 递归计数: {self._recursion_count}/{config.RECURSION_LIMIT}")
+        
+        # 获取最后一条消息(应该是来自agent的AIMessage)
+        last_message = state["messages"][-1]
+
         if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
             return {"messages": [last_message]}
 
@@ -617,15 +672,19 @@ class CustomReactAgent:
         return clean_history
 
     async def _async_update_state_after_tool_node(self, state: AgentState) -> Dict[str, Any]:
-        """在工具执行后,更新 suggested_next_step 并清理参数。"""
-        logger.info(f"📝 [Node] update_state_after_tool - Thread: {state['thread_id']}")
-        
-        # 🎯 打印 state 全部信息
-        self._print_state_info(state, "update_state_after_tool")
+        """异步更新工具执行后的状态。"""
+        # 增加递归计数
+        if hasattr(self, '_recursion_count'):
+            self._recursion_count += 1
+        else:
+            self._recursion_count = 1
+            
+        logger.info(f"📝 [Async Node] update_state_after_tool - Thread: {state['thread_id']} | 递归计数: {self._recursion_count}/{config.RECURSION_LIMIT}")
         
-        last_tool_message = state['messages'][-1]
-        tool_name = last_tool_message.name
-        tool_output = last_tool_message.content
+        # 获取最后一条工具消息
+        last_message = state["messages"][-1]
+        tool_name = last_message.name
+        tool_output = last_message.content
         next_step = None
 
         if tool_name == 'generate_sql':
@@ -660,15 +719,17 @@ class CustomReactAgent:
                         logger.info(f"   已将 generate_sql 的 history_messages 设置为空字符串")
 
     async def _async_format_final_response_node(self, state: AgentState) -> Dict[str, Any]:
-        """异步最终输出格式化节点。"""
-        logger.info(f"🎨 [Async Node] format_final_response - Thread: {state['thread_id']}")
-        
-        # 保持原有的消息格式化(用于shell.py兼容)
-        last_message = state['messages'][-1]
-        # 注释掉前缀添加,直接使用原始内容
-        # last_message.content = f"[Formatted Output]\n{last_message.content}"
+        """异步格式化最终响应节点。"""
+        # 增加递归计数
+        if hasattr(self, '_recursion_count'):
+            self._recursion_count += 1
+        else:
+            self._recursion_count = 1
+            
+        logger.info(f"✨ [Async Node] format_final_response - Thread: {state['thread_id']} | 递归计数: {self._recursion_count}/{config.RECURSION_LIMIT}")
         
-        return {"messages": [last_message]}
+        # 这个节点主要用于最终处理,通常不需要修改状态
+        return {"messages": state["messages"]}
 
     async def _async_generate_api_data(self, state: AgentState) -> Dict[str, Any]:
         """异步生成API格式的数据结构"""
@@ -867,12 +928,18 @@ class CustomReactAgent:
             thread_id = f"{user_id}:{now.strftime('%Y%m%d%H%M%S')}{milliseconds:03d}"
             logger.info(f"🆕 新建会话,Thread ID: {thread_id}")
         
-        config = {
+        # 初始化递归计数器(用于日志显示)
+        self._recursion_count = 0
+        
+        run_config = {
             "configurable": {
                 "thread_id": thread_id,
-            }
+            },
+            "recursion_limit": config.RECURSION_LIMIT
         }
         
+        logger.info(f"🔢 递归限制设置: {config.RECURSION_LIMIT}")
+        
         inputs = {
             "messages": [HumanMessage(content=message)],
             "user_id": user_id,
@@ -899,7 +966,7 @@ class CustomReactAgent:
                     else:
                         logger.warning(f"⚠️ Checkpointer测试失败,但继续执行: {checkpoint_error}")
             
-            final_state = await self.agent_executor.ainvoke(inputs, config)
+            final_state = await self.agent_executor.ainvoke(inputs, run_config)
             
             # 🔍 调试:打印 final_state 的所有 keys
             logger.info(f"🔍 Final state keys: {list(final_state.keys())}")
@@ -957,9 +1024,9 @@ class CustomReactAgent:
         if not self.checkpointer:
             return []
         
-        config = {"configurable": {"thread_id": thread_id}}
+        thread_config = {"configurable": {"thread_id": thread_id}}
         try:
-            conversation_state = await self.checkpointer.aget(config)
+            conversation_state = await self.checkpointer.aget(thread_config)
         except RuntimeError as e:
             if "Event loop is closed" in str(e):
                 logger.warning(f"⚠️ Event loop已关闭,尝试重新获取对话历史: {thread_id}")
@@ -1185,7 +1252,7 @@ not on explaining your decision-making process.
 2. 修复语法错误后,调用 valid_sql 工具重新验证
 3. 常见问题:缺少逗号、括号不匹配、关键词拼写错误"""
 
-        # 新增的合并条件,处理所有“不存在”类型的错误
+        # 新增的合并条件,处理所有"不存在"类型的错误
         elif ("不存在" in validation_error or 
               "no such table" in validation_error.lower() or
               "does not exist" in validation_error.lower()):

+ 3 - 0
react_agent/config.py

@@ -30,6 +30,9 @@ LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(messag
 # --- Agent 配置 ---
 DEFAULT_USER_ID = "guest"
 
+# --- StateGraph 配置 ---
+RECURSION_LIMIT = 100  # StateGraph递归限制
+
 # --- 网络重试配置 ---
 MAX_RETRIES = 3                    # 最大重试次数(减少以避免与OpenAI客户端冲突)
 RETRY_BASE_DELAY = 3               # 重试基础延迟(秒)

+ 466 - 0
unified_api.py

@@ -25,11 +25,16 @@ initialize_logging()
 # 标准 Flask 导入
 from flask import Flask, request, jsonify, session, send_file
 import redis.asyncio as redis
+from werkzeug.utils import secure_filename
 
 # 基础依赖
 import pandas as pd
 import json
 import sqlparse
+import tempfile
+import os
+import psycopg2
+import re
 
 # 项目模块导入
 from core.vanna_llm_factory import create_vanna_instance
@@ -1660,6 +1665,467 @@ def training_data_delete():
             response_text="删除训练数据失败,请稍后重试"
         )), 500
 
+@app.route('/api/v0/training_data/update', methods=['POST'])
+def training_data_update():
+    """更新训练数据API - 支持单条更新,采用先删除后插入策略"""
+    try:
+        req = request.get_json(force=True)
+        
+        # 1. 参数验证
+        original_id = req.get('id')
+        if not original_id:
+            return jsonify(bad_request_response(
+                response_text="缺少必需参数:id",
+                missing_params=["id"]
+            )), 400
+        
+        training_type = req.get('training_data_type')
+        if not training_type:
+            return jsonify(bad_request_response(
+                response_text="缺少必需参数:training_data_type",
+                missing_params=["training_data_type"]
+            )), 400
+        
+        # 2. 先删除原始记录
+        try:
+            success = vn.remove_training_data(original_id)
+            if not success:
+                return jsonify(bad_request_response(
+                    response_text=f"原始记录 {original_id} 不存在或删除失败"
+                )), 400
+        except Exception as e:
+            return jsonify(internal_error_response(
+                response_text=f"删除原始记录失败: {str(e)}"
+            )), 500
+        
+        # 3. 根据类型验证和准备新数据
+        try:
+            if training_type == 'sql':
+                sql = req.get('sql')
+                if not sql:
+                    return jsonify(bad_request_response(
+                        response_text="SQL字段是必需的",
+                        missing_params=["sql"]
+                    )), 400
+                
+                # SQL语法检查
+                is_valid, error_msg = validate_sql_syntax(sql)
+                if not is_valid:
+                    return jsonify(bad_request_response(
+                        response_text=f"SQL语法错误: {error_msg}"
+                    )), 400
+                
+                question = req.get('question')
+                if question:
+                    training_id = vn.train(question=question, sql=sql)
+                else:
+                    training_id = vn.train(sql=sql)
+                    
+            elif training_type == 'error_sql':
+                question = req.get('question')
+                sql = req.get('sql')
+                if not question or not sql:
+                    return jsonify(bad_request_response(
+                        response_text="question和sql字段都是必需的",
+                        missing_params=["question", "sql"]
+                    )), 400
+                training_id = vn.train_error_sql(question=question, sql=sql)
+                
+            elif training_type == 'documentation':
+                content = req.get('content')
+                if not content:
+                    return jsonify(bad_request_response(
+                        response_text="content字段是必需的",
+                        missing_params=["content"]
+                    )), 400
+                training_id = vn.train(documentation=content)
+                
+            elif training_type == 'ddl':
+                ddl = req.get('ddl')
+                if not ddl:
+                    return jsonify(bad_request_response(
+                        response_text="ddl字段是必需的",
+                        missing_params=["ddl"]
+                    )), 400
+                training_id = vn.train(ddl=ddl)
+                
+            else:
+                return jsonify(bad_request_response(
+                    response_text=f"不支持的训练数据类型: {training_type}"
+                )), 400
+                
+        except Exception as e:
+            return jsonify(internal_error_response(
+                response_text=f"创建新训练数据失败: {str(e)}"
+            )), 500
+        
+        # 4. 获取更新后的总记录数
+        current_total = get_total_training_count()
+        
+        return jsonify(success_response(
+            response_text="训练数据更新成功",
+            data={
+                "original_id": original_id,
+                "new_training_id": training_id,
+                "type": training_type,
+                "current_total_count": current_total
+            }
+        ))
+        
+    except Exception as e:
+        logger.error(f"training_data_update执行失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="更新训练数据失败,请稍后重试"
+        )), 500
+
+# 导入现有的专业训练函数
+from data_pipeline.trainer.run_training import (
+    train_ddl_statements,
+    train_documentation_blocks, 
+    train_json_question_sql_pairs,
+    train_formatted_question_sql_pairs,
+    train_sql_examples
+)
+
+def get_allowed_extensions(file_type: str) -> list:
+    """根据文件类型返回允许的扩展名"""
+    type_specific_extensions = {
+        'ddl': ['ddl', 'sql', 'txt', ''],  # 支持无扩展名
+        'markdown': ['md', 'markdown'],    # 不支持无扩展名
+        'sql_pair_json': ['json', 'txt', ''],  # 支持无扩展名
+        'sql_pair': ['sql', 'txt', ''],   # 支持无扩展名
+        'sql': ['sql', 'txt', '']         # 支持无扩展名
+    }
+    
+    return type_specific_extensions.get(file_type, [])
+
+def validate_file_content(content: str, file_type: str) -> dict:
+    """验证文件内容格式"""
+    try:
+        if file_type == 'ddl':
+            # 检查是否包含CREATE语句
+            if not re.search(r'\bCREATE\b', content, re.IGNORECASE):
+                return {'valid': False, 'error': '文件内容不符合DDL格式,必须包含CREATE语句'}
+            
+        elif file_type == 'markdown':
+            # 检查是否包含##标题
+            if '##' not in content:
+                return {'valid': False, 'error': '文件内容不符合Markdown格式,必须包含##标题'}
+            
+        elif file_type == 'sql_pair_json':
+            # 检查是否为有效JSON
+            try:
+                data = json.loads(content)
+                if not isinstance(data, list) or not data:
+                    return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须是非空数组'}
+                
+                # 检查是否包含question和sql字段
+                for item in data:
+                    if not isinstance(item, dict):
+                        return {'valid': False, 'error': '文件内容不符合JSON问答对格式,数组元素必须是对象'}
+                    
+                    has_question = any(key.lower() == 'question' for key in item.keys())
+                    has_sql = any(key.lower() == 'sql' for key in item.keys())
+                    
+                    if not has_question or not has_sql:
+                        return {'valid': False, 'error': '文件内容不符合JSON问答对格式,必须包含question和sql字段'}
+                        
+            except json.JSONDecodeError:
+                return {'valid': False, 'error': '文件内容不符合JSON问答对格式,JSON格式错误'}
+            
+        elif file_type == 'sql_pair':
+            # 检查是否包含Question:和SQL:
+            if not re.search(r'\bQuestion\s*:', content, re.IGNORECASE):
+                return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含Question:'}
+            if not re.search(r'\bSQL\s*:', content, re.IGNORECASE):
+                return {'valid': False, 'error': '文件内容不符合问答对格式,必须包含SQL:'}
+            
+        elif file_type == 'sql':
+            # 检查是否包含;分隔符
+            if ';' not in content:
+                return {'valid': False, 'error': '文件内容不符合SQL格式,必须包含;分隔符'}
+        
+        return {'valid': True}
+        
+    except Exception as e:
+        return {'valid': False, 'error': f'文件内容验证失败: {str(e)}'}
+
+@app.route('/api/v0/training_data/upload', methods=['POST'])
+def upload_training_data():
+    """上传训练数据文件API - 支持多种文件格式的自动解析和导入"""
+    try:
+        # 1. 参数验证
+        if 'file' not in request.files:
+            return jsonify(bad_request_response("未提供文件"))
+        
+        file = request.files['file']
+        if file.filename == '':
+            return jsonify(bad_request_response("未选择文件"))
+        
+        # 获取file_type参数
+        file_type = request.form.get('file_type')
+        if not file_type:
+            return jsonify(bad_request_response("缺少必需参数:file_type"))
+        
+        # 验证file_type参数
+        valid_file_types = ['ddl', 'markdown', 'sql_pair_json', 'sql_pair', 'sql']
+        if file_type not in valid_file_types:
+            return jsonify(bad_request_response(f"不支持的文件类型:{file_type},支持的类型:{', '.join(valid_file_types)}"))
+        
+        # 2. 文件大小验证 (500KB)
+        file.seek(0, 2)
+        file_size = file.tell()
+        file.seek(0)
+        
+        if file_size > 500 * 1024:  # 500KB
+            return jsonify(bad_request_response("文件大小不能超过500KB"))
+        
+        # 3. 验证文件扩展名(基于file_type)
+        filename = secure_filename(file.filename)
+        allowed_extensions = get_allowed_extensions(file_type)
+        file_ext = filename.split('.')[-1].lower() if '.' in filename else ''
+        
+        if file_ext not in allowed_extensions:
+            # 构建友好的错误信息
+            non_empty_extensions = [ext for ext in allowed_extensions if ext]
+            if '' in allowed_extensions:
+                ext_message = f"{', '.join(non_empty_extensions)} 或无扩展名"
+            else:
+                ext_message = ', '.join(non_empty_extensions)
+            return jsonify(bad_request_response(f"文件类型 {file_type} 不支持的文件扩展名:{file_ext},支持的扩展名:{ext_message}"))
+        
+        # 4. 读取文件内容并验证格式
+        file.seek(0)
+        content = file.read().decode('utf-8')
+        
+        # 格式验证
+        validation_result = validate_file_content(content, file_type)
+        if not validation_result['valid']:
+            return jsonify(bad_request_response(validation_result['error']))
+        
+        # 5. 创建临时文件(复用现有函数需要文件路径)
+        temp_file_path = None
+        
+        try:
+            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp', encoding='utf-8') as tmp_file:
+                tmp_file.write(content)
+                temp_file_path = tmp_file.name
+            
+            # 6. 根据文件类型调用现有的训练函数
+            if file_type == 'ddl':
+                train_ddl_statements(temp_file_path)
+                
+            elif file_type == 'markdown':
+                train_documentation_blocks(temp_file_path)
+                
+            elif file_type == 'sql_pair_json':
+                train_json_question_sql_pairs(temp_file_path)
+                
+            elif file_type == 'sql_pair':
+                train_formatted_question_sql_pairs(temp_file_path)
+                
+            elif file_type == 'sql':
+                train_sql_examples(temp_file_path)
+            
+            return jsonify(success_response(
+                response_text=f"文件上传并训练成功:{filename}",
+                data={
+                    "filename": filename,
+                    "file_type": file_type,
+                    "file_size": file_size,
+                    "status": "completed"
+                }
+            ))
+            
+        except Exception as e:
+            logger.error(f"训练失败: {str(e)}")
+            return jsonify(internal_error_response(f"训练失败: {str(e)}"))
+        finally:
+            # 清理临时文件
+            if temp_file_path and os.path.exists(temp_file_path):
+                try:
+                    os.unlink(temp_file_path)
+                except Exception as e:
+                    logger.warning(f"清理临时文件失败: {str(e)}")
+        
+    except Exception as e:
+        logger.error(f"文件上传失败: {str(e)}")
+        return jsonify(internal_error_response(f"文件上传失败: {str(e)}"))
+
+def get_db_connection():
+    """获取数据库连接"""
+    try:
+        from app_config import PGVECTOR_CONFIG
+        return psycopg2.connect(**PGVECTOR_CONFIG)
+    except Exception as e:
+        logger.error(f"数据库连接失败: {str(e)}")
+        raise
+
+def get_db_connection_for_transaction():
+    """获取用于事务操作的数据库连接(非自动提交模式)"""
+    try:
+        from app_config import PGVECTOR_CONFIG
+        conn = psycopg2.connect(**PGVECTOR_CONFIG)
+        conn.autocommit = False  # 设置为非自动提交模式,允许手动控制事务
+        return conn
+    except Exception as e:
+        logger.error(f"数据库连接失败: {str(e)}")
+        raise
+
+@app.route('/api/v0/training_data/combine', methods=['POST'])
+def combine_training_data():
+    """合并训练数据API - 支持合并重复记录"""
+    try:
+        # 1. 参数验证
+        data = request.get_json()
+        if not data:
+            return jsonify(bad_request_response("请求体不能为空"))
+        
+        collection_names = data.get('collection_names', [])
+        if not collection_names or not isinstance(collection_names, list):
+            return jsonify(bad_request_response("collection_names 参数必须是非空数组"))
+        
+        # 验证集合名称
+        valid_collections = ['sql', 'ddl', 'documentation', 'error_sql']
+        invalid_collections = [name for name in collection_names if name not in valid_collections]
+        if invalid_collections:
+            return jsonify(bad_request_response(f"不支持的集合名称: {invalid_collections}"))
+        
+        dry_run = data.get('dry_run', True)
+        keep_strategy = data.get('keep_strategy', 'first')
+        
+        if keep_strategy not in ['first', 'last', 'by_metadata_time']:
+            return jsonify(bad_request_response("keep_strategy 必须是 'first', 'last' 或 'by_metadata_time'"))
+        
+        # 2. 获取数据库连接(用于事务操作)
+        conn = get_db_connection_for_transaction()
+        cursor = conn.cursor()
+        
+        # 3. 查找重复记录
+        duplicate_groups = []
+        total_before = 0
+        total_duplicates = 0
+        collections_stats = {}
+        
+        for collection_name in collection_names:
+            # 获取集合ID
+            cursor.execute(
+                "SELECT uuid FROM langchain_pg_collection WHERE name = %s",
+                (collection_name,)
+            )
+            collection_result = cursor.fetchone()
+            if not collection_result:
+                continue
+            
+            collection_id = collection_result[0]
+            
+            # 统计该集合的记录数
+            cursor.execute(
+                "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_id = %s",
+                (collection_id,)
+            )
+            collection_before = cursor.fetchone()[0]
+            total_before += collection_before
+            
+            # 查找重复记录
+            if keep_strategy in ['first', 'last']:
+                order_by = "id"
+            else:
+                order_by = "COALESCE((cmetadata->>'createdat')::timestamp, '1970-01-01'::timestamp) DESC, id"
+            
+            cursor.execute(f"""
+                SELECT document, COUNT(*) as duplicate_count, 
+                       array_agg(id ORDER BY {order_by}) as record_ids
+                FROM langchain_pg_embedding 
+                WHERE collection_id = %s 
+                GROUP BY document 
+                HAVING COUNT(*) > 1
+            """, (collection_id,))
+            
+            collection_duplicates = 0
+            for row in cursor.fetchall():
+                document_content, duplicate_count, record_ids = row
+                collection_duplicates += duplicate_count - 1  # 减去要保留的一条
+                
+                # 根据保留策略选择要保留的记录
+                if keep_strategy == 'first':
+                    keep_id = record_ids[0]
+                    remove_ids = record_ids[1:]
+                elif keep_strategy == 'last':
+                    keep_id = record_ids[-1]
+                    remove_ids = record_ids[:-1]
+                else:  # by_metadata_time
+                    keep_id = record_ids[0]  # 已经按时间排序
+                    remove_ids = record_ids[1:]
+                
+                duplicate_groups.append({
+                    "collection_name": collection_name,
+                    "document_content": document_content[:100] + "..." if len(document_content) > 100 else document_content,
+                    "duplicate_count": duplicate_count,
+                    "kept_record_id" if not dry_run else "records_to_keep": keep_id,
+                    "removed_record_ids" if not dry_run else "records_to_remove": remove_ids
+                })
+            
+            total_duplicates += collection_duplicates
+            collections_stats[collection_name] = {
+                "before": collection_before,
+                "after": collection_before - collection_duplicates,
+                "duplicates_removed" if not dry_run else "duplicates_to_remove": collection_duplicates
+            }
+        
+        # 4. 执行合并操作(如果不是dry_run)
+        if not dry_run:
+            try:
+                # 连接已经设置为非自动提交模式,直接开始事务
+                for group in duplicate_groups:
+                    remove_ids = group["removed_record_ids"]
+                    if remove_ids:
+                        cursor.execute(
+                            "DELETE FROM langchain_pg_embedding WHERE id = ANY(%s)",
+                            (remove_ids,)
+                        )
+                
+                conn.commit()
+                
+            except Exception as e:
+                conn.rollback()
+                return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
+        
+        # 5. 构建响应
+        total_after = total_before - total_duplicates
+        
+        summary = {
+            "total_records_before": total_before,
+            "total_records_after": total_after,
+            "duplicates_removed" if not dry_run else "duplicates_to_remove": total_duplicates,
+            "collections_stats": collections_stats
+        }
+        
+        if dry_run:
+            response_text = f"发现 {total_duplicates} 条重复记录,预计删除后将从 {total_before} 条减少到 {total_after} 条记录"
+            data_key = "duplicate_groups"
+        else:
+            response_text = f"成功合并重复记录,删除了 {total_duplicates} 条重复记录,从 {total_before} 条减少到 {total_after} 条记录"
+            data_key = "merged_groups"
+        
+        return jsonify(success_response(
+            response_text=response_text,
+            data={
+                "dry_run": dry_run,
+                "collections_processed": collection_names,
+                "summary": summary,
+                data_key: duplicate_groups
+            }
+        ))
+        
+    except Exception as e:
+        return jsonify(internal_error_response(f"合并操作失败: {str(e)}"))
+    finally:
+        if 'cursor' in locals():
+            cursor.close()
+        if 'conn' in locals():
+            conn.close()
+
 # ==================== React Agent 扩展API ====================
 
 @app.route('/api/v0/react/users/<user_id>/conversations', methods=['GET'])