123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- from typing import Any, Callable, Generator, List
- import pytest
- from .._events import (
- ConnectionClosed,
- Data,
- EndOfMessage,
- Event,
- InformationalResponse,
- Request,
- Response,
- )
- from .._headers import Headers, normalize_and_validate
- from .._readers import (
- _obsolete_line_fold,
- ChunkedReader,
- ContentLengthReader,
- Http10Reader,
- READERS,
- )
- from .._receivebuffer import ReceiveBuffer
- from .._state import (
- CLIENT,
- CLOSED,
- DONE,
- IDLE,
- MIGHT_SWITCH_PROTOCOL,
- MUST_CLOSE,
- SEND_BODY,
- SEND_RESPONSE,
- SERVER,
- SWITCHED_PROTOCOL,
- )
- from .._util import LocalProtocolError
- from .._writers import (
- ChunkedWriter,
- ContentLengthWriter,
- Http10Writer,
- write_any_response,
- write_headers,
- write_request,
- WRITERS,
- )
- from .helpers import normalize_data_events
# Round-trip fixtures shared by test_writers_simple and test_readers_simple.
# Each entry is ((role, state), event, wire_bytes): the (role, state) pair is
# the key used to index WRITERS / READERS, and the event and the bytes are
# the two sides of the conversion being checked.
SIMPLE_CASES = [
    # Client request head
    (
        (CLIENT, IDLE),
        Request(
            target="/a",
            method="GET",
            headers=[("Host", "foo"), ("Connection", "close")],
        ),
        b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
    ),
    # Final response with one header
    (
        (SERVER, SEND_RESPONSE),
        Response(reason=b"OK", status_code=200, headers=[("Connection", "close")]),
        b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
    ),
    # Final response with no headers at all
    (
        (SERVER, SEND_RESPONSE),
        Response(reason=b"OK", status_code=200, headers=[]),  # type: ignore[arg-type]
        b"HTTP/1.1 200 OK\r\n\r\n",
    ),
    # 1xx response with one header
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(
            reason=b"Upgrade", status_code=101, headers=[("Upgrade", "websocket")]
        ),
        b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
    ),
    # 1xx response with no headers at all
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(reason=b"Upgrade", status_code=101, headers=[]),  # type: ignore[arg-type]
        b"HTTP/1.1 101 Upgrade\r\n\r\n",
    ),
]
def dowrite(writer: Callable[..., None], obj: Any) -> bytes:
    """Run *writer* on *obj* and return everything it emitted, concatenated.

    The writer is invoked as ``writer(obj, write)``; every chunk it hands to
    ``write`` is collected in order and joined into a single bytestring.
    """
    chunks: List[bytes] = []
    emit = chunks.append
    writer(obj, emit)
    return b"".join(chunks)
def tw(writer: Any, obj: Any, expected: Any) -> None:
    """Assert that *writer*, applied to *obj*, emits exactly *expected*."""
    assert dowrite(writer, obj) == expected
def makebuf(data: bytes) -> ReceiveBuffer:
    """Return a new ReceiveBuffer that has had *data* appended to it."""
    filled = ReceiveBuffer()
    filled += data
    return filled
def tr(reader: Any, data: bytes, expected: Any) -> None:
    """Check that *reader* turns *data* into *expected* under three feed patterns.

    1. The whole input is available at once; the buffer must end up empty.
    2. The input arrives one byte at a time; the reader must return ``None``
       on every proper prefix and the expected event once all bytes are in.
    3. The input is followed by ``b"trailing"``; exactly those extra bytes
       must be left in the buffer afterwards.
    """

    def check(got: Any) -> None:
        assert got == expected
        # Headers should always be returned as bytes, not e.g. bytearray
        # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478
        header_pairs = getattr(got, "headers", [])
        for name, value in header_pairs:
            assert type(name) is bytes
            assert type(value) is bytes

    # Simple: consume whole thing
    whole = makebuf(data)
    check(reader(whole))
    assert not whole

    # Incrementally growing buffer
    growing = ReceiveBuffer()
    for end in range(1, len(data) + 1):
        assert reader(growing) is None
        growing += data[end - 1 : end]
    check(reader(growing))

    # Trailing data
    padded = makebuf(data)
    padded += b"trailing"
    check(reader(padded))
    assert bytes(padded) == b"trailing"
def test_writers_simple() -> None:
    """Each SIMPLE_CASES event must serialize to its paired wire bytes."""
    for key, event, wire in SIMPLE_CASES:
        # key is the (role, state) tuple, which is exactly the WRITERS index
        tw(WRITERS[key], event, wire)
def test_readers_simple() -> None:
    """Each SIMPLE_CASES wire bytes must parse back to its paired event."""
    for key, event, wire in SIMPLE_CASES:
        # key is the (role, state) tuple, which is exactly the READERS index
        tr(READERS[key], wire, event)
