load_prompts.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. """
  2. 提示词加载器
  3. 用于从yaml文件中加载LLM提示词配置
  4. """
  5. import os
  6. import yaml
  7. from typing import Dict, Any
  8. from core.logging import get_vanna_logger
  9. class PromptLoader:
  10. """提示词加载器类"""
  11. def __init__(self, config_path: str = None):
  12. """
  13. 初始化提示词加载器
  14. Args:
  15. config_path: yaml配置文件路径,默认为当前目录下的llm_prompts.yaml
  16. """
  17. self.logger = get_vanna_logger("PromptLoader")
  18. if config_path is None:
  19. # 默认配置文件路径
  20. current_dir = os.path.dirname(os.path.abspath(__file__))
  21. config_path = os.path.join(current_dir, "llm_prompts.yaml")
  22. self.config_path = config_path
  23. self._prompts = None
  24. self._load_prompts()
  25. def _load_prompts(self):
  26. """从yaml文件加载提示词配置"""
  27. try:
  28. with open(self.config_path, 'r', encoding='utf-8') as file:
  29. self._prompts = yaml.safe_load(file)
  30. self.logger.debug(f"成功加载提示词配置: {self.config_path}")
  31. except FileNotFoundError:
  32. self.logger.error(f"提示词配置文件未找到: {self.config_path}")
  33. self._prompts = {}
  34. except yaml.YAMLError as e:
  35. self.logger.error(f"解析yaml配置文件失败: {e}")
  36. self._prompts = {}
  37. except Exception as e:
  38. self.logger.error(f"加载提示词配置时出现未知错误: {e}")
  39. self._prompts = {}
  40. def get_prompt(self, category: str, key: str, **kwargs) -> str:
  41. """
  42. 获取指定的提示词
  43. Args:
  44. category: 提示词类别 (如 'sql_generation', 'chart_generation' 等)
  45. key: 提示词键名 (如 'initial_prompt', 'response_guidelines' 等)
  46. **kwargs: 用于格式化提示词的变量
  47. Returns:
  48. str: 格式化后的提示词,如果找不到则返回空字符串
  49. """
  50. try:
  51. if category not in self._prompts:
  52. self.logger.warning(f"未找到提示词类别: {category}")
  53. return ""
  54. if key not in self._prompts[category]:
  55. self.logger.warning(f"未找到提示词键: {category}.{key}")
  56. return ""
  57. prompt_template = self._prompts[category][key]
  58. # 如果有格式化参数,进行格式化
  59. if kwargs:
  60. try:
  61. return prompt_template.format(**kwargs)
  62. except KeyError as e:
  63. self.logger.warning(f"提示词格式化失败,缺少参数: {e}")
  64. return prompt_template
  65. return prompt_template
  66. except Exception as e:
  67. self.logger.error(f"获取提示词时出现错误: {e}")
  68. return ""
  69. def get_sql_initial_prompt(self, dialect: str) -> str:
  70. """获取SQL生成的初始提示词"""
  71. return self.get_prompt("sql_generation", "initial_prompt", dialect=dialect)
  72. def get_sql_response_guidelines(self, dialect: str) -> str:
  73. """获取SQL生成的响应指南"""
  74. return self.get_prompt("sql_generation", "response_guidelines", dialect=dialect)
  75. def get_chart_instructions(self) -> str:
  76. """获取图表生成的中文指令"""
  77. return self.get_prompt("chart_generation", "chinese_chart_instructions")
  78. def get_chart_system_message(self, question: str = None, sql: str = None, df_metadata: str = None) -> str:
  79. """获取图表生成的系统消息"""
  80. # 构建SQL部分
  81. sql_part = f"数据来源SQL查询:\n{sql}" if sql else ""
  82. # 构建问题部分
  83. if question:
  84. question_text = f"用户问题:'{question}'\n\n以下是回答用户问题的pandas DataFrame数据:"
  85. else:
  86. question_text = "以下是一个pandas DataFrame数据:"
  87. return self.get_prompt(
  88. "chart_generation",
  89. "system_message_template",
  90. question=question_text,
  91. sql_part=sql_part,
  92. df_metadata=df_metadata or ""
  93. )
  94. def get_chart_user_message(self) -> str:
  95. """获取图表生成的用户消息"""
  96. chinese_instructions = self.get_chart_instructions()
  97. return self.get_prompt(
  98. "chart_generation",
  99. "user_message_template",
  100. chinese_chart_instructions=chinese_instructions
  101. )
  102. def get_question_generation_prompt(self) -> str:
  103. """获取根据SQL生成问题的提示词"""
  104. return self.get_prompt("question_generation", "system_prompt")
  105. def get_chat_default_prompt(self) -> str:
  106. """获取聊天对话的默认系统提示词"""
  107. return self.get_prompt("chat_with_llm", "default_system_prompt")
  108. def get_question_merge_prompt(self) -> str:
  109. """获取问题合并的系统提示词"""
  110. return self.get_prompt("question_merge", "system_prompt")
  111. def get_summary_system_message(self, question: str, df_markdown: str) -> str:
  112. """获取摘要生成的系统消息"""
  113. return self.get_prompt(
  114. "summary_generation",
  115. "system_message_template",
  116. question=question,
  117. df_markdown=df_markdown
  118. )
  119. def get_summary_user_instructions(self) -> str:
  120. """获取摘要生成的用户指令"""
  121. return self.get_prompt("summary_generation", "user_instructions")
  122. def reload_prompts(self):
  123. """重新加载提示词配置"""
  124. self.logger.info("重新加载提示词配置")
  125. self._load_prompts()
  126. # 全局提示词加载器实例
  127. _prompt_loader = None
  128. def get_prompt_loader() -> PromptLoader:
  129. """
  130. 获取全局提示词加载器实例(单例模式)
  131. Returns:
  132. PromptLoader: 提示词加载器实例
  133. """
  134. global _prompt_loader
  135. if _prompt_loader is None:
  136. _prompt_loader = PromptLoader()
  137. return _prompt_loader