pytest_plugin.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. import asyncio
  2. import contextlib
  3. import inspect
  4. import warnings
  5. from typing import (
  6. Any,
  7. Awaitable,
  8. Callable,
  9. Dict,
  10. Iterator,
  11. Optional,
  12. Protocol,
  13. Type,
  14. Union,
  15. overload,
  16. )
  17. import pytest
  18. from .test_utils import (
  19. BaseTestServer,
  20. RawTestServer,
  21. TestClient,
  22. TestServer,
  23. loop_context,
  24. setup_test_loop,
  25. teardown_test_loop,
  26. unused_port as _unused_port,
  27. )
  28. from .web import Application, BaseRequest, Request
  29. from .web_protocol import _RequestHandler
  30. try:
  31. import uvloop
  32. except ImportError: # pragma: no cover
  33. uvloop = None # type: ignore[assignment]
  34. class AiohttpClient(Protocol):
  35. @overload
  36. async def __call__(
  37. self,
  38. __param: Application,
  39. *,
  40. server_kwargs: Optional[Dict[str, Any]] = None,
  41. **kwargs: Any,
  42. ) -> TestClient[Request, Application]: ...
  43. @overload
  44. async def __call__(
  45. self,
  46. __param: BaseTestServer,
  47. *,
  48. server_kwargs: Optional[Dict[str, Any]] = None,
  49. **kwargs: Any,
  50. ) -> TestClient[BaseRequest, None]: ...
  51. class AiohttpServer(Protocol):
  52. def __call__(
  53. self, app: Application, *, port: Optional[int] = None, **kwargs: Any
  54. ) -> Awaitable[TestServer]: ...
  55. class AiohttpRawServer(Protocol):
  56. def __call__(
  57. self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
  58. ) -> Awaitable[RawTestServer]: ...
  59. def pytest_addoption(parser): # type: ignore[no-untyped-def]
  60. parser.addoption(
  61. "--aiohttp-fast",
  62. action="store_true",
  63. default=False,
  64. help="run tests faster by disabling extra checks",
  65. )
  66. parser.addoption(
  67. "--aiohttp-loop",
  68. action="store",
  69. default="pyloop",
  70. help="run tests with specific loop: pyloop, uvloop or all",
  71. )
  72. parser.addoption(
  73. "--aiohttp-enable-loop-debug",
  74. action="store_true",
  75. default=False,
  76. help="enable event loop debug mode",
  77. )
  78. def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
  79. """Set up pytest fixture.
  80. Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.
  81. """
  82. func = fixturedef.func
  83. if inspect.isasyncgenfunction(func):
  84. # async generator fixture
  85. is_async_gen = True
  86. elif inspect.iscoroutinefunction(func):
  87. # regular async fixture
  88. is_async_gen = False
  89. else:
  90. # not an async fixture, nothing to do
  91. return
  92. strip_request = False
  93. if "request" not in fixturedef.argnames:
  94. fixturedef.argnames += ("request",)
  95. strip_request = True
  96. def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
  97. request = kwargs["request"]
  98. if strip_request:
  99. del kwargs["request"]
  100. # if neither the fixture nor the test use the 'loop' fixture,
  101. # 'getfixturevalue' will fail because the test is not parameterized
  102. # (this can be removed someday if 'loop' is no longer parameterized)
  103. if "loop" not in request.fixturenames:
  104. raise Exception(
  105. "Asynchronous fixtures must depend on the 'loop' fixture or "
  106. "be used in tests depending from it."
  107. )
  108. _loop = request.getfixturevalue("loop")
  109. if is_async_gen:
  110. # for async generators, we need to advance the generator once,
  111. # then advance it again in a finalizer
  112. gen = func(*args, **kwargs)
  113. def finalizer(): # type: ignore[no-untyped-def]
  114. try:
  115. return _loop.run_until_complete(gen.__anext__())
  116. except StopAsyncIteration:
  117. pass
  118. request.addfinalizer(finalizer)
  119. return _loop.run_until_complete(gen.__anext__())
  120. else:
  121. return _loop.run_until_complete(func(*args, **kwargs))
  122. fixturedef.func = wrapper
  123. @pytest.fixture
  124. def fast(request): # type: ignore[no-untyped-def]
  125. """--fast config option"""
  126. return request.config.getoption("--aiohttp-fast")
  127. @pytest.fixture
  128. def loop_debug(request): # type: ignore[no-untyped-def]
  129. """--enable-loop-debug config option"""
  130. return request.config.getoption("--aiohttp-enable-loop-debug")
  131. @contextlib.contextmanager
  132. def _runtime_warning_context(): # type: ignore[no-untyped-def]
  133. """Context manager which checks for RuntimeWarnings.
  134. This exists specifically to
  135. avoid "coroutine 'X' was never awaited" warnings being missed.
  136. If RuntimeWarnings occur in the context a RuntimeError is raised.
  137. """
  138. with warnings.catch_warnings(record=True) as _warnings:
  139. yield
  140. rw = [
  141. "{w.filename}:{w.lineno}:{w.message}".format(w=w)
  142. for w in _warnings
  143. if w.category == RuntimeWarning
  144. ]
  145. if rw:
  146. raise RuntimeError(
  147. "{} Runtime Warning{},\n{}".format(
  148. len(rw), "" if len(rw) == 1 else "s", "\n".join(rw)
  149. )
  150. )
  151. @contextlib.contextmanager
  152. def _passthrough_loop_context(loop, fast=False): # type: ignore[no-untyped-def]
  153. """Passthrough loop context.
  154. Sets up and tears down a loop unless one is passed in via the loop
  155. argument when it's passed straight through.
  156. """
  157. if loop:
  158. # loop already exists, pass it straight through
  159. yield loop
  160. else:
  161. # this shadows loop_context's standard behavior
  162. loop = setup_test_loop()
  163. yield loop
  164. teardown_test_loop(loop, fast=fast)
  165. def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def]
  166. """Fix pytest collecting for coroutines."""
  167. if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj):
  168. return list(collector._genfunctions(name, obj))
  169. def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def]
  170. """Run coroutines in an event loop instead of a normal function call."""
  171. fast = pyfuncitem.config.getoption("--aiohttp-fast")
  172. if inspect.iscoroutinefunction(pyfuncitem.function):
  173. existing_loop = pyfuncitem.funcargs.get(
  174. "proactor_loop"
  175. ) or pyfuncitem.funcargs.get("loop", None)
  176. with _runtime_warning_context():
  177. with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
  178. testargs = {
  179. arg: pyfuncitem.funcargs[arg]
  180. for arg in pyfuncitem._fixtureinfo.argnames
  181. }
  182. _loop.run_until_complete(pyfuncitem.obj(**testargs))
  183. return True
  184. def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def]
  185. if "loop_factory" not in metafunc.fixturenames:
  186. return
  187. loops = metafunc.config.option.aiohttp_loop
  188. avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]]
  189. avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy}
  190. if uvloop is not None: # pragma: no cover
  191. avail_factories["uvloop"] = uvloop.EventLoopPolicy
  192. if loops == "all":
  193. loops = "pyloop,uvloop?"
  194. factories = {} # type: ignore[var-annotated]
  195. for name in loops.split(","):
  196. required = not name.endswith("?")
  197. name = name.strip(" ?")
  198. if name not in avail_factories: # pragma: no cover
  199. if required:
  200. raise ValueError(
  201. "Unknown loop '%s', available loops: %s"
  202. % (name, list(factories.keys()))
  203. )
  204. else:
  205. continue
  206. factories[name] = avail_factories[name]
  207. metafunc.parametrize(
  208. "loop_factory", list(factories.values()), ids=list(factories.keys())
  209. )
  210. @pytest.fixture
  211. def loop(loop_factory, fast, loop_debug): # type: ignore[no-untyped-def]
  212. """Return an instance of the event loop."""
  213. policy = loop_factory()
  214. asyncio.set_event_loop_policy(policy)
  215. with loop_context(fast=fast) as _loop:
  216. if loop_debug:
  217. _loop.set_debug(True) # pragma: no cover
  218. asyncio.set_event_loop(_loop)
  219. yield _loop
  220. @pytest.fixture
  221. def proactor_loop(): # type: ignore[no-untyped-def]
  222. policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore[attr-defined]
  223. asyncio.set_event_loop_policy(policy)
  224. with loop_context(policy.new_event_loop) as _loop:
  225. asyncio.set_event_loop(_loop)
  226. yield _loop
  227. @pytest.fixture
  228. def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]:
  229. warnings.warn(
  230. "Deprecated, use aiohttp_unused_port fixture instead",
  231. DeprecationWarning,
  232. stacklevel=2,
  233. )
  234. return aiohttp_unused_port
  235. @pytest.fixture
  236. def aiohttp_unused_port() -> Callable[[], int]:
  237. """Return a port that is unused on the current host."""
  238. return _unused_port
  239. @pytest.fixture
  240. def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
  241. """Factory to create a TestServer instance, given an app.
  242. aiohttp_server(app, **kwargs)
  243. """
  244. servers = []
  245. async def go(
  246. app: Application, *, port: Optional[int] = None, **kwargs: Any
  247. ) -> TestServer:
  248. server = TestServer(app, port=port)
  249. await server.start_server(loop=loop, **kwargs)
  250. servers.append(server)
  251. return server
  252. yield go
  253. async def finalize() -> None:
  254. while servers:
  255. await servers.pop().close()
  256. loop.run_until_complete(finalize())
  257. @pytest.fixture
  258. def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no cover
  259. warnings.warn(
  260. "Deprecated, use aiohttp_server fixture instead",
  261. DeprecationWarning,
  262. stacklevel=2,
  263. )
  264. return aiohttp_server
  265. @pytest.fixture
  266. def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
  267. """Factory to create a RawTestServer instance, given a web handler.
  268. aiohttp_raw_server(handler, **kwargs)
  269. """
  270. servers = []
  271. async def go(
  272. handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
  273. ) -> RawTestServer:
  274. server = RawTestServer(handler, port=port)
  275. await server.start_server(loop=loop, **kwargs)
  276. servers.append(server)
  277. return server
  278. yield go
  279. async def finalize() -> None:
  280. while servers:
  281. await servers.pop().close()
  282. loop.run_until_complete(finalize())
  283. @pytest.fixture
  284. def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover
  285. aiohttp_raw_server,
  286. ):
  287. warnings.warn(
  288. "Deprecated, use aiohttp_raw_server fixture instead",
  289. DeprecationWarning,
  290. stacklevel=2,
  291. )
  292. return aiohttp_raw_server
  293. @pytest.fixture
  294. def aiohttp_client(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpClient]:
  295. """Factory to create a TestClient instance.
  296. aiohttp_client(app, **kwargs)
  297. aiohttp_client(server, **kwargs)
  298. aiohttp_client(raw_server, **kwargs)
  299. """
  300. clients = []
  301. @overload
  302. async def go(
  303. __param: Application,
  304. *,
  305. server_kwargs: Optional[Dict[str, Any]] = None,
  306. **kwargs: Any,
  307. ) -> TestClient[Request, Application]: ...
  308. @overload
  309. async def go(
  310. __param: BaseTestServer,
  311. *,
  312. server_kwargs: Optional[Dict[str, Any]] = None,
  313. **kwargs: Any,
  314. ) -> TestClient[BaseRequest, None]: ...
  315. async def go(
  316. __param: Union[Application, BaseTestServer],
  317. *args: Any,
  318. server_kwargs: Optional[Dict[str, Any]] = None,
  319. **kwargs: Any,
  320. ) -> TestClient[Any, Any]:
  321. if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
  322. __param, (Application, BaseTestServer)
  323. ):
  324. __param = __param(loop, *args, **kwargs)
  325. kwargs = {}
  326. else:
  327. assert not args, "args should be empty"
  328. if isinstance(__param, Application):
  329. server_kwargs = server_kwargs or {}
  330. server = TestServer(__param, loop=loop, **server_kwargs)
  331. client = TestClient(server, loop=loop, **kwargs)
  332. elif isinstance(__param, BaseTestServer):
  333. client = TestClient(__param, loop=loop, **kwargs)
  334. else:
  335. raise ValueError("Unknown argument type: %r" % type(__param))
  336. await client.start_server()
  337. clients.append(client)
  338. return client
  339. yield go
  340. async def finalize() -> None:
  341. while clients:
  342. await clients.pop().close()
  343. loop.run_until_complete(finalize())
  344. @pytest.fixture
  345. def test_client(aiohttp_client): # type: ignore[no-untyped-def] # pragma: no cover
  346. warnings.warn(
  347. "Deprecated, use aiohttp_client fixture instead",
  348. DeprecationWarning,
  349. stacklevel=2,
  350. )
  351. return aiohttp_client