test_utils.py 22 KB


  1. """Utilities shared by tests."""
  2. import asyncio
  3. import contextlib
  4. import gc
  5. import inspect
  6. import ipaddress
  7. import os
  8. import socket
  9. import sys
  10. import warnings
  11. from abc import ABC, abstractmethod
  12. from types import TracebackType
  13. from typing import (
  14. TYPE_CHECKING,
  15. Any,
  16. Callable,
  17. Generic,
  18. Iterator,
  19. List,
  20. Optional,
  21. Type,
  22. TypeVar,
  23. cast,
  24. overload,
  25. )
  26. from unittest import IsolatedAsyncioTestCase, mock
  27. from aiosignal import Signal
  28. from multidict import CIMultiDict, CIMultiDictProxy
  29. from yarl import URL
  30. import aiohttp
  31. from aiohttp.client import (
  32. _RequestContextManager,
  33. _RequestOptions,
  34. _WSRequestContextManager,
  35. )
  36. from . import ClientSession, hdrs
  37. from .abc import AbstractCookieJar
  38. from .client_reqrep import ClientResponse
  39. from .client_ws import ClientWebSocketResponse
  40. from .helpers import sentinel
  41. from .http import HttpVersion, RawRequestMessage
  42. from .streams import EMPTY_PAYLOAD, StreamReader
  43. from .typedefs import StrOrURL
  44. from .web import (
  45. Application,
  46. AppRunner,
  47. BaseRequest,
  48. BaseRunner,
  49. Request,
  50. Server,
  51. ServerRunner,
  52. SockSite,
  53. UrlMappingMatchInfo,
  54. )
  55. from .web_protocol import _RequestHandler
  56. if TYPE_CHECKING:
  57. from ssl import SSLContext
  58. else:
  59. SSLContext = None
  60. if sys.version_info >= (3, 11) and TYPE_CHECKING:
  61. from typing import Unpack
  62. if sys.version_info >= (3, 11):
  63. from typing import Self
  64. else:
  65. Self = Any
