ソースを参照

修改了load_data和load_file时的schema问题。

wangxq 1 ヶ月 前
コミット
8029197e8c
4 ファイル変更111 行追加52 行削除
  1. 3 1
      .gitignore
  2. 1 1
      dags/config.py
  3. 87 36
      dataops_scripts/load_data.py
  4. 20 14
      dataops_scripts/load_file.py

+ 3 - 1
.gitignore

@@ -25,4 +25,6 @@ Thumbs.db
 node_modules/
 
 # 忽略 JetBrains IDE 配置
-.idea/
+.idea/
+
+/test

+ 1 - 1
dags/config.py

@@ -38,7 +38,7 @@ DATAOPS_DAGS_PATH = os.path.join(AIRFLOW_BASE_PATH, 'dags')
 SCRIPTS_BASE_PATH = os.path.join(AIRFLOW_BASE_PATH, 'dataops_scripts')
 
 # 上传的CSV/EXCEL文件的基准上传路径
-STRUCTURE_UPLOAD_BASE_PATH ="/data/csv"
+STRUCTURE_UPLOAD_BASE_PATH ="/data/upload"
 STRUCTURE_UPLOAD_ARCHIVE_BASE_PATH ="/data/archive"
 
 # 本地开发环境脚本路径(如果需要区分环境)

+ 87 - 36
dataops_scripts/load_data.py

@@ -86,7 +86,7 @@ def get_source_database_info(table_name, script_name=None):
                 "username": record.get("username"),
                 "password": record.get("password"),
                 "db_type": record.get("db_type"),
-                "schema": record.get("schema", "public"),
+                "schema": record.get("schema"),
                 "source_table": record.get("source_table"),
                 "labels": record.get("labels", [])
             }
@@ -139,10 +139,13 @@ def get_target_database_info():
             "username": pg_config.get("user"),
             "password": pg_config.get("password"),
             "database": pg_config.get("database"),
-            "db_type": "postgresql",
-            "schema": "public"
+            "db_type": "postgresql"
         }
         
+        # 如果配置中有schema,则添加到连接信息中
+        if "schema" in pg_config and pg_config["schema"]:
+            database_info["schema"] = pg_config["schema"]
+        
         logger.info(f"成功获取目标数据库连接信息: {database_info['host']}:{database_info['port']}/{database_info['database']}")
         return database_info
     except Exception as e:
@@ -187,7 +190,7 @@ def get_sqlalchemy_engine(db_info):
         logger.error(traceback.format_exc())
         return None
 