def test_writers_unusual() -> None:
    """Cover write_headers directly, and the refusal to write HTTP/1.0 messages."""
    # Simple test of the write_headers utility routine
    two_headers = normalize_and_validate([("foo", "bar"), ("baz", "quux")])
    tw(write_headers, two_headers, b"foo: bar\r\nbaz: quux\r\n\r\n")
    no_headers = Headers([])
    tw(write_headers, no_headers, b"\r\n")

    # We understand HTTP/1.0, but we don't speak it
    with pytest.raises(LocalProtocolError):
        old_request = Request(
            method="GET",
            target="/",
            headers=[("Host", "foo"), ("Connection", "close")],
            http_version="1.0",
        )
        tw(write_request, old_request, None)
    with pytest.raises(LocalProtocolError):
        old_response = Response(
            status_code=200, headers=[("Connection", "close")], http_version="1.0"
        )
        tw(write_any_response, old_response, None)
def test_readers_unusual() -> None:
    """Exercise the message-head readers on odd-but-legal and malformed input.

    The accepted cases cover HTTP/1.0 messages, messages with no headers,
    unusual whitespace in header values, empty header values, a status line
    with no reason phrase, bare-LF line endings, and obsolete line folding.
    The rejected cases are malformed header lines that must raise
    ``LocalProtocolError``.
    """
    # Reading HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\nSome: header\r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[("Some", "header")],
            http_version="1.0",
        ),
    )

    # check no-headers, since it's only legal with HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\n\r\n",
        Request(method="HEAD", target="/foo", headers=[], http_version="1.0"),  # type: ignore[arg-type]
    )

    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\nSome: header\r\n\r\n",
        Response(
            status_code=200,
            headers=[("Some", "header")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # single-character header values (actually disallowed by the ABNF in RFC
    # 7230 -- this is a bug in the standard that we originally copied...)
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: a a a a a \r\n\r\n",
        Response(
            status_code=200,
            headers=[("Foo", "a a a a a")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # Empty headers -- also legal
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo:\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # A value made only of spaces and tabs also comes back as empty
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: \t \t \r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # Tolerate broken servers that leave off the reason phrase
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200\r\n" b"Foo: bar\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "bar")], http_version="1.0", reason=b""
        ),
    )

    # Tolerate both \r\n and bare \n as header line endings
    # \n\r\n between headers and body
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader: val\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader", "val")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # delimited only with \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\nSomeHeader1: val1\nSomeHeader2: val2\n\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # mixed \r\n and \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader1: val1\nSomeHeader2: val2\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # obsolete line folding
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Some: multi-line\r\n"
        b" header\r\n"
        b"\tnonsense\r\n"
        b" \t \t\tI guess\r\n"
        b"Connection: close\r\n"
        b"More-nonsense: in the\r\n"
        b" last header \r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[
                ("Host", "example.com"),
                ("Some", "multi-line header nonsense I guess"),
                ("Connection", "close"),
                ("More-nonsense", "in the last header"),
            ],
        ),
    )

    # Malformed header lines that must be rejected.  (The tab-before-colon
    # case used to be listed twice; one copy is enough.)
    malformed_heads = [
        # continuation line with no header line before it
        b"HEAD /foo HTTP/1.1\r\n" b" folded: line\r\n\r\n",
        # space between the field name and the colon
        b"HEAD /foo HTTP/1.1\r\n" b"foo : line\r\n\r\n",
        # tab between the field name and the colon
        b"HEAD /foo HTTP/1.1\r\n" b"foo\t: line\r\n\r\n",
        # empty field name
        b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n",
    ]
    for raw in malformed_heads:
        with pytest.raises(LocalProtocolError):
            tr(READERS[CLIENT, IDLE], raw, None)
def test__obsolete_line_fold_bytes() -> None:
    """Feed plain ``bytes`` lines to _obsolete_line_fold and check the folded output."""
    # _obsolete_line_fold has a defensive cast to bytearray, which is
    # necessary to protect against O(n^2) behavior in case anyone ever passes
    # in regular bytestrings... but right now we never pass in regular
    # bytestrings. so this test just exists to get some coverage on that
    # defensive cast.
    plain_lines = [b"aaa", b"bbb", b" ccc", b"ddd"]
    folded = list(_obsolete_line_fold(plain_lines))
    assert folded == [b"aaa", bytearray(b"bbb ccc"), b"ddd"]
