|
@@ -0,0 +1,525 @@
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import time
|
|
|
+from datetime import datetime
|
|
|
+from pathlib import Path
|
|
|
+from typing import List, Dict, Any, Optional
|
|
|
+
|
|
|
+from schema_tools.config import SCHEMA_TOOLS_CONFIG
|
|
|
+from schema_tools.validators import FileCountValidator
|
|
|
+from schema_tools.analyzers import MDFileAnalyzer, ThemeExtractor
|
|
|
+from schema_tools.utils.logger import setup_logging
|
|
|
+from core.vanna_llm_factory import create_vanna_instance
|
|
|
+
|
|
|
+
|
|
|
class QuestionSQLGenerationAgent:
    """Agent that generates Question-SQL training pairs from DDL/MD schema files."""

    def __init__(self,
                 output_dir: str,
                 table_list_file: str,
                 business_context: str,
                 db_name: Optional[str] = None):
        """
        Initialize the agent.

        Args:
            output_dir: Output directory (contains the DDL and MD files).
            table_list_file: Path to the table-list file.
            business_context: Business context description fed to the LLM.
            db_name: Database name used in output file naming; defaults to "db".
                (Annotation fixed: the default is None, so the type is Optional[str].)
        """
        self.output_dir = Path(output_dir)
        self.table_list_file = table_list_file
        self.business_context = business_context
        self.db_name = db_name or "db"

        self.config = SCHEMA_TOOLS_CONFIG
        self.logger = logging.getLogger("schema_tools.QSAgent")

        # Components that need no LLM are created eagerly.
        self.validator = FileCountValidator()
        self.md_analyzer = MDFileAnalyzer(output_dir)

        # The vanna instance and theme extractor are created lazily in
        # _initialize_llm_components to avoid paying LLM setup cost up front.
        self.vn = None
        self.theme_extractor = None

        # Intermediate-result storage used for crash recovery.
        self.intermediate_results: List[Dict] = []
        self.intermediate_file: Optional[Path] = None
