|
@@ -1,8 +1,8 @@
|
|
import logging
|
|
import logging
|
|
-from typing import Dict, List, Optional, Any
|
|
|
|
|
|
+from typing import Dict, List, Optional, Any, Union
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
import json
|
|
import json
|
|
-from app.core.llm.llm_service import llm_client
|
|
|
|
|
|
+from app.core.llm.llm_service import llm_client, llm_sql
|
|
from app.core.graph.graph_operations import connect_graph, create_or_get_node, get_node, relationship_exists
|
|
from app.core.graph.graph_operations import connect_graph, create_or_get_node, get_node, relationship_exists
|
|
from app.core.meta_data import translate_and_parse, get_formatted_time
|
|
from app.core.meta_data import translate_and_parse, get_formatted_time
|
|
from py2neo import Relationship
|
|
from py2neo import Relationship
|
|
@@ -611,36 +611,115 @@ class DataFlowService:
|
|
raise e
|
|
raise e
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
- def create_script(request_data: str) -> str:
|
|
|
|
|
|
+ def create_script(request_data: Union[Dict[str, Any], str]) -> str:
|
|
"""
|
|
"""
|
|
- 使用Deepseek模型生成脚本
|
|
|
|
|
|
+ 使用Deepseek模型生成SQL脚本
|
|
|
|
|
|
Args:
|
|
Args:
|
|
- request_data: 请求数据,用户需求的文本描述
|
|
|
|
|
|
+ request_data: 包含input, output, request_content的请求数据字典,或JSON字符串
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
- 生成的脚本内容(TXT格式)
|
|
|
|
|
|
+ 生成的SQL脚本内容
|
|
"""
|
|
"""
|
|
try:
|
|
try:
|
|
- # 构建prompt
|
|
|
|
- prompt_parts = []
|
|
|
|
|
|
+ logger.info(f"开始处理脚本生成请求: {request_data}")
|
|
|
|
+ logger.info(f"request_data类型: {type(request_data)}")
|
|
|
|
+
|
|
|
|
+ # 类型检查和处理
|
|
|
|
+ if isinstance(request_data, str):
|
|
|
|
+ logger.warning(f"request_data是字符串,尝试解析为JSON: {request_data}")
|
|
|
|
+ try:
|
|
|
|
+ import json
|
|
|
|
+ request_data = json.loads(request_data)
|
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
|
+ raise ValueError(f"无法解析request_data为JSON: {str(e)}")
|
|
|
|
+
|
|
|
|
+ if not isinstance(request_data, dict):
|
|
|
|
+ raise ValueError(f"request_data必须是字典类型,实际类型: {type(request_data)}")
|
|
|
|
+
|
|
|
|
+ # 1. 从传入的request_data中解析input, output, request_content内容
|
|
|
|
+ input_data = request_data.get('input', '')
|
|
|
|
+ output_data = request_data.get('output', '')
|
|
|
|
+
|
|
|
|
+ request_content = request_data.get('request_data', '')
|
|
|
|
+
|
|
|
|
+ # 如果request_content是HTML格式,提取纯文本
|
|
|
|
+ if request_content and (request_content.startswith('<p>') or '<' in request_content):
|
|
|
|
+ # 简单的HTML标签清理
|
|
|
|
+ import re
|
|
|
|
+ request_content = re.sub(r'<[^>]+>', '', request_content).strip()
|
|
|
|
+
|
|
|
|
+ if not input_data or not output_data or not request_content:
|
|
|
|
+ raise ValueError(f"缺少必要参数:input='{input_data}', output='{output_data}', request_content='{request_content[:100] if request_content else ''}' 不能为空")
|
|
|
|
+
|
|
|
|
+ logger.info(f"解析得到 - input: {input_data}, output: {output_data}, request_content: {request_content}")
|
|
|
|
+
|
|
|
|
+ # 2. 解析input中的多个数据表并生成源表DDL
|
|
|
|
+ source_tables_ddl = []
|
|
|
|
+ input_tables = []
|
|
|
|
+ if input_data:
|
|
|
|
+ tables = [table.strip() for table in input_data.split(',') if table.strip()]
|
|
|
|
+ for table in tables:
|
|
|
|
+ ddl = DataFlowService._parse_table_and_get_ddl(table, 'input')
|
|
|
|
+ if ddl:
|
|
|
|
+ input_tables.append(table)
|
|
|
|
+ source_tables_ddl.append(ddl)
|
|
|
|
+ else:
|
|
|
|
+ logger.warning(f"无法获取输入表 {table} 的DDL结构")
|
|
|
|
+
|
|
|
|
+ # 3. 解析output中的数据表并生成目标表DDL
|
|
|
|
+ target_table_ddl = ""
|
|
|
|
+ if output_data:
|
|
|
|
+ target_table_ddl = DataFlowService._parse_table_and_get_ddl(output_data.strip(), 'output')
|
|
|
|
+ if not target_table_ddl:
|
|
|
|
+ logger.warning(f"无法获取输出表 {output_data} 的DDL结构")
|
|
|
|
|
|
- # 添加系统提示
|
|
|
|
- prompt_parts.append("请根据以下需求生成相应的数据处理脚本:")
|
|
|
|
|
|
+ # 4. 按照Deepseek-prompt.txt的框架构建提示语
|
|
|
|
+ prompt_parts = []
|
|
|
|
|
|
- # 直接将request_data作为文本描述添加到prompt中
|
|
|
|
- prompt_parts.append(request_data)
|
|
|
|
|
|
+ # 开场白 - 角色定义
|
|
|
|
+ prompt_parts.append("你是一名数据库工程师,正在构建一个PostgreSQL数据中的汇总逻辑。请为以下需求生成一段标准的 PostgreSQL SQL 脚本:")
|
|
|
|
+
|
|
|
|
+ # 动态生成源表部分(第1点)
|
|
|
|
+ for i, (table, ddl) in enumerate(zip(input_tables, source_tables_ddl), 1):
|
|
|
|
+ table_name = table.split(':')[-1] if ':' in table else table
|
|
|
|
+ prompt_parts.append(f"{i}.有一个源表: {table_name},它的定义语句如下:")
|
|
|
|
+ prompt_parts.append(ddl)
|
|
|
|
+ prompt_parts.append("") # 添加空行分隔
|
|
|
|
+
|
|
|
|
+ # 动态生成目标表部分(第2点)
|
|
|
|
+ if target_table_ddl:
|
|
|
|
+ target_table_name = output_data.split(':')[-1] if ':' in output_data else output_data
|
|
|
|
+ next_index = len(input_tables) + 1
|
|
|
|
+ prompt_parts.append(f"{next_index}.有一个目标表:{target_table_name},它的定义语句如下:")
|
|
|
|
+ prompt_parts.append(target_table_ddl)
|
|
|
|
+ prompt_parts.append("") # 添加空行分隔
|
|
|
|
+
|
|
|
|
+ # 动态生成处理逻辑部分(第3点)
|
|
|
|
+ next_index = len(input_tables) + 2 if target_table_ddl else len(input_tables) + 1
|
|
|
|
+ prompt_parts.append(f"{next_index}.处理逻辑为:{request_content}")
|
|
|
|
+ prompt_parts.append("") # 添加空行分隔
|
|
|
|
+
|
|
|
|
+ # 固定的技术要求部分(第4-8点)
|
|
|
|
+ tech_requirements = [
|
|
|
|
+ f"{next_index + 1}.脚本应使用标准的 PostgreSQL 语法,适合在 Airflow、Python 脚本、或调度系统中调用;",
|
|
|
|
+ f"{next_index + 2}.无需使用 UPSERT 或 ON CONFLICT",
|
|
|
|
+ f"{next_index + 3}.请直接输出SQL,无需进行解释。",
|
|
|
|
+ f"{next_index + 4}.请给这段sql起个英文名,不少于三个英文单词,使用\"_\"分隔,采用蛇形命名法。把sql的名字作为注释写在返回的sql中。",
|
|
|
|
+ f"{next_index + 5}.生成的sql在向目标表插入数据的时候,向create_time字段写入当前日期时间now(),不用处理update_time字段"
|
|
|
|
+ ]
|
|
|
|
|
|
- # 添加格式要求
|
|
|
|
- prompt_parts.append("\n请生成完整可执行的脚本代码,包含必要的注释和错误处理。")
|
|
|
|
|
|
+ prompt_parts.extend(tech_requirements)
|
|
|
|
|
|
- # 组合prompt
|
|
|
|
- full_prompt = "\n\n".join(prompt_parts)
|
|
|
|
|
|
+ # 组合完整的提示语
|
|
|
|
+ full_prompt = "\n".join(prompt_parts)
|
|
|
|
|
|
- logger.info(f"开始调用Deepseek模型生成脚本,prompt长度: {len(full_prompt)}")
|
|
|
|
|
|
+ logger.info(f"构建的完整提示语长度: {len(full_prompt)}")
|
|
|
|
+ logger.info(f"完整提示语内容: {full_prompt}")
|
|
|
|
|
|
- # 调用LLM服务
|
|
|
|
- script_content = llm_client(full_prompt)
|
|
|
|
|
|
+ # 5. 调用LLM生成SQL脚本
|
|
|
|
+ logger.info("开始调用Deepseek模型生成SQL脚本")
|
|
|
|
+ script_content = llm_sql(full_prompt)
|
|
|
|
|
|
if not script_content:
|
|
if not script_content:
|
|
raise ValueError("Deepseek模型返回空内容")
|
|
raise ValueError("Deepseek模型返回空内容")
|
|
@@ -649,14 +728,113 @@ class DataFlowService:
|
|
if not isinstance(script_content, str):
|
|
if not isinstance(script_content, str):
|
|
script_content = str(script_content)
|
|
script_content = str(script_content)
|
|
|
|
|
|
- logger.info(f"脚本生成成功,内容长度: {len(script_content)}")
|
|
|
|
|
|
+ logger.info(f"SQL脚本生成成功,内容长度: {len(script_content)}")
|
|
|
|
|
|
return script_content
|
|
return script_content
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- logger.error(f"生成脚本失败: {str(e)}")
|
|
|
|
|
|
+ logger.error(f"生成SQL脚本失败: {str(e)}")
|
|
raise e
|
|
raise e
|
|
|
|
|
|
|
|
+ @staticmethod
|
|
|
|
def _parse_table_and_get_ddl(table_str: str, table_type: str) -> str:
    """Build PostgreSQL-style DDL text for a table described by a graph node.

    ``table_str`` is split on its first ``:`` into a node label and an
    ``en_name``. The node with that label and ``en_name`` is looked up
    together with every ``DataMeta`` node reachable through an ``INCLUDES``
    relationship, and each of those becomes one column of the generated
    ``CREATE TABLE`` statement.

    Args:
        table_str: Table reference in ``"label:en_name"`` form. The label may
            only contain word characters (letters, digits, underscore);
            anything else is rejected because the label has to be embedded in
            the query text.
        table_type: Free-form tag (the caller passes ``'input'`` or
            ``'output'``) that is only used in log messages.

    Returns:
        The DDL text: a ``CREATE TABLE`` statement followed by
        ``COMMENT ON TABLE`` / ``COMMENT ON COLUMN`` statements where a
        display name is available. An empty string is returned when the
        reference is malformed, the node is not found, or any error occurs;
        callers treat a falsy result as "DDL unavailable".
    """
    # Function-scope import: the module's top-level import block is not
    # guaranteed to provide ``re``.
    import re

    def _sql_literal(text) -> str:
        # Quote a value as a SQL string literal, doubling embedded quotes.
        return "'" + str(text).replace("'", "''") + "'"

    try:
        label, separator, en_name = table_str.partition(':')
        label = label.strip()
        en_name = en_name.strip()

        if not separator:
            logger.error(f"表格式错误,应为'label:en_name'格式: {table_str}")
            return ""

        if not label or not en_name:
            logger.error(f"标签或英文名为空: label={label}, en_name={en_name}")
            return ""

        # Cypher does not accept a parameter in label position, so the label
        # must be spliced into the query text. Restricting it to word
        # characters (and backtick-quoting it below) prevents a crafted label
        # from changing the structure of the query.
        if not re.fullmatch(r"\w+", label):
            logger.error(f"标签包含非法字符,拒绝查询: label={label}")
            return ""

        logger.info(f"开始查询{table_type}表: label={label}, en_name={en_name}")

        cypher = f"""
        MATCH (n:`{label}` {{en_name: $en_name}})
        OPTIONAL MATCH (n)-[:INCLUDES]->(m:DataMeta)
        RETURN n, collect(m) as metadata
        """

        # NOTE(review): assumes connect_graph() returns a driver-like object
        # exposing .session(); any failure here is handled by the except below.
        with connect_graph().session() as session:
            record = session.run(cypher, en_name=en_name).single()

            if not record:
                logger.error(f"未找到节点: label={label}, en_name={en_name}")
                return ""

            # Copy everything needed into plain dicts while the session is open.
            node_props = dict(record['n'])
            metadata = record['metadata']
            columns = [dict(meta) for meta in metadata if meta]

        logger.info(f"找到节点,关联元数据数量: {len(metadata)}")

        column_defs = []
        column_comments = []
        for props in columns:
            # ``or`` chains also cover properties that exist but are None/empty,
            # which would otherwise render as the text "None" in the DDL.
            column_name = props.get('en_name') or props.get('name') or 'unknown_column'
            data_type = props.get('data_type') or 'VARCHAR(255)'
            column_defs.append(f"    {column_name} {data_type}")

            display_name = props.get('name')
            if display_name:
                # PostgreSQL has no inline column COMMENT clause; comments are
                # attached with separate COMMENT ON COLUMN statements.
                column_comments.append(
                    f"COMMENT ON COLUMN {en_name}.{column_name} IS {_sql_literal(display_name)};"
                )

        if not column_defs:
            # No usable metadata: fall back to a single surrogate key column.
            column_defs.append("    id BIGINT PRIMARY KEY")
            column_comments.append(f"COMMENT ON COLUMN {en_name}.id IS '主键ID';")

        ddl_lines = [
            f"CREATE TABLE {en_name} (",
            ",\n".join(column_defs),
            ");",
        ]

        table_comment = node_props.get('name') or node_props.get('describe') or en_name
        if table_comment != en_name:
            ddl_lines.append(f"COMMENT ON TABLE {en_name} IS {_sql_literal(table_comment)};")

        ddl_lines.extend(column_comments)

        ddl_content = "\n".join(ddl_lines)
        logger.info(f"{table_type}表DDL生成成功: {en_name}")
        logger.debug(f"生成的DDL: {ddl_content}")

        return ddl_content

    except Exception as e:
        # Deliberate best-effort: the caller only needs to know that no DDL is
        # available, but the traceback is kept in the log for diagnosis.
        logger.exception(f"解析表格式和生成DDL失败: {str(e)}")
        return ""
|
|
|
|
+
|
|
@staticmethod
|
|
@staticmethod
|
|
def _handle_script_relationships(data: Dict[str, Any],dataflow_name:str,name_en:str):
|
|
def _handle_script_relationships(data: Dict[str, Any],dataflow_name:str,name_en:str):
|
|
"""
|
|
"""
|