worker.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. """Async gunicorn worker for aiohttp.web"""
  2. import asyncio
  3. import inspect
  4. import os
  5. import re
  6. import signal
  7. import sys
  8. from types import FrameType
  9. from typing import TYPE_CHECKING, Any, Optional
  10. from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
  11. from gunicorn.workers import base
  12. from aiohttp import web
  13. from .helpers import set_result
  14. from .web_app import Application
  15. from .web_log import AccessLogger
  # Resolve ``ssl`` / ``SSLContext`` once at import time.
  if TYPE_CHECKING:
      # Type checkers always get the real module and class.
      import ssl

      SSLContext = ssl.SSLContext
  else:
      try:
          import ssl
      except ImportError:  # pragma: no cover
          # No ssl module available: keep both names bound so annotations
          # and the runtime ``ssl is None`` check still work.
          ssl = None  # type: ignore[assignment]
          SSLContext = object  # type: ignore[misc,assignment]
      else:
          SSLContext = ssl.SSLContext
  # Names exported by ``from <this module> import *``: the two worker classes
  # defined below.
  __all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker")
  class GunicornWebWorker(base.Worker):  # type: ignore[misc,no-any-unimported]
      """Gunicorn worker that serves an aiohttp application on an asyncio loop.

      ``self.wsgi`` must be either an ``Application`` instance or a coroutine
      function that returns an ``Application`` or a ``web.AppRunner``; anything
      else makes ``_run`` raise ``RuntimeError``.
      """

      # aiohttp's own default access-log format.
      DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
      # gunicorn's default access-log format; when the configured format equals
      # this value it is replaced by the aiohttp default (see
      # ``_get_valid_log_format``).
      DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default

      def __init__(self, *args: Any, **kw: Any) -> None:  # pragma: no cover
          """Forward all arguments to the base worker and reset local state."""
          super().__init__(*args, **kw)

          # Task wrapping ``_run()``; created in ``run()``.
          self._task: Optional[asyncio.Task[None]] = None
          # Status passed to ``sys.exit`` at the end of ``run()``.
          self.exit_code = 0
          # Future the main loop is currently sleeping on, if any.
          self._notify_waiter: Optional[asyncio.Future[bool]] = None

      def init_process(self) -> None:
          """Install a fresh event loop, then run the base class initialisation.

          Any loop that is already current is closed first.
          """
          # create new event_loop after fork
          #
          # On Python 3.14+ ``asyncio.get_event_loop()`` raises RuntimeError
          # when no loop is current, so only the lookup is guarded; errors
          # raised by ``close()`` itself still propagate as before.
          try:
              stale_loop = asyncio.get_event_loop()
          except RuntimeError:
              stale_loop = None
          if stale_loop is not None:
              stale_loop.close()

          self.loop = asyncio.new_event_loop()
          asyncio.set_event_loop(self.loop)

          super().init_process()

      def run(self) -> None:
          """Run ``_run()`` to completion, tear the loop down and exit.

          Always ends with ``sys.exit(self.exit_code)``.
          """
          self._task = self.loop.create_task(self._run())

          try:  # ignore all finalization problems
              self.loop.run_until_complete(self._task)
          except Exception:
              self.log.exception("Exception in gunicorn worker")
          self.loop.run_until_complete(self.loop.shutdown_asyncgens())
          self.loop.close()

          sys.exit(self.exit_code)

      async def _run(self) -> None:
          """Start a site on every listening socket and supervise the worker.

          Raises:
              RuntimeError: if ``self.wsgi`` is neither an ``Application`` nor
                  a coroutine function.
          """
          runner = None
          if isinstance(self.wsgi, Application):
              app = self.wsgi
          elif inspect.iscoroutinefunction(self.wsgi) or (
              sys.version_info < (3, 14) and asyncio.iscoroutinefunction(self.wsgi)
          ):
              # The factory may hand back a ready-made runner or a bare app.
              wsgi = await self.wsgi()
              if isinstance(wsgi, web.AppRunner):
                  runner = wsgi
                  app = runner.app
              else:
                  app = wsgi
          else:
              raise RuntimeError(
                  "wsgi app should be either Application or "
                  f"async function returning Application, got {self.wsgi}"
              )

          if runner is None:
              # No runner supplied: build one from the gunicorn configuration.
              access_log = self.log.access_log if self.cfg.accesslog else None
              runner = web.AppRunner(
                  app,
                  logger=self.log,
                  keepalive_timeout=self.cfg.keepalive,
                  access_log=access_log,
                  access_log_format=self._get_valid_log_format(
                      self.cfg.access_log_format
                  ),
                  # 95% of gunicorn's graceful timeout.
                  shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
              )
          await runner.setup()

          ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None

          assert runner is not None
          server = runner.server
          assert server is not None
          for sock in self.sockets:
              site = web.SockSite(
                  runner,
                  sock,
                  ssl_context=ctx,
              )
              await site.start()

          # If our parent changed then we shut down.
          pid = os.getpid()
          try:
              while self.alive:  # type: ignore[has-type]
                  self.notify()

                  cnt = server.requests_count
                  if self.max_requests and cnt > self.max_requests:
                      self.alive = False
                      self.log.info("Max requests, shutting down: %s", self)
                  elif pid == os.getpid() and self.ppid != os.getppid():
                      self.alive = False
                      self.log.info("Parent changed, shutting down: %s", self)
                  else:
                      # Sleep until the 1 second timer fires or a signal
                      # handler wakes us up early.
                      await self._wait_next_notify()
          except BaseException:
              # Deliberately best-effort: whatever interrupts the loop is
              # dropped and control falls through to the cleanup below.
              # NOTE(review): presumably so that ``runner.cleanup()`` always
              # runs, even on cancellation - confirm before narrowing this.
              pass

          await runner.cleanup()

      def _wait_next_notify(self) -> "asyncio.Future[bool]":
          """Return a future that is resolved after 1 second at the latest.

          Any previously pending waiter is resolved first.
          """
          self._notify_waiter_done()

          loop = self.loop
          assert loop is not None
          self._notify_waiter = waiter = loop.create_future()
          loop.call_later(1.0, self._notify_waiter_done, waiter)

          return waiter

      def _notify_waiter_done(
          self, waiter: Optional["asyncio.Future[bool]"] = None
      ) -> None:
          """Resolve *waiter* (default: the current waiter) with ``True``.

          The stored waiter is cleared only when it is the one being resolved,
          so a late timer callback for an older future leaves a newer waiter
          in place.
          """
          if waiter is None:
              waiter = self._notify_waiter
          if waiter is not None:
              set_result(waiter, True)

          if waiter is self._notify_waiter:
              self._notify_waiter = None

      def init_signals(self) -> None:
          """Register the worker's signal handlers on the event loop.

          ``handle_exit``, ``handle_winch`` and ``handle_usr1`` are not defined
          in this class; presumably they are inherited from ``base.Worker``.
          """
          # Set up signals through the event loop API.
          # Each handler is called as ``handler(signum, None)``.
          handlers = (
              (signal.SIGQUIT, self.handle_quit),
              (signal.SIGTERM, self.handle_exit),
              (signal.SIGINT, self.handle_quit),
              (signal.SIGWINCH, self.handle_winch),
              (signal.SIGUSR1, self.handle_usr1),
              (signal.SIGABRT, self.handle_abort),
          )
          for signum, handler in handlers:
              self.loop.add_signal_handler(signum, handler, signum, None)

          # Don't let SIGTERM and SIGUSR1 disturb active requests
          # by interrupting system calls
          signal.siginterrupt(signal.SIGTERM, False)
          signal.siginterrupt(signal.SIGUSR1, False)
          # Reset signals so Gunicorn doesn't swallow subprocess return codes
          # See: https://github.com/aio-libs/aiohttp/issues/6130
          # NOTE(review): no statement follows the comment above; it looks
          # like a leftover from removed code - confirm and drop if so.

      def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None:
          """Stop the main loop, call ``cfg.worker_int`` and wake the waiter."""
          self.alive = False

          # worker_int callback
          self.cfg.worker_int(self)

          # wakeup closing process
          self._notify_waiter_done()

      def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None:
          """Mark the worker failed, call ``cfg.worker_abort`` and exit with 1."""
          self.alive = False
          self.exit_code = 1
          self.cfg.worker_abort(self)
          sys.exit(1)

      @staticmethod
      def _create_ssl_context(cfg: Any) -> "SSLContext":
          """Creates SSLContext instance for usage in asyncio.create_server.

          See ssl.SSLSocket.__init__ for more details.

          Reads ``ssl_version``, ``certfile``, ``keyfile``, ``cert_reqs``,
          ``ca_certs`` and ``ciphers`` from *cfg*.

          Raises:
              RuntimeError: if the ``ssl`` module could not be imported.
          """
          if ssl is None:  # pragma: no cover
              raise RuntimeError("SSL is not supported.")
          ctx = ssl.SSLContext(cfg.ssl_version)
          ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
          ctx.verify_mode = cfg.cert_reqs
          if cfg.ca_certs:
              ctx.load_verify_locations(cfg.ca_certs)
          if cfg.ciphers:
              ctx.set_ciphers(cfg.ciphers)
          return ctx

      def _get_valid_log_format(self, source_format: str) -> str:
          """Return an access-log format that aiohttp can use.

          gunicorn's default format is replaced by aiohttp's default; any other
          format is returned unchanged.

          Raises:
              ValueError: if *source_format* contains a ``%(name)`` style
                  placeholder.
          """
          if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
              return self.DEFAULT_AIOHTTP_LOG_FORMAT
          if re.search(r"%\([^\)]+\)", source_format):
              raise ValueError(
                  "Gunicorn's style options in form of `%(name)s` are not "
                  "supported for the log formatting. Please use aiohttp's "
                  "format specification to configure access log formatting: "
                  "http://docs.aiohttp.org/en/stable/logging.html"
                  "#format-specification"
              )
          return source_format
  class GunicornUVLoopWebWorker(GunicornWebWorker):
      """Variant of ``GunicornWebWorker`` that installs the uvloop policy."""

      def init_process(self) -> None:
          """Install ``uvloop.EventLoopPolicy`` and defer to the parent class.

          ``uvloop`` is imported here, so it is only required when this worker
          class is actually used.
          """
          import uvloop

          # Close any existing event loop before setting a
          # new policy.
          #
          # On Python 3.14+ ``asyncio.get_event_loop()`` raises RuntimeError
          # when no loop is current, so only the lookup is guarded; errors
          # raised by ``close()`` itself still propagate as before.
          try:
              stale_loop = asyncio.get_event_loop()
          except RuntimeError:
              stale_loop = None
          if stale_loop is not None:
              stale_loop.close()

          # Setup uvloop policy, so that every
          # asyncio.get_event_loop() will create an instance
          # of uvloop event loop.
          asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

          super().init_process()