""" 系统用户认证模块 提供用户注册、登录验证等功能 """ import logging import base64 import time import uuid import psycopg2 from psycopg2 import pool from urllib.parse import urlparse, unquote from flask import current_app, request, jsonify from functools import wraps logger = logging.getLogger(__name__) # PostgreSQL连接池 pg_pool = None def get_pg_connection(): """ 获取PostgreSQL数据库连接 Returns: connection: PostgreSQL连接对象 """ global pg_pool if pg_pool is None: try: # 解析SQLAlchemy URI,处理包含特殊字符的密码 db_uri = current_app.config['SQLALCHEMY_DATABASE_URI'] # 尝试使用urlparse解析 uri = urlparse(db_uri) # 如果解析失败或密码包含特殊字符导致解析错误,使用手动解析 if uri.username is None or uri.password is None: # 手动解析URI: postgresql://username:password@host:port/database scheme_end = db_uri.find('://') if scheme_end == -1: raise ValueError("Invalid database URI format") auth_and_host = db_uri[scheme_end + 3:] # 跳过 '://' at_pos = auth_and_host.rfind('@') # 从右向左查找最后一个@ if at_pos == -1: raise ValueError("Invalid database URI: missing @ separator") auth_part = auth_and_host[:at_pos] host_part = auth_and_host[at_pos + 1:] # 解析用户名和密码(可能包含特殊字符) colon_pos = auth_part.find(':') if colon_pos == -1: username = unquote(auth_part) password = None else: username = unquote(auth_part[:colon_pos]) password = unquote(auth_part[colon_pos + 1:]) # 解析主机、端口和数据库 slash_pos = host_part.find('/') if slash_pos == -1: raise ValueError("Invalid database URI: missing database name") host_port = host_part[:slash_pos] database = unquote(host_part[slash_pos + 1:]) # 解析主机和端口 colon_pos = host_port.find(':') if colon_pos == -1: hostname = host_port port = 5432 else: hostname = host_port[:colon_pos] port = int(host_port[colon_pos + 1:]) else: # urlparse解析成功,解码可能被URL编码的字段 username = unquote(uri.username) if uri.username else None password = unquote(uri.password) if uri.password else None database = unquote(uri.path[1:]) if uri.path and len(uri.path) > 1 else None hostname = uri.hostname port = uri.port or 5432 if not all([username, password, database, hostname]): raise ValueError("Missing required database connection parameters") # 创建连接池 pg_pool = psycopg2.pool.SimpleConnectionPool( 1, 20, host=hostname, database=database, user=username, password=password, port=str(port) ) logger.info("PostgreSQL连接池初始化成功") except Exception as e: logger.error(f"PostgreSQL连接池初始化失败: {str(e)}") raise return pg_pool.getconn() def release_pg_connection(conn): """ 释放PostgreSQL连接到连接池 Args: conn: 数据库连接对象 """ global pg_pool if pg_pool and conn: pg_pool.putconn(conn) def encode_password(password): """ 对密码进行base64编码 Args: password: 原始密码 Returns: str: 编码后的密码 """ return base64.b64encode(password.encode('utf-8')).decode('utf-8') def create_user_table(): """ 创建用户表,如果不存在 Returns: bool: 是否成功创建 """ conn = None try: conn = get_pg_connection() cursor = conn.cursor() # 创建用户表 create_table_query = """ CREATE TABLE IF NOT EXISTS users ( id VARCHAR(100) PRIMARY KEY, username VARCHAR(50) UNIQUE NOT NULL, password VARCHAR(100) NOT NULL, created_at FLOAT NOT NULL, last_login FLOAT, is_admin BOOLEAN DEFAULT FALSE ); """ cursor.execute(create_table_query) # 创建索引加速查询 create_index_query = """ CREATE INDEX IF NOT EXISTS idx_users_username ON users(username); """ cursor.execute(create_index_query) conn.commit() cursor.close() logger.info("用户表创建成功") return True except Exception as e: logger.error(f"创建用户表失败: {str(e)}") if conn: conn.rollback() return False finally: if conn: release_pg_connection(conn) def register_user(username, password): """ 注册新用户 Args: username: 用户名 password: 密码 Returns: tuple: (是否成功, 消息) """ conn = None try: # 确保表已创建 create_user_table() # 对密码进行编码 encoded_password = encode_password(password) # 生成用户ID user_id = str(uuid.uuid4()) conn = get_pg_connection() cursor = conn.cursor() # 检查用户名是否存在 check_query = "SELECT username FROM users WHERE username = %s" cursor.execute(check_query, (username,)) if cursor.fetchone(): return False, "用户名已存在" # 创建用户 insert_query = """ INSERT INTO users (id, username, password, created_at, last_login) VALUES (%s, %s, %s, %s, %s) """ cursor.execute( insert_query, (user_id, username, encoded_password, time.time(), None) ) conn.commit() cursor.close() return True, "注册成功" except Exception as e: logger.error(f"用户注册失败: {str(e)}") if conn: conn.rollback() return False, f"注册失败: {str(e)}" finally: if conn: release_pg_connection(conn) def login_user(username, password): """ 用户登录验证 Args: username: 用户名 password: 密码 Returns: tuple: (是否成功, 用户信息/错误消息) """ conn = None try: # 对输入的密码进行编码 encoded_password = encode_password(password) conn = get_pg_connection() cursor = conn.cursor() # 查询用户 query = """ SELECT id, username, password, created_at, last_login, is_admin FROM users WHERE username = %s """ cursor.execute(query, (username,)) user = cursor.fetchone() # 检查用户是否存在 if not user: return False, "用户名或密码错误" # 验证密码 if user[2] != encoded_password: return False, "用户名或密码错误" # 更新最后登录时间 current_time = time.time() update_query = """ UPDATE users SET last_login = %s WHERE username = %s """ cursor.execute(update_query, (current_time, username)) conn.commit() # 构建用户信息 user_info = { "id": user[0], "username": user[1], "created_at": user[3], "last_login": current_time, "is_admin": user[5] if len(user) > 5 else False } cursor.close() return True, user_info except Exception as e: logger.error(f"用户登录失败: {str(e)}") if conn: conn.rollback() return False, f"登录失败: {str(e)}" finally: if conn: release_pg_connection(conn) def get_user_by_username(username): """ 根据用户名获取用户信息 Args: username: 用户名 Returns: dict: 用户信息(不包含密码) """ conn = None try: conn = get_pg_connection() cursor = conn.cursor() query = """ SELECT id, username, created_at, last_login, is_admin FROM users WHERE username = %s """ cursor.execute(query, (username,)) user = cursor.fetchone() cursor.close() if not user: return None user_info = { "id": user[0], "username": user[1], "created_at": user[2], "last_login": user[3], "is_admin": user[4] if user[4] is not None else False } return user_info except Exception as e: logger.error(f"获取用户信息失败: {str(e)}") return None finally: if conn: release_pg_connection(conn) def init_db(): """ 初始化数据库,创建用户表 Returns: bool: 是否成功初始化 """ return create_user_table() def require_auth(f): @wraps(f) def decorated(*args, **kwargs): auth_header = request.headers.get('Authorization') if not auth_header: return jsonify({'message': '缺少认证头'}), 401 try: # 验证认证头 if auth_header != current_app.config['SECRET_KEY']: return jsonify({'message': '无效的认证信息'}), 401 return f(*args, **kwargs) except Exception as e: return jsonify({'message': '认证失败'}), 401 return decorated