# _quoting_py.py: pure-Python percent-encoding helpers (_Quoter and _Unquoter).
  1. import codecs
  2. import re
  3. from string import ascii_letters, ascii_lowercase, digits
  4. from typing import Union, cast, overload
  5. BASCII_LOWERCASE = ascii_lowercase.encode("ascii")
  6. BPCT_ALLOWED = {f"%{i:02X}".encode("ascii") for i in range(256)}
  7. GEN_DELIMS = ":/?#[]@"
  8. SUB_DELIMS_WITHOUT_QS = "!$'()*,"
  9. SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + "+&=;"
  10. RESERVED = GEN_DELIMS + SUB_DELIMS
  11. UNRESERVED = ascii_letters + digits + "-._~"
  12. ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS
  13. _IS_HEX = re.compile(b"[A-Z0-9][A-Z0-9]")
  14. _IS_HEX_STR = re.compile("[A-Fa-f0-9][A-Fa-f0-9]")
  15. utf8_decoder = codecs.getincrementaldecoder("utf-8")
  16. class _Quoter:
  17. def __init__(
  18. self,
  19. *,
  20. safe: str = "",
  21. protected: str = "",
  22. qs: bool = False,
  23. requote: bool = True,
  24. ) -> None:
  25. self._safe = safe
  26. self._protected = protected
  27. self._qs = qs
  28. self._requote = requote
  29. @overload
  30. def __call__(self, val: str) -> str: ...
  31. @overload
  32. def __call__(self, val: None) -> None: ...
  33. def __call__(self, val: Union[str, None]) -> Union[str, None]:
  34. if val is None:
  35. return None
  36. if not isinstance(val, str):
  37. raise TypeError("Argument should be str")
  38. if not val:
  39. return ""
  40. bval = val.encode("utf8", errors="ignore")
  41. ret = bytearray()
  42. pct = bytearray()
  43. safe = self._safe
  44. safe += ALLOWED
  45. if not self._qs:
  46. safe += "+&=;"
  47. safe += self._protected
  48. bsafe = safe.encode("ascii")
  49. idx = 0
  50. while idx < len(bval):
  51. ch = bval[idx]
  52. idx += 1
  53. if pct:
  54. if ch in BASCII_LOWERCASE:
  55. ch = ch - 32 # convert to uppercase
  56. pct.append(ch)
  57. if len(pct) == 3: # pragma: no branch # peephole optimizer
  58. buf = pct[1:]
  59. if not _IS_HEX.match(buf):
  60. ret.extend(b"%25")
  61. pct.clear()
  62. idx -= 2
  63. continue
  64. try:
  65. unquoted = chr(int(pct[1:].decode("ascii"), base=16))
  66. except ValueError:
  67. ret.extend(b"%25")
  68. pct.clear()
  69. idx -= 2
  70. continue
  71. if unquoted in self._protected:
  72. ret.extend(pct)
  73. elif unquoted in safe:
  74. ret.append(ord(unquoted))
  75. else:
  76. ret.extend(pct)
  77. pct.clear()
  78. # special case, if we have only one char after "%"
  79. elif len(pct) == 2 and idx == len(bval):
  80. ret.extend(b"%25")
  81. pct.clear()
  82. idx -= 1
  83. continue
  84. elif ch == ord("%") and self._requote:
  85. pct.clear()
  86. pct.append(ch)
  87. # special case if "%" is last char
  88. if idx == len(bval):
  89. ret.extend(b"%25")
  90. continue
  91. if self._qs and ch == ord(" "):
  92. ret.append(ord("+"))
  93. continue
  94. if ch in bsafe:
  95. ret.append(ch)
  96. continue
  97. ret.extend((f"%{ch:02X}").encode("ascii"))
  98. ret2 = ret.decode("ascii")
  99. if ret2 == val:
  100. return val
  101. return ret2
  102. class _Unquoter:
  103. def __init__(
  104. self,
  105. *,
  106. ignore: str = "",
  107. unsafe: str = "",
  108. qs: bool = False,
  109. plus: bool = False,
  110. ) -> None:
  111. self._ignore = ignore
  112. self._unsafe = unsafe
  113. self._qs = qs
  114. self._plus = plus # to match urllib.parse.unquote_plus
  115. self._quoter = _Quoter()
  116. self._qs_quoter = _Quoter(qs=True)
  117. @overload
  118. def __call__(self, val: str) -> str: ...
  119. @overload
  120. def __call__(self, val: None) -> None: ...
  121. def __call__(self, val: Union[str, None]) -> Union[str, None]:
  122. if val is None:
  123. return None
  124. if not isinstance(val, str):
  125. raise TypeError("Argument should be str")
  126. if not val:
  127. return ""
  128. decoder = cast(codecs.BufferedIncrementalDecoder, utf8_decoder())
  129. ret = []
  130. idx = 0
  131. while idx < len(val):
  132. ch = val[idx]
  133. idx += 1
  134. if ch == "%" and idx <= len(val) - 2:
  135. pct = val[idx : idx + 2]
  136. if _IS_HEX_STR.fullmatch(pct):
  137. b = bytes([int(pct, base=16)])
  138. idx += 2
  139. try:
  140. unquoted = decoder.decode(b)
  141. except UnicodeDecodeError:
  142. start_pct = idx - 3 - len(decoder.buffer) * 3
  143. ret.append(val[start_pct : idx - 3])
  144. decoder.reset()
  145. try:
  146. unquoted = decoder.decode(b)
  147. except UnicodeDecodeError:
  148. ret.append(val[idx - 3 : idx])
  149. continue
  150. if not unquoted:
  151. continue
  152. if self._qs and unquoted in "+=&;":
  153. to_add = self._qs_quoter(unquoted)
  154. if to_add is None: # pragma: no cover
  155. raise RuntimeError("Cannot quote None")
  156. ret.append(to_add)
  157. elif unquoted in self._unsafe or unquoted in self._ignore:
  158. to_add = self._quoter(unquoted)
  159. if to_add is None: # pragma: no cover
  160. raise RuntimeError("Cannot quote None")
  161. ret.append(to_add)
  162. else:
  163. ret.append(unquoted)
  164. continue
  165. if decoder.buffer:
  166. start_pct = idx - 1 - len(decoder.buffer) * 3
  167. ret.append(val[start_pct : idx - 1])
  168. decoder.reset()
  169. if ch == "+":
  170. if (not self._qs and not self._plus) or ch in self._unsafe:
  171. ret.append("+")
  172. else:
  173. ret.append(" ")
  174. continue
  175. if ch in self._unsafe:
  176. ret.append("%")
  177. h = hex(ord(ch)).upper()[2:]
  178. for ch in h:
  179. ret.append(ch)
  180. continue
  181. ret.append(ch)
  182. if decoder.buffer:
  183. ret.append(val[-len(decoder.buffer) * 3 :])
  184. ret2 = "".join(ret)
  185. if ret2 == val:
  186. return val
  187. return ret2