123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402 |
- import asyncio
- from typing import List, Dict, Any, Tuple
- from schema_tools.tools.base import BaseTool, ToolRegistry
- from schema_tools.utils.data_structures import ProcessingResult, TableProcessingContext, FieldInfo
@ToolRegistry.register("comment_generator")
class CommentGeneratorTool(BaseTool):
    """LLM comment generator tool.

    For one table it asks the LLM for a Chinese table comment, then for
    per-field Chinese comments plus enum-field suggestions (one batched JSON
    request). Each enum suggestion is checked against the real column values
    via the ``database_inspector`` tool before its values are recorded.
    """

    needs_llm = True
    tool_name = "注释生成器"

    # Prefixes the LLM may echo in front of the table comment. The prompt ends
    # with "表注释:" (ASCII colon), but Chinese replies frequently use the
    # full-width colon "：", so both spellings are stripped.
    _COMMENT_PREFIXES = (
        '表注释:', '注释:', '说明:', '表说明:',
        '表注释：', '注释：', '说明：', '表说明：',
    )

    def __init__(self, **kwargs):
        """Initialise the tool.

        Args:
            **kwargs: forwarded to ``BaseTool``; ``business_context`` (str,
                default ``''``) is additionally stored on the instance.
        """
        super().__init__(**kwargs)
        self.business_context = kwargs.get('business_context', '')
        # Loaded once per instance and embedded verbatim in every prompt.
        self.business_dictionary = self._load_business_dictionary()

    async def execute(self, context: TableProcessingContext) -> ProcessingResult:
        """Generate comments and enum info for ``context.table_metadata`` in place.

        Args:
            context: processing context. ``context.table_metadata`` is mutated
                (table/field ``comment`` and ``generated_comment``, field
                ``is_enum``/``enum_values``/``enum_description``);
                ``context.business_context`` is injected into the prompts.

        Returns:
            ProcessingResult whose ``data`` holds generation statistics and the
            raw LLM enum suggestions, or ``success=False`` with an error
            message if anything unexpected raised.
        """
        try:
            table_metadata = context.table_metadata

            # Table comment first. Store it immediately so the field prompt
            # (which includes ``table_metadata.comment``) sees the new comment.
            table_comment_result = await self._generate_table_comment(table_metadata, context.business_context)
            if table_comment_result['success']:
                table_metadata.generated_comment = table_comment_result['comment']
                table_metadata.comment = table_comment_result['comment']

            # Field comments and enum suggestions (one result per field, in order).
            field_results = await self._generate_field_comments_and_enums(table_metadata, context.business_context)

            enum_suggestions = []
            # is_enum value each suggested field had before this run, so an
            # unconfirmed LLM suggestion can be rolled back below.
            previous_is_enum = {}
            for field, result in zip(table_metadata.fields, field_results):
                if not result['success']:
                    continue
                field.generated_comment = result['comment']
                field.comment = result['comment']

                if result.get('is_enum'):
                    previous_is_enum.setdefault(field.name, field.is_enum)
                    field.is_enum = True
                    enum_suggestions.append({
                        'field_name': field.name,
                        'suggested_values': result.get('enum_values', []),
                        'enum_description': result.get('enum_description', '')
                    })

            # Confirm the suggestions against the actual data.
            if enum_suggestions:
                validated_enums = await self._validate_enum_suggestions(table_metadata, enum_suggestions)
                validated_by_name = {info['field_name']: info for info in validated_enums}

                for field in table_metadata.fields:
                    if field.name not in previous_is_enum:
                        continue
                    info = validated_by_name.get(field.name)
                    if info is None:
                        # Not confirmed (too many distinct values, or the check
                        # itself failed): undo the LLM's enum flag.
                        field.is_enum = previous_is_enum[field.name]
                    else:
                        field.enum_values = info['actual_values']
                        field.enum_description = info['description']

            return ProcessingResult(
                success=True,
                data={
                    'table_comment_generated': table_comment_result['success'],
                    'field_comments_generated': sum(1 for r in field_results if r['success']),
                    'enum_fields_detected': len([f for f in table_metadata.fields if f.is_enum]),
                    'enum_suggestions': enum_suggestions
                },
                metadata={'tool': self.tool_name}
            )

        except Exception as e:
            self.logger.exception("注释生成失败")
            return ProcessingResult(
                success=False,
                error_message=f"注释生成失败: {str(e)}"
            )

    async def _generate_table_comment(self, table_metadata, business_context: str) -> Dict[str, Any]:
        """Ask the LLM for a table comment.

        Returns:
            ``{'success': True, 'comment': ..., 'original_response': ...}`` on
            success. On any error, ``{'success': False, 'comment': fallback,
            'error': ...}``, where the fallback is the table's original comment
            or ``"<table_name>表"``.
        """
        try:
            prompt = self._build_table_comment_prompt(table_metadata, business_context)
            response = await self._call_llm_with_retry(prompt)
            comment = self._extract_table_comment(response)

            return {
                'success': True,
                'comment': comment,
                'original_response': response
            }

        except Exception as e:
            self.logger.error(f"表注释生成失败: {e}")
            return {
                'success': False,
                'comment': table_metadata.original_comment or f"{table_metadata.table_name}表",
                'error': str(e)
            }

    async def _generate_field_comments_and_enums(self, table_metadata, business_context: str) -> List[Dict[str, Any]]:
        """Ask the LLM, in one batched request, for every field's comment and enum hint.

        Returns:
            One result dict per entry of ``table_metadata.fields``, in the same
            order (see ``_parse_field_batch_response``). If the LLM call itself
            fails, every entry is a failure result carrying the original comment
            (or field name) and the error text.
        """
        try:
            prompt = self._build_field_batch_prompt(table_metadata, business_context)
            response = await self._call_llm_with_retry(prompt)
            return self._parse_field_batch_response(response, table_metadata.fields)

        except Exception as e:
            self.logger.error(f"字段注释批量生成失败: {e}")
            return [
                {
                    'success': False,
                    'comment': field.original_comment or field.name,
                    'is_enum': False,
                    'error': str(e)
                }
                for field in table_metadata.fields
            ]

    def _build_table_comment_prompt(self, table_metadata, business_context: str) -> str:
        """Build the prompt that asks for a single table comment.

        Only the first 10 fields, and the first 5 columns of at most 3 sample
        rows, are included to keep the prompt short.
        """
        fields_summary = []
        for field in table_metadata.fields[:10]:
            field_desc = f"- {field.name} ({field.type})"
            if field.comment:
                field_desc += f": {field.comment}"
            fields_summary.append(field_desc)

        sample_summary = ""
        if table_metadata.sample_data:
            sample_count = min(3, len(table_metadata.sample_data))
            sample_summary = f"\n样例数据({sample_count}条):\n"
            for i, sample in enumerate(table_metadata.sample_data[:sample_count]):
                sample_str = ", ".join([f"{k}={v}" for k, v in list(sample.items())[:5]])
                sample_summary += f"{i+1}. {sample_str}\n"

        prompt = f"""你是一个数据库文档专家。请根据以下信息为数据库表生成简洁、准确的中文注释。
业务背景: {business_context}
{self.business_dictionary}
表信息:
- 表名: {table_metadata.table_name}
- Schema: {table_metadata.schema_name}
- 现有注释: {table_metadata.original_comment or "无"}
- 字段数量: {len(table_metadata.fields)}
- 数据行数: {table_metadata.row_count or "未知"}
主要字段:
{chr(10).join(fields_summary)}
{sample_summary}
请生成一个简洁、准确的中文表注释,要求:
1. 如果现有注释是英文,请翻译为中文并改进
2. 根据字段名称和样例数据推断表的业务用途
3. 注释长度控制在50字以内
4. 突出表的核心业务价值
表注释:"""

        return prompt

    def _build_field_batch_prompt(self, table_metadata, business_context: str) -> str:
        """Build the batched prompt asking for every field's comment and enum hint as JSON.

        Up to 5 non-null sample values per column are collected from the first
        5 sample rows, and at most 3 of them are shown per field. A missing
        (None) ``sample_data`` is treated as "no samples".
        """
        sample_values: Dict[str, List[str]] = {}
        for sample in (table_metadata.sample_data or [])[:5]:
            for field_name, value in sample.items():
                bucket = sample_values.setdefault(field_name, [])
                if value is not None and len(bucket) < 5:
                    bucket.append(str(value))

        fields_info = []
        for field in table_metadata.fields:
            field_info = f"{field.name} ({field.type})"
            if field.original_comment:
                field_info += f" - 原注释: {field.original_comment}"

            values = sample_values.get(field.name)
            if values:
                values_str = ", ".join(values[:3])
                field_info += f" - 样例值: {values_str}"

            fields_info.append(field_info)

        prompt = f"""你是一个数据库文档专家。请为以下表的所有字段生成中文注释,并识别可能的枚举字段。
业务背景: {business_context}
{self.business_dictionary}
表名: {table_metadata.schema_name}.{table_metadata.table_name}
表注释: {table_metadata.comment or "无"}
字段列表:
{chr(10).join([f"{i+1}. {info}" for i, info in enumerate(fields_info)])}
请按以下JSON格式输出每个字段的分析结果:
```json
{{
"fields": [
{{
"name": "字段名",
"comment": "中文注释(简洁明确,15字以内)",
"is_enum": true/false,
"enum_values": ["值1", "值2", "值3"] (如果是枚举),
"enum_description": "枚举含义说明" (如果是枚举)
}}
]
}}
```
注释生成要求:
1. 如果原注释是英文,翻译为中文并改进
2. 根据字段名、类型和样例值推断字段含义
3. 识别可能的枚举字段(如状态、类型、级别等)
4. 枚举判断标准: VARCHAR类型 + 样例值重复度高 + 字段名暗示分类
5. 注释要贴近{business_context}的业务场景
请输出JSON格式的分析结果:"""

        return prompt

    async def _call_llm_with_retry(self, prompt: str, max_retries: int = 3) -> str:
        """Send ``prompt`` to the LLM, retrying on failure or empty replies.

        The blocking ``self.vn.chat_with_llm`` call runs in a worker thread via
        ``asyncio.to_thread``. Per the original note, that method is free-form
        chat that does not query the vector store.

        Returns:
            The stripped, non-empty response text.

        Raises:
            The last attempt's exception once ``max_retries`` attempts have
            failed, or a generic ``Exception`` when ``max_retries < 1``.
        """
        for attempt in range(max_retries):
            try:
                response = await asyncio.to_thread(
                    self.vn.chat_with_llm,
                    question=prompt,
                    system_prompt="你是一个专业的数据库文档专家,专门负责生成高质量的中文数据库表和字段注释。"
                )

                if response and response.strip():
                    return response.strip()
                raise ValueError("LLM返回空响应")

            except Exception as e:
                self.logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(1)  # fixed 1-second pause between attempts

        raise Exception("LLM调用达到最大重试次数")

    def _extract_table_comment(self, llm_response: str) -> str:
        """Pull the table comment out of a free-text LLM reply.

        Returns the first line that is not blank, a markdown heading/bullet
        (``#``/``*``) or a code fence, with any known label prefix (see
        ``_COMMENT_PREFIXES``) removed. The result is truncated to 200 chars.
        If no such line remains, returns the first 200 chars of the whole
        stripped reply.
        """
        for raw_line in llm_response.strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith(('#', '*', '```')):
                continue
            for prefix in self._COMMENT_PREFIXES:
                if line.startswith(prefix):
                    line = line[len(prefix):].strip()
            # A bare label line such as "表注释:" becomes empty; keep scanning.
            if line:
                return line[:200]

        return llm_response.strip()[:200]

    @staticmethod
    def _extract_field_entries(llm_response: str) -> List[Any]:
        """Return the list of per-field entries from the LLM's JSON reply.

        Candidates are tried in order: the body of the first fenced code block
        (```` ```json ```` or a bare ```` ``` ````), the whole reply, and the
        span from the first ``{`` to the last ``}``, because LLMs often wrap the
        JSON in prose. Both ``{"fields": [...]}`` and a bare top-level list are
        accepted; a dict without ``fields`` yields ``[]``.

        Raises:
            ValueError: no candidate parses as JSON (``json.JSONDecodeError``
                subclasses ``ValueError``), or the entries are not a list.
        """
        import json
        import re

        candidates = []
        fenced = re.search(r'```(?:json)?\s*(.*?)\s*```', llm_response, re.DOTALL | re.IGNORECASE)
        if fenced:
            candidates.append(fenced.group(1))
        stripped = llm_response.strip()
        candidates.append(stripped)
        start, end = stripped.find('{'), stripped.rfind('}')
        if 0 <= start < end:
            candidates.append(stripped[start:end + 1])

        last_error: Exception = ValueError("无可解析的JSON")
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError as err:
                last_error = err
                continue
            entries = parsed.get('fields', []) if isinstance(parsed, dict) else parsed
            if not isinstance(entries, list):
                raise ValueError("LLM响应中的fields不是列表")
            return entries
        raise last_error

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret an LLM-provided flag; strings like "false" must not count as true."""
        if isinstance(value, str):
            return value.strip().lower() in ('true', 'yes', '1', '是')
        return bool(value)

    def _parse_field_batch_response(self, llm_response: str, fields: List[FieldInfo]) -> List[Dict[str, Any]]:
        """Turn the batched LLM reply into one result dict per entry of ``fields``.

        Entries are matched to fields by ``name`` (case-insensitive). An entry
        is used positionally only if its name does not belong to any field,
        so a reordered or partial reply no longer shifts comments onto the
        wrong field. Unmatched fields get ``success=False`` with the original
        comment (or field name). If the reply cannot be parsed at all, every
        field gets such a failure result plus the error text.
        """
        try:
            field_data = self._extract_field_entries(llm_response)

            known_names = {field.name.lower() for field in fields}
            by_name: Dict[str, Dict[str, Any]] = {}
            for entry in field_data:
                if isinstance(entry, dict) and entry.get('name'):
                    by_name.setdefault(str(entry['name']).strip().lower(), entry)

            results = []
            for i, field in enumerate(fields):
                data = by_name.get(field.name.lower())
                if data is None and i < len(field_data):
                    candidate = field_data[i]
                    if (isinstance(candidate, dict)
                            and str(candidate.get('name') or '').strip().lower() not in known_names):
                        data = candidate

                if data is None:
                    results.append({
                        'success': False,
                        'comment': field.original_comment or field.name,
                        'is_enum': False
                    })
                    continue

                results.append({
                    'success': True,
                    # `or` also covers explicit null / empty strings from the LLM.
                    'comment': data.get('comment') or field.name,
                    'is_enum': self._as_bool(data.get('is_enum', False)),
                    'enum_values': data.get('enum_values') or [],
                    'enum_description': data.get('enum_description') or ''
                })

            return results

        except Exception as e:
            self.logger.error(f"解析字段批量响应失败: {e}")
            return [
                {
                    'success': False,
                    'comment': field.original_comment or field.name,
                    'is_enum': False,
                    'error': str(e)
                }
                for field in fields
            ]

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Double-quote a SQL identifier, escaping embedded double quotes."""
        return '"' + str(name).replace('"', '""') + '"'

    async def _validate_enum_suggestions(self, table_metadata, enum_suggestions: List[Dict]) -> List[Dict]:
        """Check each LLM enum suggestion against the column's real values.

        For each suggested field, it queries the value counts (most frequent
        first) through the ``database_inspector`` tool's connection pool. The
        suggestion is kept only if the number of distinct values is at most
        ``SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]``. Per-field failures
        are logged and that field is skipped.

        Returns:
            One dict per confirmed field with ``field_name``, ``actual_values``
            (as strings), ``suggested_values``, ``description`` and
            ``value_counts``.
        """
        # Not referenced directly; presumably imported so the module's tool
        # registration runs before ToolRegistry.get_tool below.
        from schema_tools.tools.database_inspector import DatabaseInspectorTool  # noqa: F401
        from schema_tools.config import SCHEMA_TOOLS_CONFIG

        validated_enums = []
        inspector = ToolRegistry.get_tool("database_inspector")
        sample_limit = SCHEMA_TOOLS_CONFIG["enum_detection_sample_limit"]

        for enum_info in enum_suggestions:
            field_name = enum_info['field_name']

            try:
                max_enum_values = SCHEMA_TOOLS_CONFIG["enum_max_distinct_values"]
                # Fetch at least one row more than the threshold; with a smaller
                # LIMIT the "too many values" branch below could never trigger.
                query_limit = max(sample_limit, max_enum_values + 1)
                # Quote the column so reserved words / mixed-case names work.
                # NOTE(review): assumes field names are the exact catalog names
                # (not case-folded); confirm against database_inspector.
                column = self._quote_identifier(field_name)
                query = f"""
                    SELECT {column} AS value, COUNT(*) AS count
                    FROM {table_metadata.full_name}
                    WHERE {column} IS NOT NULL
                    GROUP BY {column}
                    ORDER BY count DESC
                    LIMIT {query_limit}
                """

                async with inspector.connection_pool.acquire() as conn:
                    rows = await conn.fetch(query)

                actual_values = [str(row['value']) for row in rows]

                if len(actual_values) <= max_enum_values:
                    validated_enums.append({
                        'field_name': field_name,
                        'actual_values': actual_values,
                        'suggested_values': enum_info['suggested_values'],
                        'description': enum_info['enum_description'],
                        'value_counts': [(row['value'], row['count']) for row in rows]
                    })
                    self.logger.info(f"确认字段 {field_name} 为枚举类型,包含 {len(actual_values)} 个值")
                else:
                    self.logger.info(f"字段 {field_name} 不同值过多({len(actual_values)}),不认为是枚举")

            except Exception as e:
                self.logger.warning(f"验证字段 {field_name} 的枚举建议失败: {e}")

        return validated_enums

    def _load_business_dictionary(self) -> str:
        """Load ``../prompts/business_dictionary.txt`` (relative to this file) for the prompts.

        Returns:
            ``"\\n业务词典:\\n<content>\\n"`` when the file exists and is
            non-empty, otherwise ``""``. Read errors are logged as warnings and
            also yield ``""``.
        """
        try:
            import os
            dict_file = os.path.join(os.path.dirname(__file__), '..', 'prompts', 'business_dictionary.txt')
            if os.path.exists(dict_file):
                with open(dict_file, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                return f"\n业务词典:\n{content}\n" if content else ""
            return ""
        except Exception as e:
            self.logger.warning(f"加载业务词典失败: {e}")
            return ""
|