Browse Source

准备重构database agent,把它一分为二。

wangxq 1 week ago
parent
commit
c26e9ffb72

+ 432 - 0
citu_app.py

@@ -11,6 +11,7 @@ import re
 import chainlit as cl
 import json
 from flask import session  # 添加session导入
+import sqlparse  # 用于SQL语法检查
 from common.redis_conversation_manager import RedisConversationManager  # 添加Redis对话管理器导入
 
 from common.qa_feedback_manager import QAFeedbackManager
@@ -2198,6 +2199,437 @@ def qa_cache_cleanup():
         )), 500
 
 
+# ==================== 训练数据管理接口 ====================
+
+def validate_sql_syntax(sql: str) -> tuple[bool, str]:
+    """SQL语法检查(仅对sql类型)"""
+    try:
+        parsed = sqlparse.parse(sql.strip())
+        
+        if not parsed or not parsed[0].tokens:
+            return False, "SQL语法错误:空语句"
+        
+        # 基本语法检查
+        sql_upper = sql.strip().upper()
+        if not any(sql_upper.startswith(keyword) for keyword in 
+                  ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP']):
+            return False, "SQL语法错误:不是有效的SQL语句"
+        
+        # 安全检查:禁止危险的SQL操作
+        dangerous_operations = ['UPDATE', 'DELETE', 'ALERT', 'DROP']
+        for operation in dangerous_operations:
+            if sql_upper.startswith(operation):
+                return False, f'在训练集中禁止使用"{",".join(dangerous_operations)}"'
+        
+        return True, ""
+    except Exception as e:
+        return False, f"SQL语法错误:{str(e)}"
+
+def paginate_data(data_list: list, page: int, page_size: int):
+    """分页处理算法"""
+    total = len(data_list)
+    start_idx = (page - 1) * page_size
+    end_idx = start_idx + page_size
+    page_data = data_list[start_idx:end_idx]
+    
+    return {
+        "data": page_data,
+        "pagination": {
+            "page": page,
+            "page_size": page_size,
+            "total": total,
+            "total_pages": (total + page_size - 1) // page_size,
+            "has_next": end_idx < total,
+            "has_prev": page > 1
+        }
+    }
+
+def filter_by_type(data_list: list, training_data_type: str):
+    """按类型筛选算法"""
+    if not training_data_type:
+        return data_list
+    
+    return [
+        record for record in data_list 
+        if record.get('training_data_type') == training_data_type
+    ]
+
+def search_in_data(data_list: list, search_keyword: str):
+    """在数据中搜索关键词"""
+    if not search_keyword:
+        return data_list
+    
+    keyword_lower = search_keyword.lower()
+    return [
+        record for record in data_list
+        if (record.get('question') and keyword_lower in record['question'].lower()) or
+           (record.get('content') and keyword_lower in record['content'].lower())
+    ]
+
+def process_single_training_item(item: dict, index: int) -> dict:
+    """处理单个训练数据项"""
+    training_type = item.get('training_data_type')
+    
+    if training_type == 'sql':
+        sql = item.get('sql')
+        if not sql:
+            raise ValueError("SQL字段是必需的")
+        
+        # SQL语法检查
+        is_valid, error_msg = validate_sql_syntax(sql)
+        if not is_valid:
+            raise ValueError(error_msg)
+        
+        question = item.get('question')
+        if question:
+            training_id = vn.train(question=question, sql=sql)
+        else:
+            training_id = vn.train(sql=sql)
+            
+    elif training_type == 'error_sql':
+        # error_sql不需要语法检查
+        question = item.get('question')
+        sql = item.get('sql')
+        if not question or not sql:
+            raise ValueError("question和sql字段都是必需的")
+        training_id = vn.train_error_sql(question=question, sql=sql)
+        
+    elif training_type == 'documentation':
+        content = item.get('content')
+        if not content:
+            raise ValueError("content字段是必需的")
+        training_id = vn.train(documentation=content)
+        
+    elif training_type == 'ddl':
+        ddl = item.get('ddl')
+        if not ddl:
+            raise ValueError("ddl字段是必需的")
+        training_id = vn.train(ddl=ddl)
+        
+    else:
+        raise ValueError(f"不支持的训练数据类型: {training_type}")
+    
+    return {
+        "index": index,
+        "success": True,
+        "training_id": training_id,
+        "type": training_type,
+        "message": f"{training_type}训练数据创建成功"
+    }
+
+def get_total_training_count():
+    """获取当前训练数据总数"""
+    try:
+        training_data = vn.get_training_data()
+        if training_data is not None and not training_data.empty:
+            return len(training_data)
+        return 0
+    except Exception as e:
+        print(f"[WARNING] 获取训练数据总数失败: {e}")
+        return 0
+
+@app.flask_app.route('/api/v0/training_data/query', methods=['POST'])
+def training_data_query():
+    """
+    分页查询训练数据API
+    支持类型筛选、搜索和排序功能
+    """
+    try:
+        req = request.get_json(force=True)
+        
+        # 解析参数,设置默认值
+        page = req.get('page', 1)
+        page_size = req.get('page_size', 20)
+        training_data_type = req.get('training_data_type')
+        sort_by = req.get('sort_by', 'id')
+        sort_order = req.get('sort_order', 'desc')
+        search_keyword = req.get('search_keyword')
+        
+        # 参数验证
+        if page < 1:
+            return jsonify(bad_request_response(
+                response_text="页码必须大于0",
+                missing_params=["page"]
+            )), 400
+        
+        if page_size < 1 or page_size > 100:
+            return jsonify(bad_request_response(
+                response_text="每页大小必须在1-100之间",
+                missing_params=["page_size"]
+            )), 400
+        
+        if search_keyword and len(search_keyword) > 100:
+            return jsonify(bad_request_response(
+                response_text="搜索关键词最大长度为100字符",
+                missing_params=["search_keyword"]
+            )), 400
+        
+        # 获取训练数据
+        training_data = vn.get_training_data()
+        
+        if training_data is None or training_data.empty:
+            return jsonify(success_response(
+                response_text="查询成功,暂无训练数据",
+                data={
+                    "records": [],
+                    "pagination": {
+                        "page": page,
+                        "page_size": page_size,
+                        "total": 0,
+                        "total_pages": 0,
+                        "has_next": False,
+                        "has_prev": False
+                    },
+                    "filters_applied": {
+                        "training_data_type": training_data_type,
+                        "search_keyword": search_keyword
+                    }
+                }
+            ))
+        
+        # 转换为列表格式
+        records = training_data.to_dict(orient="records")
+        
+        # 应用筛选条件
+        if training_data_type:
+            records = filter_by_type(records, training_data_type)
+        
+        if search_keyword:
+            records = search_in_data(records, search_keyword)
+        
+        # 排序
+        if sort_by in ['id', 'training_data_type']:
+            reverse = (sort_order.lower() == 'desc')
+            records.sort(key=lambda x: x.get(sort_by, ''), reverse=reverse)
+        
+        # 分页
+        paginated_result = paginate_data(records, page, page_size)
+        
+        return jsonify(success_response(
+            response_text=f"查询成功,共找到 {paginated_result['pagination']['total']} 条记录",
+            data={
+                "records": paginated_result["data"],
+                "pagination": paginated_result["pagination"],
+                "filters_applied": {
+                    "training_data_type": training_data_type,
+                    "search_keyword": search_keyword
+                }
+            }
+        ))
+        
+    except Exception as e:
+        print(f"[ERROR] training_data_query执行失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="查询训练数据失败,请稍后重试"
+        )), 500
+
+@app.flask_app.route('/api/v0/training_data/create', methods=['POST'])
+def training_data_create():
+    """
+    创建训练数据API
+    支持单条和批量创建,支持四种数据类型
+    """
+    try:
+        req = request.get_json(force=True)
+        data = req.get('data')
+        
+        if not data:
+            return jsonify(bad_request_response(
+                response_text="缺少必需参数:data",
+                missing_params=["data"]
+            )), 400
+        
+        # 统一处理为列表格式
+        if isinstance(data, dict):
+            data_list = [data]
+        elif isinstance(data, list):
+            data_list = data
+        else:
+            return jsonify(bad_request_response(
+                response_text="data字段格式错误,应为对象或数组"
+            )), 400
+        
+        # 批量操作限制
+        if len(data_list) > 50:
+            return jsonify(bad_request_response(
+                response_text="批量操作最大支持50条记录"
+            )), 400
+        
+        results = []
+        successful_count = 0
+        type_summary = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
+        
+        for index, item in enumerate(data_list):
+            try:
+                result = process_single_training_item(item, index)
+                results.append(result)
+                if result['success']:
+                    successful_count += 1
+                    type_summary[result['type']] += 1
+            except Exception as e:
+                results.append({
+                    "index": index,
+                    "success": False,
+                    "type": item.get('training_data_type', 'unknown'),
+                    "error": str(e),
+                    "message": "创建失败"
+                })
+        
+        # 获取创建后的总记录数
+        current_total = get_total_training_count()
+        
+        return jsonify(success_response(
+            response_text="训练数据创建完成",
+            data={
+                "total_requested": len(data_list),
+                "successfully_created": successful_count,
+                "failed_count": len(data_list) - successful_count,
+                "results": results,
+                "summary": type_summary,
+                "current_total_count": current_total
+            }
+        ))
+        
+    except Exception as e:
+        print(f"[ERROR] training_data_create执行失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="创建训练数据失败,请稍后重试"
+        )), 500
+
+@app.flask_app.route('/api/v0/training_data/delete', methods=['POST'])
+def training_data_delete():
+    """
+    删除训练数据API
+    支持批量删除
+    """
+    try:
+        req = request.get_json(force=True)
+        ids = req.get('ids', [])
+        confirm = req.get('confirm', False)
+        
+        if not ids or not isinstance(ids, list):
+            return jsonify(bad_request_response(
+                response_text="缺少有效的ID列表",
+                missing_params=["ids"]
+            )), 400
+        
+        if not confirm:
+            return jsonify(bad_request_response(
+                response_text="删除操作需要确认,请设置confirm为true"
+            )), 400
+        
+        # 批量操作限制
+        if len(ids) > 50:
+            return jsonify(bad_request_response(
+                response_text="批量删除最大支持50条记录"
+            )), 400
+        
+        deleted_ids = []
+        failed_ids = []
+        failed_details = []
+        
+        for training_id in ids:
+            try:
+                success = vn.remove_training_data(training_id)
+                if success:
+                    deleted_ids.append(training_id)
+                else:
+                    failed_ids.append(training_id)
+                    failed_details.append({
+                        "id": training_id,
+                        "error": "记录不存在或删除失败"
+                    })
+            except Exception as e:
+                failed_ids.append(training_id)
+                failed_details.append({
+                    "id": training_id,
+                    "error": str(e)
+                })
+        
+        # 获取删除后的总记录数
+        current_total = get_total_training_count()
+        
+        return jsonify(success_response(
+            response_text="训练数据删除完成",
+            data={
+                "total_requested": len(ids),
+                "successfully_deleted": len(deleted_ids),
+                "failed_count": len(failed_ids),
+                "deleted_ids": deleted_ids,
+                "failed_ids": failed_ids,
+                "failed_details": failed_details,
+                "current_total_count": current_total
+            }
+        ))
+        
+    except Exception as e:
+        print(f"[ERROR] training_data_delete执行失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="删除训练数据失败,请稍后重试"
+        )), 500
+
+@app.flask_app.route('/api/v0/training_data/stats', methods=['GET'])
+def training_data_stats():
+    """
+    获取训练数据统计信息API
+    """
+    try:
+        training_data = vn.get_training_data()
+        
+        if training_data is None or training_data.empty:
+            return jsonify(success_response(
+                response_text="统计信息获取成功",
+                data={
+                    "total_count": 0,
+                    "type_breakdown": {
+                        "sql": 0,
+                        "documentation": 0,
+                        "ddl": 0,
+                        "error_sql": 0
+                    },
+                    "type_percentages": {
+                        "sql": 0.0,
+                        "documentation": 0.0,
+                        "ddl": 0.0,
+                        "error_sql": 0.0
+                    },
+                    "last_updated": datetime.now().isoformat()
+                }
+            ))
+        
+        total_count = len(training_data)
+        
+        # 统计各类型数量
+        type_breakdown = {"sql": 0, "documentation": 0, "ddl": 0, "error_sql": 0}
+        
+        if 'training_data_type' in training_data.columns:
+            type_counts = training_data['training_data_type'].value_counts()
+            for data_type, count in type_counts.items():
+                if data_type in type_breakdown:
+                    type_breakdown[data_type] = int(count)
+        
+        # 计算百分比
+        type_percentages = {}
+        for data_type, count in type_breakdown.items():
+            type_percentages[data_type] = round(count / max(total_count, 1) * 100, 2)
+        
+        return jsonify(success_response(
+            response_text="统计信息获取成功",
+            data={
+                "total_count": total_count,
+                "type_breakdown": type_breakdown,
+                "type_percentages": type_percentages,
+                "last_updated": datetime.now().isoformat()
+            }
+        ))
+        
+    except Exception as e:
+        print(f"[ERROR] training_data_stats执行失败: {str(e)}")
+        return jsonify(internal_error_response(
+            response_text="获取统计信息失败,请稍后重试"
+        )), 500
+
+
 @app.flask_app.route('/api/v0/cache_overview_full', methods=['GET'])
 def cache_overview_full():
     """获取所有缓存系统的综合概览"""

