state.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # 在 agent/state.py 中更新 AgentState 定义
  2. from typing import TypedDict, Literal, Optional, List, Dict, Any
  3. class AgentState(TypedDict):
  4. """LangGraph Agent状态定义"""
  5. # 输入信息
  6. question: str
  7. session_id: Optional[str]
  8. # 上下文信息
  9. context_type: Optional[str] # 上下文类型 ("DATABASE" 或 "CHAT")
  10. # 分类结果
  11. question_type: Literal["DATABASE", "CHAT", "UNCERTAIN"]
  12. classification_confidence: float
  13. classification_reason: str
  14. classification_method: str # "rule_based_*", "enhanced_llm", "direct_database", "direct_chat", etc.
  15. # 数据库查询流程状态
  16. sql: Optional[str]
  17. sql_generation_attempts: int
  18. query_result: Optional[Dict[str, Any]]
  19. summary: Optional[str]
  20. # SQL验证和修复相关状态
  21. sql_generation_success: bool
  22. sql_validation_success: bool
  23. sql_repair_attempted: bool
  24. sql_repair_success: bool
  25. validation_error_type: Optional[str] # "forbidden_keywords" | "syntax_error" | None
  26. user_prompt: Optional[str]
  27. # 聊天响应
  28. chat_response: Optional[str]
  29. # 最终输出
  30. final_response: Dict[str, Any]
  31. # 错误处理
  32. error: Optional[str]
  33. error_code: Optional[int]
  34. # 流程控制
  35. current_step: str
  36. execution_path: List[str]
  37. retry_count: int
  38. max_retries: int
  39. # 调试信息
  40. debug_info: Dict[str, Any]
  41. # 路由模式相关
  42. routing_mode: Optional[str] # 记录使用的路由模式