shell.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. """
  2. 重构后的 CustomReactAgent 的交互式命令行客户端
  3. """
  4. import asyncio
  5. import logging
  6. import sys
  7. import os
  8. import json
  9. # 动态地将项目根目录添加到 sys.path,以支持跨模块导入
  10. # 这使得脚本更加健壮,无论从哪里执行
  11. PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
  12. sys.path.insert(0, PROJECT_ROOT)
  13. # 从新模块导入 Agent 和配置
  14. try:
  15. # 相对导入(当作为模块导入时)
  16. from .agent import CustomReactAgent
  17. from . import config
  18. except ImportError:
  19. # 绝对导入(当直接运行时)
  20. from test.custom_react_agent.agent import CustomReactAgent
  21. from test.custom_react_agent import config
  22. # 配置日志
  23. logging.basicConfig(level=config.LOG_LEVEL, format=config.LOG_FORMAT)
  24. logger = logging.getLogger(__name__)
  25. class CustomAgentShell:
  26. """新 Agent 的交互式 Shell 客户端"""
  27. def __init__(self, agent: CustomReactAgent):
  28. """私有构造函数,请使用 create() 类方法。"""
  29. self.agent = agent
  30. self.user_id: str = config.DEFAULT_USER_ID
  31. self.thread_id: str | None = None
  32. @classmethod
  33. async def create(cls):
  34. """异步工厂方法,创建 Shell 实例。"""
  35. agent = await CustomReactAgent.create()
  36. return cls(agent)
  37. async def close(self):
  38. """关闭 Agent 资源。"""
  39. if self.agent:
  40. await self.agent.close()
  41. async def start(self):
  42. """启动 Shell 界面。"""
  43. print("\n🚀 Custom React Agent Shell (StateGraph Version)")
  44. print("=" * 50)
  45. # 获取用户ID
  46. user_input = input(f"请输入您的用户ID (默认: {self.user_id}): ").strip()
  47. if user_input:
  48. self.user_id = user_input
  49. print(f"👤 当前用户: {self.user_id}")
  50. # 这里可以加入显示历史会话的逻辑
  51. print("\n💬 开始对话 (输入 'exit' 或 'quit' 退出)")
  52. print("-" * 50)
  53. await self._chat_loop()
  54. async def _chat_loop(self):
  55. """主要的聊天循环。"""
  56. while True:
  57. user_input = input(f"👤 [{self.user_id[:8]}]> ").strip()
  58. if not user_input:
  59. continue
  60. if user_input.lower() in ['quit', 'exit']:
  61. raise KeyboardInterrupt # 优雅退出
  62. if user_input.lower() == 'new':
  63. self.thread_id = None
  64. print("🆕 已开始新会话。")
  65. continue
  66. if user_input.lower() == 'history':
  67. await self._show_current_history()
  68. continue
  69. # 正常对话
  70. print("🤖 Agent 正在思考...")
  71. result = await self.agent.chat(user_input, self.user_id, self.thread_id)
  72. if result.get("success"):
  73. answer = result.get('answer', '')
  74. # 去除 [Formatted Output] 标记,只显示真正的回答
  75. if answer.startswith("[Formatted Output]\n"):
  76. answer = answer.replace("[Formatted Output]\n", "")
  77. print(f"🤖 Agent: {answer}")
  78. # 如果包含 SQL 数据,也显示出来
  79. if 'sql_data' in result:
  80. print(f"📊 SQL 查询结果: {result['sql_data']}")
  81. # 更新 thread_id 以便在同一会话中继续
  82. self.thread_id = result.get("thread_id")
  83. else:
  84. print(f"❌ 发生错误: {result.get('error')}")
  85. async def _show_current_history(self):
  86. """显示当前会话的历史记录。"""
  87. if not self.thread_id:
  88. print("当前没有活跃的会话。请先开始对话。")
  89. return
  90. print(f"\n--- 对话历史: {self.thread_id} ---")
  91. history = await self.agent.get_conversation_history(self.thread_id)
  92. if not history:
  93. print("无法获取历史或历史为空。")
  94. return
  95. for msg in history:
  96. print(f"[{msg['type']}] {msg['content']}")
  97. print("--- 历史结束 ---")
  98. async def main():
  99. """主函数入口"""
  100. shell = None
  101. try:
  102. shell = await CustomAgentShell.create()
  103. await shell.start()
  104. except KeyboardInterrupt:
  105. logger.info("\n👋 检测到退出指令,正在清理资源...")
  106. except Exception as e:
  107. logger.error(f"❌ 程序发生严重错误: {e}", exc_info=True)
  108. finally:
  109. if shell:
  110. await shell.close()
  111. print("✅ 程序已成功关闭。")
  112. if __name__ == "__main__":
  113. try:
  114. asyncio.run(main())
  115. except KeyboardInterrupt:
  116. # 这个捕获是为了处理在 main 之外的 Ctrl+C
  117. print("\n👋 程序被强制退出。")