123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- import asyncio
- from typing import Optional, cast
- from .client_exceptions import ClientConnectionResetError
- from .helpers import set_exception
- from .tcp_helpers import tcp_nodelay
class BaseProtocol(asyncio.Protocol):
    """Shared asyncio protocol base with write flow control and read pausing.

    Tracks whether writing has been paused via ``pause_writing`` /
    ``resume_writing`` and lets coroutines wait for writing to resume through
    ``_drain_helper``. It also offers idempotent ``pause_reading`` /
    ``resume_reading`` wrappers that forward to the attached transport.
    """

    # NOTE(review): "_connection_lost" is declared but never assigned in this
    # class; presumably a subclass relies on the slot -- confirm before removing.
    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        "_connection_lost",
        "_reading_paused",
        "transport",
    )

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        # Attached in connection_made(), detached in connection_lost().
        self.transport: Optional[asyncio.Transport] = None
        self._loop: asyncio.AbstractEventLoop = loop
        # Read side: remembered locally so repeated pause/resume calls are no-ops.
        self._reading_paused = False
        # Write side: pause flag plus the future that drainers wait on.
        self._paused = False
        self._drain_waiter: Optional[asyncio.Future[None]] = None

    @property
    def connected(self) -> bool:
        """Whether a transport is currently attached (the connection is open)."""
        return self.transport is not None

    @property
    def writing_paused(self) -> bool:
        """True between a pause_writing() call and the matching resume_writing()."""
        return self._paused

    def pause_writing(self) -> None:
        """Mark writing as paused.

        Must not be called again before resume_writing(); the assert enforces it.
        """
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:
        """Mark writing as resumed and release any coroutine waiting to drain."""
        assert self._paused
        self._paused = False

        # Detach the waiter first so later drains start with a fresh future.
        pending, self._drain_waiter = self._drain_waiter, None
        if pending is not None and not pending.done():
            pending.set_result(None)

    def pause_reading(self) -> None:
        """Ask the transport to stop reading, once, while a transport is attached.

        AttributeError, NotImplementedError and RuntimeError raised by the
        transport are ignored; the local flag is set regardless.
        """
        if self._reading_paused or self.transport is None:
            return
        try:
            self.transport.pause_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            pass
        self._reading_paused = True

    def resume_reading(self) -> None:
        """Ask the transport to resume reading if it was paused by pause_reading().

        Tolerates the same transport errors as pause_reading(); the local flag
        is cleared regardless.
        """
        if not self._reading_paused or self.transport is None:
            return
        try:
            self.transport.resume_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            pass
        self._reading_paused = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Attach *transport*, first passing it to ``tcp_nodelay(..., True)``."""
        tcp_transport = cast(asyncio.Transport, transport)
        tcp_nodelay(tcp_transport, True)
        self.transport = tcp_transport

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Detach the transport and wake a writer blocked in _drain_helper().

        If writing is paused and a drain waiter is still pending, it is resolved
        when *exc* is None. Otherwise ``ConnectionError("Connection lost")`` and
        *exc* are handed to ``set_exception`` for that waiter.
        """
        self.transport = None
        if self._paused:
            pending, self._drain_waiter = self._drain_waiter, None
            if pending is not None and not pending.done():
                if exc is None:
                    pending.set_result(None)
                else:
                    set_exception(pending, ConnectionError("Connection lost"), exc)

    async def _drain_helper(self) -> None:
        """Wait until writing is no longer paused.

        Raises:
            ClientConnectionResetError: if no transport is attached.
        """
        if self.transport is None:
            raise ClientConnectionResetError("Connection lost")
        if self._paused:
            # All concurrent drainers share one future; shield() keeps a
            # cancelled caller from cancelling the shared future itself.
            pending = self._drain_waiter
            if pending is None:
                pending = self._drain_waiter = self._loop.create_future()
            await asyncio.shield(pending)
|