| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612 |
- """
- 数据资源导入工具
- 功能:从远程数据源读取数据,按照指定的更新模式写入到目标数据资源表中
- 支持:
- - 灵活的数据源配置(PostgreSQL/MySQL等)
- - 灵活的目标表配置
- - 两种更新模式:append(追加)/ full(全量更新)
- 作者:cursor
- 创建时间:2025-11-28
- 更新时间:2025-11-28
- """
import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

import psycopg2
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
# Put the project root on sys.path so that ``app.config`` can be imported when
# this file is executed directly as a script. The root is obtained by applying
# dirname four times to this file's path. ``abspath`` is applied first: when
# ``__file__`` is relative (e.g. a script started as ``python tool.py`` on
# Python < 3.9) the dirname chain would otherwise collapse to "" and put the
# current working directory on the path instead of the project root.
_PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.insert(0, _PROJECT_ROOT)
try:
    from app.config.config import config, get_environment  # type: ignore

    # Select the configuration class for the current environment, falling
    # back to the "default" entry when the environment has no dedicated one.
    _current_env = get_environment()
    Config = config.get(_current_env, config["default"])
except ImportError:
    # The project config package is not importable (e.g. the tool is run
    # standalone): take the target database URI from the environment.
    class Config:  # type: ignore
        # No placeholder default: when DATABASE_URI is unset this stays None,
        # so the "missing SQLALCHEMY_DATABASE_URI" check reports the real
        # problem instead of the tool trying to log in to a dummy
        # "user:password@localhost" database.
        SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URI")
# pymysql is optional: it is only needed when the source database is MySQL.
# When it cannot be imported, the module name is bound to None and the
# availability flag is derived from that.
try:
    import pymysql  # type: ignore
except ImportError:
    pymysql = None  # type: ignore