-def create_table_if_not_exists(source_engine, target_engine, source_table, target_table, schema="public"):
+def create_table_if_not_exists(source_engine, target_engine, source_table, target_table, schema=None):
     """
     如果目标表不存在,则从源表复制表结构创建目标表
     
@@ -196,7 +199,7 @@ def create_table_if_not_exists(source_engine, target_engine, source_table, targe
         target_engine: 目标数据库引擎
         source_table: 源表名
         target_table: 目标表名
-        schema: 模式名称
+        schema: 模式名称,如果为None或空字符串则使用"ods"
         
     返回:
         bool: 操作是否成功
@@ -207,12 +210,26 @@ def create_table_if_not_exists(source_engine, target_engine, source_table, targe
     logger.info(f"检查目标表 {target_table} 是否存在,不存在则创建")
     
     try:
+        # 处理schema参数
+        if schema == "" or schema is None:
+            # 如果传递的schema为空,使用"ods"
+            schema = "ods"
+            logger.info(f"schema参数为空,使用默认schema: {schema}")
+        else:
+            # 如果传递的schema不为空,使用传递的schema
+            if schema != "ods":
+                logger.warning(f"使用非标准schema: {schema},建议使用'ods'作为目标schema")
+            logger.info(f"使用传递的schema: {schema}")
+        
+        table_display_name = f"{schema}.{target_table}"
+        logger.info(f"目标表完整名称: {table_display_name}")
+        
         # 检查目标表是否存在
         target_inspector = inspect(target_engine)
         target_exists = target_inspector.has_table(target_table, schema=schema)
         
         if target_exists:
-            logger.info(f"目标表 {target_table} 已存在,无需创建")
+            logger.info(f"目标表 {table_display_name} 已存在,无需创建")
             return True
         
         # 目标表不存在,从源表获取表结构
@@ -253,15 +270,16 @@ def create_table_if_not_exists(source_engine, target_engine, source_table, targe
         if not has_create_time:
             from sqlalchemy import TIMESTAMP
             columns.append(Column('create_time', TIMESTAMP, nullable=True))
-            logger.info(f"为表 {target_table} 添加 create_time 字段")
+            logger.info(f"为表 {table_display_name} 添加 create_time 字段")
         
         # 如果不存在update_time字段,则添加
         if not has_update_time:
             from sqlalchemy import TIMESTAMP
             columns.append(Column('update_time', TIMESTAMP, nullable=True))
-            logger.info(f"为表 {target_table} 添加 update_time 字段")
+            logger.info(f"为表 {table_display_name} 添加 update_time 字段")
         
         # 定义目标表结构,让SQLAlchemy处理数据类型映射
+        # 现在schema总是有值(至少是"ods")
         table_def = Table(
             target_table,
             metadata,
@@ -271,7 +289,7 @@ def create_table_if_not_exists(source_engine, target_engine, source_table, targe
         
         # 在目标数据库中创建表
         metadata.create_all(target_engine)
-        logger.info(f"成功在目标数据库中创建表 {schema}.{target_table}")
+        logger.info(f"成功在目标数据库中创建表 {table_display_name}")
         
         return True
     except Exception as e:
@@ -315,23 +333,43 @@ def load_data_from_source(table_name, exec_date=None, update_mode=None, script_n
         if not source_engine or not target_engine:
             raise Exception("无法创建数据库引擎,无法加载数据")
 
-        # 获取源表名
+        # 获取源表名和源schema
         source_table = source_db_info.get("source_table", table_name) or table_name
+        source_schema = source_db_info.get("schema")
+        
+        # 构建完整的源表名
+        if source_schema:
+            full_source_table_name = f"{source_schema}.{source_table}"
+        else:
+            full_source_table_name = source_table
+        
+        logger.info(f"源表完整名称: {full_source_table_name}")
 
-        # 确保目标表存在
-        if not create_table_if_not_exists(source_engine, target_engine, source_table, table_name):
-            raise Exception(f"无法创建目标表 {table_name},无法加载数据")
+        # 获取目标schema
+        target_schema = target_db_info.get("schema")
+        
+        # 构建完整的目标表名
+        if target_schema:
+            full_table_name = f"{target_schema}.{table_name}"
+        else:
+            full_table_name = table_name
+        
+        logger.info(f"目标表完整名称: {full_table_name}")
+        
+        # 确保目标表存在 - create_table_if_not_exists必须使用"ods"作为schema
+        if not create_table_if_not_exists(source_engine, target_engine, full_source_table_name, table_name, "ods"):
+            raise Exception(f"无法创建目标表 {full_table_name},无法加载数据")
 
         # 根据更新模式处理数据
         if update_mode == "full_refresh":
             # 执行全量刷新,清空表
-            logger.info(f"执行全量刷新,清空表 {table_name}")
+            logger.info(f"执行全量刷新,清空表 {full_table_name}")
             with target_engine.begin() as conn:  # 使用begin()自动管理事务
-                conn.execute(f"TRUNCATE TABLE {table_name}")
-            logger.info(f"成功清空表 {table_name}")
+                conn.execute(text(f"TRUNCATE TABLE {full_table_name}"))
+            logger.info(f"成功清空表 {full_table_name}")
 
             # 构建全量查询
-            query = f"SELECT * FROM {source_table}"
+            query = f"SELECT * FROM {full_source_table_name}"
         else:
             # 增量更新,需要获取目标日期列和日期范围
             target_dt_column = get_target_dt_column(table_name, script_name)
@@ -348,13 +386,13 @@ def load_data_from_source(table_name, exec_date=None, update_mode=None, script_n
                     
                     # 执行删除操作
                     delete_sql = f"""