+ 559 - 0
docs/Database Agent节点重构概要设计.md

@@ -0,0 +1,559 @@
+# Database节点重构概要设计文档
+
+## 🎯 重构目标
+
+将当前的 `_agent_database_node` 拆分为两个独立节点,实现SQL生成验证与执行的分离,提供中间路由能力,提升系统的智能性和用户体验。
+
+## 📋 整体架构变更
+
+### **重构前架构**
+```
+classify_question → agent_database → format_response → END
+                    (generate_sql + execute_sql + summary)
+```
+
+### **重构后架构**
+```
+classify_question → agent_sql_generation_node → agent_sql_execution_node → format_response → END
+                    (generate_sql + validation)    (execute_sql + summary)
+                           ↓ (失败路由)
+                    format_response → END
+```
+
+## 🔧 节点详细设计
+
+### **1. _agent_sql_generation_node (SQL生成验证节点)**
+
+#### **功能职责**
+- 调用 `generate_sql()` 工具生成SQL
+- 使用复用schema_tools的验证逻辑进行SQL验证
+- 根据验证结果决定路由方向
+
+#### **核心逻辑**
+```python
+def _agent_sql_generation_node(self, state: AgentState) -> AgentState:
+    """SQL生成验证节点"""
+    try:
+        question = state["question"]
+        
+        # 步骤1: SQL生成
+        sql_result = generate_sql(question, allow_llm_to_see_data=True)
+        
+        if not sql_result.get("success"):
+            # SQL生成失败处理
+            return self._handle_sql_generation_failure(state, sql_result)
+        
+        sql = sql_result.get("sql")
+        state["sql"] = sql
+        
+        # 步骤2: SQL验证 (如果启用)
+        if self._is_sql_validation_enabled():
+            validation_result = await self._validate_sql_with_schema_tools(sql)
+            
+            if not validation_result.get("valid"):
+                # 验证失败,尝试修复
+                return await self._handle_sql_validation_failure(state, sql, validation_result)
+        
+        # 生成和验证都成功
+        state["sql_generation_success"] = True
+        state["execution_path"].append("agent_sql_generation")
+        return state
+        
+    except Exception as e:
+        state["error"] = f"SQL生成节点异常: {str(e)}"
+        return state
+```
+
+#### **SQL验证集成 (复用schema_tools)**
+```python
+async def _validate_sql_with_schema_tools(self, sql: str) -> Dict[str, Any]:
+    """复用schema_tools的SQL验证逻辑"""
+    try:
+        # 1. 语法验证 (EXPLAIN SQL)
+        syntax_valid = await self._validate_sql_syntax(sql)
+        if not syntax_valid.get("valid"):
+            return {
+                "valid": False,
+                "error_type": "syntax_error",
+                "error_message": syntax_valid.get("error"),
+                "can_repair": True
+            }
+        
+        # 2. 禁止词检查
+        forbidden_check = self._check_forbidden_keywords(sql)
+        if not forbidden_check.get("valid"):
+            return {
+                "valid": False,
+                "error_type": "forbidden_keywords",
+                "error_message": forbidden_check.get("error"),
+                "can_repair": False
+            }
+        
+        return {"valid": True}
+        
+    except Exception as e:
+        return {
+            "valid": False,
+            "error_type": "validation_exception",
+            "error_message": str(e),
+            "can_repair": False
+        }
+
+def _check_forbidden_keywords(self, sql: str) -> Dict[str, Any]:
+    """检查禁止的SQL关键词"""
+    forbidden_keywords = ['UPDATE', 'DELETE', 'DROP', 'ALTER', 'INSERT']
+    sql_upper = sql.upper()
+    
+    for keyword in forbidden_keywords:
+        if keyword in sql_upper:
+            return {
+                "valid": False,
+                "error": f"不允许的操作: {keyword}。本系统只支持查询操作(SELECT)。"
+            }
+    
+    return {"valid": True}
+
+async def _validate_sql_syntax(self, sql: str) -> Dict[str, Any]:
+    """语法验证 - 复用schema_tools逻辑"""
+    try:
+        # 获取数据库连接 (复用现有连接逻辑)
+        from common.vanna_instance import get_vanna_instance
+        vn = get_vanna_instance()
+        
+        # 执行EXPLAIN验证
+        explain_sql = f"EXPLAIN {sql}"
+        # 注意: 这里需要适配到实际的数据库连接方式
+        result = await vn.run_sql(explain_sql)
+        
+        return {"valid": True}
+        
+    except Exception as e:
+        return {
+            "valid": False,
+            "error": str(e)
+        }
+```
+
+#### **SQL修复逻辑 (复用schema_tools)**
+```python
+async def _handle_sql_validation_failure(self, state: AgentState, sql: str, validation_result: Dict) -> AgentState:
+    """处理SQL验证失败"""
+    error_type = validation_result.get("error_type")
+    
+    # 禁止词错误,直接失败
+    if error_type == "forbidden_keywords":
+        state["sql_generation_success"] = False
+        state["user_prompt"] = validation_result.get("error_message")
+        return state
+    
+    # 语法错误,尝试LLM修复 (只修复一次)
+    if error_type == "syntax_error" and self._is_auto_repair_enabled():
+        repaired_sql = await self._repair_sql_with_llm(sql, validation_result.get("error_message"))
+        
+        if repaired_sql:
+            # 再次验证修复后的SQL
+            revalidation = await self._validate_sql_with_schema_tools(repaired_sql)
+            
+            if revalidation.get("valid"):
+                state["sql"] = repaired_sql
+                state["sql_generation_success"] = True
+                state["sql_repair_applied"] = True
+                return state
+    
+    # 修复失败或不支持修复
+    state["sql_generation_success"] = False
+    state["user_prompt"] = f"SQL生成遇到问题: {validation_result.get('error_message')}"
+    return state
+
+```python
+async def _repair_sql_with_llm(self, sql: str, error_message: str) -> Optional[str]:
+    """使用LLM修复SQL - 只尝试一次"""
+    try:
+        from common.vanna_instance import get_vanna_instance
+        vn = get_vanna_instance()
+        
+        # 构建修复提示词
+        repair_prompt = f"""你是一个PostgreSQL SQL专家,请修复以下SQL语句的语法错误。
+
+当前数据库类型: PostgreSQL
+错误信息: {error_message}
+
+需要修复的SQL:
+{sql}
+
+修复要求:
+1. 只修复语法错误和表结构错误
+2. 保持SQL的原始业务逻辑不变  
+3. 使用PostgreSQL标准语法
+4. 确保修复后的SQL语法正确
+
+请直接输出修复后的SQL语句,不要添加其他说明文字。"""
+
+        # 调用LLM修复 - 复用schema_tools的异步调用方式
+        response = await asyncio.to_thread(
+            vn.chat_with_llm,
+            question=repair_prompt,
+            system_prompt="你是一个专业的PostgreSQL SQL专家,专门负责修复SQL语句中的语法错误。"
+        )
+        
+        if response and response.strip():
+            return response.strip()
+        
+        return None
+        
+    except Exception as e:
+        print(f"[ERROR] SQL修复失败: {str(e)}")
+        return None
+
+async def _handle_sql_validation_failure(self, state: AgentState, sql: str, validation_result: Dict) -> AgentState:
+    """处理SQL验证失败 - 重要约束:只修复一次"""
+    error_type = validation_result.get("error_type")
+    
+    # 禁止词错误,直接失败,不尝试修复
+    if error_type == "forbidden_keywords":
+        state["sql_generation_success"] = False
+        state["user_prompt"] = validation_result.get("error_message")
+        state["execution_path"].append("forbidden_keywords_failed")
+        return state
+    
+    # 语法错误,仅尝试修复一次
+    if error_type == "syntax_error" and self._is_auto_repair_enabled():
+        print(f"[SQL_REPAIR] 尝试修复SQL语法错误(仅一次): {validation_result.get('error_message')}")
+        
+        repaired_sql = await self._repair_sql_with_llm(sql, validation_result.get("error_message"))
+        
+        if repaired_sql:
+            # 对修复后的SQL进行验证 - 不管结果如何,不再重试
+            revalidation = await self._validate_sql_with_schema_tools(repaired_sql)
+            
+            if revalidation.get("valid"):
+                # 修复成功
+                state["sql"] = repaired_sql
+                state["sql_generation_success"] = True
+                state["sql_repair_applied"] = True
+                state["execution_path"].append("sql_repair_success")
+                print(f"[SQL_REPAIR] SQL修复成功")
+                return state
+            else:
+                # 修复后仍然失败,直接结束
+                print(f"[SQL_REPAIR] 修复后验证仍然失败: {revalidation.get('error_message')}")
+                state["sql_generation_success"] = False
+                state["user_prompt"] = f"SQL修复尝试失败: {revalidation.get('error_message')}"
+                state["execution_path"].append("sql_repair_failed")
+                return state
+        else:
+            # LLM修复失败
+            print(f"[SQL_REPAIR] LLM修复调用失败")
+            state["sql_generation_success"] = False
+            state["user_prompt"] = f"SQL语法修复失败: {validation_result.get('error_message')}"
+            state["execution_path"].append("sql_repair_failed")
+            return state
+    
+    # 不启用修复或其他错误类型,直接失败
+    state["sql_generation_success"] = False
+    state["user_prompt"] = f"SQL验证失败: {validation_result.get('error_message')}"
+    state["execution_path"].append("sql_validation_failed")
+    return state
+```
+
+### **2. _agent_sql_execution_node (SQL执行节点)**
+
+#### **功能职责**
+- 执行已验证的SQL语句
+- 根据配置决定是否生成摘要
+- 保持原有的执行逻辑
+
+#### **核心逻辑**
+```python
+def _agent_sql_execution_node(self, state: AgentState) -> AgentState:
+    """SQL执行节点 - 保持原有逻辑"""
+    try:
+        sql = state.get("sql")
+        question = state["question"]
+        
+        # 步骤1: 执行SQL (复用原有逻辑)
+        execute_result = execute_sql.invoke({"sql": sql})
+        
+        if not execute_result.get("success"):
+            state["error"] = execute_result.get("error", "SQL执行失败")
+            return state
+        
+        query_result = execute_result.get("data_result")
+        state["query_result"] = query_result
+        
+        # 步骤2: 生成摘要 (根据配置)
+        if self._should_generate_summary(query_result):
+            original_question = self._extract_original_question(question)
+            
+            summary_result = generate_summary.invoke({
+                "question": original_question,
+                "query_result": query_result,
+                "sql": sql
+            })
+            
+            if summary_result.get("success"):
+                state["summary"] = summary_result.get("summary")
+            else:
+                # 摘要生成失败不是致命错误
+                state["summary"] = f"查询执行完成,共返回 {query_result.get('row_count', 0)} 条记录。"
+        
+        state["execution_path"].append("agent_sql_execution")
+        return state
+        
+    except Exception as e:
+        state["error"] = f"SQL执行节点异常: {str(e)}"
+        return state
+
+def _should_generate_summary(self, query_result: Dict) -> bool:
+    """判断是否应该生成摘要"""
+    from app_config import ENABLE_RESULT_SUMMARY
+    return ENABLE_RESULT_SUMMARY and query_result.get('row_count', 0) > 0
+```
+
+## 🔀 条件路由设计
+
+### **SQL生成节点的条件路由**
+```python
+def _route_after_sql_generation(self, state: AgentState) -> Literal["continue_execution", "return_to_user"]:
+    """SQL生成后的路由决策"""
+    
+    if state.get("sql_generation_success"):
+        return "continue_execution"  # 路由到SQL执行节点
+    else:
+        return "return_to_user"      # 路由到format_response,结束流程
+```
+
+### **工作流配置更新**
+```python
+def _create_workflow(self, routing_mode: str = None) -> StateGraph:
+    """更新工作流创建逻辑"""
+    workflow = StateGraph(AgentState)
+    
+    # 添加新的节点
+    workflow.add_node("classify_question", self._classify_question_node)
+    workflow.add_node("agent_sql_generation", self._agent_sql_generation_node)
+    workflow.add_node("agent_sql_execution", self._agent_sql_execution_node)
+    workflow.add_node("agent_chat", self._agent_chat_node)
+    workflow.add_node("format_response", self._format_response_node)
+    
+    # 设置条件路由
+    workflow.add_conditional_edges(
+        "classify_question",
+        self._route_after_classification,
+        {
+            "DATABASE": "agent_sql_generation",
+            "CHAT": "agent_chat"
+        }
+    )
+    
+    # SQL生成后的条件路由
+    workflow.add_conditional_edges(
+        "agent_sql_generation",
+        self._route_after_sql_generation,
+        {
+            "continue_execution": "agent_sql_execution",
+            "return_to_user": "format_response"
+        }
+    )
+    
+    # 普通边缘
+    workflow.add_edge("agent_sql_execution", "format_response")
+    workflow.add_edge("agent_chat", "format_response")
+    workflow.add_edge("format_response", END)
+    
+    return workflow.compile()
+```
+
+## ⚙️ 配置参数设计
+
+### **新增配置参数 - 精简版**
+```python
+# 在app_config.py中添加
+SQL_VALIDATION_CONFIG = {
+    "enable_syntax_validation": True,      # 是否启用语法验证(EXPLAIN SQL)
+    "enable_forbidden_check": True,       # 是否启用禁止词检查  
+    "enable_auto_repair": True,           # 是否启用自动修复(只尝试一次)
+}
+
+# 现有配置保持不变
+ENABLE_RESULT_SUMMARY = True  # 控制摘要生成
+```
+
+### **配置使用逻辑 - 明确约束**
+```python
+def _is_sql_validation_enabled(self) -> bool:
+    """检查是否启用SQL验证"""
+    # 注意:任一验证功能启用都算启用验证
+    return (SQL_VALIDATION_CONFIG.get("enable_syntax_validation", False) or 
+            SQL_VALIDATION_CONFIG.get("enable_forbidden_check", False))
+
+def _is_auto_repair_enabled(self) -> bool:
+    """检查是否启用自动修复"""
+    # 只有在语法验证启用的情况下,自动修复才有意义
+    return (SQL_VALIDATION_CONFIG.get("enable_auto_repair", False) and 
+            SQL_VALIDATION_CONFIG.get("enable_syntax_validation", False))
+
+def _should_skip_validation(self) -> bool:
+    """判断是否跳过所有验证"""
+    # 当所有验证功能都禁用时,跳过验证步骤
+    return not self._is_sql_validation_enabled()
+```
+
+### **验证策略的完整逻辑**
+```python
+# 验证流程的完整决策树
+if not self._is_sql_validation_enabled():
+    # 跳过所有验证,直接使用生成的SQL
+    pass
+else:
+    # 按优先级执行验证
+    if SQL_VALIDATION_CONFIG.get("enable_syntax_validation"):
+        # 1. 语法验证 (EXPLAIN SQL)
+        syntax_result = await self._validate_sql_syntax(sql)
+        if not syntax_result.valid and self._is_auto_repair_enabled():
+            # 尝试修复 (只一次)
+            repaired_sql = await self._repair_sql_with_llm(sql, syntax_result.error)
+            # 修复后不管成功失败,都不再重试
+    
+    if SQL_VALIDATION_CONFIG.get("enable_forbidden_check"):
+        # 2. 禁止词检查 (不可修复)
+        forbidden_result = self._check_forbidden_keywords(sql)
+        if not forbidden_result.valid:
+            # 直接失败,不尝试修复
+            return self._handle_forbidden_keywords_error(state, forbidden_result)
+```
+
+## 📊 状态字段更新
+
+### **AgentState新增字段**
+```python
+class AgentState(TypedDict):
+    # 现有字段保持不变...
+    
+    # 新增字段
+    sql_generation_success: bool           # SQL生成是否成功
+    sql_repair_applied: bool              # 是否应用了SQL修复
+    user_prompt: Optional[str]            # 给用户的提示信息
+```
+
+## 🔄 错误处理和用户提示
+
+### **SQL生成失败的情况处理**
+```python
+```python
+def _handle_sql_generation_failure(self, state: AgentState, sql_result: Dict) -> AgentState:
+    """处理SQL生成失败 - 统一处理三种情况"""
+    error_message = sql_result.get("error", "")
+    error_type = sql_result.get("error_type", "")
+    
+    # 重要设计决策:不进行二次分类判断,统一按数据库问题处理
+    # 原因:第一次LLM分类已经判断为DATABASE,第二次大概率仍是DATABASE
+    
+    # 根据错误类型和内容生成统一的用户提示
+    if "no relevant tables" in error_message.lower() or "table not found" in error_message.lower():
+        # 情况1:数据库缺少表/字段
+        user_prompt = "数据库中没有相关的表或字段信息,请您提供更多具体信息或修改问题。"
+        failure_reason = "missing_database_info"
+    elif "ambiguous" in error_message.lower() or "more information" in error_message.lower():
+        # 情况2:问题太模糊  
+        user_prompt = "您的问题需要更多信息才能准确查询,请提供更详细的描述。"
+        failure_reason = "ambiguous_question"
+    elif error_type == "llm_explanation":
+        # 情况3:LLM返回解释性文本而非SQL
+        user_prompt = error_message + " 请尝试重新描述您的问题或询问其他内容。"
+        failure_reason = "llm_explanation"
+    else:
+        # 其他未分类的失败情况
+        user_prompt = "无法生成有效的SQL查询,请尝试重新描述您的问题。"
+        failure_reason = "unknown_generation_failure"
+    
+    # 关键决策:所有失败都返回用户提示,不路由到CHAT
+    state["sql_generation_success"] = False
+    state["user_prompt"] = user_prompt
+    state["sql_generation_failure_reason"] = failure_reason
+    state["execution_path"].append("sql_generation_failed")
+    
+    print(f"[SQL_GENERATION] 生成失败: {failure_reason} - {user_prompt}")
+    return state
+```
+
+### **format_response节点的适配**
+```python
+def _format_response_node(self, state: AgentState) -> AgentState:
+    """格式化响应节点 - 适配新的失败处理"""
+    
+    # 处理SQL生成失败的情况
+    if not state.get("sql_generation_success", True) and state.get("user_prompt"):
+        state["final_response"] = {
+            "success": False,
+            "response": state["user_prompt"],
+            "type": "DATABASE",
+            "sql_generation_failed": True,
+            "execution_path": state["execution_path"],
+            "classification_info": {
+                "confidence": state.get("classification_confidence", 0),
+                "reason": state.get("classification_reason", ""),
+                "method": state.get("classification_method", "")
+            }
+        }
+        return state
+    
+    # 其他情况保持原有逻辑
+    # ... (原有的format_response逻辑)
+```
+
+## 🚀 实施计划
+
+### **阶段1: 基础重构**
+1. 创建 `_agent_sql_generation_node` 节点
+2. 创建 `_agent_sql_execution_node` 节点  
+3. 更新工作流配置和条件路由
+4. 基础功能测试
+
+### **阶段2: 验证集成**
+1. 集成schema_tools的SQL验证逻辑
+2. 实现SQL修复功能
+3. 添加配置参数控制
+4. 验证功能测试
+
+### **阶段3: 错误处理优化**
+1. 完善错误分类和用户提示
+2. 优化format_response节点适配
+3. 用户体验测试和优化
+
+### **阶段4: 全面测试**
+1. 各种路由模式兼容性测试
+2. 边界情况和异常处理测试
+3. 性能和稳定性测试
+
+## 🔍 重要设计细节和约束
+
+### **SQL修复的执行限制**
+- **修复次数限制**:SQL语法修复只执行一次,不进行多次重试
+- **修复范围限制**:只修复语法错误和表结构错误,不修改业务逻辑
+- **修复失败处理**:如果修复后仍无法通过验证,直接返回错误给用户
+
+### **验证流程的优先级**
+1. **语法验证优先**:先进行EXPLAIN SQL验证
+2. **禁止词检查**:通过语法验证后检查禁止的操作词
+3. **修复策略**:只对语法错误尝试修复,禁止词错误直接失败
+
+### **错误处理的统一策略**
+- **三种失败情况合并处理**:数据库缺少表/字段、问题模糊、无法判定都统一返回用户提示
+- **不进行二次分类**:坚持第一次分类结果,不因SQL生成失败而重新路由到CHAT
+- **提示信息明确**:根据具体错误原因给出针对性的用户指导
+
+### **配置参数的作用范围**
+- **验证开关独立**:语法验证和禁止词检查可独立控制
+- **修复功能可选**:可以只验证不修复,由配置决定
+- **全局生效**:所有路由模式(包括database_direct)都遵循验证配置
+
+### **节点内部处理约束**
+- **原子性保证**:每个节点的处理对LangGraph来说是原子的
+- **状态完整性**:节点间通过状态传递所有必要信息
+- **错误不中断流程**:验证或修复失败不抛异常,通过状态标记处理
+
+### **与现有架构的兼容性**
+- **工具函数不变**:继续使用现有的@tool装饰的函数
+- **状态结构兼容**:新增字段不影响现有状态处理逻辑
+- **路由模式兼容**:database_direct、chat_direct、hybrid模式都支持新流程