|
|
|
+ async def generate(self) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 生成Question-SQL对
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 生成结果报告
|
|
|
+ """
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.logger.info("🚀 开始生成Question-SQL训练数据")
|
|
|
+
|
|
|
+ # 1. 验证文件数量
|
|
|
+ self.logger.info("📋 验证文件数量...")
|
|
|
+ validation_result = self.validator.validate(self.table_list_file, str(self.output_dir))
|
|
|
+
|
|
|
+ if not validation_result.is_valid:
|
|
|
+ self.logger.error(f"❌ 文件验证失败: {validation_result.error}")
|
|
|
+ if validation_result.missing_ddl:
|
|
|
+ self.logger.error(f"缺失DDL文件: {validation_result.missing_ddl}")
|
|
|
+ if validation_result.missing_md:
|
|
|
+ self.logger.error(f"缺失MD文件: {validation_result.missing_md}")
|
|
|
+ raise ValueError(f"文件验证失败: {validation_result.error}")
|
|
|
+
|
|
|
+ self.logger.info(f"✅ 文件验证通过: {validation_result.table_count}个表")
|
|
|
+
|
|
|
+ # 2. 读取所有MD文件内容
|
|
|
+ self.logger.info("📖 读取MD文件...")
|
|
|
+ md_contents = await self.md_analyzer.read_all_md_files()
|
|
|
+
|
|
|
+ # 3. 初始化LLM相关组件
|
|
|
+ self._initialize_llm_components()
|
|
|
+
|
|
|
+ # 4. 提取分析主题
|
|
|
+ self.logger.info("🎯 提取分析主题...")
|
|
|
+ themes = await self.theme_extractor.extract_themes(md_contents)
|
|
|
+ self.logger.info(f"✅ 成功提取 {len(themes)} 个分析主题")
|
|
|
+
|
|
|
+ for i, theme in enumerate(themes):
|
|
|
+ topic_name = theme.get('topic_name', theme.get('name', ''))
|
|
|
+ description = theme.get('description', '')
|
|
|
+ self.logger.info(f" {i+1}. {topic_name}: {description}")
|
|
|
+
|
|
|
+ # 5. 初始化中间结果文件
|
|
|
+ self._init_intermediate_file()
|
|
|
+
|
|
|
+ # 6. 处理每个主题
|
|
|
+ all_qs_pairs = []
|
|
|
+ failed_themes = []
|
|
|
+
|
|
|
+ # 根据配置决定是并行还是串行处理
|
|
|
+ max_concurrent = self.config['qs_generation'].get('max_concurrent_themes', 1)
|
|
|
+ if max_concurrent > 1:
|
|
|
+ results = await self._process_themes_parallel(themes, md_contents, max_concurrent)
|
|
|
+ else:
|
|
|
+ results = await self._process_themes_serial(themes, md_contents)
|
|
|
+
|
|
|
+ # 7. 整理结果
|
|
|
+ for result in results:
|
|
|
+ if result['success']:
|
|
|
+ all_qs_pairs.extend(result['qs_pairs'])
|
|
|
+ else:
|
|
|
+ failed_themes.append(result['theme_name'])
|
|
|
+
|
|
|
+ # 8. 保存最终结果
|
|
|
+ output_file = await self._save_final_results(all_qs_pairs)
|
|
|
+
|
|
|
+ # 8.5 生成metadata.txt文件
|
|
|
+ await self._generate_metadata_file(themes)
|
|
|
+
|
|
|
+ # 9. 清理中间文件
|
|
|
+ if not failed_themes: # 只有全部成功才清理
|
|
|
+ self._cleanup_intermediate_file()
|
|
|
+
|
|
|
+ # 10. 生成报告
|
|
|
+ end_time = time.time()
|
|
|
+ report = {
|
|
|
+ 'success': True,
|
|
|
+ 'total_themes': len(themes),
|
|
|
+ 'successful_themes': len(themes) - len(failed_themes),
|
|
|
+ 'failed_themes': failed_themes,
|
|
|
+ 'total_questions': len(all_qs_pairs),
|
|
|
+ 'output_file': str(output_file),
|
|
|
+ 'execution_time': end_time - start_time
|
|
|
+ }
|
|
|
+
|
|
|
+ self._print_summary(report)
|
|
|
+
|
|
|
+ return report
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.exception("❌ Question-SQL生成失败")
|
|
|
+
|
|
|
+ # 保存当前已生成的结果
|
|
|
+ if self.intermediate_results:
|
|
|
+ recovery_file = self._save_intermediate_results()
|
|
|
+ self.logger.warning(f"⚠️ 中间结果已保存到: {recovery_file}")
|
|
|
+
|
|
|
+ raise
|
|
|
+
|
|
|
+ def _initialize_llm_components(self):
|
|
|
+ """初始化LLM相关组件"""
|
|
|
+ if not self.vn:
|
|
|
+ self.logger.info("初始化LLM组件...")
|
|
|
+ self.vn = create_vanna_instance()
|
|
|
+ self.theme_extractor = ThemeExtractor(self.vn, self.business_context)
|
|
|
+
|
|
|
+ async def _process_themes_serial(self, themes: List[Dict], md_contents: str) -> List[Dict]:
|
|
|
+ """串行处理主题"""
|
|
|
+ results = []
|
|
|
+
|
|
|
+ for i, theme in enumerate(themes):
|
|
|
+ self.logger.info(f"处理主题 {i+1}/{len(themes)}: {theme.get('topic_name', theme.get('name', ''))}")
|
|
|
+ result = await self._process_single_theme(theme, md_contents)
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ # 检查是否需要继续
|
|
|
+ if not result['success'] and not self.config['qs_generation']['continue_on_theme_error']:
|
|
|
+ self.logger.error(f"主题处理失败,停止处理")
|
|
|
+ break
|
|
|
+
|
|
|
+ return results
|
|
|
+
|
|
|
+ async def _process_themes_parallel(self, themes: List[Dict], md_contents: str, max_concurrent: int) -> List[Dict]:
|
|
|
+ """并行处理主题"""
|
|
|
+ semaphore = asyncio.Semaphore(max_concurrent)
|
|
|
+
|
|
|
+ async def process_with_semaphore(theme):
|
|
|
+ async with semaphore:
|
|
|
+ return await self._process_single_theme(theme, md_contents)
|
|
|
+
|
|
|
+ tasks = [process_with_semaphore(theme) for theme in themes]
|
|
|
+ results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
+
|
|
|
+ # 处理异常结果
|
|
|
+ processed_results = []
|
|
|
+ for i, result in enumerate(results):
|
|
|
+ if isinstance(result, Exception):
|
|
|
+ theme_name = themes[i].get('topic_name', themes[i].get('name', ''))
|
|
|
+ self.logger.error(f"主题 '{theme_name}' 处理异常: {result}")
|
|
|
+ processed_results.append({
|
|
|
+ 'success': False,
|
|
|
+ 'theme_name': theme_name,
|
|
|
+ 'error': str(result)
|
|
|
+ })
|
|
|
+ else:
|
|
|
+ processed_results.append(result)
|
|
|
+
|
|
|
+ return processed_results
|
|
|
+
|
|
|
+ async def _process_single_theme(self, theme: Dict, md_contents: str) -> Dict:
|
|
|
+ """处理单个主题"""
|
|
|
+ theme_name = theme.get('topic_name', theme.get('name', ''))
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.logger.info(f"🔍 开始处理主题: {theme_name}")
|
|
|
+
|
|
|
+ # 构建prompt
|
|
|
+ prompt = self._build_qs_generation_prompt(theme, md_contents)
|
|
|
+
|
|
|
+ # 调用LLM生成
|
|
|
+ response = await self._call_llm(prompt)
|
|
|
+
|
|
|
+ # 解析响应
|
|
|
+ qs_pairs = self._parse_qs_response(response)
|
|
|
+
|
|
|
+ # 验证和清理
|
|
|
+ validated_pairs = self._validate_qs_pairs(qs_pairs, theme['name'])
|
|
|
+
|
|
|
+ # 保存中间结果
|
|
|
+ await self._save_theme_results(theme_name, validated_pairs)
|
|
|
+
|
|
|
+ self.logger.info(f"✅ 主题 '{theme_name}' 处理成功,生成 {len(validated_pairs)} 个问题")
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'success': True,
|
|
|
+ 'theme_name': theme_name,
|
|
|
+ 'qs_pairs': validated_pairs
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"❌ 处理主题 '{theme_name}' 失败: {e}")
|
|
|
+ return {
|
|
|
+ 'success': False,
|
|
|
+ 'theme_name': theme_name,
|
|
|
+ 'error': str(e),
|
|
|
+ 'qs_pairs': []
|
|
|
+ }
|
|
|
+
|
|
|
+ def _build_qs_generation_prompt(self, theme: Dict, md_contents: str) -> str:
|
|
|
+ """构建Question-SQL生成的prompt"""
|
|
|
+ questions_count = self.config['qs_generation']['questions_per_theme']
|
|
|
+
|
|
|
+ # 兼容新旧格式
|
|
|
+ topic_name = theme.get('topic_name', theme.get('name', ''))
|
|
|
+ description = theme.get('description', '')
|
|
|
+ focus_areas = theme.get('focus_areas', theme.get('keywords', []))
|
|
|
+ related_tables = theme.get('related_tables', [])
|
|
|
+
|
|
|
+ prompt = f"""你是一位业务数据分析师,正在为{self.business_context}设计数据查询。
|
|
|
+
|
|
|
+当前分析主题:{topic_name}
|
|
|
+主题描述:{description}
|
|
|
+关注领域:{', '.join(focus_areas)}
|
|
|
+相关表:{', '.join(related_tables)}
|
|
|
+
|
|
|
+数据库表结构信息:
|
|
|
+{md_contents}
|
|
|
+
|
|
|
+请为这个主题生成 {questions_count} 个业务问题和对应的SQL查询。
|
|
|
+
|
|
|
+要求:
|
|
|
+1. 问题应该从业务角度出发,贴合主题要求,具有实际分析价值
|
|
|
+2. SQL必须使用PostgreSQL语法
|
|
|
+3. 考虑实际业务逻辑(如软删除使用 delete_ts IS NULL 条件)
|
|
|
+4. 使用中文别名提高可读性(使用 AS 指定列别名)
|
|
|
+5. 问题应该多样化,覆盖不同的分析角度
|
|
|
+6. 包含时间筛选、分组统计、排序、限制等不同类型的查询
|
|
|
+7. SQL语句末尾必须以分号结束
|
|
|
+
|
|
|
+输出JSON格式(注意SQL中的双引号需要转义):
|
|
|
+```json
|
|
|
+[
|
|
|
+ {{
|
|
|
+ "question": "具体的业务问题?",
|
|
|
+ "sql": "SELECT column AS 中文名 FROM table WHERE condition;"
|
|
|
+ }}
|
|
|
+]
|
|
|
+```
|
|
|
+
|
|
|
+生成的问题应该包括但不限于:
|
|
|
+- 趋势分析(按时间维度)
|
|
|
+- 对比分析(不同维度对比)
|
|
|
+- 排名统计(TOP N)
|
|
|
+- 汇总统计(总量、平均值等)
|
|
|
+- 明细查询(特定条件的详细数据)"""
|
|
|
+
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ async def _call_llm(self, prompt: str) -> str:
|
|
|
+ """调用LLM"""
|
|
|
+ try:
|
|
|
+ response = await asyncio.to_thread(
|
|
|
+ self.vn.chat_with_llm,
|
|
|
+ question=prompt,
|
|
|
+ system_prompt="你是一个专业的数据分析师,精通PostgreSQL语法,擅长设计有业务价值的数据查询。请严格按照JSON格式输出。"
|
|
|
+ )
|
|
|
+
|
|
|
+ if not response or not response.strip():
|
|
|
+ raise ValueError("LLM返回空响应")
|
|
|
+
|
|
|
+ return response.strip()
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"LLM调用失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def _parse_qs_response(self, response: str) -> List[Dict[str, str]]:
|
|
|
+ """解析Question-SQL响应"""
|
|
|
+ try:
|
|
|
+ # 提取JSON部分
|
|
|
+ import re
|
|
|
+ json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
|
|
|
+ if json_match:
|
|
|
+ json_str = json_match.group(1)
|
|
|
+ else:
|
|
|
+ json_str = response
|
|
|
+
|
|
|
+ # 解析JSON
|
|
|
+ qs_pairs = json.loads(json_str)
|
|
|
+
|
|
|
+ if not isinstance(qs_pairs, list):
|
|
|
+ raise ValueError("响应不是列表格式")
|
|
|
+
|
|
|
+ return qs_pairs
|
|
|
+
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ self.logger.error(f"JSON解析失败: {e}")
|
|
|
+ self.logger.debug(f"原始响应: {response}")
|
|
|
+ raise ValueError(f"无法解析LLM响应为JSON格式: {e}")
|
|
|
+
|
|
|
+ def _validate_qs_pairs(self, qs_pairs: List[Dict], theme_name: str) -> List[Dict[str, str]]:
|
|
|
+ """验证和清理Question-SQL对"""
|
|
|
+ validated = []
|
|
|
+
|
|
|
+ for i, pair in enumerate(qs_pairs):
|
|
|
+ if not isinstance(pair, dict):
|
|
|
+ self.logger.warning(f"跳过无效格式的项 {i+1}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ question = pair.get('question', '').strip()
|
|
|
+ sql = pair.get('sql', '').strip()
|
|
|
+
|
|
|
+ if not question or not sql:
|
|
|
+ self.logger.warning(f"跳过空问题或SQL的项 {i+1}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 确保SQL以分号结束
|
|
|
+ if not sql.endswith(';'):
|
|
|
+ sql += ';'
|
|
|
+
|
|
|
+ validated.append({
|
|
|
+ 'question': question,
|
|
|
+ 'sql': sql
|
|
|
+ })
|
|
|
+
|
|
|
+ self.logger.info(f"主题 '{theme_name}': 验证通过 {len(validated)}/{len(qs_pairs)} 个问题")
|
|
|
+
|
|
|
+ return validated
|
|
|
+
|
|
|
+ def _init_intermediate_file(self):
|
|
|
+ """初始化中间结果文件"""
|
|
|
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
|
+ self.intermediate_file = self.output_dir / f"qs_intermediate_{timestamp}.json"
|
|
|
+ self.intermediate_results = []
|
|
|
+ self.logger.info(f"中间结果文件: {self.intermediate_file}")
|
|
|
+
|
|
|
+ async def _save_theme_results(self, theme_name: str, qs_pairs: List[Dict]):
|
|
|
+ """保存单个主题的结果"""
|
|
|
+ theme_result = {
|
|
|
+ "theme": theme_name,
|
|
|
+ "timestamp": datetime.now().isoformat(),
|
|
|
+ "questions_count": len(qs_pairs),
|
|
|
+ "questions": qs_pairs
|
|
|
+ }
|
|
|
+
|
|
|
+ self.intermediate_results.append(theme_result)
|
|
|
+
|
|
|
+ # 立即保存到中间文件
|
|
|
+ if self.config['qs_generation']['save_intermediate']:
|
|
|
+ try:
|
|
|
+ with open(self.intermediate_file, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(self.intermediate_results, f, ensure_ascii=False, indent=2)
|
|
|
+ self.logger.debug(f"中间结果已更新: {self.intermediate_file}")
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.warning(f"保存中间结果失败: {e}")
|
|
|
+
|
|
|
+ def _save_intermediate_results(self) -> Path:
|
|
|
+ """异常时保存中间结果"""
|
|
|
+ if not self.intermediate_results:
|
|
|
+ return None
|
|
|
+
|
|
|
+ recovery_file = self.output_dir / f"qs_recovery_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
|
|
+
|
|
|
+ try:
|
|
|
+ with open(recovery_file, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump({
|
|
|
+ "status": "interrupted",
|
|
|
+ "timestamp": datetime.now().isoformat(),
|
|
|
+ "completed_themes": len(self.intermediate_results),
|
|
|
+ "results": self.intermediate_results
|
|
|
+ }, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ return recovery_file
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"保存恢复文件失败: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ async def _save_final_results(self, all_qs_pairs: List[Dict]) -> Path:
|
|
|
+ """保存最终结果"""
|
|
|
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
|
+ output_file = self.output_dir / f"{self.config['qs_generation']['output_file_prefix']}_{self.db_name}_{timestamp}_pair.json"
|
|
|
+
|
|
|
+ try:
|
|
|
+ with open(output_file, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(all_qs_pairs, f, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ self.logger.info(f"✅ 最终结果已保存到: {output_file}")
|
|
|
+ return output_file
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"保存最终结果失败: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ def _cleanup_intermediate_file(self):
|
|
|
+ """清理中间文件"""
|
|
|
+ if self.intermediate_file and self.intermediate_file.exists():
|
|
|
+ try:
|
|
|
+ self.intermediate_file.unlink()
|
|
|
+ self.logger.info("已清理中间文件")
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.warning(f"清理中间文件失败: {e}")
|
|
|
+
|
|
|
+ def _print_summary(self, report: Dict):
|
|
|
+ """打印总结信息"""
|
|
|
+ self.logger.info("=" * 60)
|
|
|
+ self.logger.info("📊 生成总结")
|
|
|
+ self.logger.info(f" ✅ 总主题数: {report['total_themes']}")
|
|
|
+ self.logger.info(f" ✅ 成功主题: {report['successful_themes']}")
|
|
|
+
|
|
|
+ if report['failed_themes']:
|
|
|
+ self.logger.info(f" ❌ 失败主题: {len(report['failed_themes'])}")
|
|
|
+ for theme in report['failed_themes']:
|
|
|
+ self.logger.info(f" - {theme}")
|
|
|
+
|
|
|
+ self.logger.info(f" 📝 总问题数: {report['total_questions']}")
|
|
|
+ self.logger.info(f" 📁 输出文件: {report['output_file']}")
|
|
|
+ self.logger.info(f" ⏱️ 执行时间: {report['execution_time']:.2f}秒")
|
|
|
+ self.logger.info("=" * 60)
|
|
|
+
|
|
|
+ async def _generate_metadata_file(self, themes: List[Dict]):
|
|
|
+ """生成metadata.txt文件,包含INSERT语句"""
|
|
|
+ metadata_file = self.output_dir / "metadata.txt"
|
|
|
+
|
|
|
+ try:
|
|
|
+ with open(metadata_file, 'w', encoding='utf-8') as f:
|
|
|
+ f.write("-- Schema Tools生成的主题元数据\n")
|
|
|
+ f.write(f"-- 业务背景: {self.business_context}\n")
|
|
|
+ f.write(f"-- 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
|
|
+ f.write(f"-- 数据库: {self.db_name}\n\n")
|
|
|
+
|
|
|
+ f.write("-- 创建表(如果不存在)\n")
|
|
|
+ f.write("CREATE TABLE IF NOT EXISTS metadata (\n")
|
|
|
+ f.write(" id SERIAL PRIMARY KEY,\n")
|
|
|
+ f.write(" topic_name VARCHAR(100) NOT NULL,\n")
|
|
|
+ f.write(" description TEXT,\n")
|
|
|
+ f.write(" related_tables TEXT[],\n")
|
|
|
+ f.write(" keywords TEXT[],\n")
|
|
|
+ f.write(" focus_areas TEXT[],\n")
|
|
|
+ f.write(" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP\n")
|
|
|
+ f.write(");\n\n")
|
|
|
+
|
|
|
+ f.write("-- 插入主题数据\n")
|
|
|
+ for theme in themes:
|
|
|
+ # 获取字段值,使用新格式
|
|
|
+ topic_name = theme.get('topic_name', theme.get('name', ''))
|
|
|
+ description = theme.get('description', '')
|
|
|
+
|
|
|
+ # 处理related_tables
|
|
|
+ related_tables = theme.get('related_tables', [])
|
|
|
+ if isinstance(related_tables, list):
|
|
|
+ tables_str = '{' + ','.join(related_tables) + '}'
|
|
|
+ else:
|
|
|
+ tables_str = '{}'
|
|
|
+
|
|
|
+ # 处理keywords
|
|
|
+ keywords = theme.get('keywords', [])
|
|
|
+ if isinstance(keywords, list):
|
|
|
+ keywords_str = '{' + ','.join(keywords) + '}'
|
|
|
+ else:
|
|
|
+ keywords_str = '{}'
|
|
|
+
|
|
|
+ # 处理focus_areas
|
|
|
+ focus_areas = theme.get('focus_areas', [])
|
|
|
+ if isinstance(focus_areas, list):
|
|
|
+ focus_areas_str = '{' + ','.join(focus_areas) + '}'
|
|
|
+ else:
|
|
|
+ focus_areas_str = '{}'
|
|
|
+
|
|
|
+ # 生成INSERT语句
|
|
|
+ f.write("INSERT INTO metadata(topic_name, description, related_tables, keywords, focus_areas) VALUES\n")
|
|
|
+ f.write("(\n")
|
|
|
+ f.write(f" '{self._escape_sql_string(topic_name)}',\n")
|
|
|
+ f.write(f" '{self._escape_sql_string(description)}',\n")
|
|
|
+ f.write(f" '{tables_str}',\n")
|
|
|
+ f.write(f" '{keywords_str}',\n")
|
|
|
+ f.write(f" '{focus_areas_str}'\n")
|
|
|
+ f.write(");\n\n")
|
|
|
+
|
|
|
+ self.logger.info(f"✅ metadata.txt文件已生成: {metadata_file}")
|
|
|
+ return metadata_file
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"生成metadata.txt文件失败: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ def _escape_sql_string(self, value: str) -> str:
|
|
|
+ """转义SQL字符串中的特殊字符"""
|
|
|
+ if not value:
|
|
|
+ return ""
|
|
|
+ # 转义单引号
|
|
|
+ return value.replace("'", "''")
|