Explorar o código

在加载数据之前建表,根据neo4j生成建表语句和字段注释语句。

wangxq hai 1 mes
pai
achega
7d361fb7c8
Modificáronse 2 ficheiros con 216 adicións e 14 borrados
  1. 64 13
      dataops_scripts/load_file.py
  2. 152 1
      dataops_scripts/script_utils.py

+ 64 - 13
dataops_scripts/load_file.py

@@ -52,6 +52,7 @@ def get_pg_conn():
 def get_table_columns(table_name):
     """
     获取表的列信息,包括列名和注释
+    如果表不存在,则尝试从Neo4j创建表
     
     返回:
         dict: {列名: 列注释} 的字典
@@ -59,18 +60,66 @@ def get_table_columns(table_name):
     conn = get_pg_conn()
     cursor = conn.cursor()
     try:
+        # 首先检查表是否存在
+        # 解析表名,可能包含schema(如ods.table_name)
+        if '.' in table_name:
+            schema_name, actual_table_name = table_name.split('.', 1)
+            cursor.execute("""
+                SELECT EXISTS (
+                    SELECT FROM information_schema.tables 
+                    WHERE table_schema = %s AND table_name = %s
+                );
+            """, (schema_name.lower(), actual_table_name.lower()))
+        else:
+            cursor.execute("""
+                SELECT EXISTS (
+                    SELECT FROM information_schema.tables 
+                    WHERE table_name = %s
+                );
+            """, (table_name.lower(),))
+        
+        table_exists = cursor.fetchone()[0]
+        
+        if not table_exists:
+            logger.warning(f"表 '{table_name}' 不存在,尝试从Neo4j创建表")
+            # 调用script_utils中的create_table_from_neo4j函数
+            try:
+                if script_utils.create_table_from_neo4j(table_name):
+                    logger.info(f"成功从Neo4j创建表 '{table_name}'")
+                else:
+                    logger.error(f"从Neo4j创建表 '{table_name}' 失败")
+                    return {}
+            except Exception as create_err:
+                logger.error(f"调用create_table_from_neo4j创建表 '{table_name}' 时出错: {str(create_err)}")
+                return {}
+        
         # 查询表列信息
-        cursor.execute("""
-            SELECT 
-                column_name, 
-                col_description((table_schema || '.' || table_name)::regclass::oid, ordinal_position) as column_comment
-            FROM 
-                information_schema.columns
-            WHERE 
-                table_name = %s
-            ORDER BY 
-                ordinal_position
-        """, (table_name.lower(),))
+        # 根据表名是否包含schema来构建查询
+        if '.' in table_name:
+            schema_name, actual_table_name = table_name.split('.', 1)
+            cursor.execute("""
+                SELECT 
+                    column_name, 
+                    col_description((table_schema || '.' || table_name)::regclass::oid, ordinal_position) as column_comment
+                FROM 
+                    information_schema.columns
+                WHERE 
+                    table_schema = %s AND table_name = %s
+                ORDER BY 
+                    ordinal_position
+            """, (schema_name.lower(), actual_table_name.lower()))
+        else:
+            cursor.execute("""
+                SELECT 
+                    column_name, 
+                    col_description((table_schema || '.' || table_name)::regclass::oid, ordinal_position) as column_comment
+                FROM 
+                    information_schema.columns
+                WHERE 
+                    table_name = %s
+                ORDER BY 
+                    ordinal_position
+            """, (table_name.lower(),))
         
         columns = {}
         empty_comment_columns = []  # 记录注释为空的列
@@ -100,8 +149,10 @@ def get_table_columns(table_name):
         logger.error(f"获取表 '{table_name}' 的列信息时出错: {str(e)}")
         return {}
     finally:
-        cursor.close()
-        conn.close()
+        if cursor:
+            cursor.close()
+        if conn:
+            conn.close()
 
 def match_file_columns(file_headers, table_columns):
     """

+ 152 - 1
dataops_scripts/script_utils.py

@@ -525,4 +525,155 @@ def get_config_param(param_name, default_value=None):
         return getattr(config_module, param_name)
     except Exception as e:
         logger.warning(f"获取配置参数 {param_name} 失败: {str(e)},使用默认值: {default_value}")