# Type variable for the application type a TestClient exposes via ``.app``:
# Application when the client wraps a TestServer, None for other servers
# (see the TestClient.__init__ overloads).
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
# Type variable for the request class handled by the wrapped server.
_Request = TypeVar("_Request", bound=BaseRequest)
# SO_REUSEADDR is only enabled on POSIX platforms other than Cygwin; Windows
# gives the option different semantics (see get_port_socket).
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
  69. def get_unused_port_socket(
  70. host: str, family: socket.AddressFamily = socket.AF_INET
  71. ) -> socket.socket:
  72. return get_port_socket(host, 0, family)
  73. def get_port_socket(
  74. host: str, port: int, family: socket.AddressFamily
  75. ) -> socket.socket:
  76. s = socket.socket(family, socket.SOCK_STREAM)
  77. if REUSE_ADDRESS:
  78. # Windows has different semantics for SO_REUSEADDR,
  79. # so don't set it. Ref:
  80. # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
  81. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  82. s.bind((host, port))
  83. return s
  84. def unused_port() -> int:
  85. """Return a port that is unused on the current host."""
  86. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  87. s.bind(("127.0.0.1", 0))
  88. return cast(int, s.getsockname()[1])
  89. class BaseTestServer(ABC):
  90. __test__ = False
  91. def __init__(
  92. self,
  93. *,
  94. scheme: str = "",
  95. loop: Optional[asyncio.AbstractEventLoop] = None,
  96. host: str = "127.0.0.1",
  97. port: Optional[int] = None,
  98. skip_url_asserts: bool = False,
  99. socket_factory: Callable[
  100. [str, int, socket.AddressFamily], socket.socket
  101. ] = get_port_socket,
  102. **kwargs: Any,
  103. ) -> None:
  104. self._loop = loop
  105. self.runner: Optional[BaseRunner] = None
  106. self._root: Optional[URL] = None
  107. self.host = host
  108. self.port = port
  109. self._closed = False
  110. self.scheme = scheme
  111. self.skip_url_asserts = skip_url_asserts
  112. self.socket_factory = socket_factory
  113. async def start_server(
  114. self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
  115. ) -> None:
  116. if self.runner:
  117. return
  118. self._loop = loop
  119. self._ssl = kwargs.pop("ssl", None)
  120. self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
  121. await self.runner.setup()
  122. if not self.port:
  123. self.port = 0
  124. absolute_host = self.host
  125. try:
  126. version = ipaddress.ip_address(self.host).version
  127. except ValueError:
  128. version = 4
  129. if version == 6:
  130. absolute_host = f"[{self.host}]"
  131. family = socket.AF_INET6 if version == 6 else socket.AF_INET
  132. _sock = self.socket_factory(self.host, self.port, family)
  133. self.host, self.port = _sock.getsockname()[:2]
  134. site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
  135. await site.start()
  136. server = site._server
  137. assert server is not None
  138. sockets = server.sockets # type: ignore[attr-defined]
  139. assert sockets is not None
  140. self.port = sockets[0].getsockname()[1]
  141. if not self.scheme:
  142. self.scheme = "https" if self._ssl else "http"
  143. self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")
  144. @abstractmethod # pragma: no cover
  145. async def _make_runner(self, **kwargs: Any) -> BaseRunner:
  146. pass
  147. def make_url(self, path: StrOrURL) -> URL:
  148. assert self._root is not None
  149. url = URL(path)
  150. if not self.skip_url_asserts:
  151. assert not url.absolute
  152. return self._root.join(url)
  153. else:
  154. return URL(str(self._root) + str(path))
  155. @property
  156. def started(self) -> bool:
  157. return self.runner is not None
  158. @property
  159. def closed(self) -> bool:
  160. return self._closed
  161. @property
  162. def handler(self) -> Server:
  163. # for backward compatibility
  164. # web.Server instance
  165. runner = self.runner
  166. assert runner is not None
  167. assert runner.server is not None
  168. return runner.server
  169. async def close(self) -> None:
  170. """Close all fixtures created by the test client.
  171. After that point, the TestClient is no longer usable.
  172. This is an idempotent function: running close multiple times
  173. will not have any additional effects.
  174. close is also run when the object is garbage collected, and on
  175. exit when used as a context manager.
  176. """
  177. if self.started and not self.closed:
  178. assert self.runner is not None
  179. await self.runner.cleanup()
  180. self._root = None
  181. self.port = None
  182. self._closed = True
  183. def __enter__(self) -> None:
  184. raise TypeError("Use async with instead")
  185. def __exit__(
  186. self,
  187. exc_type: Optional[Type[BaseException]],
  188. exc_value: Optional[BaseException],
  189. traceback: Optional[TracebackType],
  190. ) -> None:
  191. # __exit__ should exist in pair with __enter__ but never executed
  192. pass # pragma: no cover
  193. async def __aenter__(self) -> "BaseTestServer":
  194. await self.start_server(loop=self._loop)
  195. return self
  196. async def __aexit__(
  197. self,
  198. exc_type: Optional[Type[BaseException]],
  199. exc_value: Optional[BaseException],
  200. traceback: Optional[TracebackType],
  201. ) -> None:
  202. await self.close()
  203. class TestServer(BaseTestServer):
  204. def __init__(
  205. self,
  206. app: Application,
  207. *,
  208. scheme: str = "",
  209. host: str = "127.0.0.1",
  210. port: Optional[int] = None,
  211. **kwargs: Any,
  212. ):
  213. self.app = app
  214. super().__init__(scheme=scheme, host=host, port=port, **kwargs)
  215. async def _make_runner(self, **kwargs: Any) -> BaseRunner:
  216. return AppRunner(self.app, **kwargs)
  217. class RawTestServer(BaseTestServer):
  218. def __init__(
  219. self,
  220. handler: _RequestHandler,
  221. *,
  222. scheme: str = "",
  223. host: str = "127.0.0.1",
  224. port: Optional[int] = None,
  225. **kwargs: Any,
  226. ) -> None:
  227. self._handler = handler
  228. super().__init__(scheme=scheme, host=host, port=port, **kwargs)
  229. async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
  230. srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
  231. return ServerRunner(srv, debug=debug, **kwargs)
  232. class TestClient(Generic[_Request, _ApplicationNone]):
  233. """
  234. A test client implementation.
  235. To write functional tests for aiohttp based servers.
  236. """
  237. __test__ = False
  238. @overload
  239. def __init__(
  240. self: "TestClient[Request, Application]",
  241. server: TestServer,
  242. *,
  243. cookie_jar: Optional[AbstractCookieJar] = None,
  244. **kwargs: Any,
  245. ) -> None: ...
  246. @overload
  247. def __init__(
  248. self: "TestClient[_Request, None]",
  249. server: BaseTestServer,
  250. *,
  251. cookie_jar: Optional[AbstractCookieJar] = None,
  252. **kwargs: Any,
  253. ) -> None: ...
  254. def __init__(
  255. self,
  256. server: BaseTestServer,
  257. *,
  258. cookie_jar: Optional[AbstractCookieJar] = None,
  259. loop: Optional[asyncio.AbstractEventLoop] = None,
  260. **kwargs: Any,
  261. ) -> None:
  262. if not isinstance(server, BaseTestServer):
  263. raise TypeError(
  264. "server must be TestServer instance, found type: %r" % type(server)
  265. )
  266. self._server = server
  267. self._loop = loop
  268. if cookie_jar is None:
  269. cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
  270. self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
  271. self._session._retry_connection = False
  272. self._closed = False
  273. self._responses: List[ClientResponse] = []
  274. self._websockets: List[ClientWebSocketResponse] = []
  275. async def start_server(self) -> None:
  276. await self._server.start_server(loop=self._loop)
  277. @property
  278. def host(self) -> str:
  279. return self._server.host
  280. @property
  281. def port(self) -> Optional[int]:
  282. return self._server.port
  283. @property
  284. def server(self) -> BaseTestServer:
  285. return self._server
  286. @property
  287. def app(self) -> _ApplicationNone:
  288. return getattr(self._server, "app", None) # type: ignore[return-value]
  289. @property
  290. def session(self) -> ClientSession:
  291. """An internal aiohttp.ClientSession.
  292. Unlike the methods on the TestClient, client session requests
  293. do not automatically include the host in the url queried, and
  294. will require an absolute path to the resource.
  295. """
  296. return self._session
  297. def make_url(self, path: StrOrURL) -> URL:
  298. return self._server.make_url(path)
  299. async def _request(
  300. self, method: str, path: StrOrURL, **kwargs: Any
  301. ) -> ClientResponse:
  302. resp = await self._session.request(method, self.make_url(path), **kwargs)
  303. # save it to close later
  304. self._responses.append(resp)
  305. return resp
  306. if sys.version_info >= (3, 11) and TYPE_CHECKING:
  307. def request(
  308. self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
  309. ) -> _RequestContextManager: ...
  310. def get(
  311. self,
  312. path: StrOrURL,
  313. **kwargs: Unpack[_RequestOptions],
  314. ) -> _RequestContextManager: ...
  315. def options(
  316. self,
  317. path: StrOrURL,
  318. **kwargs: Unpack[_RequestOptions],
  319. ) -> _RequestContextManager: ...
  320. def head(
  321. self,
  322. path: StrOrURL,
  323. **kwargs: Unpack[_RequestOptions],
  324. ) -> _RequestContextManager: ...
  325. def post(
  326. self,
  327. path: StrOrURL,
  328. **kwargs: Unpack[_RequestOptions],
  329. ) -> _RequestContextManager: ...
  330. def put(
  331. self,
  332. path: StrOrURL,
  333. **kwargs: Unpack[_RequestOptions],
  334. ) -> _RequestContextManager: ...
  335. def patch(
  336. self,
  337. path: StrOrURL,
  338. **kwargs: Unpack[_RequestOptions],
  339. ) -> _RequestContextManager: ...
  340. def delete(
  341. self,
  342. path: StrOrURL,
  343. **kwargs: Unpack[_RequestOptions],
  344. ) -> _RequestContextManager: ...
  345. else:
  346. def request(
  347. self, method: str, path: StrOrURL, **kwargs: Any
  348. ) -> _RequestContextManager:
  349. """Routes a request to tested http server.
  350. The interface is identical to aiohttp.ClientSession.request,
  351. except the loop kwarg is overridden by the instance used by the
  352. test server.
  353. """
  354. return _RequestContextManager(self._request(method, path, **kwargs))
  355. def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  356. """Perform an HTTP GET request."""
  357. return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
  358. def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  359. """Perform an HTTP POST request."""
  360. return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
  361. def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  362. """Perform an HTTP OPTIONS request."""
  363. return _RequestContextManager(
  364. self._request(hdrs.METH_OPTIONS, path, **kwargs)
  365. )
  366. def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  367. """Perform an HTTP HEAD request."""
  368. return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
  369. def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  370. """Perform an HTTP PUT request."""
  371. return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
  372. def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  373. """Perform an HTTP PATCH request."""
  374. return _RequestContextManager(
  375. self._request(hdrs.METH_PATCH, path, **kwargs)
  376. )
  377. def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
  378. """Perform an HTTP PATCH request."""
  379. return _RequestContextManager(
  380. self._request(hdrs.METH_DELETE, path, **kwargs)
  381. )
  382. def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
  383. """Initiate websocket connection.
  384. The api corresponds to aiohttp.ClientSession.ws_connect.
  385. """
  386. return _WSRequestContextManager(self._ws_connect(path, **kwargs))
  387. async def _ws_connect(
  388. self, path: StrOrURL, **kwargs: Any
  389. ) -> ClientWebSocketResponse:
  390. ws = await self._session.ws_connect(self.make_url(path), **kwargs)
  391. self._websockets.append(ws)
  392. return ws
  393. async def close(self) -> None:
  394. """Close all fixtures created by the test client.
  395. After that point, the TestClient is no longer usable.
  396. This is an idempotent function: running close multiple times
  397. will not have any additional effects.
  398. close is also run on exit when used as a(n) (asynchronous)
  399. context manager.
  400. """
  401. if not self._closed:
  402. for resp in self._responses:
  403. resp.close()
  404. for ws in self._websockets:
  405. await ws.close()
  406. await self._session.close()
  407. await self._server.close()
  408. self._closed = True
  409. def __enter__(self) -> None:
  410. raise TypeError("Use async with instead")
  411. def __exit__(
  412. self,
  413. exc_type: Optional[Type[BaseException]],
  414. exc: Optional[BaseException],
  415. tb: Optional[TracebackType],
  416. ) -> None:
  417. # __exit__ should exist in pair with __enter__ but never executed
  418. pass # pragma: no cover
  419. async def __aenter__(self) -> Self:
  420. await self.start_server()
  421. return self
  422. async def __aexit__(
  423. self,
  424. exc_type: Optional[Type[BaseException]],
  425. exc: Optional[BaseException],
  426. tb: Optional[TracebackType],
  427. ) -> None:
  428. await self.close()
  429. class AioHTTPTestCase(IsolatedAsyncioTestCase):
  430. """A base class to allow for unittest web applications using aiohttp.
  431. Provides the following:
  432. * self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
  433. * self.loop (asyncio.BaseEventLoop): the event loop in which the
  434. application and server are running.
  435. * self.app (aiohttp.web.Application): the application returned by
  436. self.get_application()
  437. Note that the TestClient's methods are asynchronous: you have to
  438. execute function on the test client using asynchronous methods.
  439. """
  440. async def get_application(self) -> Application:
  441. """Get application.
  442. This method should be overridden
  443. to return the aiohttp.web.Application
  444. object to test.
  445. """
  446. return self.get_app()
  447. def get_app(self) -> Application:
  448. """Obsolete method used to constructing web application.
  449. Use .get_application() coroutine instead.
  450. """
  451. raise RuntimeError("Did you forget to define get_application()?")
  452. async def asyncSetUp(self) -> None:
  453. self.loop = asyncio.get_running_loop()
  454. return await self.setUpAsync()
  455. async def setUpAsync(self) -> None:
  456. self.app = await self.get_application()
  457. self.server = await self.get_server(self.app)
  458. self.client = await self.get_client(self.server)
  459. await self.client.start_server()
  460. async def asyncTearDown(self) -> None:
  461. return await self.tearDownAsync()
  462. async def tearDownAsync(self) -> None:
  463. await self.client.close()
  464. async def get_server(self, app: Application) -> TestServer:
  465. """Return a TestServer instance."""
  466. return TestServer(app, loop=self.loop)
  467. async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
  468. """Return a TestClient instance."""
  469. return TestClient(server, loop=self.loop)
  470. def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
  471. """
  472. A decorator dedicated to use with asynchronous AioHTTPTestCase test methods.
  473. In 3.8+, this does nothing.
  474. """
  475. warnings.warn(
  476. "Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+",
  477. DeprecationWarning,
  478. stacklevel=2,
  479. )
  480. return func
