helpers.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from typing import cast, List, Type, Union, ValuesView
  2. from .._connection import Connection, NEED_DATA, PAUSED
  3. from .._events import (
  4. ConnectionClosed,
  5. Data,
  6. EndOfMessage,
  7. Event,
  8. InformationalResponse,
  9. Request,
  10. Response,
  11. )
  12. from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER
  13. from .._util import Sentinel
  14. try:
  15. from typing import Literal
  16. except ImportError:
  17. from typing_extensions import Literal # type: ignore
  def get_all_events(conn: Connection) -> List[Event]:
      """Drain every event that *conn* can currently produce.

      Repeatedly calls ``conn.next_event()`` until it returns NEED_DATA or
      PAUSED; those sentinels are not included in the result. Collection also
      stops right after a ConnectionClosed event, which *is* included.
      """
      collected: List[Event] = []
      pulled = conn.next_event()
      while pulled not in (NEED_DATA, PAUSED):
          ev = cast(Event, pulled)
          collected.append(ev)
          # Nothing can follow a close, so don't ask the connection again.
          if type(ev) is ConnectionClosed:
              break
          pulled = conn.next_event()
      return collected
  def receive_and_get(conn: Connection, data: bytes) -> List[Event]:
      """Feed *data* into *conn*, then return every event it can now emit.

      The events are gathered with ``get_all_events``.
      """
      conn.receive_data(data)
      drained = get_all_events(conn)
      return drained
  def normalize_data_events(in_events: List[Event]) -> List[Event]:
      """Canonicalize Data events so event lists can be compared directly.

      Every Data event is rebuilt with its payload converted to ``bytes`` and
      with ``chunk_start`` and ``chunk_end`` both set to False, which erases
      chunk boundaries. Runs of adjacent Data events are merged into a single
      Data whose payload is the concatenation of the run. All other events
      are passed through unchanged, in their original order.
      """
      merged: List[Event] = []
      for ev in in_events:
          if type(ev) is not Data:
              merged.append(ev)
              continue
          payload = bytes(ev.data)
          prev = merged[-1] if merged else None
          if type(prev) is Data:
              # Extend the Data run already at the tail of the output.
              merged[-1] = Data(
                  data=prev.data + payload,
                  chunk_start=False,
                  chunk_end=False,
              )
          else:
              merged.append(Data(data=payload, chunk_start=False, chunk_end=False))
      return merged
  class ConnectionPair:
      """A client Connection and a server Connection joined by a fake network link.

      Tests that push events through a Connection usually also want to check
      how its state updates. We might as well make a habit of pushing those
      events through two Connections, so that whatever one side sends is fed
      to the other side and checked there.
      """

      def __init__(self) -> None:
          # One Connection per role, plus a lookup from each role to its peer.
          self.conn = {role: Connection(role) for role in (CLIENT, SERVER)}
          self.other = {CLIENT: SERVER, SERVER: CLIENT}

      @property
      def conns(self) -> ValuesView[Connection]:
          """Both Connection objects (client first, then server)."""
          return self.conn.values()

      def send(
          self,
          role: Type[Sentinel],
          send_events: Union[List[Event], Event],
          expect: Union[List[Event], Event, Literal["match"]] = "match",
      ) -> bytes:
          """Send events from *role* and assert what the peer receives.

          *send_events* may be a single event or a list of events. *expect*
          gives the events the peer must produce: a single event, a list of
          events, or the default ``"match"``, which means "exactly the events
          that were sent".

          Returns the bytes that were transmitted to the peer.
          """
          outgoing = send_events if isinstance(send_events, list) else [send_events]
          sender = self.conn[role]
          receiver = self.conn[self.other[role]]

          # Connection.send returns bytes to transmit, or None to signal that
          # the connection was closed.
          pieces: List[bytes] = []
          saw_close = False
          for ev in outgoing:
              produced = sender.send(ev)
              if produced is None:
                  saw_close = True
              else:
                  pieces.append(produced)
          data = b"".join(pieces)

          # On the receiving side, receive_data(b"") signals closure, so an
          # empty payload is only delivered when a close was actually seen.
          if data:
              receiver.receive_data(data)
          if saw_close:
              receiver.receive_data(b"")
          got_events = get_all_events(receiver)

          if expect == "match":
              wanted = outgoing
          elif isinstance(expect, list):
              wanted = expect
          else:
              wanted = [expect]
          assert got_events == wanted
          return data