web_runner.py 12 KB


  1. import asyncio
  2. import signal
  3. import socket
  4. import warnings
  5. from abc import ABC, abstractmethod
  6. from typing import TYPE_CHECKING, Any, List, Optional, Set
  7. from yarl import URL
  8. from .typedefs import PathLike
  9. from .web_app import Application
  10. from .web_server import Server
  11. if TYPE_CHECKING:
  12. from ssl import SSLContext
  13. else:
  14. try:
  15. from ssl import SSLContext
  16. except ImportError: # pragma: no cover
  17. SSLContext = object # type: ignore[misc,assignment]
# Names exported by ``from <module> import *`` -- the public API of this module.
__all__ = (
    "BaseSite",
    "TCPSite",
    "UnixSite",
    "NamedPipeSite",
    "SockSite",
    "BaseRunner",
    "AppRunner",
    "ServerRunner",
    "GracefulExit",
)
  29. class GracefulExit(SystemExit):
  30. code = 1
  31. def _raise_graceful_exit() -> None:
  32. raise GracefulExit()
  33. class BaseSite(ABC):
  34. __slots__ = ("_runner", "_ssl_context", "_backlog", "_server")
  35. def __init__(
  36. self,
  37. runner: "BaseRunner",
  38. *,
  39. shutdown_timeout: float = 60.0,
  40. ssl_context: Optional[SSLContext] = None,
  41. backlog: int = 128,
  42. ) -> None:
  43. if runner.server is None:
  44. raise RuntimeError("Call runner.setup() before making a site")
  45. if shutdown_timeout != 60.0:
  46. msg = "shutdown_timeout should be set on BaseRunner"
  47. warnings.warn(msg, DeprecationWarning, stacklevel=2)
  48. runner._shutdown_timeout = shutdown_timeout
  49. self._runner = runner
  50. self._ssl_context = ssl_context
  51. self._backlog = backlog
  52. self._server: Optional[asyncio.AbstractServer] = None
  53. @property
  54. @abstractmethod
  55. def name(self) -> str:
  56. pass # pragma: no cover
  57. @abstractmethod
  58. async def start(self) -> None:
  59. self._runner._reg_site(self)
  60. async def stop(self) -> None:
  61. self._runner._check_site(self)
  62. if self._server is not None: # Maybe not started yet
  63. self._server.close()
  64. self._runner._unreg_site(self)
  65. class TCPSite(BaseSite):
  66. __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
  67. def __init__(
  68. self,
  69. runner: "BaseRunner",
  70. host: Optional[str] = None,
  71. port: Optional[int] = None,
  72. *,
  73. shutdown_timeout: float = 60.0,
  74. ssl_context: Optional[SSLContext] = None,
  75. backlog: int = 128,
  76. reuse_address: Optional[bool] = None,
  77. reuse_port: Optional[bool] = None,
  78. ) -> None:
  79. super().__init__(
  80. runner,
  81. shutdown_timeout=shutdown_timeout,
  82. ssl_context=ssl_context,
  83. backlog=backlog,
  84. )
  85. self._host = host
  86. if port is None:
  87. port = 8443 if self._ssl_context else 8080
  88. self._port = port
  89. self._reuse_address = reuse_address
  90. self._reuse_port = reuse_port
  91. @property
  92. def name(self) -> str:
  93. scheme = "https" if self._ssl_context else "http"
  94. host = "0.0.0.0" if not self._host else self._host
  95. return str(URL.build(scheme=scheme, host=host, port=self._port))
  96. async def start(self) -> None:
  97. await super().start()
  98. loop = asyncio.get_event_loop()
  99. server = self._runner.server
  100. assert server is not None
  101. self._server = await loop.create_server(
  102. server,
  103. self._host,
  104. self._port,
  105. ssl=self._ssl_context,
  106. backlog=self._backlog,
  107. reuse_address=self._reuse_address,
  108. reuse_port=self._reuse_port,
  109. )
  110. class UnixSite(BaseSite):
  111. __slots__ = ("_path",)
  112. def __init__(
  113. self,
  114. runner: "BaseRunner",
  115. path: PathLike,
  116. *,
  117. shutdown_timeout: float = 60.0,
  118. ssl_context: Optional[SSLContext] = None,
  119. backlog: int = 128,
  120. ) -> None:
  121. super().__init__(
  122. runner,
  123. shutdown_timeout=shutdown_timeout,
  124. ssl_context=ssl_context,
  125. backlog=backlog,
  126. )
  127. self._path = path
  128. @property
  129. def name(self) -> str:
  130. scheme = "https" if self._ssl_context else "http"
  131. return f"{scheme}://unix:{self._path}:"
  132. async def start(self) -> None:
  133. await super().start()
  134. loop = asyncio.get_event_loop()
  135. server = self._runner.server
  136. assert server is not None
  137. self._server = await loop.create_unix_server(
  138. server,
  139. self._path,
  140. ssl=self._ssl_context,
  141. backlog=self._backlog,
  142. )
  143. class NamedPipeSite(BaseSite):
  144. __slots__ = ("_path",)
  145. def __init__(
  146. self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
  147. ) -> None:
  148. loop = asyncio.get_event_loop()
  149. if not isinstance(
  150. loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
  151. ):
  152. raise RuntimeError(
  153. "Named Pipes only available in proactor loop under windows"
  154. )
  155. super().__init__(runner, shutdown_timeout=shutdown_timeout)
  156. self._path = path
  157. @property
  158. def name(self) -> str:
  159. return self._path
  160. async def start(self) -> None:
  161. await super().start()
  162. loop = asyncio.get_event_loop()
  163. server = self._runner.server
  164. assert server is not None
  165. _server = await loop.start_serving_pipe( # type: ignore[attr-defined]
  166. server, self._path
  167. )
  168. self._server = _server[0]
  169. class SockSite(BaseSite):
  170. __slots__ = ("_sock", "_name")
  171. def __init__(
  172. self,
  173. runner: "BaseRunner",
  174. sock: socket.socket,
  175. *,
  176. shutdown_timeout: float = 60.0,
  177. ssl_context: Optional[SSLContext] = None,
  178. backlog: int = 128,
  179. ) -> None:
  180. super().__init__(
  181. runner,
  182. shutdown_timeout=shutdown_timeout,
  183. ssl_context=ssl_context,
  184. backlog=backlog,
  185. )
  186. self._sock = sock
  187. scheme = "https" if self._ssl_context else "http"
  188. if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
  189. name = f"{scheme}://unix:{sock.getsockname()}:"
  190. else:
  191. host, port = sock.getsockname()[:2]
  192. name = str(URL.build(scheme=scheme, host=host, port=port))
  193. self._name = name
  194. @property
  195. def name(self) -> str:
  196. return self._name
  197. async def start(self) -> None:
  198. await super().start()
  199. loop = asyncio.get_event_loop()
  200. server = self._runner.server
  201. assert server is not None
  202. self._server = await loop.create_server(
  203. server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
  204. )
  205. class BaseRunner(ABC):
  206. __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")
  207. def __init__(
  208. self,
  209. *,
  210. handle_signals: bool = False,
  211. shutdown_timeout: float = 60.0,
  212. **kwargs: Any,
  213. ) -> None:
  214. self._handle_signals = handle_signals
  215. self._kwargs = kwargs
  216. self._server: Optional[Server] = None
  217. self._sites: List[BaseSite] = []
  218. self._shutdown_timeout = shutdown_timeout
  219. @property
  220. def server(self) -> Optional[Server]:
  221. return self._server
  222. @property
  223. def addresses(self) -> List[Any]:
  224. ret: List[Any] = []
  225. for site in self._sites:
  226. server = site._server
  227. if server is not None:
  228. sockets = server.sockets # type: ignore[attr-defined]
  229. if sockets is not None:
  230. for sock in sockets:
  231. ret.append(sock.getsockname())
  232. return ret
  233. @property
  234. def sites(self) -> Set[BaseSite]:
  235. return set(self._sites)
  236. async def setup(self) -> None:
  237. loop = asyncio.get_event_loop()
  238. if self._handle_signals:
  239. try:
  240. loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
  241. loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
  242. except NotImplementedError: # pragma: no cover
  243. # add_signal_handler is not implemented on Windows
  244. pass
  245. self._server = await self._make_server()
  246. @abstractmethod
  247. async def shutdown(self) -> None:
  248. """Call any shutdown hooks to help server close gracefully."""
  249. async def cleanup(self) -> None:
  250. # The loop over sites is intentional, an exception on gather()
  251. # leaves self._sites in unpredictable state.
  252. # The loop guaranties that a site is either deleted on success or
  253. # still present on failure
  254. for site in list(self._sites):
  255. await site.stop()
  256. if self._server: # If setup succeeded
  257. # Yield to event loop to ensure incoming requests prior to stopping the sites
  258. # have all started to be handled before we proceed to close idle connections.
  259. await asyncio.sleep(0)
  260. self._server.pre_shutdown()
  261. await self.shutdown()
  262. await self._server.shutdown(self._shutdown_timeout)
  263. await self._cleanup_server()
  264. self._server = None
  265. if self._handle_signals:
  266. loop = asyncio.get_running_loop()
  267. try:
  268. loop.remove_signal_handler(signal.SIGINT)
  269. loop.remove_signal_handler(signal.SIGTERM)
  270. except NotImplementedError: # pragma: no cover
  271. # remove_signal_handler is not implemented on Windows
  272. pass
  273. @abstractmethod
  274. async def _make_server(self) -> Server:
  275. pass # pragma: no cover
  276. @abstractmethod
  277. async def _cleanup_server(self) -> None:
  278. pass # pragma: no cover
  279. def _reg_site(self, site: BaseSite) -> None:
  280. if site in self._sites:
  281. raise RuntimeError(f"Site {site} is already registered in runner {self}")
  282. self._sites.append(site)
  283. def _check_site(self, site: BaseSite) -> None:
  284. if site not in self._sites:
  285. raise RuntimeError(f"Site {site} is not registered in runner {self}")
  286. def _unreg_site(self, site: BaseSite) -> None:
  287. if site not in self._sites:
  288. raise RuntimeError(f"Site {site} is not registered in runner {self}")
  289. self._sites.remove(site)
  290. class ServerRunner(BaseRunner):
  291. """Low-level web server runner"""
  292. __slots__ = ("_web_server",)
  293. def __init__(
  294. self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
  295. ) -> None:
  296. super().__init__(handle_signals=handle_signals, **kwargs)
  297. self._web_server = web_server
  298. async def shutdown(self) -> None:
  299. pass
  300. async def _make_server(self) -> Server:
  301. return self._web_server
  302. async def _cleanup_server(self) -> None:
  303. pass
  304. class AppRunner(BaseRunner):
  305. """Web Application runner"""
  306. __slots__ = ("_app",)
  307. def __init__(
  308. self, app: Application, *, handle_signals: bool = False, **kwargs: Any
  309. ) -> None:
  310. super().__init__(handle_signals=handle_signals, **kwargs)
  311. if not isinstance(app, Application):
  312. raise TypeError(
  313. "The first argument should be web.Application "
  314. "instance, got {!r}".format(app)
  315. )
  316. self._app = app
  317. @property
  318. def app(self) -> Application:
  319. return self._app
  320. async def shutdown(self) -> None:
  321. await self._app.shutdown()
  322. async def _make_server(self) -> Server:
  323. loop = asyncio.get_event_loop()
  324. self._app._set_loop(loop)
  325. self._app.on_startup.freeze()
  326. await self._app.startup()
  327. self._app.freeze()
  328. return self._app._make_handler(loop=loop, **self._kwargs)
  329. async def _cleanup_server(self) -> None:
  330. await self._app.cleanup()