""" 数据资源导入工具 功能:从远程数据源读取数据,按照指定的更新模式写入到目标数据资源表中 支持: - 灵活的数据源配置(PostgreSQL/MySQL等) - 灵活的目标表配置 - 两种更新模式:append(追加)/ full(全量更新) 作者:cursor 创建时间:2025-11-28 更新时间:2025-11-28 """ import argparse import json import logging import os import sys from typing import Any, Dict, List, Optional import psycopg2 from sqlalchemy import create_engine, inspect, text from sqlalchemy.engine import Engine from sqlalchemy.orm import Session, sessionmaker # 添加项目根目录到路径 sys.path.insert( 0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) ) try: from app.config.config import config, get_environment # type: ignore # 获取当前环境的配置类 _current_env = get_environment() Config = config.get(_current_env, config["default"]) except ImportError: # 如果无法导入,使用环境变量 class Config: # type: ignore SQLALCHEMY_DATABASE_URI = os.environ.get( "DATABASE_URI", "postgresql://user:password@localhost:5432/database" ) try: import pymysql # type: ignore MYSQL_AVAILABLE = True except ImportError: MYSQL_AVAILABLE = False pymysql = None # type: ignore # 配置日志 logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) class ResourceDataImporter: """数据资源导入器""" # 目标表所在的 schema TARGET_SCHEMA = "dags" def __init__( self, source_config: Dict[str, Any], target_table_name: str, update_mode: str = "append", ): """ 初始化导入器 Args: source_config: 源数据库配置 { 'type': 'postgresql', # 或 'mysql' 'host': '10.52.31.104', 'port': 5432, 'database': 'source_db', 'username': 'user', 'password': 'password', 'table_name': 'TB_JC_KSDZB' # 源表名 } target_table_name: 目标表名(数据资源的英文名) update_mode: 更新模式,'append'(追加)或 'full'(全量更新) """ self.source_config = source_config self.target_table_name = target_table_name self.update_mode = update_mode.lower() self.source_connection: Optional[Any] = None self.target_engine: Optional[Engine] = None self.target_session: Optional[Session] = None self.imported_count = 0 self.updated_count = 0 self.error_count = 0 # 验证更新模式 if self.update_mode not in ["append", "full"]: raise ValueError( f"不支持的更新模式: {update_mode},仅支持 'append' 或 'full'" ) logger.info( f"初始化数据导入器: 目标表={self.TARGET_SCHEMA}.{target_table_name}, 更新模式={update_mode}" ) def connect_target_database(self) -> bool: """ 连接目标数据库(从 config.py 获取配置) Returns: 连接是否成功 """ try: # 从 Config 获取 PostgreSQL 配置 db_uri = Config.SQLALCHEMY_DATABASE_URI if not db_uri: logger.error("未找到目标数据库配置(SQLALCHEMY_DATABASE_URI)") return False # 创建目标数据库引擎 self.target_engine = create_engine(db_uri) Session = sessionmaker(bind=self.target_engine) self.target_session = Session() # 测试连接 self.target_engine.connect() logger.info(f"成功连接目标数据库: {db_uri.split('@')[-1]}") # 隐藏密码 return True except Exception as e: logger.error(f"连接目标数据库失败: {str(e)}") return False def connect_source_database(self) -> bool: """ 连接源数据库 Returns: 连接是否成功 """ try: db_type = self.source_config["type"].lower() if db_type == "postgresql": self.source_connection = psycopg2.connect( host=self.source_config["host"], port=self.source_config["port"], database=self.source_config["database"], user=self.source_config["username"], password=self.source_config["password"], ) logger.info( f"成功连接源数据库(PostgreSQL): {self.source_config['host']}:{self.source_config['port']}/{self.source_config['database']}" ) return True elif db_type == "mysql": if not MYSQL_AVAILABLE or pymysql is None: logger.error("pymysql未安装,无法连接MySQL数据库") return False self.source_connection = pymysql.connect( host=self.source_config["host"], port=self.source_config["port"], database=self.source_config["database"], user=self.source_config["username"], password=self.source_config["password"], ) logger.info( f"成功连接源数据库(MySQL): {self.source_config['host']}:{self.source_config['port']}/{self.source_config['database']}" ) return True else: logger.error(f"不支持的数据库类型: {db_type}") return False except Exception as e: logger.error(f"连接源数据库失败: {str(e)}") return False def get_full_table_name(self) -> str: """ 获取带 schema 的完整表名 Returns: 完整表名 (schema.table_name) """ return f"{self.TARGET_SCHEMA}.{self.target_table_name}" def get_target_table_columns(self) -> List[str]: """ 获取目标表的列名 Returns: 列名列表 """ try: if not self.target_engine: logger.error("目标数据库引擎未初始化") return [] inspector = inspect(self.target_engine) # 指定 schema 来获取表的列名 columns = inspector.get_columns( self.target_table_name, schema=self.TARGET_SCHEMA ) column_names = [ col["name"] for col in columns if col["name"] != "create_time" ] logger.info(f"目标表 {self.get_full_table_name()} 的列: {column_names}") return column_names except Exception as e: logger.error(f"获取目标表列名失败: {str(e)}") return [] def extract_source_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: """ 从源数据库提取数据 Args: limit: 限制提取的数据行数(None 表示不限制) Returns: 数据行列表 """ try: if not self.source_connection: logger.error("源数据库连接未建立") return [] cursor = self.source_connection.cursor() source_table = self.source_config.get("table_name") if not source_table: logger.error("源表名未指定") return [] # 构建查询语句 query = f"SELECT * FROM {source_table}" # 添加过滤条件(如果有) where_clause = self.source_config.get("where_clause", "") if where_clause: query += f" WHERE {where_clause}" # 添加排序(如果有) order_by = self.source_config.get("order_by", "") if order_by: query += f" ORDER BY {order_by}" # 添加限制 if limit: query += f" LIMIT {limit}" logger.info(f"执行查询: {query}") cursor.execute(query) # 获取列名 columns = [desc[0] for desc in cursor.description] # 提取数据 rows = [] for row in cursor.fetchall(): row_dict = dict(zip(columns, row)) rows.append(row_dict) cursor.close() logger.info(f"从源表 {source_table} 提取了 {len(rows)} 条数据") return rows except Exception as e: logger.error(f"提取源数据失败: {str(e)}") return [] def clear_target_table(self) -> bool: """ 清空目标表(用于全量更新模式) Returns: 清空是否成功 """ try: if not self.target_session: logger.error("目标数据库会话未初始化") return False full_table_name = self.get_full_table_name() delete_sql = text(f"DELETE FROM {full_table_name}") self.target_session.execute(delete_sql) self.target_session.commit() logger.info(f"目标表 {full_table_name} 已清空") return True except Exception as e: if self.target_session: self.target_session.rollback() logger.error(f"清空目标表失败: {str(e)}") return False def map_source_to_target_columns( self, source_row: Dict[str, Any], target_columns: List[str] ) -> Dict[str, Any]: """ 将源数据列映射到目标表列 Args: source_row: 源数据行 target_columns: 目标表列名列表 Returns: 映射后的数据行 """ mapped_row = {} for col in target_columns: # 优先使用精确匹配(不区分大小写) col_lower = col.lower() for source_col, value in source_row.items(): if source_col.lower() == col_lower: mapped_row[col] = value break else: # 如果没有匹配到,设置为 None mapped_row[col] = None return mapped_row def insert_data_to_target(self, data_rows: List[Dict[str, Any]]) -> bool: """ 将数据插入目标表 Args: data_rows: 数据行列表 Returns: 插入是否成功 """ try: if not data_rows: logger.warning("没有数据需要插入") return True if not self.target_session: logger.error("目标数据库会话未初始化") return False # 获取目标表列名 target_columns = self.get_target_table_columns() if not target_columns: logger.error("无法获取目标表列名") return False # 全量更新模式:先清空目标表 if self.update_mode == "full" and not self.clear_target_table(): return False # 构建插入 SQL(使用带 schema 的完整表名) full_table_name = self.get_full_table_name() columns_str = ", ".join(target_columns + ["create_time"]) placeholders = ", ".join( [f":{col}" for col in target_columns] + ["CURRENT_TIMESTAMP"] ) insert_sql = text(f""" INSERT INTO {full_table_name} ({columns_str}) VALUES ({placeholders}) """) # 批量插入 success_count = 0 for source_row in data_rows: try: # 映射列名 mapped_row = self.map_source_to_target_columns( source_row, target_columns ) # 执行插入 self.target_session.execute(insert_sql, mapped_row) success_count += 1 # 每 100 条提交一次 if success_count % 100 == 0: self.target_session.commit() logger.info(f"已插入 {success_count} 条数据...") except Exception as e: self.error_count += 1 logger.error(f"插入数据失败: {str(e)}, 数据: {source_row}") # 最终提交 self.target_session.commit() self.imported_count = success_count logger.info( f"数据插入完成: 成功 {self.imported_count} 条, 失败 {self.error_count} 条" ) return True except Exception as e: if self.target_session: self.target_session.rollback() logger.error(f"批量插入数据失败: {str(e)}") return False def close_connections(self): """关闭所有数据库连接""" # 关闭源数据库连接 if self.source_connection: try: self.source_connection.close() logger.info("源数据库连接已关闭") except Exception as e: logger.error(f"关闭源数据库连接失败: {str(e)}") # 关闭目标数据库连接 if self.target_session: try: self.target_session.close() logger.info("目标数据库会话已关闭") except Exception as e: logger.error(f"关闭目标数据库会话失败: {str(e)}") if self.target_engine: try: self.target_engine.dispose() logger.info("目标数据库引擎已释放") except Exception as e: logger.error(f"释放目标数据库引擎失败: {str(e)}") def run(self, limit: Optional[int] = None) -> Dict[str, Any]: """ 执行导入流程 Args: limit: 限制导入的数据行数(None 表示不限制) Returns: 执行结果 """ result = { "success": False, "imported_count": 0, "error_count": 0, "update_mode": self.update_mode, "message": "", } try: logger.info("=" * 60) logger.info("开始数据导入") logger.info(f"源表: {self.source_config.get('table_name')}") logger.info(f"目标表: {self.get_full_table_name()}") logger.info(f"更新模式: {self.update_mode}") logger.info("=" * 60) # 1. 连接源数据库 if not self.connect_source_database(): result["message"] = "连接源数据库失败" return result # 2. 连接目标数据库 if not self.connect_target_database(): result["message"] = "连接目标数据库失败" return result # 3. 提取源数据 data_rows = self.extract_source_data(limit=limit) if not data_rows: result["message"] = "未提取到数据" result["success"] = True # 没有数据不算失败 return result # 4. 插入数据到目标表 if self.insert_data_to_target(data_rows): result["success"] = True result["imported_count"] = self.imported_count result["error_count"] = self.error_count result["message"] = ( f"导入完成: 成功 {self.imported_count} 条, 失败 {self.error_count} 条" ) else: result["message"] = "插入数据到目标表失败" except Exception as e: logger.error(f"导入过程发生异常: {str(e)}") result["message"] = f"导入失败: {str(e)}" finally: # 5. 关闭连接 self.close_connections() logger.info("=" * 60) logger.info(f"导入结果: {result['message']}") logger.info("=" * 60) return result def import_resource_data( source_config: Dict[str, Any], target_table_name: str, update_mode: str = "append", limit: Optional[int] = None, ) -> Dict[str, Any]: """ 导入数据资源(入口函数) Args: source_config: 源数据库配置 { 'type': 'postgresql', # 或 'mysql' 'host': '10.52.31.104', 'port': 5432, 'database': 'source_db', 'username': 'user', 'password': 'password', 'table_name': 'TB_JC_KSDZB', # 源表名 'where_clause': "TBRQ >= '2025-01-01'", # 可选:WHERE条件 'order_by': 'TBRQ DESC' # 可选:排序 } target_table_name: 目标表名(数据资源的英文名) update_mode: 更新模式,'append'(追加)或 'full'(全量更新) limit: 限制导入的数据行数(None 表示不限制) Returns: 导入结果 """ importer = ResourceDataImporter( source_config=source_config, target_table_name=target_table_name, update_mode=update_mode, ) return importer.run(limit=limit) def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser(description="数据资源导入工具") parser.add_argument( "--source-config", type=str, required=True, help="源数据库配置(JSON格式字符串或文件路径)", ) parser.add_argument( "--target-table", type=str, required=True, help="目标表名(数据资源的英文名)" ) parser.add_argument( "--update-mode", type=str, choices=["append", "full"], default="append", help="更新模式:append(追加)或 full(全量更新)", ) parser.add_argument("--limit", type=int, default=None, help="限制导入的数据行数") return parser.parse_args() if __name__ == "__main__": # 解析命令行参数 args = parse_args() # 解析源数据库配置 try: # 尝试作为JSON字符串解析 source_config = json.loads(args.source_config) except json.JSONDecodeError: # 尝试作为文件路径读取 try: with open(args.source_config, encoding="utf-8") as f: source_config = json.load(f) except Exception as e: logger.error(f"解析源数据库配置失败: {str(e)}") exit(1) # 执行导入 result = import_resource_data( source_config=source_config, target_table_name=args.target_table, update_mode=args.update_mode, limit=args.limit, ) # 输出结果 print("\n" + "=" * 60) print(f"导入结果: {'成功' if result['success'] else '失败'}") print(f"消息: {result['message']}") print(f"成功: {result['imported_count']} 条") print(f"失败: {result['error_count']} 条") print(f"更新模式: {result['update_mode']}") print("=" * 60) # 设置退出代码 exit(0 if result["success"] else 1)