123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- from __future__ import annotations
- from collections.abc import Callable, Mapping, Sequence
- from dataclasses import dataclass
- from typing import Any, Generic, TypeVar
- from ..abc import (
- ByteReceiveStream,
- ByteSendStream,
- ByteStream,
- Listener,
- ObjectReceiveStream,
- ObjectSendStream,
- ObjectStream,
- TaskGroup,
- )
# Type of the objects sent and received through a StapledObjectStream.
T_Item = TypeVar("T_Item")
# Type of the objects that the listeners combined by a MultiListener hand to the
# connection handler passed to ``serve``.
T_Stream = TypeVar("T_Stream")
@dataclass(eq=False)
class StapledByteStream(ByteStream):
    """
    Combines two byte streams into a single, bidirectional byte stream.

    Extra attributes will be provided from both streams, with the receive stream
    providing the values in case of a conflict.

    :param ByteSendStream send_stream: the sending byte stream
    :param ByteReceiveStream receive_stream: the receiving byte stream
    """

    send_stream: ByteSendStream
    receive_stream: ByteReceiveStream

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Receive at most ``max_bytes`` bytes by delegating to ``receive_stream``."""
        return await self.receive_stream.receive(max_bytes)

    async def send(self, item: bytes) -> None:
        """Send ``item`` by delegating to ``send_stream``."""
        await self.send_stream.send(item)

    async def send_eof(self) -> None:
        """
        Signal end-of-data by closing ``send_stream``.

        ``receive_stream`` is not closed here.
        """
        await self.send_stream.aclose()

    async def aclose(self) -> None:
        """
        Close both underlying streams, the send stream first.

        The receive stream is closed even if closing the send stream raises (or
        is cancelled), so a failure on one side does not leak the other. Any
        exception from the send side propagates after the receive side has been
        closed.
        """
        try:
            await self.send_stream.aclose()
        finally:
            await self.receive_stream.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The receive stream's entries are unpacked last, so they win on conflicts.
        return {
            **self.send_stream.extra_attributes,
            **self.receive_stream.extra_attributes,
        }
@dataclass(eq=False)
class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
    """
    Combines two object streams into a single, bidirectional object stream.

    Extra attributes will be provided from both streams, with the receive stream
    providing the values in case of a conflict.

    :param ObjectSendStream send_stream: the sending object stream
    :param ObjectReceiveStream receive_stream: the receiving object stream
    """

    send_stream: ObjectSendStream[T_Item]
    receive_stream: ObjectReceiveStream[T_Item]

    async def receive(self) -> T_Item:
        """Receive the next item by delegating to ``receive_stream``."""
        return await self.receive_stream.receive()

    async def send(self, item: T_Item) -> None:
        """Send ``item`` by delegating to ``send_stream``."""
        await self.send_stream.send(item)

    async def send_eof(self) -> None:
        """
        Signal end-of-data by closing ``send_stream``.

        ``receive_stream`` is not closed here.
        """
        await self.send_stream.aclose()

    async def aclose(self) -> None:
        """
        Close both underlying streams, the send stream first.

        The receive stream is closed even if closing the send stream raises (or
        is cancelled), so a failure on one side does not leak the other. Any
        exception from the send side propagates after the receive side has been
        closed.
        """
        try:
            await self.send_stream.aclose()
        finally:
            await self.receive_stream.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The receive stream's entries are unpacked last, so they win on conflicts.
        return {
            **self.send_stream.extra_attributes,
            **self.receive_stream.extra_attributes,
        }
@dataclass(eq=False)
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
    """
    Combines multiple listeners into one, serving connections from all of them at once.

    Any MultiListeners in the given collection of listeners will have their listeners
    moved into this one.

    Extra attributes are provided from each listener, with each successive listener
    overriding any conflicting attributes from the previous one.

    :param listeners: listeners to serve
    :type listeners: Sequence[Listener[T_Stream]]
    """

    listeners: Sequence[Listener[T_Stream]]

    def __post_init__(self) -> None:
        # Flatten nested MultiListeners into a single list owned by this instance.
        flattened: list[Listener[T_Stream]] = []
        for listener in self.listeners:
            if isinstance(listener, MultiListener):
                flattened.extend(listener.listeners)
                # Empty the nested MultiListener so it no longer references the
                # moved listeners (its own aclose() would then close nothing).
                del listener.listeners[:]  # type: ignore[attr-defined]
            else:
                flattened.append(listener)

        self.listeners = flattened

    async def serve(
        self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None
    ) -> None:
        """
        Run every wrapped listener's ``serve`` concurrently in a new task group.

        ``handler`` and ``task_group`` are forwarded unchanged to each listener.
        Returns only once every listener's ``serve`` call has returned.
        """
        from .. import create_task_group

        async with create_task_group() as tg:
            for listener in self.listeners:
                tg.start_soon(listener.serve, handler, task_group)

    async def aclose(self) -> None:
        """
        Close every wrapped listener, in order.

        Every listener is closed even if closing an earlier one raises (or is
        cancelled). Exceptions are propagated after all listeners have been
        processed; later exceptions carry earlier ones as their context.
        """
        from contextlib import AsyncExitStack

        async with AsyncExitStack() as stack:
            # The exit stack runs callbacks in LIFO order, so register them in
            # reverse to close the listeners in their original order.
            for listener in reversed(self.listeners):
                stack.push_async_callback(listener.aclose)

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # Later listeners override conflicting keys from earlier ones.
        attributes: dict = {}
        for listener in self.listeners:
            attributes.update(listener.extra_attributes)

        return attributes
|