test_util.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import re
  2. import sys
  3. import traceback
  4. from typing import NoReturn
  5. import pytest
  6. from .._util import (
  7. bytesify,
  8. LocalProtocolError,
  9. ProtocolError,
  10. RemoteProtocolError,
  11. Sentinel,
  12. validate,
  13. )
  14. def test_ProtocolError() -> None:
  15. with pytest.raises(TypeError):
  16. ProtocolError("abstract base class")
  17. def test_LocalProtocolError() -> None:
  18. try:
  19. raise LocalProtocolError("foo")
  20. except LocalProtocolError as e:
  21. assert str(e) == "foo"
  22. assert e.error_status_hint == 400
  23. try:
  24. raise LocalProtocolError("foo", error_status_hint=418)
  25. except LocalProtocolError as e:
  26. assert str(e) == "foo"
  27. assert e.error_status_hint == 418
  28. def thunk() -> NoReturn:
  29. raise LocalProtocolError("a", error_status_hint=420)
  30. try:
  31. try:
  32. thunk()
  33. except LocalProtocolError as exc1:
  34. orig_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
  35. exc1._reraise_as_remote_protocol_error()
  36. except RemoteProtocolError as exc2:
  37. assert type(exc2) is RemoteProtocolError
  38. assert exc2.args == ("a",)
  39. assert exc2.error_status_hint == 420
  40. new_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
  41. assert new_traceback.endswith(orig_traceback)
  42. def test_validate() -> None:
  43. my_re = re.compile(rb"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)")
  44. with pytest.raises(LocalProtocolError):
  45. validate(my_re, b"0.")
  46. groups = validate(my_re, b"0.1")
  47. assert groups == {"group1": b"0", "group2": b"1"}
  48. # successful partial matches are an error - must match whole string
  49. with pytest.raises(LocalProtocolError):
  50. validate(my_re, b"0.1xx")
  51. with pytest.raises(LocalProtocolError):
  52. validate(my_re, b"0.1\n")
  53. def test_validate_formatting() -> None:
  54. my_re = re.compile(rb"foo")
  55. with pytest.raises(LocalProtocolError) as excinfo:
  56. validate(my_re, b"", "oops")
  57. assert "oops" in str(excinfo.value)
  58. with pytest.raises(LocalProtocolError) as excinfo:
  59. validate(my_re, b"", "oops {}")
  60. assert "oops {}" in str(excinfo.value)
  61. with pytest.raises(LocalProtocolError) as excinfo:
  62. validate(my_re, b"", "oops {} xx", 10)
  63. assert "oops 10 xx" in str(excinfo.value)
  64. def test_make_sentinel() -> None:
  65. class S(Sentinel, metaclass=Sentinel):
  66. pass
  67. assert repr(S) == "S"
  68. assert S == S
  69. assert type(S).__name__ == "S"
  70. assert S in {S}
  71. assert type(S) is S
  72. class S2(Sentinel, metaclass=Sentinel):
  73. pass
  74. assert repr(S2) == "S2"
  75. assert S != S2
  76. assert S not in {S2}
  77. assert type(S) is not type(S2)
  78. def test_bytesify() -> None:
  79. assert bytesify(b"123") == b"123"
  80. assert bytesify(bytearray(b"123")) == b"123"
  81. assert bytesify("123") == b"123"
  82. with pytest.raises(UnicodeEncodeError):
  83. bytesify("\u1234")
  84. with pytest.raises(TypeError):
  85. bytesify(10)