123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 微信用户表迁移脚本
- 创建微信用户表和相关索引
- """
- import os
- import sys
- import logging
- import psycopg2
- from psycopg2 import sql
- # 添加项目根目录到Python路径
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
- from app.config.config import config, current_env
- # 获取配置
- app_config = config[current_env]
- # 配置日志
- log_level_name = getattr(app_config, 'LOG_LEVEL', 'INFO')
- log_level = getattr(logging, log_level_name)
- log_format = getattr(app_config, 'LOG_FORMAT', '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- logging.basicConfig(
- level=log_level,
- format=log_format
- )
- logger = logging.getLogger(__name__)
- def get_database_connection():
- """
- 获取数据库连接
-
- Returns:
- psycopg2.connection: 数据库连接对象
- """
- try:
- # 从配置中获取数据库连接信息
- db_config = {
- 'host': app_config.PG_HOST,
- 'port': app_config.PG_PORT,
- 'database': app_config.PG_DATABASE,
- 'user': app_config.PG_USERNAME,
- 'password': app_config.PG_PASSWORD
- }
-
- connection = psycopg2.connect(**db_config)
- logger.info("成功连接到数据库")
- return connection
-
- except Exception as e:
- logger.error(f"连接数据库失败: {str(e)}")
- raise
- def check_table_exists(connection, table_name, schema='public'):
- """
- 检查表是否存在
-
- Args:
- connection: 数据库连接
- table_name (str): 表名
- schema (str): 模式名,默认为public
-
- Returns:
- bool: 表存在返回True,否则返回False
- """
- try:
- with connection.cursor() as cursor:
- cursor.execute("""
- SELECT EXISTS (
- SELECT FROM information_schema.tables
- WHERE table_schema = %s AND table_name = %s
- );
- """, (schema, table_name))
-
- result = cursor.fetchone()
- return result[0] if result else False
-
- except Exception as e:
- logger.error(f"检查表是否存在时发生错误: {str(e)}")
- return False
- def create_wechat_users_table(connection):
- """
- 创建微信用户表
-
- Args:
- connection: 数据库连接
-
- Returns:
- bool: 创建成功返回True,否则返回False
- """
- try:
- # 读取SQL DDL文件
- sql_file_path = os.path.join(os.path.dirname(__file__), '../../database/create_wechat_users.sql')
-
- if not os.path.exists(sql_file_path):
- logger.error(f"SQL文件不存在: {sql_file_path}")
- return False
-
- with open(sql_file_path, 'r', encoding='utf-8') as file:
- sql_content = file.read()
-
- with connection.cursor() as cursor:
- # 执行SQL脚本
- cursor.execute(sql_content)
- connection.commit()
-
- logger.info("微信用户表创建成功")
- return True
-
- except Exception as e:
- logger.error(f"创建微信用户表失败: {str(e)}")
- connection.rollback()
- return False
- def migrate_wechat_users():
- """
- 执行微信用户表迁移
-
- Returns:
- bool: 迁移成功返回True,否则返回False
- """
- connection = None
-
- try:
- # 获取数据库连接
- connection = get_database_connection()
-
- # 检查表是否已存在
- if check_table_exists(connection, 'wechat_users'):
- logger.warning("微信用户表已存在,跳过创建")
- return True
-
- logger.info("开始创建微信用户表...")
-
- # 创建微信用户表
- if create_wechat_users_table(connection):
- logger.info("微信用户表迁移完成")
- return True
- else:
- logger.error("微信用户表迁移失败")
- return False
-
- except Exception as e:
- logger.error(f"迁移过程中发生错误: {str(e)}")
- return False
-
- finally:
- if connection:
- connection.close()
- logger.info("数据库连接已关闭")
- def rollback_wechat_users():
- """
- 回滚微信用户表迁移(删除表)
-
- Returns:
- bool: 回滚成功返回True,否则返回False
- """
- connection = None
-
- try:
- # 获取数据库连接
- connection = get_database_connection()
-
- # 检查表是否存在
- if not check_table_exists(connection, 'wechat_users'):
- logger.warning("微信用户表不存在,无需回滚")
- return True
-
- logger.info("开始回滚微信用户表...")
-
- with connection.cursor() as cursor:
- # 删除表
- cursor.execute("DROP TABLE IF EXISTS public.wechat_users CASCADE;")
- connection.commit()
-
- logger.info("微信用户表回滚完成")
- return True
-
- except Exception as e:
- logger.error(f"回滚过程中发生错误: {str(e)}")
- if connection:
- connection.rollback()
- return False
-
- finally:
- if connection:
- connection.close()
- logger.info("数据库连接已关闭")
- def main():
- """
- 主函数,根据命令行参数执行相应操作
- """
- import argparse
-
- parser = argparse.ArgumentParser(description='微信用户表迁移脚本')
- parser.add_argument('--action', choices=['migrate', 'rollback'], default='migrate',
- help='执行的操作:migrate(迁移)或 rollback(回滚)')
-
- args = parser.parse_args()
-
- if args.action == 'migrate':
- logger.info("开始执行微信用户表迁移...")
- success = migrate_wechat_users()
- elif args.action == 'rollback':
- logger.info("开始执行微信用户表回滚...")
- success = rollback_wechat_users()
- else:
- logger.error("未知的操作类型")
- sys.exit(1)
-
- if success:
- logger.info("操作完成")
- sys.exit(0)
- else:
- logger.error("操作失败")
- sys.exit(1)
- if __name__ == "__main__":
- main()
|