-                        DELETE FROM {table_name}
+                        DELETE FROM {full_table_name}
                         WHERE {target_dt_column} >= '{start_date}'
                         AND {target_dt_column} < '{end_date}'
                     """
                     with target_engine.begin() as conn:  # 使用begin()自动管理事务
-                        conn.execute(delete_sql)
-                    logger.info(f"成功删除表 {table_name} 中 {target_dt_column} 从 {start_date} 到 {end_date} 的数据")
+                        conn.execute(text(delete_sql))
+                    logger.info(f"成功删除表 {full_table_name} 中 {target_dt_column} 从 {start_date} 到 {end_date} 的数据")
                 else:
                     # 自动调度
                     start_datetime, end_datetime = get_one_day_range(exec_date)
@@ -364,14 +402,14 @@ def load_data_from_source(table_name, exec_date=None, update_mode=None, script_n
                     
                     # 执行删除操作
                     delete_sql = f"""
-                        DELETE FROM {table_name}
+                        DELETE FROM {full_table_name}
                         WHERE create_time >= '{start_date}'
                         AND create_time < '{end_date}'
                     """
                     try:
                         with target_engine.begin() as conn:  # 使用begin()自动管理事务
-                            conn.execute(delete_sql)
-                        logger.info(f"成功删除表 {table_name} 中 create_time 从 {start_date} 到 {end_date} 的数据")
+                            conn.execute(text(delete_sql))
+                        logger.info(f"成功删除表 {full_table_name} 中 create_time 从 {start_date} 到 {end_date} 的数据")
                     except Exception as del_err:
                         logger.error(f"删除数据时出错: {str(del_err)}")
                         logger.warning("继续执行数据加载")
@@ -383,19 +421,23 @@ def load_data_from_source(table_name, exec_date=None, update_mode=None, script_n
                 
                 # 检查源表是否含有目标日期列
                 source_inspector = inspect(source_engine)
-                source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table)]
+                # 处理源表的schema信息用于检查列
+                if source_schema:
+                    source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table, schema=source_schema)]
+                else:
+                    source_columns = [col['name'].lower() for col in source_inspector.get_columns(source_table)]
                 
                 if target_dt_column.lower() in source_columns:
                     # 源表含有目标日期列,构建包含日期条件的查询
                     query = f"""
-                        SELECT * FROM {source_table}
+                        SELECT * FROM {full_source_table_name}
                         WHERE {target_dt_column} >= '{start_date}'
                         AND {target_dt_column} < '{end_date}'
                     """
                 else:
                     # 源表不含目标日期列,构建全量查询
-                    logger.warning(f"源表 {source_table} 没有目标日期列 {target_dt_column},将加载全部数据")
-                    query = f"SELECT * FROM {source_table}"
+                    logger.warning(f"源表 {full_source_table_name} 没有目标日期列 {target_dt_column},将加载全部数据")
+                    query = f"SELECT * FROM {full_source_table_name}"
                 
             except Exception as date_err:
                 logger.error(f"计算日期范围时出错: {str(date_err)}")
@@ -435,17 +477,26 @@ def load_data_from_source(table_name, exec_date=None, update_mode=None, script_n
                 logger.info(f"为数据添加 update_time 字段,初始值为: NULL (数据加载时不设置更新时间)")
             
             # 写入数据到目标表
-            logger.info(f"开始写入数据到目标表 {table_name},共 {len(df)} 行")
+            logger.info(f"开始写入数据到目标表 {full_table_name},共 {len(df)} 行")
             with target_engine.connect() as connection:
-                df.to_sql(
-                    name=table_name,
-                    con=connection,
-                    if_exists='append',
-                    index=False,
-                    schema=target_db_info.get("schema", "public")
-                )
+                # 处理schema参数,如果为空则不传递schema参数
+                if target_schema:
+                    df.to_sql(
+                        name=table_name,
+                        con=connection,
+                        if_exists='append',
+                        index=False,
+                        schema=target_schema
+                    )
+                else:
+                    df.to_sql(
+                        name=table_name,
+                        con=connection,
+                        if_exists='append',
+                        index=False
+                    )
             
-            logger.info(f"成功写入数据到目标表 {table_name}")
+            logger.info(f"成功写入数据到目标表 {full_table_name}")
             return True
             
         except Exception as query_err:

+ 20 - 14
dataops_scripts/load_file.py

@@ -67,8 +67,7 @@ def get_table_columns(table_name):
             FROM 
                 information_schema.columns
             WHERE 
-                table_schema = 'public' -- 明确指定 schema,如果需要
-                AND table_name = %s
+                table_name = %s
             ORDER BY 
                 ordinal_position
         """, (table_name.lower(),))
