engines.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. # testing/engines.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from __future__ import absolute_import
  8. import collections
  9. import re
  10. import warnings
  11. import weakref
  12. from . import config
  13. from .util import decorator
  14. from .util import gc_collect
  15. from .. import event
  16. from .. import pool
  17. from ..util import await_only
  18. class ConnectionKiller(object):
  19. def __init__(self):
  20. self.proxy_refs = weakref.WeakKeyDictionary()
  21. self.testing_engines = collections.defaultdict(set)
  22. self.dbapi_connections = set()
  23. def add_pool(self, pool):
  24. event.listen(pool, "checkout", self._add_conn)
  25. event.listen(pool, "checkin", self._remove_conn)
  26. event.listen(pool, "close", self._remove_conn)
  27. event.listen(pool, "close_detached", self._remove_conn)
  28. # note we are keeping "invalidated" here, as those are still
  29. # opened connections we would like to roll back
  30. def _add_conn(self, dbapi_con, con_record, con_proxy):
  31. self.dbapi_connections.add(dbapi_con)
  32. self.proxy_refs[con_proxy] = True
  33. def _remove_conn(self, dbapi_conn, *arg):
  34. self.dbapi_connections.discard(dbapi_conn)
  35. def add_engine(self, engine, scope):
  36. self.add_pool(engine.pool)
  37. assert scope in ("class", "global", "function", "fixture")
  38. self.testing_engines[scope].add(engine)
  39. def _safe(self, fn):
  40. try:
  41. fn()
  42. except Exception as e:
  43. warnings.warn(
  44. "testing_reaper couldn't rollback/close connection: %s" % e
  45. )
  46. def rollback_all(self):
  47. for rec in list(self.proxy_refs):
  48. if rec is not None and rec.is_valid:
  49. self._safe(rec.rollback)
  50. def checkin_all(self):
  51. # run pool.checkin() for all ConnectionFairy instances we have
  52. # tracked.
  53. for rec in list(self.proxy_refs):
  54. if rec is not None and rec.is_valid:
  55. self.dbapi_connections.discard(rec.dbapi_connection)
  56. self._safe(rec._checkin)
  57. # for fairy refs that were GCed and could not close the connection,
  58. # such as asyncio, roll back those remaining connections
  59. for con in self.dbapi_connections:
  60. self._safe(con.rollback)
  61. self.dbapi_connections.clear()
  62. def close_all(self):
  63. self.checkin_all()
  64. def prepare_for_drop_tables(self, connection):
  65. # don't do aggressive checks for third party test suites
  66. if not config.bootstrapped_as_sqlalchemy:
  67. return
  68. from . import provision
  69. provision.prepare_for_drop_tables(connection.engine.url, connection)
  70. def _drop_testing_engines(self, scope):
  71. eng = self.testing_engines[scope]
  72. for rec in list(eng):
  73. for proxy_ref in list(self.proxy_refs):
  74. if proxy_ref is not None and proxy_ref.is_valid:
  75. if (
  76. proxy_ref._pool is not None
  77. and proxy_ref._pool is rec.pool
  78. ):
  79. self._safe(proxy_ref._checkin)
  80. if hasattr(rec, "sync_engine"):
  81. await_only(rec.dispose())
  82. else:
  83. rec.dispose()
  84. eng.clear()
  85. def after_test(self):
  86. self._drop_testing_engines("function")
  87. def after_test_outside_fixtures(self, test):
  88. # don't do aggressive checks for third party test suites
  89. if not config.bootstrapped_as_sqlalchemy:
  90. return
  91. if test.__class__.__leave_connections_for_teardown__:
  92. return
  93. self.checkin_all()
  94. # on PostgreSQL, this will test for any "idle in transaction"
  95. # connections. useful to identify tests with unusual patterns
  96. # that can't be cleaned up correctly.
  97. from . import provision
  98. with config.db.connect() as conn:
  99. provision.prepare_for_drop_tables(conn.engine.url, conn)
  100. def stop_test_class_inside_fixtures(self):
  101. self.checkin_all()
  102. self._drop_testing_engines("function")
  103. self._drop_testing_engines("class")
  104. def stop_test_class_outside_fixtures(self):
  105. # ensure no refs to checked out connections at all.
  106. if pool.base._strong_ref_connection_records:
  107. gc_collect()
  108. if pool.base._strong_ref_connection_records:
  109. ln = len(pool.base._strong_ref_connection_records)
  110. pool.base._strong_ref_connection_records.clear()
  111. assert (
  112. False
  113. ), "%d connection recs not cleared after test suite" % (ln)
  114. def final_cleanup(self):
  115. self.checkin_all()
  116. for scope in self.testing_engines:
  117. self._drop_testing_engines(scope)
  118. def assert_all_closed(self):
  119. for rec in self.proxy_refs:
  120. if rec.is_valid:
  121. assert False
