web_server.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. """Low level HTTP server."""
  2. import asyncio
  3. from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa
  4. from .abc import AbstractStreamWriter
  5. from .http_parser import RawRequestMessage
  6. from .streams import StreamReader
  7. from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
  8. from .web_request import BaseRequest
  9. __all__ = ("Server",)
  10. class Server:
  11. def __init__(
  12. self,
  13. handler: _RequestHandler,
  14. *,
  15. request_factory: Optional[_RequestFactory] = None,
  16. handler_cancellation: bool = False,
  17. loop: Optional[asyncio.AbstractEventLoop] = None,
  18. **kwargs: Any,
  19. ) -> None:
  20. self._loop = loop or asyncio.get_running_loop()
  21. self._connections: Dict[RequestHandler, asyncio.Transport] = {}
  22. self._kwargs = kwargs
  23. # requests_count is the number of requests being processed by the server
  24. # for the lifetime of the server.
  25. self.requests_count = 0
  26. self.request_handler = handler
  27. self.request_factory = request_factory or self._make_request
  28. self.handler_cancellation = handler_cancellation
  29. @property
  30. def connections(self) -> List[RequestHandler]:
  31. return list(self._connections.keys())
  32. def connection_made(
  33. self, handler: RequestHandler, transport: asyncio.Transport
  34. ) -> None:
  35. self._connections[handler] = transport
  36. def connection_lost(
  37. self, handler: RequestHandler, exc: Optional[BaseException] = None
  38. ) -> None:
  39. if handler in self._connections:
  40. if handler._task_handler:
  41. handler._task_handler.add_done_callback(
  42. lambda f: self._connections.pop(handler, None)
  43. )
  44. else:
  45. del self._connections[handler]
  46. def _make_request(
  47. self,
  48. message: RawRequestMessage,
  49. payload: StreamReader,
  50. protocol: RequestHandler,
  51. writer: AbstractStreamWriter,
  52. task: "asyncio.Task[None]",
  53. ) -> BaseRequest:
  54. return BaseRequest(message, payload, protocol, writer, task, self._loop)
  55. def pre_shutdown(self) -> None:
  56. for conn in self._connections:
  57. conn.close()
  58. async def shutdown(self, timeout: Optional[float] = None) -> None:
  59. coros = (conn.shutdown(timeout) for conn in self._connections)
  60. await asyncio.gather(*coros)
  61. self._connections.clear()
  62. def __call__(self) -> RequestHandler:
  63. try:
  64. return RequestHandler(self, loop=self._loop, **self._kwargs)
  65. except TypeError:
  66. # Failsafe creation: remove all custom handler_args
  67. kwargs = {
  68. k: v
  69. for k, v in self._kwargs.items()
  70. if k in ["debug", "access_log_class"]
  71. }
  72. return RequestHandler(self, loop=self._loop, **kwargs)