# Type of a zero-argument callable that creates a new event loop, such as
# ``asyncio.new_event_loop`` (the default used by loop_context and
# setup_test_loop).
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
  482. @contextlib.contextmanager
  483. def loop_context(
  484. loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
  485. ) -> Iterator[asyncio.AbstractEventLoop]:
  486. """A contextmanager that creates an event_loop, for test purposes.
  487. Handles the creation and cleanup of a test loop.
  488. """
  489. loop = setup_test_loop(loop_factory)
  490. yield loop
  491. teardown_test_loop(loop, fast=fast)
  492. def setup_test_loop(
  493. loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
  494. ) -> asyncio.AbstractEventLoop:
  495. """Create and return an asyncio.BaseEventLoop instance.
  496. The caller should also call teardown_test_loop,
  497. once they are done with the loop.
  498. """
  499. loop = loop_factory()
  500. asyncio.set_event_loop(loop)
  501. return loop
  502. def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
  503. """Teardown and cleanup an event_loop created by setup_test_loop."""
  504. closed = loop.is_closed()
  505. if not closed:
  506. loop.call_soon(loop.stop)
  507. loop.run_forever()
  508. loop.close()
  509. if not fast:
  510. gc.collect()
  511. asyncio.set_event_loop(None)
  512. def _create_app_mock() -> mock.MagicMock:
  513. def get_dict(app: Any, key: str) -> Any:
  514. return app.__app_dict[key]
  515. def set_dict(app: Any, key: str, value: Any) -> None:
  516. app.__app_dict[key] = value
  517. app = mock.MagicMock(spec=Application)
  518. app.__app_dict = {}
  519. app.__getitem__ = get_dict
  520. app.__setitem__ = set_dict
  521. app._debug = False
  522. app.on_response_prepare = Signal(app)
  523. app.on_response_prepare.freeze()
  524. return app
  525. def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
  526. transport = mock.Mock()
  527. def get_extra_info(key: str) -> Optional[SSLContext]:
  528. if key == "sslcontext":
  529. return sslcontext
  530. else:
  531. return None
  532. transport.get_extra_info.side_effect = get_extra_info
  533. return transport