@@ -302,13 +301,20 @@ def load_dataframe_to_table(df, file_path, table_name):
         else:
             logger.info(f"数据行数: {final_row_count}")
 
-        # 检查目标表是否有create_time字段,如果有则添加当前时间
+        # 检查目标表是否有时间戳字段,优先使用create_time,其次使用created_at
+        current_time = datetime.now()
+        timestamp_field_added = False
+
         if 'create_time' in table_columns:
-            current_time = datetime.now()
             df_mapped['create_time'] = current_time
             logger.info(f"目标表有 create_time 字段,设置值为: {current_time}")
+            timestamp_field_added = True
+        elif 'created_at' in table_columns:
+            df_mapped['created_at'] = current_time
+            logger.info(f"目标表有 created_at 字段,设置值为: {current_time}")
+            timestamp_field_added = True
         else:
-            logger.warning(f"目标表 '{table_name}' 没有 create_time 字段,跳过添加时间戳")
+            logger.warning(f"目标表 '{table_name}' 没有 create_time 字段也没有 created_at 字段,跳过添加时间戳")
 
         # 连接数据库
         conn = get_pg_conn()
@@ -317,7 +323,7 @@ def load_dataframe_to_table(df, file_path, table_name):
         # 构建INSERT语句
         columns = ', '.join([f'"{col}"' for col in df_mapped.columns])
         placeholders = ', '.join(['%s'] * len(df_mapped.columns))
-        insert_sql = f'INSERT INTO public."{table_name}" ({columns}) VALUES ({placeholders})'
+        insert_sql = f'INSERT INTO "{table_name}" ({columns}) VALUES ({placeholders})'
         
         # 批量插入数据
         rows = [tuple(row) for row in df_mapped.values]
@@ -588,12 +594,12 @@ def run(table_name, update_mode='append', exec_date=None, target_type=None,
                 conn = get_pg_conn()
                 cursor = conn.cursor()
                 # 假设表在 public schema,并为表名加引号
-                logger.info(f"执行全量刷新,清空表 public.\"{table_name}\"")
-                cursor.execute(f'TRUNCATE TABLE public.\"{table_name}\"')
+                logger.info(f"执行全量刷新,清空表 \"{table_name}\"")
+                cursor.execute(f'TRUNCATE TABLE \"{table_name}\"')
                 conn.commit()
-                logger.info("表 public.\"" + table_name + "\" 已清空。")
+                logger.info("表 \"" + table_name + "\" 已清空。")
             except Exception as e:
-                logger.error("清空表 public.\"" + table_name + "\" 时出错: " + str(e))
+                logger.error("清空表 \"" + table_name + "\" 时出错: " + str(e))
                 if conn:
                     conn.rollback()
                 return False # 清空失败则直接失败退出
@@ -614,12 +620,12 @@ def run(table_name, update_mode='append', exec_date=None, target_type=None,
                 conn = get_pg_conn()
                 cursor = conn.cursor()
                 # 假设表在 public schema,并为表名加引号
-                logger.info(f"执行全量刷新,清空表 public.\"{table_name}\"")
-                cursor.execute(f'TRUNCATE TABLE public.\"{table_name}\"')
+                logger.info(f"执行全量刷新,清空表 \"{table_name}\"")
+                cursor.execute(f'TRUNCATE TABLE \"{table_name}\"')
                 conn.commit()
-                logger.info("表 public.\"" + table_name + "\" 已清空。")
+                logger.info("表 \"" + table_name + "\" 已清空。")
             except Exception as e:
-                logger.error("清空表 public.\"" + table_name + "\" 时出错: " + str(e))
+                logger.error("清空表 \"" + table_name + "\" 时出错: " + str(e))
                 if conn:
                     conn.rollback()
                 return False # 清空失败则直接失败退出