# Module-wide ConnectionKiller instance; the connection-cleanup decorators
# and testing_engine() in this module register with and call into it.
testing_reaper = ConnectionKiller()
  123. @decorator
  124. def assert_conns_closed(fn, *args, **kw):
  125. try:
  126. fn(*args, **kw)
  127. finally:
  128. testing_reaper.assert_all_closed()
  129. @decorator
  130. def rollback_open_connections(fn, *args, **kw):
  131. """Decorator that rolls back all open connections after fn execution."""
  132. try:
  133. fn(*args, **kw)
  134. finally:
  135. testing_reaper.rollback_all()
  136. @decorator
  137. def close_first(fn, *args, **kw):
  138. """Decorator that closes all connections before fn execution."""
  139. testing_reaper.checkin_all()
  140. fn(*args, **kw)
  141. @decorator
  142. def close_open_connections(fn, *args, **kw):
  143. """Decorator that closes all connections after fn execution."""
  144. try:
  145. fn(*args, **kw)
  146. finally:
  147. testing_reaper.checkin_all()
  148. def all_dialects(exclude=None):
  149. import sqlalchemy.dialects as d
  150. for name in d.__all__:
  151. # TEMPORARY
  152. if exclude and name in exclude:
  153. continue
  154. mod = getattr(d, name, None)
  155. if not mod:
  156. mod = getattr(
  157. __import__("sqlalchemy.dialects.%s" % name).dialects, name
  158. )
  159. yield mod.dialect()
  160. class ReconnectFixture(object):
  161. def __init__(self, dbapi):
  162. self.dbapi = dbapi
  163. self.connections = []
  164. self.is_stopped = False
  165. def __getattr__(self, key):
  166. return getattr(self.dbapi, key)
  167. def connect(self, *args, **kwargs):
  168. conn = self.dbapi.connect(*args, **kwargs)
  169. if self.is_stopped:
  170. self._safe(conn.close)
  171. curs = conn.cursor() # should fail on Oracle etc.
  172. # should fail for everything that didn't fail
  173. # above, connection is closed
  174. curs.execute("select 1")
  175. assert False, "simulated connect failure didn't work"
  176. else:
  177. self.connections.append(conn)
  178. return conn
  179. def _safe(self, fn):
  180. try:
  181. fn()
  182. except Exception as e:
  183. warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
  184. def shutdown(self, stop=False):
  185. # TODO: this doesn't cover all cases
  186. # as nicely as we'd like, namely MySQLdb.
  187. # would need to implement R. Brewer's
  188. # proxy server idea to get better
  189. # coverage.
  190. self.is_stopped = stop
  191. for c in list(self.connections):
  192. self._safe(c.close)
  193. self.connections = []
  194. def restart(self):
  195. self.is_stopped = False
  196. def reconnecting_engine(url=None, options=None):
  197. url = url or config.db.url
  198. dbapi = config.db.dialect.dbapi
  199. if not options:
  200. options = {}
  201. options["module"] = ReconnectFixture(dbapi)
  202. engine = testing_engine(url, options)
  203. _dispose = engine.dispose
  204. def dispose():
  205. engine.dialect.dbapi.shutdown()
  206. engine.dialect.dbapi.is_stopped = False
  207. _dispose()
  208. engine.test_shutdown = engine.dialect.dbapi.shutdown
  209. engine.test_restart = engine.dialect.dbapi.restart
  210. engine.dispose = dispose
  211. return engine
  212. def testing_engine(
  213. url=None,
  214. options=None,
  215. future=None,
  216. asyncio=False,
  217. transfer_staticpool=False,
  218. _sqlite_savepoint=False,
  219. ):
  220. """Produce an engine configured by --options with optional overrides."""
  221. if asyncio:
  222. assert not _sqlite_savepoint
  223. from sqlalchemy.ext.asyncio import (
  224. create_async_engine as create_engine,
  225. )
  226. elif future or (
  227. config.db and config.db._is_future and future is not False
  228. ):
  229. from sqlalchemy.future import create_engine
  230. else:
  231. from sqlalchemy import create_engine
  232. from sqlalchemy.engine.url import make_url
  233. if not options:
  234. use_reaper = True
  235. scope = "function"
  236. sqlite_savepoint = False
  237. else:
  238. use_reaper = options.pop("use_reaper", True)
  239. scope = options.pop("scope", "function")
  240. sqlite_savepoint = options.pop("sqlite_savepoint", False)
  241. url = url or config.db.url
  242. url = make_url(url)
  243. if options is None:
  244. if config.db is None or url.drivername == config.db.url.drivername:
  245. options = config.db_opts
  246. else:
  247. options = {}
  248. elif config.db is not None and url.drivername == config.db.url.drivername:
  249. default_opt = config.db_opts.copy()
  250. default_opt.update(options)
  251. engine = create_engine(url, **options)
  252. if sqlite_savepoint and engine.name == "sqlite":
  253. # apply SQLite savepoint workaround
  254. @event.listens_for(engine, "connect")
  255. def do_connect(dbapi_connection, connection_record):
  256. dbapi_connection.isolation_level = None
  257. @event.listens_for(engine, "begin")
  258. def do_begin(conn):
  259. conn.exec_driver_sql("BEGIN")
  260. if transfer_staticpool:
  261. from sqlalchemy.pool import StaticPool
  262. if config.db is not None and isinstance(config.db.pool, StaticPool):
  263. use_reaper = False
  264. engine.pool._transfer_from(config.db.pool)
  265. if scope == "global":
  266. if asyncio:
  267. engine.sync_engine._has_events = True
  268. else:
  269. engine._has_events = (
  270. True # enable event blocks, helps with profiling
  271. )
  272. if isinstance(engine.pool, pool.QueuePool):
  273. engine.pool._timeout = 0
  274. engine.pool._max_overflow = 0
  275. if use_reaper:
  276. testing_reaper.add_engine(engine, scope)
  277. return engine
  278. def mock_engine(dialect_name=None):
  279. """Provides a mocking engine based on the current testing.db.
  280. This is normally used to test DDL generation flow as emitted
  281. by an Engine.
  282. It should not be used in other cases, as assert_compile() and
  283. assert_sql_execution() are much better choices with fewer
  284. moving parts.
  285. """
  286. from sqlalchemy import create_mock_engine
  287. if not dialect_name:
  288. dialect_name = config.db.name
  289. buffer = []
  290. def executor(sql, *a, **kw):
  291. buffer.append(sql)
  292. def assert_sql(stmts):
  293. recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
  294. assert recv == stmts, recv
  295. def print_sql():
  296. d = engine.dialect
  297. return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
  298. engine = create_mock_engine(dialect_name + "://", executor)
  299. assert not hasattr(engine, "mock")
  300. engine.mock = buffer
  301. engine.assert_sql = assert_sql
  302. engine.print_sql = print_sql
  303. return engine
  304. class DBAPIProxyCursor(object):
  305. """Proxy a DBAPI cursor.
  306. Tests can provide subclasses of this to intercept
  307. DBAPI-level cursor operations.
  308. """
  309. def __init__(self, engine, conn, *args, **kwargs):
  310. self.engine = engine
  311. self.connection = conn
  312. self.cursor = conn.cursor(*args, **kwargs)
  313. def execute(self, stmt, parameters=None, **kw):
  314. if parameters:
  315. return self.cursor.execute(stmt, parameters, **kw)
  316. else:
  317. return self.cursor.execute(stmt, **kw)
  318. def executemany(self, stmt, params, **kw):
  319. return self.cursor.executemany(stmt, params, **kw)
  320. def __iter__(self):
  321. return iter(self.cursor)
  322. def __getattr__(self, key):
  323. return getattr(self.cursor, key)
  324. class DBAPIProxyConnection(object):
  325. """Proxy a DBAPI connection.
  326. Tests can provide subclasses of this to intercept
  327. DBAPI-level connection operations.
  328. """
  329. def __init__(self, engine, cursor_cls):
  330. self.conn = engine.pool._creator()
  331. self.engine = engine
  332. self.cursor_cls = cursor_cls
  333. def cursor(self, *args, **kwargs):
  334. return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
  335. def close(self):
  336. self.conn.close()
  337. def __getattr__(self, key):
  338. return getattr(self.conn, key)
  339. def proxying_engine(
  340. conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
  341. ):
  342. """Produce an engine that provides proxy hooks for
  343. common methods.
  344. """
  345. def mock_conn():
  346. return conn_cls(config.db, cursor_cls)
  347. def _wrap_do_on_connect(do_on_connect):
  348. def go(dbapi_conn):
  349. return do_on_connect(dbapi_conn.conn)
  350. return go
  351. return testing_engine(
  352. options={
  353. "creator": mock_conn,
  354. "_wrap_do_on_connect": _wrap_do_on_connect,
  355. }
  356. )