def make_mocked_request(
    method: str,
    path: str,
    headers: Any = None,
    *,
    match_info: Any = sentinel,
    version: HttpVersion = HttpVersion(1, 1),
    closing: bool = False,
    app: Any = None,
    writer: Any = sentinel,
    protocol: Any = sentinel,
    transport: Any = sentinel,
    payload: StreamReader = EMPTY_PAYLOAD,
    sslcontext: Optional[SSLContext] = None,
    client_max_size: int = 1024**2,
    loop: Any = ...,
) -> Request:
    """Create a mocked web.Request for testing purposes.

    Useful in unit tests, when spinning up a full web server is overkill or
    specific conditions and errors are hard to trigger.

    Args:
        method: HTTP method of the request.
        path: request target; also parsed with ``URL(path)`` to give the
            message URL.
        headers: optional headers, wrapped in a ``CIMultiDictProxy``. Keys
            and values are UTF-8 encoded to build the raw headers, so they
            must be ``str``.
        match_info: mapping used for the request's match info
            (default: an empty dict).
        version: HTTP version; anything below 1.1 forces ``closing=True``.
        closing: whether the message is marked as closing the connection.
        app: application added to the match info; defaults to the
            ``MagicMock`` built by ``_create_app_mock()``.
        writer: payload writer; defaults to a ``Mock`` whose
            ``write_headers``/``write``/``write_eof``/``drain`` are coroutine
            mocks returning None.
        protocol: protocol object; defaults to a ``Mock``. Its ``transport``
            and ``writer`` attributes are always overwritten, even when a
            protocol is passed in.
        transport: transport; defaults to ``_create_transport(sslcontext)``.
        payload: request body stream (default: ``EMPTY_PAYLOAD``).
        sslcontext: what the default transport reports for
            ``get_extra_info("sslcontext")``; unused when *transport* is given.
        client_max_size: forwarded to ``Request``.
        loop: event loop; by default the running loop if there is one,
            otherwise a ``Mock``.

    Returns:
        The constructed ``Request``.
    """
    task = mock.Mock()
    if loop is ...:
        # no loop passed, try to get the current one if
        # it is running as we need a real loop to create
        # executor jobs to be able to do testing
        # with a real executor
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = mock.Mock()
            # NOTE(review): create_future yields an empty tuple rather than a
            # future here; the reason is not evident from this file.
            loop.create_future.return_value = ()

    if version < HttpVersion(1, 1):
        closing = True

    if headers:
        headers = CIMultiDictProxy(CIMultiDict(headers))
        raw_hdrs = tuple(
            (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
        )
    else:
        headers = CIMultiDictProxy(CIMultiDict())
        raw_hdrs = ()

    # Case-insensitive check of the Transfer-Encoding header.
    chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()

    # Positional fields: method, path, version, headers, raw headers,
    # closing, then (presumably compression=None, upgrade=False — TODO
    # confirm against RawRequestMessage), chunked and the parsed URL.
    message = RawRequestMessage(
        method,
        path,
        version,
        headers,
        raw_hdrs,
        closing,
        None,
        False,
        chunked,
        URL(path),
    )
    if app is None:
        app = _create_app_mock()

    if transport is sentinel:
        transport = _create_transport(sslcontext)

    if protocol is sentinel:
        protocol = mock.Mock()
        protocol.transport = transport

    if writer is sentinel:
        writer = mock.Mock()
        writer.write_headers = make_mocked_coro(None)
        writer.write = make_mocked_coro(None)
        writer.write_eof = make_mocked_coro(None)
        writer.drain = make_mocked_coro(None)
        writer.transport = transport

    # Applied unconditionally, so a caller-supplied protocol is rewired too.
    protocol.transport = transport
    protocol.writer = writer

    req = Request(
        message, payload, protocol, writer, task, loop, client_max_size=client_max_size
    )

    match_info = UrlMappingMatchInfo(
        {} if match_info is sentinel else match_info, mock.Mock()
    )
    match_info.add_app(app)
    req._match_info = match_info

    return req
  614. def make_mocked_coro(
  615. return_value: Any = sentinel, raise_exception: Any = sentinel
  616. ) -> Any:
  617. """Creates a coroutine mock."""
  618. async def mock_coro(*args: Any, **kwargs: Any) -> Any:
  619. if raise_exception is not sentinel:
  620. raise raise_exception
  621. if not inspect.isawaitable(return_value):
  622. return return_value
  623. await return_value
  624. return mock.Mock(wraps=mock_coro)