auth.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """
  2. 系统用户认证模块
  3. 提供用户注册、登录验证等功能
  4. """
  5. import logging
  6. import base64
  7. import time
  8. import uuid
  9. import psycopg2
  10. from psycopg2 import pool
  11. from urllib.parse import urlparse
  12. from app.config.config import Config
  13. logger = logging.getLogger(__name__)
  14. # PostgreSQL连接池
  15. pg_pool = None
  16. def get_pg_connection():
  17. """
  18. 获取PostgreSQL数据库连接
  19. Returns:
  20. connection: PostgreSQL连接对象
  21. """
  22. global pg_pool
  23. if pg_pool is None:
  24. try:
  25. # 解析SQLAlchemy URI
  26. uri = urlparse(Config.SQLALCHEMY_DATABASE_URI)
  27. username = uri.username
  28. password = uri.password
  29. database = uri.path[1:] # 移除开头的 '/'
  30. hostname = uri.hostname
  31. port = uri.port or 5432
  32. # 创建连接池
  33. pg_pool = psycopg2.pool.SimpleConnectionPool(
  34. 1, 20,
  35. host=hostname,
  36. database=database,
  37. user=username,
  38. password=password,
  39. port=str(port)
  40. )
  41. logger.info("PostgreSQL连接池初始化成功")
  42. except Exception as e:
  43. logger.error(f"PostgreSQL连接池初始化失败: {str(e)}")
  44. raise
  45. return pg_pool.getconn()
  46. def release_pg_connection(conn):
  47. """
  48. 释放PostgreSQL连接到连接池
  49. Args:
  50. conn: 数据库连接对象
  51. """
  52. global pg_pool
  53. if pg_pool and conn:
  54. pg_pool.putconn(conn)
  55. def encode_password(password):
  56. """
  57. 对密码进行base64编码
  58. Args:
  59. password: 原始密码
  60. Returns:
  61. str: 编码后的密码
  62. """
  63. return base64.b64encode(password.encode('utf-8')).decode('utf-8')
  64. def create_user_table():
  65. """
  66. 创建用户表,如果不存在
  67. Returns:
  68. bool: 是否成功创建
  69. """
  70. conn = None
  71. try:
  72. conn = get_pg_connection()
  73. cursor = conn.cursor()
  74. # 创建用户表
  75. create_table_query = """
  76. CREATE TABLE IF NOT EXISTS users (
  77. id VARCHAR(100) PRIMARY KEY,
  78. username VARCHAR(50) UNIQUE NOT NULL,
  79. password VARCHAR(100) NOT NULL,
  80. created_at FLOAT NOT NULL,
  81. last_login FLOAT,
  82. is_admin BOOLEAN DEFAULT FALSE
  83. );
  84. """
  85. cursor.execute(create_table_query)
  86. # 创建索引加速查询
  87. create_index_query = """
  88. CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
  89. """
  90. cursor.execute(create_index_query)
  91. conn.commit()
  92. cursor.close()
  93. logger.info("用户表创建成功")
  94. return True
  95. except Exception as e:
  96. logger.error(f"创建用户表失败: {str(e)}")
  97. if conn:
  98. conn.rollback()
  99. return False
  100. finally:
  101. if conn:
  102. release_pg_connection(conn)
  103. def register_user(username, password):
  104. """
  105. 注册新用户
  106. Args:
  107. username: 用户名
  108. password: 密码
  109. Returns:
  110. tuple: (是否成功, 消息)
  111. """
  112. conn = None
  113. try:
  114. # 确保表已创建
  115. create_user_table()
  116. # 对密码进行编码
  117. encoded_password = encode_password(password)
  118. # 生成用户ID
  119. user_id = str(uuid.uuid4())
  120. conn = get_pg_connection()
  121. cursor = conn.cursor()
  122. # 检查用户名是否存在
  123. check_query = "SELECT username FROM users WHERE username = %s"
  124. cursor.execute(check_query, (username,))
  125. if cursor.fetchone():
  126. return False, "用户名已存在"
  127. # 创建用户
  128. insert_query = """
  129. INSERT INTO users (id, username, password, created_at, last_login)
  130. VALUES (%s, %s, %s, %s, %s)
  131. """
  132. cursor.execute(
  133. insert_query,
  134. (user_id, username, encoded_password, time.time(), None)
  135. )
  136. conn.commit()
  137. cursor.close()
  138. return True, "注册成功"
  139. except Exception as e:
  140. logger.error(f"用户注册失败: {str(e)}")
  141. if conn:
  142. conn.rollback()
  143. return False, f"注册失败: {str(e)}"
  144. finally:
  145. if conn:
  146. release_pg_connection(conn)
  147. def login_user(username, password):
  148. """
  149. 用户登录验证
  150. Args:
  151. username: 用户名
  152. password: 密码
  153. Returns:
  154. tuple: (是否成功, 用户信息/错误消息)
  155. """
  156. conn = None
  157. try:
  158. # 对输入的密码进行编码
  159. encoded_password = encode_password(password)
  160. conn = get_pg_connection()
  161. cursor = conn.cursor()
  162. # 查询用户
  163. query = """
  164. SELECT id, username, password, created_at, last_login, is_admin
  165. FROM users WHERE username = %s
  166. """
  167. cursor.execute(query, (username,))
  168. user = cursor.fetchone()
  169. # 检查用户是否存在
  170. if not user:
  171. return False, "用户名或密码错误"
  172. # 验证密码
  173. if user[2] != encoded_password:
  174. return False, "用户名或密码错误"
  175. # 更新最后登录时间
  176. current_time = time.time()
  177. update_query = """
  178. UPDATE users SET last_login = %s WHERE username = %s
  179. """
  180. cursor.execute(update_query, (current_time, username))
  181. conn.commit()
  182. # 构建用户信息
  183. user_info = {
  184. "id": user[0],
  185. "username": user[1],
  186. "created_at": user[3],
  187. "last_login": current_time,
  188. "is_admin": user[5] if len(user) > 5 else False
  189. }
  190. cursor.close()
  191. return True, user_info
  192. except Exception as e:
  193. logger.error(f"用户登录失败: {str(e)}")
  194. if conn:
  195. conn.rollback()
  196. return False, f"登录失败: {str(e)}"
  197. finally:
  198. if conn:
  199. release_pg_connection(conn)
  200. def get_user_by_username(username):
  201. """
  202. 根据用户名获取用户信息
  203. Args:
  204. username: 用户名
  205. Returns:
  206. dict: 用户信息(不包含密码)
  207. """
  208. conn = None
  209. try:
  210. conn = get_pg_connection()
  211. cursor = conn.cursor()
  212. query = """
  213. SELECT id, username, created_at, last_login, is_admin
  214. FROM users WHERE username = %s
  215. """
  216. cursor.execute(query, (username,))
  217. user = cursor.fetchone()
  218. cursor.close()
  219. if not user:
  220. return None
  221. user_info = {
  222. "id": user[0],
  223. "username": user[1],
  224. "created_at": user[2],
  225. "last_login": user[3],
  226. "is_admin": user[4] if user[4] is not None else False
  227. }
  228. return user_info
  229. except Exception as e:
  230. logger.error(f"获取用户信息失败: {str(e)}")
  231. return None
  232. finally:
  233. if conn:
  234. release_pg_connection(conn)
  235. def init_db():
  236. """
  237. 初始化数据库,创建用户表
  238. Returns:
  239. bool: 是否成功初始化
  240. """
  241. return create_user_table()