auth.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
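"""
系统用户认证模块
提供用户注册、登录验证等功能

System user authentication module: user registration, login verification,
user lookup and a request-authentication decorator.
"""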
import base64
import hmac
import logging
import threading
import time
import uuid
from functools import wraps
from urllib.parse import unquote, urlparse

import psycopg2
from flask import current_app, jsonify, request
from psycopg2 import pool
  14. logger = logging.getLogger(__name__)
  15. # PostgreSQL连接池
  16. pg_pool = None
  17. def get_pg_connection():
  18. """
  19. 获取PostgreSQL数据库连接
  20. Returns:
  21. connection: PostgreSQL连接对象
  22. """
  23. global pg_pool
  24. if pg_pool is None:
  25. try:
  26. # 解析SQLAlchemy URI
  27. uri = urlparse(current_app.config['SQLALCHEMY_DATABASE_URI'])
  28. username = uri.username
  29. password = uri.password
  30. database = uri.path[1:] # 移除开头的 '/'
  31. hostname = uri.hostname
  32. port = uri.port or 5432
  33. # 创建连接池
  34. pg_pool = psycopg2.pool.SimpleConnectionPool(
  35. 1, 20,
  36. host=hostname,
  37. database=database,
  38. user=username,
  39. password=password,
  40. port=str(port)
  41. )
  42. logger.info("PostgreSQL连接池初始化成功")
  43. except Exception as e:
  44. logger.error(f"PostgreSQL连接池初始化失败: {str(e)}")
  45. raise
  46. return pg_pool.getconn()
  47. def release_pg_connection(conn):
  48. """
  49. 释放PostgreSQL连接到连接池
  50. Args:
  51. conn: 数据库连接对象
  52. """
  53. global pg_pool
  54. if pg_pool and conn:
  55. pg_pool.putconn(conn)
  56. def encode_password(password):
  57. """
  58. 对密码进行base64编码
  59. Args:
  60. password: 原始密码
  61. Returns:
  62. str: 编码后的密码
  63. """
  64. return base64.b64encode(password.encode('utf-8')).decode('utf-8')
  65. def create_user_table():
  66. """
  67. 创建用户表,如果不存在
  68. Returns:
  69. bool: 是否成功创建
  70. """
  71. conn = None
  72. try:
  73. conn = get_pg_connection()
  74. cursor = conn.cursor()
  75. # 创建用户表
  76. create_table_query = """
  77. CREATE TABLE IF NOT EXISTS users (
  78. id VARCHAR(100) PRIMARY KEY,
  79. username VARCHAR(50) UNIQUE NOT NULL,
  80. password VARCHAR(100) NOT NULL,
  81. created_at FLOAT NOT NULL,
  82. last_login FLOAT,
  83. is_admin BOOLEAN DEFAULT FALSE
  84. );
  85. """
  86. cursor.execute(create_table_query)
  87. # 创建索引加速查询
  88. create_index_query = """
  89. CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
  90. """
  91. cursor.execute(create_index_query)
  92. conn.commit()
  93. cursor.close()
  94. logger.info("用户表创建成功")
  95. return True
  96. except Exception as e:
  97. logger.error(f"创建用户表失败: {str(e)}")
  98. if conn:
  99. conn.rollback()
  100. return False
  101. finally:
  102. if conn:
  103. release_pg_connection(conn)
  104. def register_user(username, password):
  105. """
  106. 注册新用户
  107. Args:
  108. username: 用户名
  109. password: 密码
  110. Returns:
  111. tuple: (是否成功, 消息)
  112. """
  113. conn = None
  114. try:
  115. # 确保表已创建
  116. create_user_table()
  117. # 对密码进行编码
  118. encoded_password = encode_password(password)
  119. # 生成用户ID
  120. user_id = str(uuid.uuid4())
  121. conn = get_pg_connection()
  122. cursor = conn.cursor()
  123. # 检查用户名是否存在
  124. check_query = "SELECT username FROM users WHERE username = %s"
  125. cursor.execute(check_query, (username,))
  126. if cursor.fetchone():
  127. return False, "用户名已存在"
  128. # 创建用户
  129. insert_query = """
  130. INSERT INTO users (id, username, password, created_at, last_login)
  131. VALUES (%s, %s, %s, %s, %s)
  132. """
  133. cursor.execute(
  134. insert_query,
  135. (user_id, username, encoded_password, time.time(), None)
  136. )
  137. conn.commit()
  138. cursor.close()
  139. return True, "注册成功"
  140. except Exception as e:
  141. logger.error(f"用户注册失败: {str(e)}")
  142. if conn:
  143. conn.rollback()
  144. return False, f"注册失败: {str(e)}"
  145. finally:
  146. if conn:
  147. release_pg_connection(conn)
  148. def login_user(username, password):
  149. """
  150. 用户登录验证
  151. Args:
  152. username: 用户名
  153. password: 密码
  154. Returns:
  155. tuple: (是否成功, 用户信息/错误消息)
  156. """
  157. conn = None
  158. try:
  159. # 对输入的密码进行编码
  160. encoded_password = encode_password(password)
  161. conn = get_pg_connection()
  162. cursor = conn.cursor()
  163. # 查询用户
  164. query = """
  165. SELECT id, username, password, created_at, last_login, is_admin
  166. FROM users WHERE username = %s
  167. """
  168. cursor.execute(query, (username,))
  169. user = cursor.fetchone()
  170. # 检查用户是否存在
  171. if not user:
  172. return False, "用户名或密码错误"
  173. # 验证密码
  174. if user[2] != encoded_password:
  175. return False, "用户名或密码错误"
  176. # 更新最后登录时间
  177. current_time = time.time()
  178. update_query = """
  179. UPDATE users SET last_login = %s WHERE username = %s
  180. """
  181. cursor.execute(update_query, (current_time, username))
  182. conn.commit()
  183. # 构建用户信息
  184. user_info = {
  185. "id": user[0],
  186. "username": user[1],
  187. "created_at": user[3],
  188. "last_login": current_time,
  189. "is_admin": user[5] if len(user) > 5 else False
  190. }
  191. cursor.close()
  192. return True, user_info
  193. except Exception as e:
  194. logger.error(f"用户登录失败: {str(e)}")
  195. if conn:
  196. conn.rollback()
  197. return False, f"登录失败: {str(e)}"
  198. finally:
  199. if conn:
  200. release_pg_connection(conn)
  201. def get_user_by_username(username):
  202. """
  203. 根据用户名获取用户信息
  204. Args:
  205. username: 用户名
  206. Returns:
  207. dict: 用户信息(不包含密码)
  208. """
  209. conn = None
  210. try:
  211. conn = get_pg_connection()
  212. cursor = conn.cursor()
  213. query = """
  214. SELECT id, username, created_at, last_login, is_admin
  215. FROM users WHERE username = %s
  216. """
  217. cursor.execute(query, (username,))
  218. user = cursor.fetchone()
  219. cursor.close()
  220. if not user:
  221. return None
  222. user_info = {
  223. "id": user[0],
  224. "username": user[1],
  225. "created_at": user[2],
  226. "last_login": user[3],
  227. "is_admin": user[4] if user[4] is not None else False
  228. }
  229. return user_info
  230. except Exception as e:
  231. logger.error(f"获取用户信息失败: {str(e)}")
  232. return None
  233. finally:
  234. if conn:
  235. release_pg_connection(conn)
  236. def init_db():
  237. """
  238. 初始化数据库,创建用户表
  239. Returns:
  240. bool: 是否成功初始化
  241. """
  242. return create_user_table()
  243. def require_auth(f):
  244. @wraps(f)
  245. def decorated(*args, **kwargs):
  246. auth_header = request.headers.get('Authorization')
  247. if not auth_header:
  248. return jsonify({'message': '缺少认证头'}), 401
  249. try:
  250. # 验证认证头
  251. if auth_header != current_app.config['SECRET_KEY']:
  252. return jsonify({'message': '无效的认证信息'}), 401
  253. return f(*args, **kwargs)
  254. except Exception as e:
  255. return jsonify({'message': '认证失败'}), 401
  256. return decorated