-        return default_value
+        return default_value
+
+def create_table_from_neo4j(en_name: str):
+    """
+    根据Neo4j中的表定义创建PostgreSQL表
+    
+    参数:
+        en_name (str): 表的英文名称
+    
+    返回:
+        bool: 成功返回True,失败返回False
+    """
+    driver = None
+    conn = None
+    cur = None
+    
+    try:
+        # 使用script_utils中的方法获取连接
+        driver = get_neo4j_driver()
+        pg_config = get_pg_config()
+        
+        import psycopg2
+        conn = psycopg2.connect(**pg_config)
+        cur = conn.cursor()
+
+        with driver.session() as session:
+            # 1. 查找目标表节点(DataResource/DataModel/DataMetric)
+            result = session.run("""
+                MATCH (t)
+                WHERE t.en_name = $en_name AND (t:DataResource OR t:DataModel OR t:DataMetric)
+                RETURN labels(t) AS labels, t.en_name AS en_name, t.name AS name, id(t) AS node_id
+            """, en_name=en_name)
+
+            record = result.single()
+            if not record:
+                logger.error(f"未找到名为 {en_name} 的表节点")
+                return False
+
+            labels = record["labels"]
+            table_en_name = record["en_name"]
+            table_cn_name = record["name"]
+            node_id = record["node_id"]
+
+            schema = "ods" if "DataResource" in labels else "ads"
+
+            # 2. 查找所有字段(HAS_COLUMN关系)并按Column节点的系统id排序
+            column_result = session.run("""
+                MATCH (t)-[:HAS_COLUMN]->(c:Column)
+                WHERE id(t) = $node_id
+                RETURN c.en_name AS en_name, c.data_type AS data_type, 
+                       c.name AS name, c.is_pk AS is_pk, id(c) AS column_id
+                ORDER BY id(c) ASC
+            """, node_id=node_id)
+
+            columns = column_result.data()
+            if not columns:
+                logger.error(f"未找到表 {en_name} 的字段信息")
+                return False
+
+            # 3. 构造 DDL
+            ddl_lines = []
+            pk_fields = []
+            existing_fields = set()
+
+            for col in columns:
+                col_line = f'{col["en_name"]} {col["data_type"]}'
+                ddl_lines.append(col_line)
+                existing_fields.add(col["en_name"].lower())
+                if col.get("is_pk", False):
+                    pk_fields.append(f'{col["en_name"]}')
+
+            # 检查并添加 create_time 和 update_time 字段
+            if 'create_time' not in existing_fields:
+                ddl_lines.append('create_time timestamp')
+            if 'update_time' not in existing_fields:
+                ddl_lines.append('update_time timestamp')
+
+            if pk_fields:
+                ddl_lines.append(f'PRIMARY KEY ({", ".join(pk_fields)})')
+
+            full_table_name = f"{schema}.{table_en_name}"
+            ddl = f'CREATE SCHEMA IF NOT EXISTS {schema};\n'
+            ddl += f'CREATE TABLE IF NOT EXISTS {full_table_name} (\n  '
+            ddl += ",\n  ".join(ddl_lines)
+            ddl += "\n);"
+
+            # 生成表注释SQL
+            table_comment_sql = f"COMMENT ON TABLE {full_table_name} IS '{table_cn_name}';"
+            
+            # 生成字段注释SQL
+            column_comment_sqls = []
+            for col in columns:
+                if col["name"]:  # 如果有中文名称
+                    column_comment_sql = f"COMMENT ON COLUMN {full_table_name}.{col['en_name']} IS '{col['name']}';"
+                    column_comment_sqls.append(column_comment_sql)
+
+            logger.info(f"DDL: {ddl}")
+            logger.info(f"表注释SQL: {table_comment_sql}")
+            if column_comment_sqls:
+                logger.info("字段注释SQL:")
+                for comment_sql in column_comment_sqls:
+                    logger.info(f"  {comment_sql}")
+
+            # 4. 执行 DDL
+            try:
+                # 先检查表是否已经存在
+                check_table_sql = """
+                    SELECT EXISTS (
+                        SELECT FROM information_schema.tables 
+                        WHERE table_schema = %s AND table_name = %s
+                    );
+                """
+                cur.execute(check_table_sql, (schema, table_en_name))
+                table_exists = cur.fetchone()[0]
+                
+                if table_exists:
+                    logger.info(f"表 {full_table_name} 已存在,跳过创建")
+                    return True
+                else:
+                    # 执行创建表的DDL
+                    cur.execute(ddl)
+                    logger.info(f"成功创建新表: {full_table_name}")
+                    
+                    # 执行表注释
+                    cur.execute(table_comment_sql)
+                    logger.info(f"已添加表注释")
+                    
+                    # 执行字段注释
+                    for comment_sql in column_comment_sqls:
+                        cur.execute(comment_sql)
+                    logger.info(f"已添加 {len(column_comment_sqls)} 个字段注释")
+                    
+                    conn.commit()
+                    return True
+                    
+            except Exception as e:
+                logger.error(f"执行DDL失败: {e}")
+                conn.rollback()
+                return False
+
+    except Exception as e:
+        logger.error(f"创建表 {en_name} 时发生错误: {str(e)}")
+        if conn:
+            conn.rollback()
+        return False
+    finally:
+        if cur:
+            cur.close()
+        if conn:
+            conn.close()
+        if driver:
+            driver.close()