MYSQL_AVAILABLE = pymysql is not None
# Root logging: INFO and above, with timestamp, logger name and level.
# logging.basicConfig is a no-op if the root logger already has handlers,
# so an embedding application's own logging setup is left in place.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used throughout this file.
logger = logging.getLogger(__name__)
class ResourceDataImporter:
    """Copy rows from a source table into a table in the ``dags`` schema of the target DB.

    The source is a PostgreSQL (psycopg2) or MySQL (pymysql) database described
    by ``source_config``; the target is the database at
    ``Config.SQLALCHEMY_DATABASE_URI``, accessed through SQLAlchemy.

    Update modes:
        * ``append`` -- insert the extracted rows; each batch is committed as
          soon as it has been written.
        * ``full``   -- delete every existing row and insert the extracted rows
          in one transaction, so a failed load keeps the previous contents.

    Source columns are matched to target columns case-insensitively; target
    columns with no matching source column are inserted as NULL. A
    ``create_time`` column, when the target table has one, is filled with
    ``CURRENT_TIMESTAMP``.
    """

    # Schema that contains every target table.
    TARGET_SCHEMA = "dags"
    # Rows per executemany batch (and per commit in append mode).
    BATCH_SIZE = 100
    # Target column populated by the database rather than from source data.
    CREATE_TIME_COLUMN = "create_time"

    def __init__(
        self,
        source_config: Dict[str, Any],
        target_table_name: str,
        update_mode: str = "append",
    ):
        """Store the configuration and validate the update mode.

        Args:
            source_config: Source database settings, e.g.::

                {
                    "type": "postgresql",          # or "mysql"
                    "host": "db.example.internal",
                    "port": 5432,
                    "database": "source_db",
                    "username": "user",
                    "password": "password",
                    "table_name": "TB_JC_KSDZB",   # source table (raw SQL)
                    "where_clause": "TBRQ >= '2025-01-01'",  # optional, raw SQL
                    "order_by": "TBRQ DESC",       # optional, raw SQL
                }

                ``table_name``, ``where_clause`` and ``order_by`` are pasted
                into the query verbatim and must come from a trusted operator.
            target_table_name: Target table name inside ``TARGET_SCHEMA``.
            update_mode: ``"append"`` or ``"full"`` (case-insensitive).

        Raises:
            ValueError: If ``update_mode`` is neither ``append`` nor ``full``.
        """
        self.source_config = source_config
        self.target_table_name = target_table_name
        self.update_mode = update_mode.lower()
        self.source_connection: Optional[Any] = None
        self.target_engine: Optional[Engine] = None
        self.target_session: Optional[Session] = None
        # Outcome of the most recent insert_data_to_target call.
        self.imported_count = 0
        self.updated_count = 0  # NOTE(review): not updated anywhere in this class.
        self.error_count = 0
        if self.update_mode not in ("append", "full"):
            raise ValueError(
                f"不支持的更新模式: {update_mode},仅支持 'append' 或 'full'"
            )
        logger.info(
            f"初始化数据导入器: 目标表={self.TARGET_SCHEMA}.{target_table_name}, 更新模式={update_mode}"
        )

    # ------------------------------------------------------------------ #
    # Connections
    # ------------------------------------------------------------------ #
    def connect_target_database(self) -> bool:
        """Create the SQLAlchemy engine and session for the target database.

        The URI is read from ``Config.SQLALCHEMY_DATABASE_URI``. Connectivity is
        verified by opening one connection and closing it again.

        Returns:
            True on success; False (after logging) if the URI is missing or the
            connection attempt raised.
        """
        try:
            db_uri = Config.SQLALCHEMY_DATABASE_URI
            if not db_uri:
                logger.error("未找到目标数据库配置(SQLALCHEMY_DATABASE_URI)")
                return False
            self.target_engine = create_engine(db_uri)
            session_factory = sessionmaker(bind=self.target_engine)
            self.target_session = session_factory()
            # Probe connectivity; the context manager hands the connection back
            # to the pool instead of leaving it checked out.
            with self.target_engine.connect():
                pass
            # Log only the part after '@' so credentials are not written out.
            logger.info(f"成功连接目标数据库: {db_uri.split('@')[-1]}")
            return True
        except Exception as e:
            logger.error(f"连接目标数据库失败: {str(e)}")
            return False

    def connect_source_database(self) -> bool:
        """Open the DB-API connection to the source database.

        ``source_config["type"]`` selects the driver: ``postgresql`` uses
        psycopg2 and ``mysql`` uses pymysql (which must be installed).

        Returns:
            True on success; False (after logging) for an unsupported type, a
            missing pymysql, or any connection error.
        """
        try:
            cfg = self.source_config
            db_type = cfg["type"].lower()
            if db_type == "postgresql":
                driver, label = psycopg2, "PostgreSQL"
            elif db_type == "mysql":
                if not MYSQL_AVAILABLE or pymysql is None:
                    logger.error("pymysql未安装,无法连接MySQL数据库")
                    return False
                driver, label = pymysql, "MySQL"
            else:
                logger.error(f"不支持的数据库类型: {db_type}")
                return False
            # Both drivers accept these keyword names; int() normalizes a port
            # that the config file stored as a string.
            self.source_connection = driver.connect(
                host=cfg["host"],
                port=int(cfg["port"]),
                database=cfg["database"],
                user=cfg["username"],
                password=cfg["password"],
            )
            logger.info(
                f"成功连接源数据库({label}): {cfg['host']}:{cfg['port']}/{cfg['database']}"
            )
            return True
        except Exception as e:
            logger.error(f"连接源数据库失败: {str(e)}")
            return False

    # ------------------------------------------------------------------ #
    # Target table metadata
    # ------------------------------------------------------------------ #
    def get_full_table_name(self) -> str:
        """Return ``schema.table`` unquoted (for log messages)."""
        return f"{self.TARGET_SCHEMA}.{self.target_table_name}"

    def _identifier_preparer(self) -> Any:
        """Return the target dialect's identifier quoter; requires a connected engine."""
        if self.target_engine is None:
            raise RuntimeError("目标数据库引擎未初始化")
        return self.target_engine.dialect.identifier_preparer

    def _quoted_full_table_name(self) -> str:
        """Return ``schema.table`` quoted for the target dialect, safe to embed in SQL.

        Quoting also keeps the name consistent with the inspector lookup, which
        uses the name exactly as given (mixed case / reserved words included).
        """
        preparer = self._identifier_preparer()
        return (
            f"{preparer.quote_schema(self.TARGET_SCHEMA)}."
            f"{preparer.quote(self.target_table_name)}"
        )

    def _load_target_columns(self) -> Optional[List[str]]:
        """Return every column name of the target table, or None on failure (logged)."""
        if not self.target_engine:
            logger.error("目标数据库引擎未初始化")
            return None
        try:
            inspector = inspect(self.target_engine)
            columns = inspector.get_columns(
                self.target_table_name, schema=self.TARGET_SCHEMA
            )
        except Exception as e:
            logger.error(f"获取目标表列名失败: {str(e)}")
            return None
        return [col["name"] for col in columns]

    def get_target_table_columns(self) -> List[str]:
        """Return the target table's column names, excluding ``create_time``.

        Returns:
            Column names, or an empty list if the table could not be inspected.
        """
        all_columns = self._load_target_columns()
        if all_columns is None:
            return []
        column_names = [c for c in all_columns if c != self.CREATE_TIME_COLUMN]
        logger.info(f"目标表 {self.get_full_table_name()} 的列: {column_names}")
        return column_names

    # ------------------------------------------------------------------ #
    # Extraction
    # ------------------------------------------------------------------ #
    def _build_source_query(self, source_table: str, limit: Optional[int]) -> str:
        """Assemble ``SELECT * FROM ... [WHERE ...] [ORDER BY ...] [LIMIT n]``.

        The table, WHERE and ORDER BY fragments are raw SQL from the trusted
        source config. A falsy ``limit`` (None or 0) adds no LIMIT clause.
        """
        parts = [f"SELECT * FROM {source_table}"]
        where_clause = self.source_config.get("where_clause", "")
        if where_clause:
            parts.append(f"WHERE {where_clause}")
        order_by = self.source_config.get("order_by", "")
        if order_by:
            parts.append(f"ORDER BY {order_by}")
        if limit:
            # int() guarantees only a number is spliced into the SQL.
            parts.append(f"LIMIT {int(limit)}")
        return " ".join(parts)

    def _fetch_source_rows(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Run the source query and return rows as dicts; raises on any failure."""
        if self.source_connection is None:
            raise RuntimeError("源数据库连接未建立")
        source_table = self.source_config.get("table_name")
        if not source_table:
            raise ValueError("源表名未指定")
        query = self._build_source_query(source_table, limit)
        logger.info(f"执行查询: {query}")
        cursor = self.source_connection.cursor()
        try:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description]
            rows = [dict(zip(columns, row)) for row in cursor.fetchall()]
        finally:
            # Close the cursor even when execute/fetch raises.
            cursor.close()
        logger.info(f"从源表 {source_table} 提取了 {len(rows)} 条数据")
        return rows

    def extract_source_data(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Read rows from the source table.

        Args:
            limit: Maximum number of rows; None (or 0) means no limit.

        Returns:
            Rows as ``{column_name: value}`` dicts. Errors are logged and yield
            an empty list; use ``run`` to tell an error apart from an empty table.
        """
        try:
            return self._fetch_source_rows(limit)
        except Exception as e:
            logger.error(f"提取源数据失败: {str(e)}")
            return []

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #
    def _delete_all_rows(self) -> None:
        """Execute ``DELETE FROM`` the target table without committing."""
        if self.target_session is None:
            raise RuntimeError("目标数据库会话未初始化")
        self.target_session.execute(
            text(f"DELETE FROM {self._quoted_full_table_name()}")
        )

    def clear_target_table(self) -> bool:
        """Delete every row of the target table and commit.

        Returns:
            True on success; False (after rolling back and logging) on failure.
        """
        if not self.target_session:
            logger.error("目标数据库会话未初始化")
            return False
        try:
            self._delete_all_rows()
            self.target_session.commit()
            logger.info(f"目标表 {self.get_full_table_name()} 已清空")
            return True
        except Exception as e:
            self.target_session.rollback()
            logger.error(f"清空目标表失败: {str(e)}")
            return False

    def map_source_to_target_columns(
        self, source_row: Dict[str, Any], target_columns: List[str]
    ) -> Dict[str, Any]:
        """Project a source row onto the target columns by case-insensitive name.

        Args:
            source_row: Source row as ``{column_name: value}``.
            target_columns: Target column names.

        Returns:
            ``{target_column: value}``; columns with no source match map to None.
            If several source columns differ only in case, the first one wins.
        """
        by_lower: Dict[str, Any] = {}
        for source_col, value in source_row.items():
            # setdefault keeps the first source column for each lower-cased name.
            by_lower.setdefault(source_col.lower(), value)
        return {col: by_lower.get(col.lower()) for col in target_columns}

    def _to_bind_params(
        self, source_row: Dict[str, Any], target_columns: List[str]
    ) -> Dict[str, Any]:
        """Map a source row onto the positional bind names ``p0..pN`` of the INSERT."""
        mapped = self.map_source_to_target_columns(source_row, target_columns)
        return {f"p{i}": mapped[col] for i, col in enumerate(target_columns)}

    def _build_insert_sql(self, target_columns: List[str], include_create_time: bool) -> Any:
        """Build the parameterized INSERT for the target table.

        Column identifiers are quoted by the dialect; values are bound as
        ``:p0..:pN`` so that column names never need to be valid bind names.
        """
        preparer = self._identifier_preparer()
        column_sql = [preparer.quote(col) for col in target_columns]
        value_sql = [f":p{i}" for i in range(len(target_columns))]
        if include_create_time:
            column_sql.append(preparer.quote(self.CREATE_TIME_COLUMN))
            value_sql.append("CURRENT_TIMESTAMP")
        return text(
            f"INSERT INTO {self._quoted_full_table_name()} "
            f"({', '.join(column_sql)}) VALUES ({', '.join(value_sql)})"
        )

    def _insert_batch(
        self, insert_sql: Any, batch: List[Dict[str, Any]], target_columns: List[str]
    ) -> Tuple[int, int]:
        """Insert one batch and return ``(inserted, failed)`` row counts.

        The batch is first sent as one executemany inside a SAVEPOINT. If that
        fails, the SAVEPOINT is rolled back and each row is retried in its own
        SAVEPOINT. A failing statement therefore never aborts the enclosing
        transaction, and rows that did insert are kept.
        """
        session = self.target_session
        if session is None:
            raise RuntimeError("目标数据库会话未初始化")
        params = [self._to_bind_params(row, target_columns) for row in batch]
        try:
            with session.begin_nested():
                session.execute(insert_sql, params)
            return len(batch), 0
        except Exception as e:
            logger.warning(f"批量插入失败,改为逐行插入: {str(e)}")

        inserted = failed = 0
        for source_row, row_params in zip(batch, params):
            try:
                with session.begin_nested():
                    session.execute(insert_sql, row_params)
                inserted += 1
            except Exception as e:
                failed += 1
                logger.error(f"插入数据失败: {str(e)}, 数据: {source_row}")
        return inserted, failed

    def insert_data_to_target(self, data_rows: List[Dict[str, Any]]) -> bool:
        """Write ``data_rows`` into the target table according to ``update_mode``.

        Rows are sent in batches of ``BATCH_SIZE``; see ``_insert_batch`` for how
        individual bad rows are isolated. In ``append`` mode every batch is
        committed when done. In ``full`` mode the DELETE and all INSERTs share
        one transaction that is committed at the end.

        Afterwards ``imported_count`` holds the number of committed rows and
        ``error_count`` the number of rows that failed.

        Args:
            data_rows: Source rows as ``{column_name: value}`` dicts.

        Returns:
            True when the load completed (individual row failures are counted,
            not fatal). False if the table could not be described or cleared,
            an unexpected error occurred, or every row failed. In the last case
            the transaction is rolled back, so ``full`` mode keeps the old data.
        """
        if not data_rows:
            logger.warning("没有数据需要插入")
            return True
        if not self.target_session:
            logger.error("目标数据库会话未初始化")
            return False

        session = self.target_session
        is_full = self.update_mode == "full"
        batch_size = max(1, int(self.BATCH_SIZE))
        inserted = 0  # rows inserted so far (committed or still pending)
        committed = 0  # rows durably committed
        failed = 0  # rows that could not be inserted
        try:
            all_columns = self._load_target_columns() or []
            target_columns = [c for c in all_columns if c != self.CREATE_TIME_COLUMN]
            if not target_columns:
                logger.error("无法获取目标表列名")
                return False
            logger.info(f"目标表 {self.get_full_table_name()} 的列: {target_columns}")
            insert_sql = self._build_insert_sql(
                target_columns, self.CREATE_TIME_COLUMN in all_columns
            )

            if is_full:
                # Deliberately not committed here: the delete only becomes
                # visible together with the new rows.
                try:
                    self._delete_all_rows()
                except Exception as e:
                    session.rollback()
                    logger.error(f"清空目标表失败: {str(e)}")
                    return False
                logger.info(
                    f"全量更新: 已删除目标表 {self.get_full_table_name()} 的旧数据,将与新数据一并提交"
                )

            for start in range(0, len(data_rows), batch_size):
                ok, bad = self._insert_batch(
                    insert_sql, data_rows[start:start + batch_size], target_columns
                )
                inserted += ok
                failed += bad
                if not is_full:
                    session.commit()
                    committed = inserted
                logger.info(f"已插入 {inserted} 条数据...")

            if inserted == 0:
                # Every row failed: undo (in full mode this restores the old rows).
                session.rollback()
                logger.error(f"所有 {failed} 条数据插入失败,已回滚")
                return False

            session.commit()
            committed = inserted
            logger.info(f"数据插入完成: 成功 {committed} 条, 失败 {failed} 条")
            return True
        except Exception as e:
            session.rollback()
            logger.error(f"批量插入数据失败: {str(e)}")
            return False
        finally:
            self.imported_count = committed
            self.error_count = failed

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #
    def close_connections(self):
        """Release the source connection, target session and target engine.

        Each resource is released independently (failures are logged, not
        raised) and its attribute is reset to None, so calling this twice is
        harmless.
        """
        if self.source_connection is not None:
            try:
                self.source_connection.close()
                logger.info("源数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭源数据库连接失败: {str(e)}")
            finally:
                self.source_connection = None
        if self.target_session is not None:
            try:
                self.target_session.close()
                logger.info("目标数据库会话已关闭")
            except Exception as e:
                logger.error(f"关闭目标数据库会话失败: {str(e)}")
            finally:
                self.target_session = None
        if self.target_engine is not None:
            try:
                self.target_engine.dispose()
                logger.info("目标数据库引擎已释放")
            except Exception as e:
                logger.error(f"释放目标数据库引擎失败: {str(e)}")
            finally:
                self.target_engine = None

    def run(self, limit: Optional[int] = None) -> Dict[str, Any]:
        """Execute the full import: connect, extract, load, and always close.

        Args:
            limit: Maximum number of source rows; None (or 0) means no limit.

        Returns:
            ``{"success", "imported_count", "error_count", "update_mode",
            "message"}``. An empty source table counts as success (and leaves
            the target untouched, also in full mode); a failing source query
            does not.
        """
        result: Dict[str, Any] = {
            "success": False,
            "imported_count": 0,
            "error_count": 0,
            "update_mode": self.update_mode,
            "message": "",
        }
        try:
            logger.info("=" * 60)
            logger.info("开始数据导入")
            logger.info(f"源表: {self.source_config.get('table_name')}")
            logger.info(f"目标表: {self.get_full_table_name()}")
            logger.info(f"更新模式: {self.update_mode}")
            logger.info("=" * 60)

            if not self.connect_source_database():
                result["message"] = "连接源数据库失败"
                return result
            if not self.connect_target_database():
                result["message"] = "连接目标数据库失败"
                return result

            # Call the raising variant so that a query error is not mistaken
            # for "no data" (which is reported as success below).
            try:
                data_rows = self._fetch_source_rows(limit)
            except Exception as e:
                logger.error(f"提取源数据失败: {str(e)}")
                result["message"] = f"提取源数据失败: {str(e)}"
                return result
            if not data_rows:
                result["message"] = "未提取到数据"
                result["success"] = True  # an empty source is not a failure
                return result

            loaded = self.insert_data_to_target(data_rows)
            result["imported_count"] = self.imported_count
            result["error_count"] = self.error_count
            if loaded:
                result["success"] = True
                result["message"] = (
                    f"导入完成: 成功 {self.imported_count} 条, 失败 {self.error_count} 条"
                )
            else:
                result["message"] = "插入数据到目标表失败"
        except Exception as e:
            logger.error(f"导入过程发生异常: {str(e)}")
            result["message"] = f"导入失败: {str(e)}"
        finally:
            self.close_connections()
            logger.info("=" * 60)
            logger.info(f"导入结果: {result['message']}")
            logger.info("=" * 60)
        return result
def import_resource_data(
    source_config: Dict[str, Any],
    target_table_name: str,
    update_mode: str = "append",
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """Convenience entry point: build a ResourceDataImporter and run it once.

    Args:
        source_config: Source database settings with keys ``type``
            (``"postgresql"`` or ``"mysql"``), ``host``, ``port``, ``database``,
            ``username``, ``password`` and ``table_name``, plus the optional
            raw-SQL fragments ``where_clause`` and ``order_by``. Example::

                {
                    "type": "postgresql",
                    "host": "db.example.internal",
                    "port": 5432,
                    "database": "source_db",
                    "username": "user",
                    "password": "password",
                    "table_name": "TB_JC_KSDZB",
                    "where_clause": "TBRQ >= '2025-01-01'",
                    "order_by": "TBRQ DESC",
                }

        target_table_name: Target table name (the data resource's English name).
        update_mode: ``"append"`` or ``"full"``.
        limit: Maximum number of source rows to import; None means no limit.

    Returns:
        The result dict produced by ``ResourceDataImporter.run``.
    """
    runner = ResourceDataImporter(source_config, target_table_name, update_mode)
    outcome = runner.run(limit=limit)
    return outcome
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the import tool.

    Args:
        argv: Arguments to parse. None (the default) parses ``sys.argv[1:]``,
            which is what the previous zero-argument call did.

    Returns:
        Namespace with ``source_config``, ``target_table``, ``update_mode`` and
        ``limit`` (None when not given).
    """

    def positive_int(raw: str) -> int:
        # The importer treats a falsy limit as "no limit", so ``--limit 0``
        # would silently import (and in full mode replace) everything; reject
        # zero and negative values up front.
        value = int(raw)
        if value <= 0:
            raise argparse.ArgumentTypeError(f"必须为正整数: {raw}")
        return value

    parser = argparse.ArgumentParser(description="数据资源导入工具")
    parser.add_argument(
        "--source-config",
        type=str,
        required=True,
        help="源数据库配置(JSON格式字符串或文件路径)",
    )
    parser.add_argument(
        "--target-table", type=str, required=True, help="目标表名(数据资源的英文名)"
    )
    parser.add_argument(
        "--update-mode",
        type=str,
        choices=["append", "full"],
        default="append",
        help="更新模式:append(追加)或 full(全量更新)",
    )
    parser.add_argument(
        "--limit", type=positive_int, default=None, help="限制导入的数据行数"
    )
    return parser.parse_args(argv)
def load_source_config(raw: str) -> Dict[str, Any]:
    """Interpret ``--source-config`` as inline JSON, falling back to a JSON file path.

    Args:
        raw: Either a JSON object literal or the path of a UTF-8 JSON file.

    Returns:
        The decoded configuration dict.

    Raises:
        ValueError: If ``raw`` is neither valid inline JSON nor a readable JSON
            file, or if the decoded value is not a JSON object.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        try:
            with open(raw, encoding="utf-8") as f:
                parsed = json.load(f)
        except (OSError, ValueError) as e:
            raise ValueError(str(e)) from e
    if not isinstance(parsed, dict):
        raise ValueError(
            f"源数据库配置必须是 JSON 对象, 实际为 {type(parsed).__name__}"
        )
    return parsed


def main() -> int:
    """Run the CLI: parse arguments, load the source config, import, print a summary.

    Returns:
        Process exit code: 0 when the import succeeded, 1 otherwise.
    """
    args = parse_args()
    try:
        source_config = load_source_config(args.source_config)
    except ValueError as e:
        logger.error(f"解析源数据库配置失败: {str(e)}")
        return 1

    result = import_resource_data(
        source_config=source_config,
        target_table_name=args.target_table,
        update_mode=args.update_mode,
        limit=args.limit,
    )

    print("\n" + "=" * 60)
    print(f"导入结果: {'成功' if result['success'] else '失败'}")
    print(f"消息: {result['message']}")
    print(f"成功: {result['imported_count']} 条")
    print(f"失败: {result['error_count']} 条")
    print(f"更新模式: {result['update_mode']}")
    print("=" * 60)
    return 0 if result["success"] else 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is missing under
    # ``python -S`` and in some embedded/frozen interpreters.
    sys.exit(main())
|