shell.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. """
  2. 重构后的 CustomReactAgent 的交互式命令行客户端
  3. """
  4. # from __future__ import annotations
  5. import asyncio
  6. import logging
  7. import sys
  8. import os
  9. import json
  10. # 将当前目录和项目根目录添加到 sys.path
  11. CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
  12. PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, '..', '..'))
  13. sys.path.insert(0, CURRENT_DIR) # 当前目录优先
  14. sys.path.insert(1, PROJECT_ROOT) # 项目根目录
  15. # 导入 Agent 和配置(简化版本)
  16. from agent import CustomReactAgent
  17. import config
  18. # 配置日志
  19. logging.basicConfig(level=config.LOG_LEVEL, format=config.LOG_FORMAT)
  20. logger = logging.getLogger(__name__)
  21. class CustomAgentShell:
  22. """新 Agent 的交互式 Shell 客户端"""
  23. def __init__(self, agent: CustomReactAgent):
  24. """私有构造函数,请使用 create() 类方法。"""
  25. self.agent = agent
  26. self.user_id: str = config.DEFAULT_USER_ID
  27. self.thread_id: str | None = None
  28. @classmethod
  29. async def create(cls):
  30. """异步工厂方法,创建 Shell 实例。"""
  31. agent = await CustomReactAgent.create()
  32. return cls(agent)
  33. async def close(self):
  34. """关闭 Agent 资源。"""
  35. if self.agent:
  36. await self.agent.close()
  37. async def start(self):
  38. """启动 Shell 界面。"""
  39. print("\n🚀 Custom React Agent Shell (StateGraph Version)")
  40. print("=" * 50)
  41. # 获取用户ID
  42. user_input = input(f"请输入您的用户ID (默认: {self.user_id}): ").strip()
  43. if user_input:
  44. self.user_id = user_input
  45. print(f"👤 当前用户: {self.user_id}")
  46. # 这里可以加入显示历史会话的逻辑
  47. print("\n💬 开始对话 (输入 'exit' 或 'quit' 退出)")
  48. print("-" * 50)
  49. await self._chat_loop()
  50. async def _chat_loop(self):
  51. """主要的聊天循环。"""
  52. while True:
  53. user_input = input(f"👤 [{self.user_id[:8]}]> ").strip()
  54. if not user_input:
  55. continue
  56. if user_input.lower() in ['quit', 'exit']:
  57. raise KeyboardInterrupt # 优雅退出
  58. if user_input.lower() == 'new':
  59. self.thread_id = None
  60. print("🆕 已开始新会话。")
  61. continue
  62. if user_input.lower() == 'history':
  63. await self._show_current_history()
  64. continue
  65. # 正常对话
  66. print("🤖 Agent 正在思考...")
  67. result = await self.agent.chat(user_input, self.user_id, self.thread_id)
  68. if result.get("success"):
  69. answer = result.get('answer', '')
  70. # 去除 [Formatted Output] 标记,只显示真正的回答
  71. if answer.startswith("[Formatted Output]\n"):
  72. answer = answer.replace("[Formatted Output]\n", "")
  73. print(f"🤖 Agent: {answer}")
  74. # 如果包含 SQL 数据,也显示出来
  75. if 'sql_data' in result:
  76. print(f"📊 SQL 查询结果: {result['sql_data']}")
  77. # 更新 thread_id 以便在同一会话中继续
  78. self.thread_id = result.get("thread_id")
  79. else:
  80. error_msg = result.get('error', '未知错误')
  81. print(f"❌ 发生错误: {error_msg}")
  82. # 提供针对性的建议
  83. if "Connection error" in error_msg or "网络" in error_msg:
  84. print("💡 建议:")
  85. print(" - 检查网络连接是否正常")
  86. print(" - 稍后重试该问题")
  87. print(" - 如果问题持续,可以尝试重新启动程序")
  88. elif "timeout" in error_msg.lower():
  89. print("💡 建议:")
  90. print(" - 当前网络较慢,建议稍后重试")
  91. print(" - 尝试简化问题复杂度")
  92. else:
  93. print("💡 建议:")
  94. print(" - 请检查问题格式是否正确")
  95. print(" - 尝试重新描述您的问题")
  96. # 保持thread_id,用户可以继续对话
  97. if not self.thread_id and result.get("thread_id"):
  98. self.thread_id = result.get("thread_id")
  99. async def _show_current_history(self):
  100. """显示当前会话的历史记录。"""
  101. if not self.thread_id:
  102. print("当前没有活跃的会话。请先开始对话。")
  103. return
  104. print(f"\n--- 对话历史: {self.thread_id} ---")
  105. history = await self.agent.get_conversation_history(self.thread_id)
  106. if not history:
  107. print("无法获取历史或历史为空。")
  108. return
  109. for msg in history:
  110. print(f"[{msg['type']}] {msg['content']}")
  111. print("--- 历史结束 ---")
  112. async def main():
  113. """主函数入口"""
  114. shell = None
  115. try:
  116. shell = await CustomAgentShell.create()
  117. await shell.start()
  118. except KeyboardInterrupt:
  119. logger.info("\n👋 检测到退出指令,正在清理资源...")
  120. except Exception as e:
  121. logger.error(f"❌ 程序发生严重错误: {e}", exc_info=True)
  122. finally:
  123. if shell:
  124. await shell.close()
  125. print("✅ 程序已成功关闭。")
  126. if __name__ == "__main__":
  127. try:
  128. asyncio.run(main())
  129. except KeyboardInterrupt:
  130. # 这个捕获是为了处理在 main 之外的 Ctrl+C
  131. print("\n👋 程序被强制退出。")