md_analyzer.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import logging
  2. from pathlib import Path
  3. from typing import List, Dict, Any
  4. class MDFileAnalyzer:
  5. """MD文件分析器"""
  6. def __init__(self, output_dir: str):
  7. self.output_dir = Path(output_dir)
  8. self.logger = logging.getLogger("schema_tools.MDFileAnalyzer")
  9. async def read_all_md_files(self) -> str:
  10. """
  11. 读取所有MD文件的完整内容
  12. Returns:
  13. 所有MD文件内容的组合字符串
  14. """
  15. md_files = sorted(self.output_dir.glob("*_detail.md"))
  16. if not md_files:
  17. raise ValueError(f"在 {self.output_dir} 目录下未找到MD文件")
  18. all_contents = []
  19. all_contents.append(f"# 数据库表结构文档汇总\n")
  20. all_contents.append(f"共包含 {len(md_files)} 个表\n\n")
  21. for md_file in md_files:
  22. self.logger.info(f"读取MD文件: {md_file.name}")
  23. try:
  24. content = md_file.read_text(encoding='utf-8')
  25. # 添加分隔符,便于LLM区分不同表
  26. all_contents.append("=" * 80)
  27. all_contents.append(f"# 文件: {md_file.name}")
  28. all_contents.append("=" * 80)
  29. all_contents.append(content)
  30. all_contents.append("\n")
  31. except Exception as e:
  32. self.logger.error(f"读取文件 {md_file.name} 失败: {e}")
  33. raise
  34. combined_content = "\n".join(all_contents)
  35. # 检查内容大小(预估token数)
  36. estimated_tokens = len(combined_content) / 4 # 粗略估算
  37. if estimated_tokens > 100000: # 假设token限制
  38. self.logger.warning(f"MD内容可能过大,预估tokens: {estimated_tokens:.0f}")
  39. self.logger.info(f"成功读取 {len(md_files)} 个MD文件,总字符数: {len(combined_content)}")
  40. return combined_content
  41. def get_table_summaries(self) -> List[Dict[str, str]]:
  42. """
  43. 获取所有表的摘要信息
  44. Returns:
  45. 表摘要列表
  46. """
  47. md_files = sorted(self.output_dir.glob("*_detail.md"))
  48. summaries = []
  49. for md_file in md_files:
  50. try:
  51. content = md_file.read_text(encoding='utf-8')
  52. lines = content.split('\n')
  53. # 提取表名和描述(通常在前几行)
  54. table_name = ""
  55. description = ""
  56. for line in lines[:10]: # 只看前10行
  57. line = line.strip()
  58. if line.startswith("##"):
  59. # 提取表名
  60. table_info = line.replace("##", "").strip()
  61. if "(" in table_info:
  62. table_name = table_info.split("(")[0].strip()
  63. else:
  64. table_name = table_info
  65. elif table_name and line and not line.startswith("#"):
  66. # 第一行非标题文本作为描述
  67. description = line
  68. break
  69. if table_name:
  70. summaries.append({
  71. "file": md_file.name,
  72. "table_name": table_name,
  73. "description": description
  74. })
  75. except Exception as e:
  76. self.logger.warning(f"处理文件 {md_file.name} 时出错: {e}")
  77. return summaries