def _run_reader_iter(
    reader: Any, buf: bytes, do_eof: bool
) -> Generator[Any, None, None]:
    """Yield the events *reader* produces from *buf*, optionally followed by EOF.

    The reader is called repeatedly until it returns ``None`` or an
    ``EndOfMessage``. If *do_eof* is true, the buffer must be empty at that
    point and the result of ``reader.read_eof()`` is yielded last.

    NOTE(review): *buf* is annotated ``bytes``, but the callers in this file
    pass a ReceiveBuffer -- the annotation probably wants updating.
    """
    event = reader(buf)
    while event is not None:
        yield event
        # body readers have undefined behavior after returning EndOfMessage,
        # because this changes the state so they don't get called again
        if type(event) is EndOfMessage:
            break
        event = reader(buf)
    if do_eof:
        assert not buf
        yield reader.read_eof()
def _run_reader(*args: Any) -> List[Event]:
    """Collect everything _run_reader_iter yields and pass it through normalize_data_events."""
    return normalize_data_events(list(_run_reader_iter(*args)))
def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
    """Check that a body reader turns *data* into the *expected* event list.

    *thunk* is called to build a fresh reader for each scenario:

    1. all of *data* available at once;
    2. *data* delivered one byte at a time to a single reader;
    3. *data* followed by ``b"trailing"`` -- only when *expected* contains an
       ``EndOfMessage`` and *do_eof* is false.

    When *do_eof* is true, end-of-file is signalled once the data has been
    consumed (scenarios 1 and 2).
    """
    # Simple: consume whole thing
    print("Test 1")
    whole = makebuf(data)
    assert _run_reader(thunk(), whole, do_eof) == expected

    # Incrementally growing buffer
    print("Test 2")
    reader = thunk()
    growing = ReceiveBuffer()
    collected = []
    for end in range(1, len(data) + 1):
        collected.extend(_run_reader(reader, growing, False))
        growing += data[end - 1 : end]
    collected.extend(_run_reader(reader, growing, do_eof))
    assert normalize_data_events(collected) == expected

    # Trailing data
    message_finished = any(type(event) is EndOfMessage for event in expected)
    if not do_eof and message_finished:
        padded = makebuf(data + b"trailing")
        assert _run_reader(thunk(), padded, False) == expected
def test_ContentLengthReader() -> None:
    """ContentLengthReader: a zero-length body and a body of exactly ten bytes."""
    cases = [
        (0, b"", [EndOfMessage()]),
        (10, b"0123456789", [Data(data=b"0123456789"), EndOfMessage()]),
    ]
    for length, body, events in cases:
        # bind length as a default so each thunk keeps its own value
        t_body_reader(lambda length=length: ContentLengthReader(length), body, events)
def test_Http10Reader() -> None:
    """Http10Reader: EndOfMessage is expected only in the cases where EOF is signalled."""
    cases = [
        # (body bytes, expected events, whether EOF is signalled)
        (b"", [EndOfMessage()], True),
        (b"asdf", [Data(data=b"asdf")], False),
        (b"asdf", [Data(data=b"asdf"), EndOfMessage()], True),
    ]
    for body, events, eof in cases:
        t_body_reader(Http10Reader, body, events, do_eof=eof)
def test_ChunkedReader() -> None:
    """ChunkedReader: well-formed chunked bodies, trailers, and rejected chunk sizes."""
    two_chunks = b"5\r\n01234\r\n" + b"10\r\n0123456789abcdef\r\n"

    well_formed = [
        # terminating chunk only
        (b"0\r\n\r\n", [EndOfMessage()]),
        # terminating chunk followed by a trailer header
        (
            b"0\r\nSome: header\r\n\r\n",
            [EndOfMessage(headers=[("Some", "header")])],
        ),
        # two data chunks, then a trailer header
        (
            two_chunks + b"0\r\n" + b"Some: header\r\n\r\n",
            [
                Data(data=b"012340123456789abcdef"),
                EndOfMessage(headers=[("Some", "header")]),
            ],
        ),
        # two data chunks, no trailer
        (
            two_chunks + b"0\r\n\r\n",
            [Data(data=b"012340123456789abcdef"), EndOfMessage()],
        ),
        # handles upper and lowercase hex
        (
            b"aA\r\n" + b"x" * 0xAA + b"\r\n" + b"0\r\n\r\n",
            [Data(data=b"x" * 0xAA), EndOfMessage()],
        ),
    ]
    for wire, events in well_formed:
        t_body_reader(ChunkedReader, wire, events)

    # refuses arbitrarily long chunk integers
    with pytest.raises(LocalProtocolError):
        # Technically this is legal HTTP/1.1, but we refuse to process chunk
        # sizes that don't fit into 20 characters of hex
        t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])

    # refuses garbage in the chunk count
    with pytest.raises(LocalProtocolError):
        t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)

    tolerated = [
        # handles (and discards) "chunk extensions" omg wtf
        (
            b"5; hello=there\r\n"
            + b"xxxxx"
            + b"\r\n"
            + b'0; random="junk"; some=more; canbe=lonnnnngg\r\n\r\n',
            [Data(data=b"xxxxx"), EndOfMessage()],
        ),
        # chunk size followed by a space before the line ending
        (
            b"5 \r\n01234\r\n" + b"0\r\n\r\n",
            [Data(data=b"01234"), EndOfMessage()],
        ),
    ]
    for wire, events in tolerated:
        t_body_reader(ChunkedReader, wire, events)
