comment_generator.py

import asyncio
from typing import List, Dict, Any, Tuple

from data_pipeline.tools.base import BaseTool, ToolRegistry
from data_pipeline.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo


@ToolRegistry.register("comment_generator")
class CommentGeneratorTool(BaseTool):
    """LLM-based comment generation tool."""

    needs_llm = True
    tool_name = "注释生成器"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.business_context = kwargs.get('business_context', '')
        self.db_connection = kwargs.get('db_connection')  # optional database connection string
        self.business_dictionary = self._load_business_dictionary()
    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """Run comment generation for a single table."""
        try:
            table_metadata = context.table_metadata

            # Generate the table-level comment
            table_comment_result = await self._generate_table_comment(table_metadata, context.business_context)

            # Generate field comments and enum suggestions
            field_results = await self._generate_field_comments_and_enums(table_metadata, context.business_context)

            # Update table metadata
            if table_comment_result['success']:
                table_metadata.generated_comment = table_comment_result['comment']
                table_metadata.comment = table_comment_result['comment']

            # Update field information
            enum_suggestions = []
            for i, field in enumerate(table_metadata.fields):
                if i < len(field_results) and field_results[i]['success']:
                    field.generated_comment = field_results[i]['comment']
                    field.comment = field_results[i]['comment']

                    # Collect enum suggestions
                    if field_results[i].get('is_enum'):
                        field.is_enum = True
                        enum_suggestions.append({
                            'field_name': field.name,
                            'suggested_values': field_results[i].get('enum_values', []),
                            'enum_description': field_results[i].get('enum_description', '')
                        })

            # Validate enum suggestions against the actual data
            if enum_suggestions:
                validated_enums = await self._validate_enum_suggestions(table_metadata, enum_suggestions)

                # Apply the validated enum information
                for enum_info in validated_enums:
                    field_name = enum_info['field_name']
                    for field in table_metadata.fields:
                        if field.name == field_name:
                            field.enum_values = enum_info['actual_values']
                            field.enum_description = enum_info['description']
                            break

            return ProcessingResult(
                success=True,
                data={
                    'table_comment_generated': table_comment_result['success'],
                    'field_comments_generated': sum(1 for r in field_results if r['success']),
                    'enum_fields_detected': len([f for f in table_metadata.fields if f.is_enum]),
                    'enum_suggestions': enum_suggestions
                },
                metadata={'tool': self.tool_name}
            )

        except Exception as e:
            self.logger.exception("注释生成失败")
            return ProcessingResult(
                success=False,
                error_message=f"注释生成失败: {str(e)}"
            )
    async def _generate_table_comment(self, table_metadata, business_context: str) -> Dict[str, Any]:
        """Generate the table comment."""
        try:
            prompt = self._build_table_comment_prompt(table_metadata, business_context)

            # Call the LLM
            response = await self._call_llm_with_retry(prompt)

            # Parse the response
            comment = self._extract_table_comment(response)

            return {
                'success': True,
                'comment': comment,
                'original_response': response
            }
        except Exception as e:
            self.logger.error(f"表注释生成失败: {e}")
            return {
                'success': False,
                'comment': table_metadata.original_comment or f"{table_metadata.table_name}表",
                'error': str(e)
            }
    async def _generate_field_comments_and_enums(self, table_metadata, business_context: str) -> List[Dict[str, Any]]:
        """Generate field comments and enum suggestions in a single batch."""
        try:
            # Build the batched prompt
            prompt = self._build_field_batch_prompt(table_metadata, business_context)

            # Call the LLM
            response = await self._call_llm_with_retry(prompt)

            # Parse the batched response
            field_results = self._parse_field_batch_response(response, table_metadata.fields)

            return field_results
        except Exception as e:
            self.logger.error(f"字段注释批量生成失败: {e}")
            # Fall back to default results
            return [
                {
                    'success': False,
                    'comment': field.original_comment or field.name,
                    'is_enum': False,
                    'error': str(e)
                }
                for field in table_metadata.fields
            ]
    def _build_table_comment_prompt(self, table_metadata, business_context: str) -> str:
        """Build the prompt for table comment generation."""
        # Summarize field information (only the first 10 fields to keep the prompt short)
        fields_summary = []
        for field in table_metadata.fields[:10]:
            field_desc = f"- {field.name} ({field.type})"
            if field.comment:
                field_desc += f": {field.comment}"
            fields_summary.append(field_desc)

        # Summarize sample data
        sample_summary = ""
        if table_metadata.sample_data:
            sample_count = min(3, len(table_metadata.sample_data))
            sample_summary = f"\n样例数据({sample_count}条):\n"
            for i, sample in enumerate(table_metadata.sample_data[:sample_count]):
                sample_str = ", ".join([f"{k}={v}" for k, v in list(sample.items())[:5]])
                sample_summary += f"{i+1}. {sample_str}\n"

        prompt = f"""你是一个数据库文档专家。请根据以下信息为数据库表生成简洁、准确的中文注释。
业务背景: {business_context}
{self.business_dictionary}
表信息:
- 表名: {table_metadata.table_name}
- Schema: {table_metadata.schema_name}
- 现有注释: {table_metadata.original_comment or '无'}
- 字段数量: {len(table_metadata.fields)}
- 数据行数: {table_metadata.row_count or '未知'}
主要字段:
{chr(10).join(fields_summary)}
{sample_summary}
请生成一个简洁、准确的中文表注释,要求:
1. 如果现有注释是英文,请翻译为中文并改进
2. 根据字段名称和样例数据推断表的业务用途
3. 注释长度控制在50字以内
4. 突出表的核心业务价值
表注释:"""
        return prompt
    def _build_field_batch_prompt(self, table_metadata, business_context: str) -> str:
        """Build the batched prompt for field comments and enum detection."""
        # Prepare field information
        fields_info = []
        sample_values = {}

        # Collect sample values per field (guard against missing sample data)
        for sample in (table_metadata.sample_data or [])[:5]:
            for field_name, value in sample.items():
                if field_name not in sample_values:
                    sample_values[field_name] = []
                if value is not None and len(sample_values[field_name]) < 5:
                    sample_values[field_name].append(str(value))

        # Build the field description list
        for field in table_metadata.fields:
            field_info = f"{field.name} ({field.type})"
            if field.original_comment:
                field_info += f" - 原注释: {field.original_comment}"
            # Append sample values
            if field.name in sample_values and sample_values[field.name]:
                values_str = ", ".join(sample_values[field.name][:3])
                field_info += f" - 样例值: {values_str}"
            fields_info.append(field_info)

        prompt = f"""你是一个数据库文档专家。请为以下表的所有字段生成中文注释,并识别可能的枚举字段。
业务背景: {business_context}
{self.business_dictionary}
表名: {table_metadata.schema_name}.{table_metadata.table_name}
表注释: {table_metadata.comment or '无'}
字段列表:
{chr(10).join([f'{i+1}. {info}' for i, info in enumerate(fields_info)])}
请按以下JSON格式输出每个字段的分析结果:
```json
{{
"fields": [
{{
"name": "字段名",
"comment": "中文注释(简洁明确,15字以内)",
"is_enum": true/false,
"enum_values": ["值1", "值2", "值3"] (如果是枚举),
"enum_description": "枚举含义说明" (如果是枚举)
}}
]
}}
```
注释生成要求:
1. 如果原注释是英文,翻译为中文并改进
2. 根据字段名、类型和样例值推断字段含义
3. 识别可能的枚举字段(如状态、类型、级别等)
4. 枚举判断标准: VARCHAR类型 + 样例值重复度高 + 字段名暗示分类
5. 注释要贴近{business_context}的业务场景
请输出JSON格式的分析结果:"""
        return prompt
    async def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
        """Call the LLM, retrying on failure."""
        for attempt in range(max_retries):
            try:
                # Use the vanna instance's chat_with_llm method for free-form chat.
                # This path is meant for generating training data and does not query the vector store.
                response = await asyncio.to_thread(
                    self.vn.chat_with_llm,
                    question=prompt,
                    system_prompt="你是一个专业的数据库文档专家,专门负责生成高质量的中文数据库表和字段注释。"
                )

                if response and response.strip():
                    return response.strip()
                else:
                    raise ValueError("LLM返回空响应")

            except Exception as e:
                self.logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(1)  # wait one second before retrying

        raise Exception("LLM调用达到最大重试次数")
    def _extract_table_comment(self, llm_response: str) -> str:
        """Extract the table comment from the LLM response."""
        # Simple text cleanup and extraction logic
        lines = llm_response.strip().split('\n')

        # Find the line that carries the actual comment
        for line in lines:
            line = line.strip()
            if line and not line.startswith('#') and not line.startswith('*'):
                # Strip possible prefixes
                prefixes = ['表注释:', '注释:', '说明:', '表说明:']
                for prefix in prefixes:
                    if line.startswith(prefix):
                        line = line[len(prefix):].strip()
                if line:
                    return line[:200]  # limit length

        return llm_response.strip()[:200]
    def _parse_field_batch_response(self, llm_response: str, fields: List[FieldInfo]) -> List[Dict[str, Any]]:
        """Parse the batched field response."""
        import json
        import re

        try:
            # Try to extract a JSON code block first
            json_match = re.search(r'```json\s*(.*?)\s*```', llm_response, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                # No code block; try to parse the response directly
                json_str = llm_response

            # Parse the JSON
            parsed_data = json.loads(json_str)
            field_data = parsed_data.get('fields', [])

            # Map parsed entries back onto the field list by position
            results = []
            for i, field in enumerate(fields):
                if i < len(field_data):
                    data = field_data[i]
                    results.append({
                        'success': True,
                        'comment': data.get('comment', field.name),
                        'is_enum': data.get('is_enum', False),
                        'enum_values': data.get('enum_values', []),
                        'enum_description': data.get('enum_description', '')
                    })
                else:
                    # Default result when the LLM returned fewer entries than fields
                    results.append({
                        'success': False,
                        'comment': field.original_comment or field.name,
                        'is_enum': False
                    })

            return results

        except Exception as e:
            self.logger.error(f"解析字段批量响应失败: {e}")
            # Fall back to default results
            return [
                {
                    'success': False,
                    'comment': field.original_comment or field.name,
                    'is_enum': False,
                    'error': str(e)
                }
                for field in fields
            ]
    async def _validate_enum_suggestions(self, table_metadata, enum_suggestions: List[Dict]) -> List[Dict]:
        """Validate enum suggestions against the actual data."""
        import asyncpg
        from data_pipeline.config import SCHEMA_TOOLS_CONFIG

        validated_enums = []
        sample_limit = SCHEMA_TOOLS_CONFIG["enum_detection_sample_limit"]

        # Resolve the database connection string, preferring the one passed in
        db_connection = self.db_connection

        # Otherwise try to obtain it from the vanna instance
        if not db_connection:
            if hasattr(self.vn, 'connection_string'):
                db_connection = self.vn.connection_string
            elif hasattr(self.vn, '_connection_string'):
                db_connection = self.vn._connection_string

        if not db_connection:
            self.logger.warning("无法获取数据库连接字符串,跳过枚举验证")
            return validated_enums

        for enum_info in enum_suggestions:
            field_name = enum_info['field_name']

            try:
                # Query the distinct values of the field (identifiers come from the table metadata)
                query = f"""
                SELECT DISTINCT {field_name} as value, COUNT(*) as count
                FROM {table_metadata.full_name}
                WHERE {field_name} IS NOT NULL
                GROUP BY {field_name}
                ORDER BY count DESC
                LIMIT {sample_limit}
                """

                conn = await asyncpg.connect(db_connection)
                try:
                    rows = await conn.fetch(query)
                    actual_values = [str(row['value']) for row in rows]

                    # Treat the field as an enum only if the number of distinct values is reasonable
                    max_enum_values = SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]
                    if len(actual_values) <= max_enum_values:
                        validated_enums.append({
                            'field_name': field_name,
                            'actual_values': actual_values,
                            'suggested_values': enum_info['suggested_values'],
                            'description': enum_info['enum_description'],
                            'value_counts': [(row['value'], row['count']) for row in rows]
                        })
                        self.logger.info(f"确认字段 {field_name} 为枚举类型,包含 {len(actual_values)} 个值")
                    else:
                        self.logger.info(f"字段 {field_name} 不同值过多({len(actual_values)}),不认为是枚举")
                finally:
                    await conn.close()

            except Exception as e:
                self.logger.warning(f"验证字段 {field_name} 的枚举建议失败: {e}")

        return validated_enums
    def _load_business_dictionary(self) -> str:
        """Load the business dictionary file, if present."""
        try:
            import os
            dict_file = os.path.join(os.path.dirname(__file__), '..', 'prompts', 'business_dictionary.txt')

            if os.path.exists(dict_file):
                with open(dict_file, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                    return f"\n业务词典:\n{content}\n" if content else ""

            return ""

        except Exception as e:
            self.logger.warning(f"加载业务词典失败: {e}")
            return ""
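

# Hypothetical usage sketch (not part of the original module). It assumes the
# surrounding pipeline has already built a TableProcessingContext and that the
# BaseTool wiring injects the vanna instance used as self.vn. The function name
# annotate_table, the business context string, and the connection string below
# are illustrative placeholders only.
#
# async def annotate_table(context: TableProcessingContext) -> None:
#     tool = CommentGeneratorTool(
#         business_context="e-commerce order domain",       # example business context
#         db_connection="postgresql://user:pass@host/db",   # optional; enables enum validation
#     )
#     result = await tool.execute(context)
#     if result.success:
#         print(f"{result.data['field_comments_generated']} field comments generated")
#     else:
#         print(f"comment generation failed: {result.error_message}")
#
# asyncio.run(annotate_table(context))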