|
@@ -0,0 +1,402 @@
|
|
|
|
+import asyncio
|
|
|
|
+from typing import List, Dict, Any, Tuple
|
|
|
|
+from schema_tools.tools.base import BaseTool, ToolRegistry
|
|
|
|
+from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo
|
|
|
|
+
|
|
|
|
+@ToolRegistry.register("comment_generator")
|
|
|
|
+class CommentGeneratorTool(BaseTool):
|
|
|
|
+ """LLM注释生成工具"""
|
|
|
|
+
|
|
|
|
+ needs_llm = True
|
|
|
|
+ tool_name = "注释生成器"
|
|
|
|
+
|
|
|
|
+ def __init__(self, **kwargs):
|
|
|
|
+ super().__init__(**kwargs)
|
|
|
|
+ self.business_context = kwargs.get('business_context', '')
|
|
|
|
+ self.business_dictionary = self._load_business_dictionary()
|
|
|
|
+
|
|
|
|
+ async def execute(self, context: TableProcessingContext) -> ProcessingResult:
|
|
|
|
+ """执行注释生成"""
|
|
|
|
+ try:
|
|
|
|
+ table_metadata = context.table_metadata
|
|
|
|
+
|
|
|
|
+ # 生成表注释
|
|
|
|
+ table_comment_result = await self._generate_table_comment(table_metadata, context.business_context)
|
|
|
|
+
|
|
|
|
+ # 生成字段注释和枚举建议
|
|
|
|
+ field_results = await self._generate_field_comments_and_enums(table_metadata, context.business_context)
|
|
|
|
+
|
|
|
|
+ # 更新表元数据
|
|
|
|
+ if table_comment_result['success']:
|
|
|
|
+ table_metadata.generated_comment = table_comment_result['comment']
|
|
|
|
+ table_metadata.comment = table_comment_result['comment']
|
|
|
|
+
|
|
|
|
+ # 更新字段信息
|
|
|
|
+ enum_suggestions = []
|
|
|
|
+ for i, field in enumerate(table_metadata.fields):
|
|
|
|
+ if i < len(field_results) and field_results[i]['success']:
|
|
|
|
+ field.generated_comment = field_results[i]['comment']
|
|
|
|
+ field.comment = field_results[i]['comment']
|
|
|
|
+
|
|
|
|
+ # 处理枚举建议
|
|
|
|
+ if field_results[i].get('is_enum'):
|
|
|
|
+ field.is_enum = True
|
|
|
|
+ enum_suggestions.append({
|
|
|
|
+ 'field_name': field.name,
|
|
|
|
+ 'suggested_values': field_results[i].get('enum_values', []),
|
|
|
|
+ 'enum_description': field_results[i].get('enum_description', '')
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ # 验证枚举建议
|
|
|
|
+ if enum_suggestions:
|
|
|
|
+ validated_enums = await self._validate_enum_suggestions(table_metadata, enum_suggestions)
|
|
|
|
+
|
|
|
|
+ # 更新验证后的枚举信息
|
|
|
|
+ for enum_info in validated_enums:
|
|
|
|
+ field_name = enum_info['field_name']
|
|
|
|
+ for field in table_metadata.fields:
|
|
|
|
+ if field.name == field_name:
|
|
|
|
+ field.enum_values = enum_info['actual_values']
|
|
|
|
+ field.enum_description = enum_info['description']
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ return ProcessingResult(
|
|
|
|
+ success=True,
|
|
|
|
+ data={
|
|
|
|
+ 'table_comment_generated': table_comment_result['success'],
|
|
|
|
+ 'field_comments_generated': sum(1 for r in field_results if r['success']),
|
|
|
|
+ 'enum_fields_detected': len([f for f in table_metadata.fields if f.is_enum]),
|
|
|
|
+ 'enum_suggestions': enum_suggestions
|
|
|
|
+ },
|
|
|
|
+ metadata={'tool': self.tool_name}
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.exception(f"注释生成失败")
|
|
|
|
+ return ProcessingResult(
|
|
|
|
+ success=False,
|
|
|
|
+ error_message=f"注释生成失败: {str(e)}"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ async def _generate_table_comment(self, table_metadata, business_context: str) -> Dict[str, Any]:
|
|
|
|
+ """生成表注释"""
|
|
|
|
+ try:
|
|
|
|
+ prompt = self._build_table_comment_prompt(table_metadata, business_context)
|
|
|
|
+
|
|
|
|
+ # 调用LLM
|
|
|
|
+ response = await self._call_llm_with_retry(prompt)
|
|
|
|
+
|
|
|
|
+ # 解析响应
|
|
|
|
+ comment = self._extract_table_comment(response)
|
|
|
|
+
|
|
|
|
+ return {
|
|
|
|
+ 'success': True,
|
|
|
|
+ 'comment': comment,
|
|
|
|
+ 'original_response': response
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.error(f"表注释生成失败: {e}")
|
|
|
|
+ return {
|
|
|
|
+ 'success': False,
|
|
|
|
+ 'comment': table_metadata.original_comment or f"{table_metadata.table_name}表",
|
|
|
|
+ 'error': str(e)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ async def _generate_field_comments_and_enums(self, table_metadata, business_context: str) -> List[Dict[str, Any]]:
|
|
|
|
+ """批量生成字段注释和枚举建议"""
|
|
|
|
+ try:
|
|
|
|
+ # 构建批量处理的提示词
|
|
|
|
+ prompt = self._build_field_batch_prompt(table_metadata, business_context)
|
|
|
|
+
|
|
|
|
+ # 调用LLM
|
|
|
|
+ response = await self._call_llm_with_retry(prompt)
|
|
|
|
+
|
|
|
|
+ # 解析批量响应
|
|
|
|
+ field_results = self._parse_field_batch_response(response, table_metadata.fields)
|
|
|
|
+
|
|
|
|
+ return field_results
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.error(f"字段注释批量生成失败: {e}")
|
|
|
|
+ # 返回默认结果
|
|
|
|
+ return [
|
|
|
|
+ {
|
|
|
|
+ 'success': False,
|
|
|
|
+ 'comment': field.original_comment or field.name,
|
|
|
|
+ 'is_enum': False,
|
|
|
|
+ 'error': str(e)
|
|
|
|
+ }
|
|
|
|
+ for field in table_metadata.fields
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ def _build_table_comment_prompt(self, table_metadata, business_context: str) -> str:
|
|
|
|
+ """构建表注释生成提示词"""
|
|
|
|
+ # 准备字段信息摘要
|
|
|
|
+ fields_summary = []
|
|
|
|
+ for field in table_metadata.fields[:10]: # 只显示前10个字段避免过长
|
|
|
|
+ field_desc = f"- {field.name} ({field.type})"
|
|
|
|
+ if field.comment:
|
|
|
|
+ field_desc += f": {field.comment}"
|
|
|
|
+ fields_summary.append(field_desc)
|
|
|
|
+
|
|
|
|
+ # 准备样例数据摘要
|
|
|
|
+ sample_summary = ""
|
|
|
|
+ if table_metadata.sample_data:
|
|
|
|
+ sample_count = min(3, len(table_metadata.sample_data))
|
|
|
|
+ sample_summary = f"\n样例数据({sample_count}条):\n"
|
|
|
|
+ for i, sample in enumerate(table_metadata.sample_data[:sample_count]):
|
|
|
|
+ sample_str = ", ".join([f"{k}={v}" for k, v in list(sample.items())[:5]])
|
|
|
|
+ sample_summary += f"{i+1}. {sample_str}\n"
|
|
|
|
+
|
|
|
|
+ prompt = f"""你是一个数据库文档专家。请根据以下信息为数据库表生成简洁、准确的中文注释。
|
|
|
|
+
|
|
|
|
+业务背景: {business_context}
|
|
|
|
+{self.business_dictionary}
|
|
|
|
+
|
|
|
|
+表信息:
|
|
|
|
+- 表名: {table_metadata.table_name}
|
|
|
|
+- Schema: {table_metadata.schema_name}
|
|
|
|
+- 现有注释: {table_metadata.original_comment or "无"}
|
|
|
|
+- 字段数量: {len(table_metadata.fields)}
|
|
|
|
+- 数据行数: {table_metadata.row_count or "未知"}
|
|
|
|
+
|
|
|
|
+主要字段:
|
|
|
|
+{chr(10).join(fields_summary)}
|
|
|
|
+
|
|
|
|
+{sample_summary}
|
|
|
|
+
|
|
|
|
+请生成一个简洁、准确的中文表注释,要求:
|
|
|
|
+1. 如果现有注释是英文,请翻译为中文并改进
|
|
|
|
+2. 根据字段名称和样例数据推断表的业务用途
|
|
|
|
+3. 注释长度控制在50字以内
|
|
|
|
+4. 突出表的核心业务价值
|
|
|
|
+
|
|
|
|
+表注释:"""
|
|
|
|
+
|
|
|
|
+ return prompt
|
|
|
|
+
|
|
|
|
+ def _build_field_batch_prompt(self, table_metadata, business_context: str) -> str:
|
|
|
|
+ """构建字段批量处理提示词"""
|
|
|
|
+ # 准备字段信息
|
|
|
|
+ fields_info = []
|
|
|
|
+ sample_values = {}
|
|
|
|
+
|
|
|
|
+ # 收集字段的样例值
|
|
|
|
+ for sample in table_metadata.sample_data[:5]:
|
|
|
|
+ for field_name, value in sample.items():
|
|
|
|
+ if field_name not in sample_values:
|
|
|
|
+ sample_values[field_name] = []
|
|
|
|
+ if value is not None and len(sample_values[field_name]) < 5:
|
|
|
|
+ sample_values[field_name].append(str(value))
|
|
|
|
+
|
|
|
|
+ # 构建字段信息列表
|
|
|
|
+ for field in table_metadata.fields:
|
|
|
|
+ field_info = f"{field.name} ({field.type})"
|
|
|
|
+ if field.original_comment:
|
|
|
|
+ field_info += f" - 原注释: {field.original_comment}"
|
|
|
|
+
|
|
|
|
+ # 添加样例值
|
|
|
|
+ if field.name in sample_values and sample_values[field.name]:
|
|
|
|
+ values_str = ", ".join(sample_values[field.name][:3])
|
|
|
|
+ field_info += f" - 样例值: {values_str}"
|
|
|
|
+
|
|
|
|
+ fields_info.append(field_info)
|
|
|
|
+
|
|
|
|
+ prompt = f"""你是一个数据库文档专家。请为以下表的所有字段生成中文注释,并识别可能的枚举字段。
|
|
|
|
+
|
|
|
|
+业务背景: {business_context}
|
|
|
|
+{self.business_dictionary}
|
|
|
|
+
|
|
|
|
+表名: {table_metadata.schema_name}.{table_metadata.table_name}
|
|
|
|
+表注释: {table_metadata.comment or "无"}
|
|
|
|
+
|
|
|
|
+字段列表:
|
|
|
|
+{chr(10).join([f"{i+1}. {info}" for i, info in enumerate(fields_info)])}
|
|
|
|
+
|
|
|
|
+请按以下JSON格式输出每个字段的分析结果:
|
|
|
|
+```json
|
|
|
|
+{{
|
|
|
|
+ "fields": [
|
|
|
|
+ {{
|
|
|
|
+ "name": "字段名",
|
|
|
|
+ "comment": "中文注释(简洁明确,15字以内)",
|
|
|
|
+ "is_enum": true/false,
|
|
|
|
+ "enum_values": ["值1", "值2", "值3"] (如果是枚举),
|
|
|
|
+ "enum_description": "枚举含义说明" (如果是枚举)
|
|
|
|
+ }}
|
|
|
|
+ ]
|
|
|
|
+}}
|
|
|
|
+```
|
|
|
|
+
|
|
|
|
+注释生成要求:
|
|
|
|
+1. 如果原注释是英文,翻译为中文并改进
|
|
|
|
+2. 根据字段名、类型和样例值推断字段含义
|
|
|
|
+3. 识别可能的枚举字段(如状态、类型、级别等)
|
|
|
|
+4. 枚举判断标准: VARCHAR类型 + 样例值重复度高 + 字段名暗示分类
|
|
|
|
+5. 注释要贴近{business_context}的业务场景
|
|
|
|
+
|
|
|
|
+请输出JSON格式的分析结果:"""
|
|
|
|
+
|
|
|
|
+ return prompt
|
|
|
|
+
|
|
|
|
+ async def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
|
|
|
|
+ """带重试的LLM调用"""
|
|
|
|
+ from schema_tools.config import SCHEMA_TOOLS_CONFIG
|
|
|
|
+
|
|
|
|
+ for attempt in range(max_retries):
|
|
|
|
+ try:
|
|
|
|
+ # 使用vanna实例的chat_with_llm方法进行自由聊天
|
|
|
|
+ # 这是专门用于生成训练数据的方法,不会查询向量数据库
|
|
|
|
+ response = await asyncio.to_thread(
|
|
|
|
+ self.vn.chat_with_llm,
|
|
|
|
+ question=prompt,
|
|
|
|
+ system_prompt="你是一个专业的数据库文档专家,专门负责生成高质量的中文数据库表和字段注释。"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if response and response.strip():
|
|
|
|
+ return response.strip()
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError("LLM返回空响应")
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
|
|
|
|
+ if attempt == max_retries - 1:
|
|
|
|
+ raise
|
|
|
|
+ await asyncio.sleep(1) # 等待1秒后重试
|
|
|
|
+
|
|
|
|
+ raise Exception("LLM调用达到最大重试次数")
|
|
|
|
+
|
|
|
|
+ def _extract_table_comment(self, llm_response: str) -> str:
|
|
|
|
+ """从LLM响应中提取表注释"""
|
|
|
|
+ # 简单的文本清理和提取逻辑
|
|
|
|
+ lines = llm_response.strip().split('\n')
|
|
|
|
+
|
|
|
|
+ # 查找包含实际注释的行
|
|
|
|
+ for line in lines:
|
|
|
|
+ line = line.strip()
|
|
|
|
+ if line and not line.startswith('#') and not line.startswith('*'):
|
|
|
|
+ # 移除可能的前缀
|
|
|
|
+ prefixes = ['表注释:', '注释:', '说明:', '表说明:']
|
|
|
|
+ for prefix in prefixes:
|
|
|
|
+ if line.startswith(prefix):
|
|
|
|
+ line = line[len(prefix):].strip()
|
|
|
|
+
|
|
|
|
+ if line:
|
|
|
|
+ return line[:200] # 限制长度
|
|
|
|
+
|
|
|
|
+ return llm_response.strip()[:200]
|
|
|
|
+
|
|
|
|
+ def _parse_field_batch_response(self, llm_response: str, fields: List[FieldInfo]) -> List[Dict[str, Any]]:
|
|
|
|
+ """解析字段批量处理响应"""
|
|
|
|
+ import json
|
|
|
|
+ import re
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ # 尝试提取JSON部分
|
|
|
|
+ json_match = re.search(r'```json\s*(.*?)\s*```', llm_response, re.DOTALL)
|
|
|
|
+ if json_match:
|
|
|
|
+ json_str = json_match.group(1)
|
|
|
|
+ else:
|
|
|
|
+ # 如果没有代码块,尝试直接解析
|
|
|
|
+ json_str = llm_response
|
|
|
|
+
|
|
|
|
+ # 解析JSON
|
|
|
|
+ parsed_data = json.loads(json_str)
|
|
|
|
+ field_data = parsed_data.get('fields', [])
|
|
|
|
+
|
|
|
|
+ # 映射到字段结果
|
|
|
|
+ results = []
|
|
|
|
+ for i, field in enumerate(fields):
|
|
|
|
+ if i < len(field_data):
|
|
|
|
+ data = field_data[i]
|
|
|
|
+ results.append({
|
|
|
|
+ 'success': True,
|
|
|
|
+ 'comment': data.get('comment', field.name),
|
|
|
|
+ 'is_enum': data.get('is_enum', False),
|
|
|
|
+ 'enum_values': data.get('enum_values', []),
|
|
|
|
+ 'enum_description': data.get('enum_description', '')
|
|
|
|
+ })
|
|
|
|
+ else:
|
|
|
|
+ # 默认结果
|
|
|
|
+ results.append({
|
|
|
|
+ 'success': False,
|
|
|
|
+ 'comment': field.original_comment or field.name,
|
|
|
|
+ 'is_enum': False
|
|
|
|
+ })
|
|
|
|
+
|
|
|
|
+ return results
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.error(f"解析字段批量响应失败: {e}")
|
|
|
|
+ # 返回默认结果
|
|
|
|
+ return [
|
|
|
|
+ {
|
|
|
|
+ 'success': False,
|
|
|
|
+ 'comment': field.original_comment or field.name,
|
|
|
|
+ 'is_enum': False,
|
|
|
|
+ 'error': str(e)
|
|
|
|
+ }
|
|
|
|
+ for field in fields
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ async def _validate_enum_suggestions(self, table_metadata, enum_suggestions: List[Dict]) -> List[Dict]:
|
|
|
|
+ """验证枚举建议"""
|
|
|
|
+ from schema_tools.tools.database_inspector import DatabaseInspectorTool
|
|
|
|
+ from schema_tools.config import SCHEMA_TOOLS_CONFIG
|
|
|
|
+
|
|
|
|
+ validated_enums = []
|
|
|
|
+ inspector = ToolRegistry.get_tool("database_inspector")
|
|
|
|
+ sample_limit = SCHEMA_TOOLS_CONFIG["enum_detection_sample_limit"]
|
|
|
|
+
|
|
|
|
+ for enum_info in enum_suggestions:
|
|
|
|
+ field_name = enum_info['field_name']
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ # 查询字段的所有不同值
|
|
|
|
+ query = f"""
|
|
|
|
+ SELECT DISTINCT {field_name} as value, COUNT(*) as count
|
|
|
|
+ FROM {table_metadata.full_name}
|
|
|
|
+ WHERE {field_name} IS NOT NULL
|
|
|
|
+ GROUP BY {field_name}
|
|
|
|
+ ORDER BY count DESC
|
|
|
|
+ LIMIT {sample_limit}
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ async with inspector.connection_pool.acquire() as conn:
|
|
|
|
+ rows = await conn.fetch(query)
|
|
|
|
+
|
|
|
|
+ actual_values = [str(row['value']) for row in rows]
|
|
|
|
+
|
|
|
|
+ # 验证是否真的是枚举(不同值数量合理)
|
|
|
|
+ max_enum_values = SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]
|
|
|
|
+ if len(actual_values) <= max_enum_values:
|
|
|
|
+ validated_enums.append({
|
|
|
|
+ 'field_name': field_name,
|
|
|
|
+ 'actual_values': actual_values,
|
|
|
|
+ 'suggested_values': enum_info['suggested_values'],
|
|
|
|
+ 'description': enum_info['enum_description'],
|
|
|
|
+ 'value_counts': [(row['value'], row['count']) for row in rows]
|
|
|
|
+ })
|
|
|
|
+ self.logger.info(f"确认字段 {field_name} 为枚举类型,包含 {len(actual_values)} 个值")
|
|
|
|
+ else:
|
|
|
|
+ self.logger.info(f"字段 {field_name} 不同值过多({len(actual_values)}),不认为是枚举")
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.warning(f"验证字段 {field_name} 的枚举建议失败: {e}")
|
|
|
|
+
|
|
|
|
+ return validated_enums
|
|
|
|
+
|
|
|
|
+ def _load_business_dictionary(self) -> str:
|
|
|
|
+ """加载业务词典"""
|
|
|
|
+ try:
|
|
|
|
+ import os
|
|
|
|
+ dict_file = os.path.join(os.path.dirname(__file__), '..', 'prompts', 'business_dictionary.txt')
|
|
|
|
+ if os.path.exists(dict_file):
|
|
|
|
+ with open(dict_file, 'r', encoding='utf-8') as f:
|
|
|
|
+ content = f.read().strip()
|
|
|
|
+ return f"\n业务词典:\n{content}\n" if content else ""
|
|
|
|
+ return ""
|
|
|
|
+ except Exception as e:
|
|
|
|
+ self.logger.warning(f"加载业务词典失败: {e}")
|
|
|
|
+ return ""
|