def test_ContentLengthWriter() -> None:
    """Check ContentLengthWriter output and its enforcement of the declared length.

    Data must pass through unchanged and EndOfMessage must emit nothing.
    Writing too many bytes, ending the message early, or sending trailing
    headers must raise ``LocalProtocolError``.
    """
    # Exactly the declared number of bytes, then a clean end of message
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    assert dowrite(w, EndOfMessage()) == b""

    # Too many bytes in a single write
    w = ContentLengthWriter(5)
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"123456"))

    # Too many bytes across two writes
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"456"))

    # Ending the message before the declared length has been written
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage())

    # Trailing headers after a complete body.  These two comparisons were
    # previously bare expressions whose result was discarded; assert them so
    # the writes are actually verified.
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_ChunkedWriter() -> None:
    """ChunkedWriter: chunk framing for data, and the terminating chunk with and without trailers."""
    writer = ChunkedWriter()
    steps = [
        (Data(data=b"aaa"), b"3\r\naaa\r\n"),
        # the chunk size is written in hex: 20 bytes -> "14"
        (Data(data=b"a" * 20), b"14\r\n" + b"a" * 20 + b"\r\n"),
        # empty data is expected to produce no output at all
        (Data(data=b""), b""),
        (EndOfMessage(), b"0\r\n\r\n"),
        (
            EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")]),
            b"0\r\nEtag: asdf\r\na: b\r\n\r\n",
        ),
    ]
    # All steps go through the same writer instance, in order
    for event, wire in steps:
        assert dowrite(writer, event) == wire
def test_Http10Writer() -> None:
    """Http10Writer: data passes through, EndOfMessage emits nothing, trailers are refused."""
    writer = Http10Writer()

    body_out = dowrite(writer, Data(data=b"1234"))
    assert body_out == b"1234"

    end_out = dowrite(writer, EndOfMessage())
    assert end_out == b""

    with pytest.raises(LocalProtocolError):
        dowrite(writer, EndOfMessage(headers=[("Etag", "asdf")]))
def test_reject_garbage_after_request_line() -> None:
    """A NUL byte and junk after the reason phrase must be rejected.

    NOTE(review): despite the function name, the input here is a response
    status line, read with the (SERVER, SEND_RESPONSE) reader.
    """
    status_line_with_junk = b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[SERVER, SEND_RESPONSE], status_line_with_junk, None)
def test_reject_garbage_after_response_line() -> None:
    """Extra text after the HTTP version must be rejected.

    NOTE(review): despite the function name, the input here is a request
    line, read with the (CLIENT, IDLE) reader.
    """
    request_with_junk = b"HEAD /foo HTTP/1.1 xxxxxx\r\n" b"Host: a\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], request_with_junk, None)
def test_reject_garbage_in_header_line() -> None:
    """A NUL byte inside a header value must be rejected."""
    request_with_nul = b"HEAD /foo HTTP/1.1\r\n" b"Host: foo\x00bar\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], request_with_nul, None)
def test_reject_non_vchar_in_path() -> None:
    """A request target containing NUL, space, DEL or byte 0xEE must be rejected."""
    prefix = b"HEAD /"
    suffix = b" HTTP/1.1\r\nHost: foobar\r\n\r\n"
    for bad_char in b"\x00\x20\x7f\xee":
        # Iterating bytes yields ints; wrap the int back into a one-byte
        # string.  The message stays a bytearray, as before.
        message = bytearray(prefix) + bytes([bad_char]) + suffix
        with pytest.raises(LocalProtocolError):
            tr(READERS[CLIENT, IDLE], message, None)
# https://github.com/python-hyper/h11/issues/57
def test_allow_some_garbage_in_cookies() -> None:
    """A 0x01 control byte inside a Set-Cookie value must be accepted and preserved."""
    raw_request = (
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: foo\r\n"
        b"Set-Cookie: ___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900\r\n"
        b"\r\n"
    )
    parsed_request = Request(
        method="HEAD",
        target="/foo",
        headers=[
            ("Host", "foo"),
            ("Set-Cookie", "___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900"),
        ],
    )
    tr(READERS[CLIENT, IDLE], raw_request, parsed_request)
def test_host_comes_first() -> None:
    """write_headers is expected to emit Host first even when it was supplied last."""
    headers = normalize_and_validate([("foo", "bar"), ("Host", "example.com")])
    tw(write_headers, headers, b"Host: example.com\r\nfoo: bar\r\n\r\n")
|