fixtures.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. # testing/fixtures.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. import contextlib
  8. import itertools
  9. import re
  10. import sys
  11. import sqlalchemy as sa
  12. from . import assertions
  13. from . import config
  14. from . import schema
  15. from .assertions import eq_
  16. from .assertions import ne_
  17. from .entities import BasicEntity
  18. from .entities import ComparableEntity
  19. from .entities import ComparableMixin # noqa
  20. from .util import adict
  21. from .util import drop_all_tables_from_metadata
  22. from .. import event
  23. from .. import util
  24. from ..orm import declarative_base
  25. from ..orm import registry
  26. from ..orm.decl_api import DeclarativeMeta
  27. from ..schema import sort_tables_and_constraints
  28. from ..sql import visitors
  29. from ..sql.elements import ClauseElement
  30. @config.mark_base_test_class()
  31. class TestBase(object):
  32. # A sequence of requirement names matching testing.requires decorators
  33. __requires__ = ()
  34. # A sequence of dialect names to exclude from the test class.
  35. __unsupported_on__ = ()
  36. # If present, test class is only runnable for the *single* specified
  37. # dialect. If you need multiple, use __unsupported_on__ and invert.
  38. __only_on__ = None
  39. # A sequence of no-arg callables. If any are True, the entire testcase is
  40. # skipped.
  41. __skip_if__ = None
  42. # if True, the testing reaper will not attempt to touch connection
  43. # state after a test is completed and before the outer teardown
  44. # starts
  45. __leave_connections_for_teardown__ = False
  46. def assert_(self, val, msg=None):
  47. assert val, msg
  48. @config.fixture()
  49. def nocache(self):
  50. _cache = config.db._compiled_cache
  51. config.db._compiled_cache = None
  52. yield
  53. config.db._compiled_cache = _cache
  54. @config.fixture()
  55. def connection_no_trans(self):
  56. eng = getattr(self, "bind", None) or config.db
  57. with eng.connect() as conn:
  58. yield conn
  59. @config.fixture()
  60. def connection(self):
  61. global _connection_fixture_connection
  62. eng = getattr(self, "bind", None) or config.db
  63. conn = eng.connect()
  64. trans = conn.begin()
  65. _connection_fixture_connection = conn
  66. yield conn
  67. _connection_fixture_connection = None
  68. if trans.is_active:
  69. trans.rollback()
  70. # trans would not be active here if the test is using
  71. # the legacy @provide_metadata decorator still, as it will
  72. # run a close all connections.
  73. conn.close()
  74. @config.fixture()
  75. def close_result_when_finished(self):
  76. to_close = []
  77. to_consume = []
  78. def go(result, consume=False):
  79. to_close.append(result)
  80. if consume:
  81. to_consume.append(result)
  82. yield go
  83. for r in to_consume:
  84. try:
  85. r.all()
  86. except:
  87. pass
  88. for r in to_close:
  89. try:
  90. r.close()
  91. except:
  92. pass
  93. @config.fixture()
  94. def registry(self, metadata):
  95. reg = registry(metadata=metadata)
  96. yield reg
  97. reg.dispose()
  98. @config.fixture
  99. def decl_base(self, registry):
  100. return registry.generate_base()
  101. @config.fixture()
  102. def future_connection(self, future_engine, connection):
  103. # integrate the future_engine and connection fixtures so
  104. # that users of the "connection" fixture will get at the
  105. # "future" connection
  106. yield connection
  107. @config.fixture()
  108. def future_engine(self):
  109. eng = getattr(self, "bind", None) or config.db
  110. with _push_future_engine(eng):
  111. yield
  112. @config.fixture()
  113. def testing_engine(self):
  114. from . import engines
  115. def gen_testing_engine(
  116. url=None,
  117. options=None,
  118. future=None,
  119. asyncio=False,
  120. transfer_staticpool=False,
  121. ):
  122. if options is None:
  123. options = {}
  124. options["scope"] = "fixture"
  125. return engines.testing_engine(
  126. url=url,
  127. options=options,
  128. future=future,
  129. asyncio=asyncio,
  130. transfer_staticpool=transfer_staticpool,
  131. )
  132. yield gen_testing_engine
  133. engines.testing_reaper._drop_testing_engines("fixture")
  134. @config.fixture()
  135. def async_testing_engine(self, testing_engine):
  136. def go(**kw):
  137. kw["asyncio"] = True
  138. return testing_engine(**kw)
  139. return go
  140. @config.fixture
  141. def fixture_session(self):
  142. return fixture_session()
  143. @config.fixture()
  144. def metadata(self, request):
  145. """Provide bound MetaData for a single test, dropping afterwards."""
  146. from ..sql import schema
  147. metadata = schema.MetaData()
  148. request.instance.metadata = metadata
  149. yield metadata
  150. del request.instance.metadata
  151. if (
  152. _connection_fixture_connection
  153. and _connection_fixture_connection.in_transaction()
  154. ):
  155. trans = _connection_fixture_connection.get_transaction()
  156. trans.rollback()
  157. with _connection_fixture_connection.begin():
  158. drop_all_tables_from_metadata(
  159. metadata, _connection_fixture_connection
  160. )
  161. else:
  162. drop_all_tables_from_metadata(metadata, config.db)
  163. @config.fixture(
  164. params=[
  165. (rollback, second_operation, begin_nested)
  166. for rollback in (True, False)
  167. for second_operation in ("none", "execute", "begin")
  168. for begin_nested in (
  169. True,
  170. False,
  171. )
  172. ]
  173. )
  174. def trans_ctx_manager_fixture(self, request, metadata):
  175. rollback, second_operation, begin_nested = request.param
  176. from sqlalchemy import Table, Column, Integer, func, select
  177. from . import eq_
  178. t = Table("test", metadata, Column("data", Integer))
  179. eng = getattr(self, "bind", None) or config.db
  180. t.create(eng)
  181. def run_test(subject, trans_on_subject, execute_on_subject):
  182. with subject.begin() as trans:
  183. if begin_nested:
  184. if not config.requirements.savepoints.enabled:
  185. config.skip_test("savepoints not enabled")
  186. if execute_on_subject:
  187. nested_trans = subject.begin_nested()
  188. else:
  189. nested_trans = trans.begin_nested()
  190. with nested_trans:
  191. if execute_on_subject:
  192. subject.execute(t.insert(), {"data": 10})
  193. else:
  194. trans.execute(t.insert(), {"data": 10})
  195. # for nested trans, we always commit/rollback on the
  196. # "nested trans" object itself.
  197. # only Session(future=False) will affect savepoint
  198. # transaction for session.commit/rollback
  199. if rollback:
  200. nested_trans.rollback()
  201. else:
  202. nested_trans.commit()
  203. if second_operation != "none":
  204. with assertions.expect_raises_message(
  205. sa.exc.InvalidRequestError,
  206. "Can't operate on closed transaction "
  207. "inside context "
  208. "manager. Please complete the context "
  209. "manager "
  210. "before emitting further commands.",
  211. ):
  212. if second_operation == "execute":
  213. if execute_on_subject:
  214. subject.execute(
  215. t.insert(), {"data": 12}
  216. )
  217. else:
  218. trans.execute(t.insert(), {"data": 12})
  219. elif second_operation == "begin":
  220. if execute_on_subject:
  221. subject.begin_nested()
  222. else:
  223. trans.begin_nested()
  224. # outside the nested trans block, but still inside the
  225. # transaction block, we can run SQL, and it will be
  226. # committed
  227. if execute_on_subject:
  228. subject.execute(t.insert(), {"data": 14})
  229. else:
  230. trans.execute(t.insert(), {"data": 14})
  231. else:
  232. if execute_on_subject:
  233. subject.execute(t.insert(), {"data": 10})
  234. else:
  235. trans.execute(t.insert(), {"data": 10})
  236. if trans_on_subject:
  237. if rollback:
  238. subject.rollback()
  239. else:
  240. subject.commit()
  241. else:
  242. if rollback:
  243. trans.rollback()
  244. else:
  245. trans.commit()
  246. if second_operation != "none":
  247. with assertions.expect_raises_message(
  248. sa.exc.InvalidRequestError,
  249. "Can't operate on closed transaction inside "
  250. "context "
  251. "manager. Please complete the context manager "
  252. "before emitting further commands.",
  253. ):
  254. if second_operation == "execute":
  255. if execute_on_subject:
  256. subject.execute(t.insert(), {"data": 12})
  257. else:
  258. trans.execute(t.insert(), {"data": 12})
  259. elif second_operation == "begin":
  260. if hasattr(trans, "begin"):
  261. trans.begin()
  262. else:
  263. subject.begin()
  264. elif second_operation == "begin_nested":
  265. if execute_on_subject:
  266. subject.begin_nested()
  267. else:
  268. trans.begin_nested()
  269. expected_committed = 0
  270. if begin_nested:
  271. # begin_nested variant, we inserted a row after the nested
  272. # block
  273. expected_committed += 1
  274. if not rollback:
  275. # not rollback variant, our row inserted in the target
  276. # block itself would be committed
  277. expected_committed += 1
  278. if execute_on_subject:
  279. eq_(
  280. subject.scalar(select(func.count()).select_from(t)),
  281. expected_committed,
  282. )
  283. else:
  284. with subject.connect() as conn:
  285. eq_(
  286. conn.scalar(select(func.count()).select_from(t)),
  287. expected_committed,
  288. )
  289. return run_test
  290. _connection_fixture_connection = None
  291. @contextlib.contextmanager
  292. def _push_future_engine(engine):
  293. from ..future.engine import Engine
  294. from sqlalchemy import testing
  295. facade = Engine._future_facade(engine)
  296. config._current.push_engine(facade, testing)
  297. yield facade
  298. config._current.pop(testing)
  299. class FutureEngineMixin(object):
  300. @config.fixture(autouse=True, scope="class")
  301. def _push_future_engine(self):
  302. eng = getattr(self, "bind", None) or config.db
  303. with _push_future_engine(eng):
  304. yield
  305. class TablesTest(TestBase):
  306. # 'once', None
  307. run_setup_bind = "once"
  308. # 'once', 'each', None
  309. run_define_tables = "once"
  310. # 'once', 'each', None
  311. run_create_tables = "once"
  312. # 'once', 'each', None
  313. run_inserts = "each"
  314. # 'each', None
  315. run_deletes = "each"
  316. # 'once', None
  317. run_dispose_bind = None
  318. bind = None
  319. _tables_metadata = None
  320. tables = None
  321. other = None
  322. sequences = None
  323. @config.fixture(autouse=True, scope="class")
  324. def _setup_tables_test_class(self):
  325. cls = self.__class__
  326. cls._init_class()
  327. cls._setup_once_tables()
  328. cls._setup_once_inserts()
  329. yield
  330. cls._teardown_once_metadata_bind()
  331. @config.fixture(autouse=True, scope="function")
  332. def _setup_tables_test_instance(self):
  333. self._setup_each_tables()
  334. self._setup_each_inserts()
  335. yield
  336. self._teardown_each_tables()
  337. @property
  338. def tables_test_metadata(self):
  339. return self._tables_metadata
  340. @classmethod
  341. def _init_class(cls):
  342. if cls.run_define_tables == "each":
  343. if cls.run_create_tables == "once":
  344. cls.run_create_tables = "each"
  345. assert cls.run_inserts in ("each", None)
  346. cls.other = adict()
  347. cls.tables = adict()
  348. cls.sequences = adict()
  349. cls.bind = cls.setup_bind()
  350. cls._tables_metadata = sa.MetaData()
  351. @classmethod
  352. def _setup_once_inserts(cls):
  353. if cls.run_inserts == "once":
  354. cls._load_fixtures()
  355. with cls.bind.begin() as conn:
  356. cls.insert_data(conn)
  357. @classmethod
  358. def _setup_once_tables(cls):
  359. if cls.run_define_tables == "once":
  360. cls.define_tables(cls._tables_metadata)
  361. if cls.run_create_tables == "once":
  362. cls._tables_metadata.create_all(cls.bind)
  363. cls.tables.update(cls._tables_metadata.tables)
  364. cls.sequences.update(cls._tables_metadata._sequences)
  365. def _setup_each_tables(self):
  366. if self.run_define_tables == "each":
  367. self.define_tables(self._tables_metadata)
  368. if self.run_create_tables == "each":
  369. self._tables_metadata.create_all(self.bind)
  370. self.tables.update(self._tables_metadata.tables)
  371. self.sequences.update(self._tables_metadata._sequences)
  372. elif self.run_create_tables == "each":
  373. self._tables_metadata.create_all(self.bind)
  374. def _setup_each_inserts(self):
  375. if self.run_inserts == "each":
  376. self._load_fixtures()
  377. with self.bind.begin() as conn:
  378. self.insert_data(conn)
  379. def _teardown_each_tables(self):
  380. if self.run_define_tables == "each":
  381. self.tables.clear()
  382. if self.run_create_tables == "each":
  383. drop_all_tables_from_metadata(self._tables_metadata, self.bind)
  384. self._tables_metadata.clear()
  385. elif self.run_create_tables == "each":
  386. drop_all_tables_from_metadata(self._tables_metadata, self.bind)
  387. savepoints = getattr(config.requirements, "savepoints", False)
  388. if savepoints:
  389. savepoints = savepoints.enabled
  390. # no need to run deletes if tables are recreated on setup
  391. if (
  392. self.run_define_tables != "each"
  393. and self.run_create_tables != "each"
  394. and self.run_deletes == "each"
  395. ):
  396. with self.bind.begin() as conn:
  397. for table in reversed(
  398. [
  399. t
  400. for (t, fks) in sort_tables_and_constraints(
  401. self._tables_metadata.tables.values()
  402. )
  403. if t is not None
  404. ]
  405. ):
  406. try:
  407. if savepoints:
  408. with conn.begin_nested():
  409. conn.execute(table.delete())
  410. else:
  411. conn.execute(table.delete())
  412. except sa.exc.DBAPIError as ex:
  413. util.print_(
  414. ("Error emptying table %s: %r" % (table, ex)),
  415. file=sys.stderr,
  416. )
  417. @classmethod
  418. def _teardown_once_metadata_bind(cls):
  419. if cls.run_create_tables:
  420. drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
  421. if cls.run_dispose_bind == "once":
  422. cls.dispose_bind(cls.bind)
  423. cls._tables_metadata.bind = None
  424. if cls.run_setup_bind is not None:
  425. cls.bind = None
  426. @classmethod
  427. def setup_bind(cls):
  428. return config.db
  429. @classmethod
  430. def dispose_bind(cls, bind):
  431. if hasattr(bind, "dispose"):
  432. bind.dispose()
  433. elif hasattr(bind, "close"):
  434. bind.close()
  435. @classmethod
  436. def define_tables(cls, metadata):
  437. pass
  438. @classmethod
  439. def fixtures(cls):
  440. return {}
  441. @classmethod
  442. def insert_data(cls, connection):
  443. pass
  444. def sql_count_(self, count, fn):
  445. self.assert_sql_count(self.bind, fn, count)
  446. def sql_eq_(self, callable_, statements):
  447. self.assert_sql(self.bind, callable_, statements)
  448. @classmethod
  449. def _load_fixtures(cls):
  450. """Insert rows as represented by the fixtures() method."""
  451. headers, rows = {}, {}
  452. for table, data in cls.fixtures().items():
  453. if len(data) < 2:
  454. continue
  455. if isinstance(table, util.string_types):
  456. table = cls.tables[table]
  457. headers[table] = data[0]
  458. rows[table] = data[1:]
  459. for table, fks in sort_tables_and_constraints(
  460. cls._tables_metadata.tables.values()
  461. ):
  462. if table is None:
  463. continue
  464. if table not in headers:
  465. continue
  466. with cls.bind.begin() as conn:
  467. conn.execute(
  468. table.insert(),
  469. [
  470. dict(zip(headers[table], column_values))
  471. for column_values in rows[table]
  472. ],
  473. )
  474. class NoCache(object):
  475. @config.fixture(autouse=True, scope="function")
  476. def _disable_cache(self):
  477. _cache = config.db._compiled_cache
  478. config.db._compiled_cache = None
  479. yield
  480. config.db._compiled_cache = _cache
  481. class RemovesEvents(object):
  482. @util.memoized_property
  483. def _event_fns(self):
  484. return set()
  485. def event_listen(self, target, name, fn, **kw):
  486. self._event_fns.add((target, name, fn))
  487. event.listen(target, name, fn, **kw)
  488. @config.fixture(autouse=True, scope="function")
  489. def _remove_events(self):
  490. yield
  491. for key in self._event_fns:
  492. event.remove(*key)
  493. _fixture_sessions = set()
  494. def fixture_session(**kw):
  495. kw.setdefault("autoflush", True)
  496. kw.setdefault("expire_on_commit", True)
  497. bind = kw.pop("bind", config.db)
  498. sess = sa.orm.Session(bind, **kw)
  499. _fixture_sessions.add(sess)
  500. return sess
  501. def _close_all_sessions():
  502. # will close all still-referenced sessions
  503. sa.orm.session.close_all_sessions()
  504. _fixture_sessions.clear()
  505. def stop_test_class_inside_fixtures(cls):
  506. _close_all_sessions()
  507. sa.orm.clear_mappers()
  508. def after_test():
  509. if _fixture_sessions:
  510. _close_all_sessions()
  511. class ORMTest(TestBase):
  512. pass
  513. class MappedTest(TablesTest, assertions.AssertsExecutionResults):
  514. # 'once', 'each', None
  515. run_setup_classes = "once"
  516. # 'once', 'each', None
  517. run_setup_mappers = "each"
  518. classes = None
  519. @config.fixture(autouse=True, scope="class")
  520. def _setup_tables_test_class(self):
  521. cls = self.__class__
  522. cls._init_class()
  523. if cls.classes is None:
  524. cls.classes = adict()
  525. cls._setup_once_tables()
  526. cls._setup_once_classes()
  527. cls._setup_once_mappers()
  528. cls._setup_once_inserts()
  529. yield
  530. cls._teardown_once_class()
  531. cls._teardown_once_metadata_bind()
  532. @config.fixture(autouse=True, scope="function")
  533. def _setup_tables_test_instance(self):
  534. self._setup_each_tables()
  535. self._setup_each_classes()
  536. self._setup_each_mappers()
  537. self._setup_each_inserts()
  538. yield
  539. sa.orm.session.close_all_sessions()
  540. self._teardown_each_mappers()
  541. self._teardown_each_classes()
  542. self._teardown_each_tables()
  543. @classmethod
  544. def _teardown_once_class(cls):
  545. cls.classes.clear()
  546. @classmethod
  547. def _setup_once_classes(cls):
  548. if cls.run_setup_classes == "once":
  549. cls._with_register_classes(cls.setup_classes)
  550. @classmethod
  551. def _setup_once_mappers(cls):
  552. if cls.run_setup_mappers == "once":
  553. cls.mapper_registry, cls.mapper = cls._generate_registry()
  554. cls._with_register_classes(cls.setup_mappers)
  555. def _setup_each_mappers(self):
  556. if self.run_setup_mappers != "once":
  557. (
  558. self.__class__.mapper_registry,
  559. self.__class__.mapper,
  560. ) = self._generate_registry()
  561. if self.run_setup_mappers == "each":
  562. self._with_register_classes(self.setup_mappers)
  563. def _setup_each_classes(self):
  564. if self.run_setup_classes == "each":
  565. self._with_register_classes(self.setup_classes)
  566. @classmethod
  567. def _generate_registry(cls):
  568. decl = registry(metadata=cls._tables_metadata)
  569. return decl, decl.map_imperatively
  570. @classmethod
  571. def _with_register_classes(cls, fn):
  572. """Run a setup method, framing the operation with a Base class
  573. that will catch new subclasses to be established within
  574. the "classes" registry.
  575. """
  576. cls_registry = cls.classes
  577. assert cls_registry is not None
  578. class FindFixture(type):
  579. def __init__(cls, classname, bases, dict_):
  580. cls_registry[classname] = cls
  581. type.__init__(cls, classname, bases, dict_)
  582. class _Base(util.with_metaclass(FindFixture, object)):
  583. pass
  584. class Basic(BasicEntity, _Base):
  585. pass
  586. class Comparable(ComparableEntity, _Base):
  587. pass
  588. cls.Basic = Basic
  589. cls.Comparable = Comparable
  590. fn()
  591. def _teardown_each_mappers(self):
  592. # some tests create mappers in the test bodies
  593. # and will define setup_mappers as None -
  594. # clear mappers in any case
  595. if self.run_setup_mappers != "once":
  596. sa.orm.clear_mappers()
  597. def _teardown_each_classes(self):
  598. if self.run_setup_classes != "once":
  599. self.classes.clear()
  600. @classmethod
  601. def setup_classes(cls):
  602. pass
  603. @classmethod
  604. def setup_mappers(cls):
  605. pass
  606. class DeclarativeMappedTest(MappedTest):
  607. run_setup_classes = "once"
  608. run_setup_mappers = "once"
  609. @classmethod
  610. def _setup_once_tables(cls):
  611. pass
  612. @classmethod
  613. def _with_register_classes(cls, fn):
  614. cls_registry = cls.classes
  615. class FindFixtureDeclarative(DeclarativeMeta):
  616. def __init__(cls, classname, bases, dict_):
  617. cls_registry[classname] = cls
  618. DeclarativeMeta.__init__(cls, classname, bases, dict_)
  619. class DeclarativeBasic(object):
  620. __table_cls__ = schema.Table
  621. _DeclBase = declarative_base(
  622. metadata=cls._tables_metadata,
  623. metaclass=FindFixtureDeclarative,
  624. cls=DeclarativeBasic,
  625. )
  626. cls.DeclarativeBasic = _DeclBase
  627. # sets up cls.Basic which is helpful for things like composite
  628. # classes
  629. super(DeclarativeMappedTest, cls)._with_register_classes(fn)
  630. if cls._tables_metadata.tables and cls.run_create_tables:
  631. cls._tables_metadata.create_all(config.db)
  632. class ComputedReflectionFixtureTest(TablesTest):
  633. run_inserts = run_deletes = None
  634. __backend__ = True
  635. __requires__ = ("computed_columns", "table_reflection")
  636. regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
  637. def normalize(self, text):
  638. return self.regexp.sub("", text).lower()
  639. @classmethod
  640. def define_tables(cls, metadata):
  641. from .. import Integer
  642. from .. import testing
  643. from ..schema import Column
  644. from ..schema import Computed
  645. from ..schema import Table
  646. Table(
  647. "computed_default_table",
  648. metadata,
  649. Column("id", Integer, primary_key=True),
  650. Column("normal", Integer),
  651. Column("computed_col", Integer, Computed("normal + 42")),
  652. Column("with_default", Integer, server_default="42"),
  653. )
  654. t = Table(
  655. "computed_column_table",
  656. metadata,
  657. Column("id", Integer, primary_key=True),
  658. Column("normal", Integer),
  659. Column("computed_no_flag", Integer, Computed("normal + 42")),
  660. )
  661. if testing.requires.schemas.enabled:
  662. t2 = Table(
  663. "computed_column_table",
  664. metadata,
  665. Column("id", Integer, primary_key=True),
  666. Column("normal", Integer),
  667. Column("computed_no_flag", Integer, Computed("normal / 42")),
  668. schema=config.test_schema,
  669. )
  670. if testing.requires.computed_columns_virtual.enabled:
  671. t.append_column(
  672. Column(
  673. "computed_virtual",
  674. Integer,
  675. Computed("normal + 2", persisted=False),
  676. )
  677. )
  678. if testing.requires.schemas.enabled:
  679. t2.append_column(
  680. Column(
  681. "computed_virtual",
  682. Integer,
  683. Computed("normal / 2", persisted=False),
  684. )
  685. )
  686. if testing.requires.computed_columns_stored.enabled:
  687. t.append_column(
  688. Column(
  689. "computed_stored",
  690. Integer,
  691. Computed("normal - 42", persisted=True),
  692. )
  693. )
  694. if testing.requires.schemas.enabled:
  695. t2.append_column(
  696. Column(
  697. "computed_stored",
  698. Integer,
  699. Computed("normal * 42", persisted=True),
  700. )
  701. )
  702. class CacheKeyFixture(object):
  703. def _compare_equal(self, a, b, compare_values):
  704. a_key = a._generate_cache_key()
  705. b_key = b._generate_cache_key()
  706. if a_key is None:
  707. assert a._annotations.get("nocache")
  708. assert b_key is None
  709. else:
  710. eq_(a_key.key, b_key.key)
  711. eq_(hash(a_key.key), hash(b_key.key))
  712. for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
  713. assert a_param.compare(b_param, compare_values=compare_values)
  714. return a_key, b_key
  715. def _run_cache_key_fixture(self, fixture, compare_values):
  716. case_a = fixture()
  717. case_b = fixture()
  718. for a, b in itertools.combinations_with_replacement(
  719. range(len(case_a)), 2
  720. ):
  721. if a == b:
  722. a_key, b_key = self._compare_equal(
  723. case_a[a], case_b[b], compare_values
  724. )
  725. if a_key is None:
  726. continue
  727. else:
  728. a_key = case_a[a]._generate_cache_key()
  729. b_key = case_b[b]._generate_cache_key()
  730. if a_key is None or b_key is None:
  731. if a_key is None:
  732. assert case_a[a]._annotations.get("nocache")
  733. if b_key is None:
  734. assert case_b[b]._annotations.get("nocache")
  735. continue
  736. if a_key.key == b_key.key:
  737. for a_param, b_param in zip(
  738. a_key.bindparams, b_key.bindparams
  739. ):
  740. if not a_param.compare(
  741. b_param, compare_values=compare_values
  742. ):
  743. break
  744. else:
  745. # this fails unconditionally since we could not
  746. # find bound parameter values that differed.
  747. # Usually we intended to get two distinct keys here
  748. # so the failure will be more descriptive using the
  749. # ne_() assertion.
  750. ne_(a_key.key, b_key.key)
  751. else:
  752. ne_(a_key.key, b_key.key)
  753. # ClauseElement-specific test to ensure the cache key
  754. # collected all the bound parameters that aren't marked
  755. # as "literal execute"
  756. if isinstance(case_a[a], ClauseElement) and isinstance(
  757. case_b[b], ClauseElement
  758. ):
  759. assert_a_params = []
  760. assert_b_params = []
  761. for elem in visitors.iterate(case_a[a]):
  762. if elem.__visit_name__ == "bindparam":
  763. assert_a_params.append(elem)
  764. for elem in visitors.iterate(case_b[b]):
  765. if elem.__visit_name__ == "bindparam":
  766. assert_b_params.append(elem)
  767. # note we're asserting the order of the params as well as
  768. # if there are dupes or not. ordering has to be
  769. # deterministic and matches what a traversal would provide.
  770. eq_(
  771. sorted(a_key.bindparams, key=lambda b: b.key),
  772. sorted(
  773. util.unique_list(assert_a_params), key=lambda b: b.key
  774. ),
  775. )
  776. eq_(
  777. sorted(b_key.bindparams, key=lambda b: b.key),
  778. sorted(
  779. util.unique_list(assert_b_params), key=lambda b: b.key
  780. ),
  781. )
  782. def _run_cache_key_equal_fixture(self, fixture, compare_values):
  783. case_a = fixture()
  784. case_b = fixture()
  785. for a, b in itertools.combinations_with_replacement(
  786. range(len(case_a)), 2
  787. ):
  788. self._compare_equal(case_a[a], case_b[b], compare_values)