base_protocol.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import asyncio
  2. from typing import Optional, cast
  3. from .client_exceptions import ClientConnectionResetError
  4. from .helpers import set_exception
  5. from .tcp_helpers import tcp_nodelay
  class BaseProtocol(asyncio.Protocol):
      """Asyncio protocol base with write flow control and read pausing.

      The transport handed to ``connection_made()`` is kept in ``transport``
      (``None`` before connecting and after ``connection_lost()``).
      ``pause_writing()`` / ``resume_writing()`` toggle a paused flag, and
      ``_drain_helper()`` waits on a single shared future until writing resumes.
      ``pause_reading()`` / ``resume_reading()`` forward to the transport on a
      best-effort basis.
      """

      __slots__ = (
          "_loop",
          "_paused",
          "_drain_waiter",
          # NOTE(review): never assigned in this class; presumably reserved for
          # subclasses -- confirm before removing.
          "_connection_lost",
          "_reading_paused",
          "transport",
      )

      def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
          """Bind the protocol to *loop*; it starts unconnected and unpaused."""
          self._loop: asyncio.AbstractEventLoop = loop
          # Nothing is attached until connection_made() runs.
          self.transport: Optional[asyncio.Transport] = None
          # Independent pause flags for the write and read directions.
          self._paused = False
          self._reading_paused = False
          # Future shared by every caller of _drain_helper() while paused.
          self._drain_waiter: Optional[asyncio.Future[None]] = None

      @property
      def connected(self) -> bool:
          """Report whether the connection is open (a transport is attached)."""
          return self.transport is not None

      @property
      def writing_paused(self) -> bool:
          """True between a pause_writing() call and the next resume_writing()."""
          return self._paused

      def pause_writing(self) -> None:
          """Enter the write-paused state; writing must not already be paused."""
          assert not self._paused
          self._paused = True

      def resume_writing(self) -> None:
          """Leave the write-paused state and release any pending drain waiter.

          Writing must currently be paused.
          """
          assert self._paused
          self._paused = False
          # Detach the waiter first so a later drain creates a fresh future.
          waiter, self._drain_waiter = self._drain_waiter, None
          if waiter is not None and not waiter.done():
              waiter.set_result(None)

      def pause_reading(self) -> None:
          """Ask the transport to stop reading, once per pause.

          Does nothing when reading is already paused or no transport is set.
          """
          if self._reading_paused or self.transport is None:
              return
          try:
              self.transport.pause_reading()
          except (AttributeError, NotImplementedError, RuntimeError):
              # Best effort: failures from the transport are ignored, but the
              # protocol still records itself as paused.
              pass
          self._reading_paused = True

      def resume_reading(self) -> None:
          """Ask the transport to resume reading after a pause_reading().

          Does nothing when reading is not paused or no transport is set.
          """
          if not self._reading_paused or self.transport is None:
              return
          try:
              self.transport.resume_reading()
          except (AttributeError, NotImplementedError, RuntimeError):
              # Best effort, mirroring pause_reading().
              pass
          self._reading_paused = False

      def connection_made(self, transport: asyncio.BaseTransport) -> None:
          """Attach *transport* after calling ``tcp_nodelay(transport, True)``.

          (The helper presumably enables TCP_NODELAY -- its body is not shown.)
          """
          stream_transport = cast(asyncio.Transport, transport)
          tcp_nodelay(stream_transport, True)
          self.transport = stream_transport

      def connection_lost(self, exc: Optional[BaseException]) -> None:
          """Detach the transport and wake a writer blocked in a paused drain.

          The pending waiter is only touched while writing is paused. It is
          resolved with ``None`` when *exc* is ``None``. Otherwise it is handed to
          ``set_exception`` together with ``ConnectionError("Connection lost")``
          and *exc* (presumably recorded as the cause -- helper not shown).
          """
          self.transport = None
          waiter = self._drain_waiter if self._paused else None
          if waiter is None:
              return
          self._drain_waiter = None
          if waiter.done():
              return
          if exc is None:
              waiter.set_result(None)
          else:
              set_exception(waiter, ConnectionError("Connection lost"), exc)

      async def _drain_helper(self) -> None:
          """Wait until writing is resumed; return at once if it is not paused.

          Raises:
              ClientConnectionResetError: if no transport is attached.
          """
          if self.transport is None:
              raise ClientConnectionResetError("Connection lost")
          if not self._paused:
              return
          waiter = self._drain_waiter
          if waiter is None:
              waiter = self._drain_waiter = self._loop.create_future()
          # Several drainers share one future; shield() keeps a cancelled
          # drainer from cancelling it for everyone else.
          await asyncio.shield(waiter)