comment_generator.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. import asyncio
  2. from typing import List, Dict, Any, Tuple
  3. from schema_tools.tools.base import BaseTool, ToolRegistry
  4. from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo
  5. @ToolRegistry.register("comment_generator")
  6. class CommentGeneratorTool(BaseTool):
  7. """LLM注释生成工具"""
  8. needs_llm = True
  9. tool_name = "注释生成器"
  10. def __init__(self, **kwargs):
  11. super().__init__(**kwargs)
  12. self.business_context = kwargs.get('business_context', '')
  13. self.business_dictionary = self._load_business_dictionary()
  14. async def execute(self, context: TableProcessingContext) -> ProcessingResult:
  15. """执行注释生成"""
  16. try:
  17. table_metadata = context.table_metadata
  18. # 生成表注释
  19. table_comment_result = await self._generate_table_comment(table_metadata, context.business_context)
  20. # 生成字段注释和枚举建议
  21. field_results = await self._generate_field_comments_and_enums(table_metadata, context.business_context)
  22. # 更新表元数据
  23. if table_comment_result['success']:
  24. table_metadata.generated_comment = table_comment_result['comment']
  25. table_metadata.comment = table_comment_result['comment']
  26. # 更新字段信息
  27. enum_suggestions = []
  28. for i, field in enumerate(table_metadata.fields):
  29. if i < len(field_results) and field_results[i]['success']:
  30. field.generated_comment = field_results[i]['comment']
  31. field.comment = field_results[i]['comment']
  32. # 处理枚举建议
  33. if field_results[i].get('is_enum'):
  34. field.is_enum = True
  35. enum_suggestions.append({
  36. 'field_name': field.name,
  37. 'suggested_values': field_results[i].get('enum_values', []),
  38. 'enum_description': field_results[i].get('enum_description', '')
  39. })
  40. # 验证枚举建议
  41. if enum_suggestions:
  42. validated_enums = await self._validate_enum_suggestions(table_metadata, enum_suggestions)
  43. # 更新验证后的枚举信息
  44. for enum_info in validated_enums:
  45. field_name = enum_info['field_name']
  46. for field in table_metadata.fields:
  47. if field.name == field_name:
  48. field.enum_values = enum_info['actual_values']
  49. field.enum_description = enum_info['description']
  50. break
  51. return ProcessingResult(
  52. success=True,
  53. data={
  54. 'table_comment_generated': table_comment_result['success'],
  55. 'field_comments_generated': sum(1 for r in field_results if r['success']),
  56. 'enum_fields_detected': len([f for f in table_metadata.fields if f.is_enum]),
  57. 'enum_suggestions': enum_suggestions
  58. },
  59. metadata={'tool': self.tool_name}
  60. )
  61. except Exception as e:
  62. self.logger.exception(f"注释生成失败")
  63. return ProcessingResult(
  64. success=False,
  65. error_message=f"注释生成失败: {str(e)}"
  66. )
  67. async def _generate_table_comment(self, table_metadata, business_context: str) -> Dict[str, Any]:
  68. """生成表注释"""
  69. try:
  70. prompt = self._build_table_comment_prompt(table_metadata, business_context)
  71. # 调用LLM
  72. response = await self._call_llm_with_retry(prompt)
  73. # 解析响应
  74. comment = self._extract_table_comment(response)
  75. return {
  76. 'success': True,
  77. 'comment': comment,
  78. 'original_response': response
  79. }
  80. except Exception as e:
  81. self.logger.error(f"表注释生成失败: {e}")
  82. return {
  83. 'success': False,
  84. 'comment': table_metadata.original_comment or f"{table_metadata.table_name}表",
  85. 'error': str(e)
  86. }
  87. async def _generate_field_comments_and_enums(self, table_metadata, business_context: str) -> List[Dict[str, Any]]:
  88. """批量生成字段注释和枚举建议"""
  89. try:
  90. # 构建批量处理的提示词
  91. prompt = self._build_field_batch_prompt(table_metadata, business_context)
  92. # 调用LLM
  93. response = await self._call_llm_with_retry(prompt)
  94. # 解析批量响应
  95. field_results = self._parse_field_batch_response(response, table_metadata.fields)
  96. return field_results
  97. except Exception as e:
  98. self.logger.error(f"字段注释批量生成失败: {e}")
  99. # 返回默认结果
  100. return [
  101. {
  102. 'success': False,
  103. 'comment': field.original_comment or field.name,
  104. 'is_enum': False,
  105. 'error': str(e)
  106. }
  107. for field in table_metadata.fields
  108. ]
  109. def _build_table_comment_prompt(self, table_metadata, business_context: str) -> str:
  110. """构建表注释生成提示词"""
  111. # 准备字段信息摘要
  112. fields_summary = []
  113. for field in table_metadata.fields[:10]: # 只显示前10个字段避免过长
  114. field_desc = f"- {field.name} ({field.type})"
  115. if field.comment:
  116. field_desc += f": {field.comment}"
  117. fields_summary.append(field_desc)
  118. # 准备样例数据摘要
  119. sample_summary = ""
  120. if table_metadata.sample_data:
  121. sample_count = min(3, len(table_metadata.sample_data))
  122. sample_summary = f"\n样例数据({sample_count}条):\n"
  123. for i, sample in enumerate(table_metadata.sample_data[:sample_count]):
  124. sample_str = ", ".join([f"{k}={v}" for k, v in list(sample.items())[:5]])
  125. sample_summary += f"{i+1}. {sample_str}\n"
  126. prompt = f"""你是一个数据库文档专家。请根据以下信息为数据库表生成简洁、准确的中文注释。
  127. 业务背景: {business_context}
  128. {self.business_dictionary}
  129. 表信息:
  130. - 表名: {table_metadata.table_name}
  131. - Schema: {table_metadata.schema_name}
  132. - 现有注释: {table_metadata.original_comment or "无"}
  133. - 字段数量: {len(table_metadata.fields)}
  134. - 数据行数: {table_metadata.row_count or "未知"}
  135. 主要字段:
  136. {chr(10).join(fields_summary)}
  137. {sample_summary}
  138. 请生成一个简洁、准确的中文表注释,要求:
  139. 1. 如果现有注释是英文,请翻译为中文并改进
  140. 2. 根据字段名称和样例数据推断表的业务用途
  141. 3. 注释长度控制在50字以内
  142. 4. 突出表的核心业务价值
  143. 表注释:"""
  144. return prompt
  145. def _build_field_batch_prompt(self, table_metadata, business_context: str) -> str:
  146. """构建字段批量处理提示词"""
  147. # 准备字段信息
  148. fields_info = []
  149. sample_values = {}
  150. # 收集字段的样例值
  151. for sample in table_metadata.sample_data[:5]:
  152. for field_name, value in sample.items():
  153. if field_name not in sample_values:
  154. sample_values[field_name] = []
  155. if value is not None and len(sample_values[field_name]) < 5:
  156. sample_values[field_name].append(str(value))
  157. # 构建字段信息列表
  158. for field in table_metadata.fields:
  159. field_info = f"{field.name} ({field.type})"
  160. if field.original_comment:
  161. field_info += f" - 原注释: {field.original_comment}"
  162. # 添加样例值
  163. if field.name in sample_values and sample_values[field.name]:
  164. values_str = ", ".join(sample_values[field.name][:3])
  165. field_info += f" - 样例值: {values_str}"
  166. fields_info.append(field_info)
  167. prompt = f"""你是一个数据库文档专家。请为以下表的所有字段生成中文注释,并识别可能的枚举字段。
  168. 业务背景: {business_context}
  169. {self.business_dictionary}
  170. 表名: {table_metadata.schema_name}.{table_metadata.table_name}
  171. 表注释: {table_metadata.comment or "无"}
  172. 字段列表:
  173. {chr(10).join([f"{i+1}. {info}" for i, info in enumerate(fields_info)])}
  174. 请按以下JSON格式输出每个字段的分析结果:
  175. ```json
  176. {{
  177. "fields": [
  178. {{
  179. "name": "字段名",
  180. "comment": "中文注释(简洁明确,15字以内)",
  181. "is_enum": true/false,
  182. "enum_values": ["值1", "值2", "值3"] (如果是枚举),
  183. "enum_description": "枚举含义说明" (如果是枚举)
  184. }}
  185. ]
  186. }}
  187. ```
  188. 注释生成要求:
  189. 1. 如果原注释是英文,翻译为中文并改进
  190. 2. 根据字段名、类型和样例值推断字段含义
  191. 3. 识别可能的枚举字段(如状态、类型、级别等)
  192. 4. 枚举判断标准: VARCHAR类型 + 样例值重复度高 + 字段名暗示分类
  193. 5. 注释要贴近{business_context}的业务场景
  194. 请输出JSON格式的分析结果:"""
  195. return prompt
  196. async def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
  197. """带重试的LLM调用"""
  198. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  199. for attempt in range(max_retries):
  200. try:
  201. # 使用vanna实例的chat_with_llm方法进行自由聊天
  202. # 这是专门用于生成训练数据的方法,不会查询向量数据库
  203. response = await asyncio.to_thread(
  204. self.vn.chat_with_llm,
  205. question=prompt,
  206. system_prompt="你是一个专业的数据库文档专家,专门负责生成高质量的中文数据库表和字段注释。"
  207. )
  208. if response and response.strip():
  209. return response.strip()
  210. else:
  211. raise ValueError("LLM返回空响应")
  212. except Exception as e:
  213. self.logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
  214. if attempt == max_retries - 1:
  215. raise
  216. await asyncio.sleep(1) # 等待1秒后重试
  217. raise Exception("LLM调用达到最大重试次数")
  218. def _extract_table_comment(self, llm_response: str) -> str:
  219. """从LLM响应中提取表注释"""
  220. # 简单的文本清理和提取逻辑
  221. lines = llm_response.strip().split('\n')
  222. # 查找包含实际注释的行
  223. for line in lines:
  224. line = line.strip()
  225. if line and not line.startswith('#') and not line.startswith('*'):
  226. # 移除可能的前缀
  227. prefixes = ['表注释:', '注释:', '说明:', '表说明:']
  228. for prefix in prefixes:
  229. if line.startswith(prefix):
  230. line = line[len(prefix):].strip()
  231. if line:
  232. return line[:200] # 限制长度
  233. return llm_response.strip()[:200]
  234. def _parse_field_batch_response(self, llm_response: str, fields: List[FieldInfo]) -> List[Dict[str, Any]]:
  235. """解析字段批量处理响应"""
  236. import json
  237. import re
  238. try:
  239. # 尝试提取JSON部分
  240. json_match = re.search(r'```json\s*(.*?)\s*```', llm_response, re.DOTALL)
  241. if json_match:
  242. json_str = json_match.group(1)
  243. else:
  244. # 如果没有代码块,尝试直接解析
  245. json_str = llm_response
  246. # 解析JSON
  247. parsed_data = json.loads(json_str)
  248. field_data = parsed_data.get('fields', [])
  249. # 映射到字段结果
  250. results = []
  251. for i, field in enumerate(fields):
  252. if i < len(field_data):
  253. data = field_data[i]
  254. results.append({
  255. 'success': True,
  256. 'comment': data.get('comment', field.name),
  257. 'is_enum': data.get('is_enum', False),
  258. 'enum_values': data.get('enum_values', []),
  259. 'enum_description': data.get('enum_description', '')
  260. })
  261. else:
  262. # 默认结果
  263. results.append({
  264. 'success': False,
  265. 'comment': field.original_comment or field.name,
  266. 'is_enum': False
  267. })
  268. return results
  269. except Exception as e:
  270. self.logger.error(f"解析字段批量响应失败: {e}")
  271. # 返回默认结果
  272. return [
  273. {
  274. 'success': False,
  275. 'comment': field.original_comment or field.name,
  276. 'is_enum': False,
  277. 'error': str(e)
  278. }
  279. for field in fields
  280. ]
  281. async def _validate_enum_suggestions(self, table_metadata, enum_suggestions: List[Dict]) -> List[Dict]:
  282. """验证枚举建议"""
  283. from schema_tools.tools.database_inspector import DatabaseInspectorTool
  284. from schema_tools.config import SCHEMA_TOOLS_CONFIG
  285. validated_enums = []
  286. inspector = ToolRegistry.get_tool("database_inspector")
  287. sample_limit = SCHEMA_TOOLS_CONFIG["enum_detection_sample_limit"]
  288. for enum_info in enum_suggestions:
  289. field_name = enum_info['field_name']
  290. try:
  291. # 查询字段的所有不同值
  292. query = f"""
  293. SELECT DISTINCT {field_name} as value, COUNT(*) as count
  294. FROM {table_metadata.full_name}
  295. WHERE {field_name} IS NOT NULL
  296. GROUP BY {field_name}
  297. ORDER BY count DESC
  298. LIMIT {sample_limit}
  299. """
  300. async with inspector.connection_pool.acquire() as conn:
  301. rows = await conn.fetch(query)
  302. actual_values = [str(row['value']) for row in rows]
  303. # 验证是否真的是枚举(不同值数量合理)
  304. max_enum_values = SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]
  305. if len(actual_values) <= max_enum_values:
  306. validated_enums.append({
  307. 'field_name': field_name,
  308. 'actual_values': actual_values,
  309. 'suggested_values': enum_info['suggested_values'],
  310. 'description': enum_info['enum_description'],
  311. 'value_counts': [(row['value'], row['count']) for row in rows]
  312. })
  313. self.logger.info(f"确认字段 {field_name} 为枚举类型,包含 {len(actual_values)} 个值")
  314. else:
  315. self.logger.info(f"字段 {field_name} 不同值过多({len(actual_values)}),不认为是枚举")
  316. except Exception as e:
  317. self.logger.warning(f"验证字段 {field_name} 的枚举建议失败: {e}")
  318. return validated_enums
  319. def _load_business_dictionary(self) -> str:
  320. """加载业务词典"""
  321. try:
  322. import os
  323. dict_file = os.path.join(os.path.dirname(__file__), '..', 'prompts', 'business_dictionary.txt')
  324. if os.path.exists(dict_file):
  325. with open(dict_file, 'r', encoding='utf-8') as f:
  326. content = f.read().strip()
  327. return f"\n业务词典:\n{content}\n" if content else ""
  328. return ""
  329. except Exception as e:
  330. self.logger.warning(f"加载业务词典失败: {e}")
  331. return ""