+ 527 - 0
docs/训练数据管理API概要设计文档.md

@@ -0,0 +1,527 @@
+# 训练数据管理API概要设计文档
+
+## 📋 概述
+
+本文档描述了训练数据管理系统的API设计方案,提供完整的CRUD操作接口,支持分页查询、类型筛选、批量操作等功能。该系统旨在为AI训练数据提供统一的管理入口。
+
+### 🎯 设计目标
+- **统一管理**:提供训练数据的统一管理接口
+- **类型支持**:支持SQL、文档、DDL、错误SQL四种训练数据类型
+- **批量操作**:支持批量创建和删除操作
+- **性能优化**:支持分页查询和类型筛选
+- **数据统计**:提供详细的数据统计信息
+
+### 🔧 基础信息
+- **基础URL**: `http://localhost:5000`
+- **API前缀**: `/api/v0/training_data/`
+- **数据格式**: JSON
+- **字符编码**: UTF-8
+- **命名规范**: 统一使用动词命名(query/create/delete/stats)
+
+---
+
+## 🚀 API端点一览
+
+| API端点 | 方法 | 功能描述 |
+|---------|------|----------|
+| `/api/v0/training_data/query` | POST | 分页查询训练数据(支持类型筛选和搜索) |
+| `/api/v0/training_data/create` | POST | 创建训练数据(支持单条和批量) |
+| `/api/v0/training_data/delete` | POST | 删除训练数据(支持批量删除) |
+| `/api/v0/training_data/stats` | GET | 获取训练数据统计信息 |
+
+---
+
+## 📖 详细API设计
+
+### 1. 分页查询API
+
+**端点**: `POST /api/v0/training_data/query`
+
+**功能**: 分页查询训练数据,支持类型筛选、搜索和排序功能。
+
+#### 📝 请求参数
+
+| 参数名 | 类型 | 必填 | 默认值 | 说明 |
+|--------|------|------|--------|------|
+| `page` | int | 否 | 1 | 页码(从1开始) |
+| `page_size` | int | 否 | 20 | 每页记录数(范围:1-100) |
+| `training_data_type` | string | 否 | null | 筛选类型:sql/documentation/ddl/error_sql |
+| `sort_by` | string | 否 | "id" | 排序字段:id/training_data_type |
+| `sort_order` | string | 否 | "desc" | 排序方向:asc/desc |
+| `search_keyword` | string | 否 | null | 搜索关键词(在question/content中搜索) |
+
+#### 🌰 请求示例
+
+**基础查询**:
+```json
+{
+  "page": 1,
+  "page_size": 20
+}
+```
+
+**筛选查询**:
+```json
+{
+  "page": 1,
+  "page_size": 20,
+  "training_data_type": "sql",
+  "search_keyword": "用户",
+  "sort_by": "id",
+  "sort_order": "desc"
+}
+```
+
+#### ✅ 成功响应格式
+
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "查询成功,共找到 156 条记录",
+  "data": {
+    "records": [
+      {
+        "id": "uuid-123-sql",
+        "training_data_type": "sql",
+        "question": "查询所有用户信息",
+        "content": "SELECT * FROM users",
+        "created_at": "2024-06-24T10:30:00"
+      },
+      {
+        "id": "uuid-456-doc",
+        "training_data_type": "documentation", 
+        "question": null,
+        "content": "用户表包含用户的基本信息...",
+        "created_at": "2024-06-24T11:00:00"
+      }
+    ],
+    "pagination": {
+      "page": 1,
+      "page_size": 20,
+      "total": 156,
+      "total_pages": 8,
+      "has_next": true,
+      "has_prev": false
+    },
+    "filters_applied": {
+      "training_data_type": "sql",
+      "search_keyword": "用户"
+    }
+  }
+}
+```
+
+---
+
+### 2. 创建训练数据API
+
+**端点**: `POST /api/v0/training_data/create`
+
+**功能**: 创建训练数据,支持单条和批量创建,支持四种数据类型。
+
+#### 📝 请求参数
+
+**单条记录**:
+```json
+{
+  "data": {
+    "training_data_type": "sql",
+    "question": "查询所有用户信息",
+    "sql": "SELECT * FROM users"
+  }
+}
+```
+
+**批量记录**:
+```json
+{
+  "data": [
+    {
+      "training_data_type": "sql",
+      "question": "查询所有用户信息", 
+      "sql": "SELECT * FROM users"
+    },
+    {
+      "training_data_type": "documentation",
+      "content": "用户表包含用户的基本信息..."
+    },
+    {
+      "training_data_type": "ddl",
+      "ddl": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100));"
+    },
+    {
+      "training_data_type": "error_sql",
+      "question": "查询用户",
+      "sql": "SELECT * FROM user"
+    }
+  ]
+}
+```
+
+#### 📋 各类型字段要求
+
+| 类型 | 必填字段 | 可选字段 | 说明 |
+|------|----------|----------|------|
+| `sql` | `sql` | `question` | 如果不提供question会自动生成,SQL会进行语法检查 |
+| `error_sql` | `question`, `sql` | 无 | 错误的SQL示例,不进行语法检查 |
+| `documentation` | `content` | 无 | 文档内容,不进行格式检查 |
+| `ddl` | `ddl` | 无 | DDL语句,不进行语法检查 |
+
+#### ✅ 成功响应格式
+
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "response": "训练数据创建完成",
+    "total_requested": 4,
+    "successfully_created": 3,
+    "failed_count": 1,
+    "results": [
+      {
+        "index": 0,
+        "success": true,
+        "training_id": "uuid-123-sql",
+        "type": "sql",
+        "message": "SQL训练数据创建成功"
+      },
+      {
+        "index": 1,
+        "success": true,
+        "training_id": "uuid-456-doc",
+        "type": "documentation", 
+        "message": "文档训练数据创建成功"
+      },
+      {
+        "index": 2,
+        "success": true,
+        "training_id": "uuid-789-ddl",
+        "type": "ddl",
+        "message": "DDL训练数据创建成功"
+      },
+      {
+        "index": 3,
+        "success": false,
+        "type": "error_sql",
+        "error": "创建失败:缺少必填字段question",
+        "message": "创建失败"
+      }
+    ],
+    "summary": {
+      "sql": 1,
+      "documentation": 1,
+      "ddl": 1,
+      "error_sql": 0
+    },
+    "current_total_count": 159
+  }
+}
+```
+
+---
+
+### 3. 删除训练数据API
+
+**端点**: `POST /api/v0/training_data/delete`
+
+**功能**: 删除指定的训练数据记录,支持批量删除。
+
+#### 📝 请求参数
+
+| 参数名 | 类型 | 必填 | 说明 |
+|--------|------|------|------|
+| `ids` | array[string] | 是 | 要删除的训练数据ID列表 |
+| `confirm` | boolean | 是 | 确认删除标志,必须为true |
+
+#### 🌰 请求示例
+
+```json
+{
+  "ids": ["uuid-123-sql", "uuid-456-doc", "uuid-789-ddl"],
+  "confirm": true
+}
+```
+
+#### ✅ 成功响应格式
+
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "删除操作完成",
+  "data": {
+    "response": "训练数据删除完成",
+    "total_requested": 3,
+    "successfully_deleted": 2,
+    "failed_count": 1,
+    "deleted_ids": ["uuid-123-sql", "uuid-456-doc"],
+    "failed_ids": ["uuid-789-ddl"],
+    "failed_details": [
+      {
+        "id": "uuid-789-ddl",
+        "error": "记录不存在"
+      }
+    ],
+    "current_total_count": 157
+  }
+}
+```
+
+---
+
+### 4. 统计信息API
+
+**端点**: `GET /api/v0/training_data/stats`
+
+**功能**: 获取训练数据的统计信息,用于监控和分析。
+
+#### 🌰 请求示例
+
+```
+GET /api/v0/training_data/stats
+```
+
+#### ✅ 成功响应格式
+
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "统计信息获取成功",
+  "data": {
+    "response": "统计信息获取成功",
+    "total_count": 156,
+    "type_breakdown": {
+      "sql": 45,
+      "documentation": 38,
+      "ddl": 52,
+      "error_sql": 21
+    },
+    "type_percentages": {
+      "sql": 28.85,
+      "documentation": 24.36,
+      "ddl": 33.33,
+      "error_sql": 13.46
+    },
+    "last_updated": "2024-06-24T15:30:00"
+  }
+}
+```
+
+---
+
+## 🔧 技术实现要点
+
+### 1. 数据源集成
+
+#### 1.1 查询数据源
+- 使用现有的 `vn.get_training_data()` 方法获取训练数据
+- 基于返回的DataFrame进行分页和筛选处理
+- 根据ID后缀判断训练数据类型:
+  - `-sql` → sql类型
+  - `-doc` → documentation类型
+  - `-ddl` → ddl类型
+  - `-error_sql` → error_sql类型
+
+#### 1.2 创建数据源
+- **SQL类型**:调用 `vn.train(question=question, sql=sql)` 或 `vn.train(sql=sql)`
+- **错误SQL类型**:调用 `vn.train_error_sql(question=question, sql=sql)`
+- **文档类型**:调用 `vn.train(documentation=content)`
+- **DDL类型**:调用 `vn.train(ddl=ddl)`
+
+#### 1.3 删除数据源
+- 使用 `custompgvector/pgvector.py` 中的 `remove_training_data(id)` 方法
+
+### 2. 核心算法设计
+
+#### 2.1 分页算法
+```python
+def paginate_data(data_list: list, page: int, page_size: int):
+    """分页处理算法"""
+    total = len(data_list)
+    start_idx = (page - 1) * page_size
+    end_idx = start_idx + page_size
+    page_data = data_list[start_idx:end_idx]
+    
+    return {
+        "data": page_data,
+        "pagination": {
+            "page": page,
+            "page_size": page_size,
+            "total": total,
+            "total_pages": (total + page_size - 1) // page_size,
+            "has_next": end_idx < total,
+            "has_prev": page > 1
+        }
+    }
+```
+
+#### 2.2 类型筛选算法
+```python
+def filter_by_type(data_list: list, training_data_type: str):
+    """按类型筛选算法"""
+    if not training_data_type:
+        return data_list
+    
+    return [
+        record for record in data_list 
+        if record.get('training_data_type') == training_data_type
+    ]
+```
+
+#### 2.3 SQL语法检查算法
+```python
+def validate_sql_syntax(sql: str) -> tuple[bool, str]:
+    """SQL语法检查(仅对sql类型)"""
+    try:
+        import sqlparse
+        parsed = sqlparse.parse(sql.strip())
+        
+        if not parsed or not parsed[0].tokens:
+            return False, "SQL语法错误:空语句"
+        
+        # 基本语法检查
+        sql_upper = sql.strip().upper()
+        if not any(sql_upper.startswith(keyword) for keyword in 
+                  ['SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP']):
+            return False, "SQL语法错误:不是有效的SQL语句"
+        
+        return True, ""
+    except Exception as e:
+        return False, f"SQL语法错误:{str(e)}"
+```
+
+### 3. 性能和安全考虑
+
+#### 3.1 性能优化
+- **分页限制**:最大页面大小限制为100条记录
+- **批量限制**:批量操作最大支持50条记录
+- **查询缓存**:考虑对频繁查询结果进行缓存
+- **异步处理**:大批量操作考虑异步处理
+
+#### 3.2 安全考虑
+- **参数验证**:严格验证所有输入参数
+- **SQL注入防护**:对SQL内容进行安全检查
+- **删除确认**:删除操作必须提供确认标志
+- **权限控制**:预留权限验证接口
+
+#### 3.3 错误处理
+- **统一错误格式**:使用项目标准错误响应格式
+- **批量操作错误**:部分成功时提供详细的成功/失败信息
+- **数据库异常**:妥善处理数据库连接和操作异常
+
+---
+
+## 🔄 集成方案
+
+### 1. 代码集成
+- **主要文件**:`citu_app.py` - 添加新的API路由
+- **响应格式**:复用 `common/result.py` 中的标准响应格式
+- **数据库连接**:复用现有的Vanna实例和数据库连接
+- **错误处理**:遵循项目现有的错误处理规范
+
+### 2. 依赖关系
+```
+训练数据管理API
+├── vn.get_training_data()          # 查询数据源
+├── vn.train()                      # 创建训练数据
+├── vn.train_error_sql()            # 创建错误SQL
+├── vn.remove_training_data()       # 删除数据
+└── common/result.py                # 响应格式
+```
+
+### 3. 配置要求
+- **数据库连接**:确保PgVector或ChromaDB连接正常
+- **Vanna实例**:确保Vanna实例初始化完成
+- **依赖库**:sqlparse(用于SQL语法检查)
+
+---
+
+## 📊 使用场景示例
+
+### 1. 典型工作流程
+
+**步骤1:查看统计信息**
+```bash
+GET /api/v0/training_data/stats
+```
+
+**步骤2:查询现有数据**
+```json
+POST /api/v0/training_data/query
+{
+  "page": 1,
+  "page_size": 50,
+  "training_data_type": "sql"
+}
+```
+
+**步骤3:批量添加训练数据**
+```json
+POST /api/v0/training_data/create
+{
+  "data": [
+    {
+      "training_data_type": "sql",
+      "question": "查询活跃用户",
+      "sql": "SELECT * FROM users WHERE status = 'active'"
+    },
+    {
+      "training_data_type": "documentation",
+      "content": "用户状态字段说明:active表示活跃用户..."
+    }
+  ]
+}
+```
+
+**步骤4:清理无效数据**
+```json
+POST /api/v0/training_data/delete
+{
+  "ids": ["uuid-invalid-1", "uuid-invalid-2"],
+  "confirm": true
+}
+```
+
+### 2. 数据迁移场景
+适用于从其他系统批量导入训练数据,支持不同类型的混合导入。
+
+### 3. 数据清理场景
+适用于定期清理低质量或过时的训练数据,维护数据集质量。
+
+---
+
+## ⚠️ 注意事项
+
+### 1. 限制说明
+- 分页查询每页最大100条记录
+- 批量操作最大50条记录
+- 搜索关键词最大长度100字符
+- SQL语法检查仅适用于sql类型
+
+### 2. 兼容性
+- 需要确保Vanna实例支持所有调用的方法
+- 数据库版本兼容性(PgVector扩展)
+- Python依赖库版本要求
+
+### 3. 监控建议
+- 记录API调用日志
+- 监控批量操作性能
+- 跟踪数据质量指标
+- 设置异常告警机制
+
+---
+
+## 📝 更新记录
+
+| 版本 | 日期 | 更新内容 | 作者 |
+|------|------|----------|------|
+| 1.0 | 2024-06-24 | 初始版本设计 | AI Assistant |
+
+---
+
+**文档状态**: 概要设计完成  
+**下一步**: 详细设计和开发实现 

