config.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # testing/config.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. import collections
  8. from .. import util
  9. requirements = None
  10. db = None
  11. db_url = None
  12. db_opts = None
  13. file_config = None
  14. test_schema = None
  15. test_schema_2 = None
  16. any_async = False
  17. _current = None
  18. ident = "main"
  19. _fixture_functions = None # installed by plugin_base
  20. def combinations(*comb, **kw):
  21. r"""Deliver multiple versions of a test based on positional combinations.
  22. This is a facade over pytest.mark.parametrize.
  23. :param \*comb: argument combinations. These are tuples that will be passed
  24. positionally to the decorated function.
  25. :param argnames: optional list of argument names. These are the names
  26. of the arguments in the test function that correspond to the entries
  27. in each argument tuple. pytest.mark.parametrize requires this, however
  28. the combinations function will derive it automatically if not present
  29. by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
  30. first argument is "self" which is discarded.
  31. :param id\_: optional id template. This is a string template that
  32. describes how the "id" for each parameter set should be defined, if any.
  33. The number of characters in the template should match the number of
  34. entries in each argument tuple. Each character describes how the
  35. corresponding entry in the argument tuple should be handled, as far as
  36. whether or not it is included in the arguments passed to the function, as
  37. well as if it is included in the tokens used to create the id of the
  38. parameter set.
  39. If omitted, the argument combinations are passed to parametrize as is. If
  40. passed, each argument combination is turned into a pytest.param() object,
  41. mapping the elements of the argument tuple to produce an id based on a
  42. character value in the same position within the string template using the
  43. following scheme::
  44. i - the given argument is a string that is part of the id only, don't
  45. pass it as an argument
  46. n - the given argument should be passed and it should be added to the
  47. id by calling the .__name__ attribute
  48. r - the given argument should be passed and it should be added to the
  49. id by calling repr()
  50. s - the given argument should be passed and it should be added to the
  51. id by calling str()
  52. a - (argument) the given argument should be passed and it should not
  53. be used to generated the id
  54. e.g.::
  55. @testing.combinations(
  56. (operator.eq, "eq"),
  57. (operator.ne, "ne"),
  58. (operator.gt, "gt"),
  59. (operator.lt, "lt"),
  60. id_="na"
  61. )
  62. def test_operator(self, opfunc, name):
  63. pass
  64. The above combination will call ``.__name__`` on the first member of
  65. each tuple and use that as the "id" to pytest.param().
  66. """
  67. return _fixture_functions.combinations(*comb, **kw)
  68. def combinations_list(arg_iterable, **kw):
  69. "As combination, but takes a single iterable"
  70. return combinations(*arg_iterable, **kw)
  71. class Variation(object):
  72. __slots__ = ("_name", "_argname")
  73. def __init__(self, case, argname, case_names):
  74. self._name = case
  75. self._argname = argname
  76. for casename in case_names:
  77. setattr(self, casename, casename == case)
  78. @property
  79. def name(self):
  80. return self._name
  81. def __bool__(self):
  82. return self._name == self._argname
  83. def __nonzero__(self):
  84. return not self.__bool__()
  85. def __str__(self):
  86. return "%s=%r" % (self._argname, self._name)
  87. def __repr__(self):
  88. return str(self)
  89. def fail(self):
  90. # can't import util.fail() under py2.x without resolving
  91. # import cycle
  92. assert False, "Unknown %s" % (self,)
  93. @classmethod
  94. def idfn(cls, variation):
  95. return variation.name
  96. @classmethod
  97. def generate_cases(cls, argname, cases):
  98. case_names = [
  99. argname if c is True else "not_" + argname if c is False else c
  100. for c in cases
  101. ]
  102. typ = type(
  103. argname,
  104. (Variation,),
  105. {
  106. "__slots__": tuple(case_names),
  107. },
  108. )
  109. return [typ(casename, argname, case_names) for casename in case_names]
  110. def variation(argname, cases):
  111. """a helper around testing.combinations that provides a single namespace
  112. that can be used as a switch.
  113. e.g.::
  114. @testing.variation("querytyp", ["select", "subquery", "legacy_query"])
  115. @testing.variation("lazy", ["select", "raise", "raise_on_sql"])
  116. def test_thing(
  117. self,
  118. querytyp,
  119. lazy,
  120. decl_base
  121. ):
  122. class Thing(decl_base):
  123. __tablename__ = 'thing'
  124. # use name directly
  125. rel = relationship("Rel", lazy=lazy.name)
  126. # use as a switch
  127. if querytyp.select:
  128. stmt = select(Thing)
  129. elif querytyp.subquery:
  130. stmt = select(Thing).subquery()
  131. elif querytyp.legacy_query:
  132. stmt = Session.query(Thing)
  133. else:
  134. querytyp.fail()
  135. The variable provided is a slots object of boolean variables, as well
  136. as the name of the case itself under the attribute ".name"
  137. """
  138. cases_plus_limitations = [
  139. entry
  140. if (isinstance(entry, tuple) and len(entry) == 2)
  141. else (entry, None)
  142. for entry in cases
  143. ]
  144. variations = Variation.generate_cases(
  145. argname, [c for c, l in cases_plus_limitations]
  146. )
  147. return combinations(
  148. id_="ia",
  149. argnames=argname,
  150. *[
  151. (variation._name, variation, limitation)
  152. if limitation is not None
  153. else (variation._name, variation)
  154. for variation, (case, limitation) in zip(
  155. variations, cases_plus_limitations
  156. )
  157. ]
  158. )
  159. def variation_fixture(argname, cases, scope="function"):
  160. return fixture(
  161. params=Variation.generate_cases(argname, cases),
  162. ids=Variation.idfn,
  163. scope=scope,
  164. )
  165. def fixture(*arg, **kw):
  166. return _fixture_functions.fixture(*arg, **kw)
  167. def get_current_test_name():
  168. return _fixture_functions.get_current_test_name()
  169. def mark_base_test_class():
  170. return _fixture_functions.mark_base_test_class()
  171. class Config(object):
  172. def __init__(self, db, db_opts, options, file_config):
  173. self._set_name(db)
  174. self.db = db
  175. self.db_opts = db_opts
  176. self.options = options
  177. self.file_config = file_config
  178. self.test_schema = "test_schema"
  179. self.test_schema_2 = "test_schema_2"
  180. self.is_async = db.dialect.is_async and not util.asbool(
  181. db.url.query.get("async_fallback", False)
  182. )
  183. _stack = collections.deque()
  184. _configs = set()
  185. def _set_name(self, db):
  186. if db.dialect.server_version_info:
  187. svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
  188. self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
  189. else:
  190. self.name = "%s+%s" % (db.name, db.driver)
  191. @classmethod
  192. def register(cls, db, db_opts, options, file_config):
  193. """add a config as one of the global configs.
  194. If there are no configs set up yet, this config also
  195. gets set as the "_current".
  196. """
  197. global any_async
  198. cfg = Config(db, db_opts, options, file_config)
  199. # if any backends include an async driver, then ensure
  200. # all setup/teardown and tests are wrapped in the maybe_async()
  201. # decorator that will set up a greenlet context for async drivers.
  202. any_async = any_async or cfg.is_async
  203. cls._configs.add(cfg)
  204. return cfg
  205. @classmethod
  206. def set_as_current(cls, config, namespace):
  207. global db, _current, db_url, test_schema, test_schema_2, db_opts
  208. _current = config
  209. db_url = config.db.url
  210. db_opts = config.db_opts
  211. test_schema = config.test_schema
  212. test_schema_2 = config.test_schema_2
  213. namespace.db = db = config.db
  214. @classmethod
  215. def push_engine(cls, db, namespace):
  216. assert _current, "Can't push without a default Config set up"
  217. cls.push(
  218. Config(
  219. db, _current.db_opts, _current.options, _current.file_config
  220. ),
  221. namespace,
  222. )
  223. @classmethod
  224. def push(cls, config, namespace):
  225. cls._stack.append(_current)
  226. cls.set_as_current(config, namespace)
  227. @classmethod
  228. def pop(cls, namespace):
  229. if cls._stack:
  230. # a failed test w/ -x option can call reset() ahead of time
  231. _current = cls._stack[-1]
  232. del cls._stack[-1]
  233. cls.set_as_current(_current, namespace)
  234. @classmethod
  235. def reset(cls, namespace):
  236. if cls._stack:
  237. cls.set_as_current(cls._stack[0], namespace)
  238. cls._stack.clear()
  239. @classmethod
  240. def all_configs(cls):
  241. return cls._configs
  242. @classmethod
  243. def all_dbs(cls):
  244. for cfg in cls.all_configs():
  245. yield cfg.db
  246. def skip_test(self, msg):
  247. skip_test(msg)
  248. def skip_test(msg):
  249. raise _fixture_functions.skip_test_exception(msg)
  250. def async_test(fn):
  251. return _fixture_functions.async_test(fn)