shell.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """
  2. 重构后的 CustomReactAgent 的交互式命令行客户端
  3. """
  4. # from __future__ import annotations
  5. import asyncio
  6. import logging
  7. import sys
  8. import os
  9. import json
  10. from typing import List, Dict, Any
  11. # 将当前目录和项目根目录添加到 sys.path
  12. CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
  13. PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, '..', '..'))
  14. sys.path.insert(0, CURRENT_DIR) # 当前目录优先
  15. sys.path.insert(1, PROJECT_ROOT) # 项目根目录
  16. # 导入 Agent 和配置(简化版本)
  17. from agent import CustomReactAgent
  18. import config
  19. # 配置日志
  20. logging.basicConfig(level=config.LOG_LEVEL, format=config.LOG_FORMAT)
  21. logger = logging.getLogger(__name__)
  22. class CustomAgentShell:
  23. """新 Agent 的交互式 Shell 客户端"""
  24. def __init__(self, agent: CustomReactAgent):
  25. """私有构造函数,请使用 create() 类方法。"""
  26. self.agent = agent
  27. self.user_id: str = config.DEFAULT_USER_ID
  28. self.thread_id: str | None = None
  29. self.recent_conversations: List[Dict[str, Any]] = [] # 存储最近的对话列表
  30. @classmethod
  31. async def create(cls):
  32. """异步工厂方法,创建 Shell 实例。"""
  33. agent = await CustomReactAgent.create()
  34. return cls(agent)
  35. async def close(self):
  36. """关闭 Agent 资源。"""
  37. if self.agent:
  38. await self.agent.close()
  39. async def _fetch_recent_conversations(self, user_id: str, limit: int = 5) -> List[Dict[str, Any]]:
  40. """获取最近的对话列表"""
  41. try:
  42. logger.info(f"🔍 获取用户 {user_id} 的最近 {limit} 次对话...")
  43. conversations = await self.agent.get_user_recent_conversations(user_id, limit)
  44. logger.info(f"✅ 成功获取 {len(conversations)} 个对话")
  45. return conversations
  46. except Exception as e:
  47. logger.error(f"❌ 获取对话列表失败: {e}")
  48. print(f"⚠️ 获取历史对话失败: {e}")
  49. print(" 将直接开始新对话...")
  50. return []
  51. def _display_conversation_list(self, conversations: List[Dict[str, Any]]) -> None:
  52. """显示对话列表"""
  53. if not conversations:
  54. print("📭 暂无历史对话,将开始新对话。")
  55. return
  56. print("\n📋 最近的对话记录:")
  57. print("-" * 60)
  58. for i, conv in enumerate(conversations, 1):
  59. thread_id = conv.get('thread_id', '')
  60. formatted_time = conv.get('formatted_time', '')
  61. preview = conv.get('conversation_preview', '无预览')
  62. message_count = conv.get('message_count', 0)
  63. print(f"[{i}] {formatted_time} - {preview}")
  64. print(f" Thread ID: {thread_id} | 消息数: {message_count}")
  65. print()
  66. print("💡 选择方式:")
  67. print(" - 输入序号 (1-5): 选择对应的对话")
  68. print(" - 输入 Thread ID: 直接指定对话")
  69. print(" - 输入日期 (YYYY-MM-DD): 选择当天最新对话")
  70. print(" - 输入 'new': 开始新对话")
  71. print(" - 直接输入问题: 开始新对话")
  72. print("-" * 60)
  73. def _parse_conversation_selection(self, user_input: str) -> Dict[str, Any]:
  74. """解析用户的对话选择"""
  75. user_input = user_input.strip()
  76. # 检查是否是数字序号 (1-5)
  77. if user_input.isdigit():
  78. index = int(user_input)
  79. if 1 <= index <= len(self.recent_conversations):
  80. selected_conv = self.recent_conversations[index - 1]
  81. return {
  82. "type": "select_by_index",
  83. "thread_id": selected_conv["thread_id"],
  84. "preview": selected_conv["conversation_preview"]
  85. }
  86. else:
  87. return {"type": "invalid_index", "message": f"序号 {index} 无效,请输入 1-{len(self.recent_conversations)}"}
  88. # 检查是否是 Thread ID 格式 (包含冒号)
  89. if ':' in user_input and len(user_input.split(':')) == 2:
  90. user_part, timestamp_part = user_input.split(':')
  91. # 简单验证格式
  92. if user_part == self.user_id and timestamp_part.isdigit():
  93. # 检查该Thread ID是否存在于历史对话中
  94. for conv in self.recent_conversations:
  95. if conv["thread_id"] == user_input:
  96. return {
  97. "type": "select_by_thread_id",
  98. "thread_id": user_input,
  99. "preview": conv["conversation_preview"]
  100. }
  101. return {"type": "thread_not_found", "message": f"Thread ID {user_input} 不存在于最近的对话中"}
  102. # 检查是否是日期格式 (YYYY-MM-DD)
  103. import re
  104. date_pattern = r'^\d{4}-\d{2}-\d{2}$'
  105. if re.match(date_pattern, user_input):
  106. # 查找该日期的最新对话
  107. target_date = user_input.replace('-', '') # 转换为 YYYYMMDD 格式
  108. for conv in self.recent_conversations:
  109. timestamp = conv.get('timestamp', '')
  110. if timestamp.startswith(target_date):
  111. return {
  112. "type": "select_by_date",
  113. "thread_id": conv["thread_id"],
  114. "preview": f"日期 {user_input} 的对话: {conv['conversation_preview']}"
  115. }
  116. return {"type": "no_date_match", "message": f"未找到 {user_input} 的对话"}
  117. # 检查是否是 'new' 命令
  118. if user_input.lower() == 'new':
  119. return {"type": "new_conversation"}
  120. # 其他情况当作新问题处理
  121. return {"type": "new_question", "question": user_input}
  122. async def start(self):
  123. """启动 Shell 界面。"""
  124. print("\n🚀 Custom React Agent Shell (StateGraph Version)")
  125. print("=" * 50)
  126. # 获取用户ID
  127. user_input = input(f"请输入您的用户ID (默认: {self.user_id}): ").strip()
  128. if user_input:
  129. self.user_id = user_input
  130. print(f"👤 当前用户: {self.user_id}")
  131. # 获取并显示最近的对话列表
  132. print("\n🔍 正在获取历史对话...")
  133. self.recent_conversations = await self._fetch_recent_conversations(self.user_id, 5)
  134. self._display_conversation_list(self.recent_conversations)
  135. print("\n💬 开始对话 (输入 'exit' 或 'quit' 退出)")
  136. print("-" * 50)
  137. await self._chat_loop()
  138. async def _chat_loop(self):
  139. """主要的聊天循环。"""
  140. while True:
  141. user_input = input(f"👤 [{self.user_id[:8]}]> ").strip()
  142. if not user_input:
  143. continue
  144. if user_input.lower() in ['quit', 'exit']:
  145. raise KeyboardInterrupt # 优雅退出
  146. if user_input.lower() == 'new':
  147. self.thread_id = None
  148. print("🆕 已开始新会话。")
  149. continue
  150. if user_input.lower() == 'history':
  151. await self._show_current_history()
  152. continue
  153. # 如果还没有选择对话,且有历史对话,则处理对话选择
  154. if self.thread_id is None and self.recent_conversations:
  155. selection = self._parse_conversation_selection(user_input)
  156. if selection["type"] == "select_by_index":
  157. self.thread_id = selection["thread_id"]
  158. print(f"📖 已选择对话: {selection['preview']}")
  159. print(f"💬 Thread ID: {self.thread_id}")
  160. print("现在可以在此对话中继续聊天...\n")
  161. continue
  162. elif selection["type"] == "select_by_thread_id":
  163. self.thread_id = selection["thread_id"]
  164. print(f"📖 已选择对话: {selection['preview']}")
  165. print("现在可以在此对话中继续聊天...\n")
  166. continue
  167. elif selection["type"] == "select_by_date":
  168. self.thread_id = selection["thread_id"]
  169. print(f"📖 已选择对话: {selection['preview']}")
  170. print("现在可以在此对话中继续聊天...\n")
  171. continue
  172. elif selection["type"] == "new_conversation":
  173. self.thread_id = None
  174. print("🆕 已开始新会话。")
  175. continue
  176. elif selection["type"] == "new_question":
  177. # 当作新问题处理,继续下面的正常对话流程
  178. user_input = selection["question"]
  179. self.thread_id = None
  180. print("🆕 开始新对话...")
  181. elif selection["type"] in ["invalid_index", "no_date_match", "thread_not_found"]:
  182. print(f"❌ {selection['message']}")
  183. continue
  184. # 正常对话流程
  185. print("🤖 Agent 正在思考...")
  186. result = await self.agent.chat(user_input, self.user_id, self.thread_id)
  187. if result.get("success"):
  188. answer = result.get('answer', '')
  189. # 注释掉 [Formatted Output] 清理逻辑 - 源头已不生成前缀
  190. # if answer.startswith("[Formatted Output]\n"):
  191. # answer = answer.replace("[Formatted Output]\n", "")
  192. print(f"🤖 Agent: {answer}")
  193. # 如果包含 SQL 数据,也显示出来
  194. if 'sql_data' in result:
  195. print(f"📊 SQL 查询结果: {result['sql_data']}")
  196. # 更新 thread_id 以便在同一会话中继续
  197. self.thread_id = result.get("thread_id")
  198. else:
  199. error_msg = result.get('error', '未知错误')
  200. print(f"❌ 发生错误: {error_msg}")
  201. # 提供针对性的建议
  202. if "Connection error" in error_msg or "网络" in error_msg:
  203. print("💡 建议:")
  204. print(" - 检查网络连接是否正常")
  205. print(" - 稍后重试该问题")
  206. print(" - 如果问题持续,可以尝试重新启动程序")
  207. elif "timeout" in error_msg.lower():
  208. print("💡 建议:")
  209. print(" - 当前网络较慢,建议稍后重试")
  210. print(" - 尝试简化问题复杂度")
  211. else:
  212. print("💡 建议:")
  213. print(" - 请检查问题格式是否正确")
  214. print(" - 尝试重新描述您的问题")
  215. # 保持thread_id,用户可以继续对话
  216. if not self.thread_id and result.get("thread_id"):
  217. self.thread_id = result.get("thread_id")
  218. async def _show_current_history(self):
  219. """显示当前会话的历史记录。"""
  220. if not self.thread_id:
  221. print("当前没有活跃的会话。请先开始对话。")
  222. return
  223. print(f"\n--- 对话历史: {self.thread_id} ---")
  224. history = await self.agent.get_conversation_history(self.thread_id)
  225. if not history:
  226. print("无法获取历史或历史为空。")
  227. return
  228. for msg in history:
  229. print(f"[{msg['type']}] {msg['content']}")
  230. print("--- 历史结束 ---")
  231. async def main():
  232. """主函数入口"""
  233. shell = None
  234. try:
  235. shell = await CustomAgentShell.create()
  236. await shell.start()
  237. except KeyboardInterrupt:
  238. logger.info("\n👋 检测到退出指令,正在清理资源...")
  239. except Exception as e:
  240. logger.error(f"❌ 程序发生严重错误: {e}", exc_info=True)
  241. finally:
  242. if shell:
  243. await shell.close()
  244. print("✅ 程序已成功关闭。")
  245. if __name__ == "__main__":
  246. try:
  247. asyncio.run(main())
  248. except KeyboardInterrupt:
  249. # 这个捕获是为了处理在 main 之外的 Ctrl+C
  250. print("\n👋 程序被强制退出。")