+ 447 - 0
docs/训练数据管理API调用说明.md

@@ -0,0 +1,447 @@
+# 训练数据管理API调用说明
+
+## 概述
+
+训练数据管理API提供了完整的训练数据CRUD操作,支持SQL、DDL、文档和错误SQL四种数据类型的管理。所有API都采用统一的响应格式,并提供详细的错误信息。
+
+**基础URL:** `http://localhost:8084/api/v0`
+
+## API列表
+
+| API端点 | 方法 | 功能描述 |
+|---------|------|----------|
+| `/training_data/stats` | GET | 获取训练数据统计信息 |
+| `/training_data/query` | POST | 分页查询训练数据,支持筛选和搜索 |
+| `/training_data/create` | POST | 创建训练数据,支持单条和批量操作 |
+| `/training_data/delete` | POST | 删除训练数据,支持批量操作 |
+
+---
+
+## 1. 获取统计信息
+
+### 请求信息
+```http
+GET /api/v0/training_data/stats
+```
+
+### 请求参数
+无需参数
+
+### 响应示例
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "total_count": 228,
+    "type_breakdown": {
+      "sql": 210,
+      "ddl": 9,
+      "documentation": 8,
+      "error_sql": 1
+    },
+    "type_percentages": {
+      "sql": 92.11,
+      "ddl": 3.95,
+      "documentation": 3.51,
+      "error_sql": 0.44
+    },
+    "response": "统计信息获取成功",
+    "last_updated": "2025-06-24T17:39:36.895114"
+  }
+}
+```
+
+### 响应字段说明
+- `total_count`: 训练数据总数
+- `type_breakdown`: 各类型数据的具体数量
+- `type_percentages`: 各类型数据的百分比(保留2位小数)
+- `last_updated`: 最后更新时间
+
+---
+
+## 2. 查询训练数据
+
+### 请求信息
+```http
+POST /api/v0/training_data/query
+```
+
+### 请求参数
+```json
+{
+  "page": 1,                    // 页码,必须大于0,默认1
+  "page_size": 20,              // 每页大小,1-100之间,默认20
+  "training_data_type": "sql",  // 可选,筛选类型:"sql"|"ddl"|"documentation"|"error_sql"
+  "search_keyword": "用户",     // 可选,搜索关键词,最大100字符
+  "sort_by": "id",              // 可选,排序字段:"id"|"training_data_type",默认"id"
+  "sort_order": "desc"          // 可选,排序方向:"asc"|"desc",默认"desc"
+}
+```
+
+### 响应示例
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "records": [
+      {
+        "id": "fb113c5e-6cde-4653-ac5f-7558f6e634db-sql",
+        "training_data_type": "sql",
+        "question": "查看活跃用户列表",
+        "content": "SELECT user_id, user_name, last_login FROM users WHERE last_login >= CURRENT_DATE - INTERVAL '30 days';"
+      }
+    ],
+    "pagination": {
+      "page": 1,
+      "page_size": 5,
+      "total": 2,
+      "total_pages": 1,
+      "has_next": false,
+      "has_prev": false
+    },
+    "filters_applied": {
+      "training_data_type": "sql",
+      "search_keyword": "用户"
+    },
+    "response": "查询成功,共找到 2 条记录"
+  }
+}
+```
+
+### 错误响应示例
+```json
+{
+  "code": 400,
+  "success": false,
+  "message": "请求参数错误",
+  "data": {
+    "error_type": "missing_required_params",
+    "missing_params": ["page"],
+    "response": "页码必须大于0",
+    "timestamp": "2025-06-24T17:41:47.486749"
+  }
+}
+```
+
+---
+
+## 3. 创建训练数据
+
+### 请求信息
+```http
+POST /api/v0/training_data/create
+```
+
+### 请求参数
+
+#### 单条创建
+```json
+{
+  "data": {
+    "training_data_type": "sql",
+    "question": "查询所有用户",
+    "sql": "SELECT * FROM users WHERE delete_ts IS NULL"
+  }
+}
+```
+
+#### 批量创建
+```json
+{
+  "data": [
+    {
+      "training_data_type": "sql",
+      "question": "查询活跃用户",
+      "sql": "SELECT * FROM users WHERE status = 'active'"
+    },
+    {
+      "training_data_type": "documentation",
+      "content": "用户表用于存储系统用户的基本信息和状态。"
+    },
+    {
+      "training_data_type": "ddl",
+      "ddl": "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100));"
+    },
+    {
+      "training_data_type": "error_sql",
+      "question": "错误的查询示例",
+      "sql": "SELCT * FROM users"
+    }
+  ]
+}
+```
+
+### 数据类型字段要求
+
+| 类型 | 必需字段 | 可选字段 | 说明 |
+|------|----------|----------|------|
+| `sql` | `training_data_type`, `question`, `sql` | - | SQL查询训练数据 |
+| `documentation` | `training_data_type`, `content` | - | 文档说明训练数据 |
+| `ddl` | `training_data_type`, `ddl` | - | DDL语句训练数据 |
+| `error_sql` | `training_data_type`, `question`, `sql` | - | 错误SQL示例数据 |
+
+### 响应示例
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "total_requested": 1,
+    "successfully_created": 1,
+    "failed_count": 0,
+    "results": [
+      {
+        "index": 0,
+        "success": true,
+        "type": "sql",
+        "training_id": "e1afe1c2-6956-4133-9cb6-0f83c5e1b12d-sql",
+        "message": "sql训练数据创建成功"
+      }
+    ],
+    "summary": {
+      "sql": 1,
+      "ddl": 0,
+      "documentation": 0,
+      "error_sql": 0
+    },
+    "current_total_count": 229,
+    "response": "训练数据创建完成"
+  }
+}
+```
+
+### SQL安全检查
+
+系统会自动检查SQL语句,禁止以下危险操作:
+- `UPDATE`:数据更新操作
+- `DELETE`:数据删除操作
+- `DROP`:表删除操作
+- `ALERT`:表结构修改操作
+
+如果检测到危险操作,会返回错误:
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "total_requested": 1,
+    "successfully_created": 0,
+    "failed_count": 1,
+    "results": [
+      {
+        "index": 0,
+        "success": false,
+        "type": "sql",
+        "error": "在训练集中禁止使用\"UPDATE,DELETE,ALERT,DROP\"",
+        "message": "创建失败"
+      }
+    ]
+  }
+}
+```
+
+### 批量操作限制
+- 单次批量操作最多支持50条记录
+- 超出限制会返回400错误
+
+---
+
+## 4. 删除训练数据
+
+### 请求信息
+```http
+POST /api/v0/training_data/delete
+```
+
+### 请求参数
+```json
+{
+  "ids": [
+    "e1afe1c2-6956-4133-9cb6-0f83c5e1b12d-sql",
+    "0db3b76a-6fa5-4c8e-9115-3ec7cc6159fe-doc"
+  ],
+  "confirm": true  // 必须为true,安全确认机制
+}
+```
+
+### 参数说明
+- `ids`: 要删除的训练数据ID数组,必需
+- `confirm`: 删除确认,必须为`true`,否则返回400错误
+
+### 响应示例
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    "total_requested": 2,
+    "successfully_deleted": 1,
+    "failed_count": 1,
+    "deleted_ids": [
+      "e1afe1c2-6956-4133-9cb6-0f83c5e1b12d-sql"
+    ],
+    "failed_ids": [
+      "0db3b76a-6fa5-4c8e-9115-3ec7cc6159fe-doc"
+    ],
+    "failed_details": [
+      {
+        "id": "0db3b76a-6fa5-4c8e-9115-3ec7cc6159fe-doc",
+        "error": "记录不存在或删除失败"
+      }
+    ],
+    "current_total_count": 228,
+    "response": "训练数据删除完成"
+  }
+}
+```
+
+### 确认机制错误
+```json
+{
+  "code": 400,
+  "success": false,
+  "message": "请求参数错误",
+  "data": {
+    "error_type": "missing_required_params",
+    "response": "删除操作需要确认,请设置confirm为true",
+    "timestamp": "2025-06-24T17:39:58.501962"
+  }
+}
+```
+
+### 批量操作限制
+- 单次批量删除最多支持50条记录
+- 超出限制会返回400错误
+
+---
+
+## 通用响应格式
+
+### 成功响应
+```json
+{
+  "code": 200,
+  "success": true,
+  "message": "操作成功",
+  "data": {
+    // 具体的响应数据
+  }
+}
+```
+
+### 错误响应
+```json
+{
+  "code": 400|500|503,
+  "success": false,
+  "message": "错误类型描述",
+  "data": {
+    "error_type": "错误类型标识",
+    "response": "用户友好的错误信息",
+    "timestamp": "错误发生时间",
+    // 其他错误相关字段
+  }
+}
+```
+
+## 错误码说明
+
+| 状态码 | 含义 | 常见场景 |
+|--------|------|----------|
+| 200 | 成功 | 请求正常处理 |
+| 400 | 请求参数错误 | 参数验证失败、缺少必需参数 |
+| 500 | 系统内部错误 | 数据库错误、系统异常 |
+| 503 | 服务不可用 | 系统维护、组件异常 |
+
+## 使用示例
+
+### Python调用示例
+```python
+import requests
+import json
+
+BASE_URL = "http://localhost:8084/api/v0"
+
+# 1. 获取统计信息
+def get_stats():
+    response = requests.get(f"{BASE_URL}/training_data/stats")
+    return response.json()
+
+# 2. 查询数据
+def query_data(page=1, page_size=20, keyword=None, data_type=None):
+    data = {"page": page, "page_size": page_size}
+    if keyword:
+        data["search_keyword"] = keyword
+    if data_type:
+        data["training_data_type"] = data_type
+    
+    response = requests.post(f"{BASE_URL}/training_data/query", json=data)
+    return response.json()
+
+# 3. 创建数据
+def create_data(training_data):
+    response = requests.post(f"{BASE_URL}/training_data/create", 
+                           json={"data": training_data})
+    return response.json()
+
+# 4. 删除数据
+def delete_data(ids):
+    response = requests.post(f"{BASE_URL}/training_data/delete",
+                           json={"ids": ids, "confirm": True})
+    return response.json()
+
+# 使用示例
+if __name__ == "__main__":
+    # 获取统计
+    stats = get_stats()
+    print(f"总数据量: {stats['data']['total_count']}")
+    
+    # 查询SQL类型数据
+    results = query_data(data_type="sql", keyword="用户")
+    print(f"找到 {results['data']['pagination']['total']} 条记录")
+    
+    # 创建新数据
+    new_data = {
+        "training_data_type": "sql",
+        "question": "查询测试用户",
+        "sql": "SELECT * FROM users WHERE status = 'test'"
+    }
+    create_result = create_data(new_data)
+    if create_result['data']['successfully_created'] > 0:
+        created_id = create_result['data']['results'][0]['training_id']
+        print(f"创建成功,ID: {created_id}")
+        
+        # 删除刚创建的数据
+        delete_result = delete_data([created_id])
+        print(f"删除成功: {delete_result['data']['successfully_deleted']} 条")
+```
+
+## 注意事项
+
+1. **安全性**:
+   - SQL类型数据会进行语法检查和安全检查
+   - 禁止UPDATE、DELETE、DROP、ALERT等危险操作
+   - 删除操作需要明确确认(confirm=true)
+
+2. **性能考虑**:
+   - 查询API支持分页,建议合理设置page_size
+   - 批量操作限制在50条以内
+   - 搜索关键词限制100字符以内
+
+3. **数据类型**:
+   - 确保为不同类型的训练数据提供正确的字段
+   - SQL和error_sql类型需要question和sql字段
+   - documentation类型需要content字段
+   - ddl类型需要ddl字段
+
+4. **错误处理**:
+   - 始终检查响应的success字段
+   - 批量操作可能部分成功,需要检查具体结果
+   - 关注failed_count和failed_details获取失败详情 

+ 2 - 1
requirements.txt

@@ -5,4 +5,5 @@ langchain-core==0.3.64
 langchain-postgres==0.0.14
 langgraph==0.4.8
 langchain==0.3.23
-redis==5.0.1
+redis==5.0.1
+sqlparse==0.4.4

+ 180 - 0
test_training_data_apis.py

@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+训练数据管理API测试脚本
+用于测试新增的训练数据管理接口
+"""
+
+import requests
+import json
+import sys
+
+# API基础URL
+BASE_URL = "http://localhost:8084"
+API_PREFIX = "/api/v0/training_data"
+
+def test_api(method: str, endpoint: str, data=None, expected_status=200):
+    """测试API的通用函数"""
+    url = f"{BASE_URL}{API_PREFIX}{endpoint}"
+    
+    try:
+        if method == "GET":
+            response = requests.get(url)
+        elif method == "POST":
+            response = requests.post(url, json=data, headers={'Content-Type': 'application/json'})
+        elif method == "DELETE":
+            response = requests.delete(url, json=data, headers={'Content-Type': 'application/json'})
+        else:
+            print(f"❌ 不支持的HTTP方法: {method}")
+            return False
+        
+        print(f"📤 {method} {endpoint}")
+        if data:
+            print(f"📋 请求数据: {json.dumps(data, ensure_ascii=False, indent=2)}")
+        
+        print(f"📥 状态码: {response.status_code}")
+        
+        if response.status_code == expected_status:
+            print("✅ 状态码正确")
+        else:
+            print(f"⚠️ 期望状态码: {expected_status}, 实际状态码: {response.status_code}")
+        
+        try:
+            response_json = response.json()
+            print(f"📄 响应: {json.dumps(response_json, ensure_ascii=False, indent=2)}")
+            return True
+        except:
+            print(f"📄 响应: {response.text}")
+            return False
+            
+    except requests.ConnectionError:
+        print(f"❌ 连接失败: 请确保服务器运行在 {BASE_URL}")
+        return False
+    except Exception as e:
+        print(f"❌ 请求失败: {str(e)}")
+        return False
+
+def main():
+    """主测试函数"""
+    print("🚀 开始测试训练数据管理API...")
+    print(f"🔗 服务器地址: {BASE_URL}")
+    print("="*60)
+    
+    # 1. 测试统计API (GET)
+    print("\n📊 测试统计API")
+    test_api("GET", "/stats")
+    
+    # 2. 测试查询API (POST) - 基础查询
+    print("\n🔍 测试查询API - 基础查询")
+    test_api("POST", "/query", {
+        "page": 1,
+        "page_size": 10
+    })
+    
+    # 3. 测试查询API (POST) - 带筛选
+    print("\n🔍 测试查询API - 带筛选")
+    test_api("POST", "/query", {
+        "page": 1,
+        "page_size": 5,
+        "training_data_type": "sql",
+        "search_keyword": "用户"
+    })
+    
+    # 4. 测试创建API (POST) - 单条SQL记录
+    print("\n➕ 测试创建API - 单条SQL记录")
+    test_api("POST", "/create", {
+        "data": {
+            "training_data_type": "sql",
+            "question": "查询所有测试用户",
+            "sql": "SELECT * FROM users WHERE status = 'test'"
+        }
+    })
+    
+    # 5. 测试创建API (POST) - 批量记录
+    print("\n➕ 测试创建API - 批量记录")
+    test_api("POST", "/create", {
+        "data": [
+            {
+                "training_data_type": "documentation",
+                "content": "这是一个测试文档,用于说明用户表的结构和用途。"
+            },
+            {
+                "training_data_type": "ddl",
+                "ddl": "CREATE TABLE test_table (id INT PRIMARY KEY, name VARCHAR(100));"
+            }
+        ]
+    })
+    
+    # 6. 测试创建API (POST) - SQL语法错误
+    print("\n➕ 测试创建API - SQL语法错误")
+    test_api("POST", "/create", {
+        "data": {
+            "training_data_type": "sql",
+            "question": "测试错误SQL",
+            "sql": "INVALID SQL SYNTAX"
+        }
+    }, expected_status=200)  # 批量操作中的错误仍返回200,但results中会有错误信息
+    
+    # 6.1. 测试创建API (POST) - 危险SQL操作检查
+    print("\n➕ 测试创建API - 危险SQL操作检查")
+    test_api("POST", "/create", {
+        "data": [
+            {
+                "training_data_type": "sql",
+                "question": "测试UPDATE操作",
+                "sql": "UPDATE users SET status = 'inactive' WHERE id = 1"
+            },
+            {
+                "training_data_type": "sql",
+                "question": "测试DELETE操作",
+                "sql": "DELETE FROM users WHERE id = 1"
+            },
+            {
+                "training_data_type": "sql",
+                "question": "测试DROP操作",
+                "sql": "DROP TABLE test_table"
+            }
+        ]
+    }, expected_status=200)  # 批量操作返回200,但会有错误信息
+    
+    # 7. 测试删除API (POST) - 不存在的ID
+    print("\n🗑️ 测试删除API - 不存在的ID")
+    test_api("POST", "/delete", {
+        "ids": ["non-existent-id-1", "non-existent-id-2"],
+        "confirm": True
+    })
+    
+    # 8. 测试删除API (POST) - 缺少确认
+    print("\n🗑️ 测试删除API - 缺少确认")
+    test_api("POST", "/delete", {
+        "ids": ["test-id"],
+        "confirm": False
+    }, expected_status=400)
+    
+    # 9. 测试参数验证 - 页码错误
+    print("\n⚠️ 测试参数验证 - 页码错误")
+    test_api("POST", "/query", {
+        "page": 0,
+        "page_size": 10
+    }, expected_status=400)
+    
+    # 10. 测试参数验证 - 页面大小错误
+    print("\n⚠️ 测试参数验证 - 页面大小错误")
+    test_api("POST", "/query", {
+        "page": 1,
+        "page_size": 150
+    }, expected_status=400)
+    
+    print(f"\n{'='*60}")
+    print("🎯 测试完成!")
+    print("\n📝 说明:")
+    print("- ✅ 表示API响应正常")
+    print("- ⚠️ 表示状态码不符合预期")
+    print("- ❌ 表示连接或请求失败")
+    print("\n💡 提示:")
+    print("- 首次运行时可能没有训练数据,这是正常的")
+    print("- 创建操作成功后,再次查询可以看到新增的数据")
+    print("- 删除不存在的ID会返回成功,但failed_count会显示失败数量")
+
+if __name__ == "__main__":
+    main()