# core.py (13 KB)
  1. import bisect
  2. import re
  3. import unicodedata
  4. from typing import Optional, Union
  5. from . import idnadata
  6. from .intranges import intranges_contain
  7. _virama_combining_class = 9
  8. _alabel_prefix = b"xn--"
  9. _unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
  10. class IDNAError(UnicodeError):
  11. """Base exception for all IDNA-encoding related problems"""
  12. pass
  13. class IDNABidiError(IDNAError):
  14. """Exception when bidirectional requirements are not satisfied"""
  15. pass
  16. class InvalidCodepoint(IDNAError):
  17. """Exception when a disallowed or unallocated codepoint is used"""
  18. pass
  19. class InvalidCodepointContext(IDNAError):
  20. """Exception when the codepoint is not valid in the context it is used"""
  21. pass
  22. def _combining_class(cp: int) -> int:
  23. v = unicodedata.combining(chr(cp))
  24. if v == 0:
  25. if not unicodedata.name(chr(cp)):
  26. raise ValueError("Unknown character in unicodedata")
  27. return v
  28. def _is_script(cp: str, script: str) -> bool:
  29. return intranges_contain(ord(cp), idnadata.scripts[script])
  30. def _punycode(s: str) -> bytes:
  31. return s.encode("punycode")
  32. def _unot(s: int) -> str:
  33. return "U+{:04X}".format(s)
  34. def valid_label_length(label: Union[bytes, str]) -> bool:
  35. if len(label) > 63:
  36. return False
  37. return True
  38. def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
  39. if len(label) > (254 if trailing_dot else 253):
  40. return False
  41. return True
  42. def check_bidi(label: str, check_ltr: bool = False) -> bool:
  43. # Bidi rules should only be applied if string contains RTL characters
  44. bidi_label = False
  45. for idx, cp in enumerate(label, 1):
  46. direction = unicodedata.bidirectional(cp)
  47. if direction == "":
  48. # String likely comes from a newer version of Unicode
  49. raise IDNABidiError("Unknown directionality in label {} at position {}".format(repr(label), idx))
  50. if direction in ["R", "AL", "AN"]:
  51. bidi_label = True
  52. if not bidi_label and not check_ltr:
  53. return True
  54. # Bidi rule 1
  55. direction = unicodedata.bidirectional(label[0])
  56. if direction in ["R", "AL"]:
  57. rtl = True
  58. elif direction == "L":
  59. rtl = False
  60. else:
  61. raise IDNABidiError("First codepoint in label {} must be directionality L, R or AL".format(repr(label)))
  62. valid_ending = False
  63. number_type: Optional[str] = None
  64. for idx, cp in enumerate(label, 1):
  65. direction = unicodedata.bidirectional(cp)
  66. if rtl:
  67. # Bidi rule 2
  68. if direction not in [
  69. "R",
  70. "AL",
  71. "AN",
  72. "EN",
  73. "ES",
  74. "CS",
  75. "ET",
  76. "ON",
  77. "BN",
  78. "NSM",
  79. ]:
  80. raise IDNABidiError("Invalid direction for codepoint at position {} in a right-to-left label".format(idx))
  81. # Bidi rule 3
  82. if direction in ["R", "AL", "EN", "AN"]:
  83. valid_ending = True
  84. elif direction != "NSM":
  85. valid_ending = False
  86. # Bidi rule 4
  87. if direction in ["AN", "EN"]:
  88. if not number_type:
  89. number_type = direction
  90. else:
  91. if number_type != direction:
  92. raise IDNABidiError("Can not mix numeral types in a right-to-left label")
  93. else:
  94. # Bidi rule 5
  95. if direction not in ["L", "EN", "ES", "CS", "ET", "ON", "BN", "NSM"]:
  96. raise IDNABidiError("Invalid direction for codepoint at position {} in a left-to-right label".format(idx))
  97. # Bidi rule 6
  98. if direction in ["L", "EN"]:
  99. valid_ending = True
  100. elif direction != "NSM":
  101. valid_ending = False
  102. if not valid_ending:
  103. raise IDNABidiError("Label ends with illegal codepoint directionality")
  104. return True
  105. def check_initial_combiner(label: str) -> bool:
  106. if unicodedata.category(label[0])[0] == "M":
  107. raise IDNAError("Label begins with an illegal combining character")
  108. return True
  109. def check_hyphen_ok(label: str) -> bool:
  110. if label[2:4] == "--":
  111. raise IDNAError("Label has disallowed hyphens in 3rd and 4th position")
  112. if label[0] == "-" or label[-1] == "-":
  113. raise IDNAError("Label must not start or end with a hyphen")
  114. return True
  115. def check_nfc(label: str) -> None:
  116. if unicodedata.normalize("NFC", label) != label:
  117. raise IDNAError("Label must be in Normalization Form C")
  118. def valid_contextj(label: str, pos: int) -> bool:
  119. cp_value = ord(label[pos])
  120. if cp_value == 0x200C:
  121. if pos > 0:
  122. if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
  123. return True
  124. ok = False
  125. for i in range(pos - 1, -1, -1):
  126. joining_type = idnadata.joining_types.get(ord(label[i]))
  127. if joining_type == ord("T"):
  128. continue
  129. elif joining_type in [ord("L"), ord("D")]:
  130. ok = True
  131. break
  132. else:
  133. break
  134. if not ok:
  135. return False
  136. ok = False
  137. for i in range(pos + 1, len(label)):
  138. joining_type = idnadata.joining_types.get(ord(label[i]))
  139. if joining_type == ord("T"):
  140. continue
  141. elif joining_type in [ord("R"), ord("D")]:
  142. ok = True
  143. break
  144. else:
  145. break
  146. return ok
  147. if cp_value == 0x200D:
  148. if pos > 0:
  149. if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
  150. return True
  151. return False
  152. else:
  153. return False
  154. def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
  155. cp_value = ord(label[pos])
  156. if cp_value == 0x00B7:
  157. if 0 < pos < len(label) - 1:
  158. if ord(label[pos - 1]) == 0x006C and ord(label[pos + 1]) == 0x006C:
  159. return True
  160. return False
  161. elif cp_value == 0x0375:
  162. if pos < len(label) - 1 and len(label) > 1:
  163. return _is_script(label[pos + 1], "Greek")
  164. return False
  165. elif cp_value == 0x05F3 or cp_value == 0x05F4:
  166. if pos > 0:
  167. return _is_script(label[pos - 1], "Hebrew")
  168. return False
  169. elif cp_value == 0x30FB:
  170. for cp in label:
  171. if cp == "\u30fb":
  172. continue
  173. if _is_script(cp, "Hiragana") or _is_script(cp, "Katakana") or _is_script(cp, "Han"):
  174. return True
  175. return False
  176. elif 0x660 <= cp_value <= 0x669:
  177. for cp in label:
  178. if 0x6F0 <= ord(cp) <= 0x06F9:
  179. return False
  180. return True
  181. elif 0x6F0 <= cp_value <= 0x6F9:
  182. for cp in label:
  183. if 0x660 <= ord(cp) <= 0x0669:
  184. return False
  185. return True
  186. return False
  187. def check_label(label: Union[str, bytes, bytearray]) -> None:
  188. if isinstance(label, (bytes, bytearray)):
  189. label = label.decode("utf-8")
  190. if len(label) == 0:
  191. raise IDNAError("Empty Label")
  192. check_nfc(label)
  193. check_hyphen_ok(label)
  194. check_initial_combiner(label)
  195. for pos, cp in enumerate(label):
  196. cp_value = ord(cp)
  197. if intranges_contain(cp_value, idnadata.codepoint_classes["PVALID"]):
  198. continue
  199. elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTJ"]):
  200. try:
  201. if not valid_contextj(label, pos):
  202. raise InvalidCodepointContext(
  203. "Joiner {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
  204. )
  205. except ValueError:
  206. raise IDNAError(
  207. "Unknown codepoint adjacent to joiner {} at position {} in {}".format(
  208. _unot(cp_value), pos + 1, repr(label)
  209. )
  210. )
  211. elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTO"]):
  212. if not valid_contexto(label, pos):
  213. raise InvalidCodepointContext(
  214. "Codepoint {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
  215. )
  216. else:
  217. raise InvalidCodepoint(
  218. "Codepoint {} at position {} of {} not allowed".format(_unot(cp_value), pos + 1, repr(label))
  219. )
  220. check_bidi(label)
  221. def alabel(label: str) -> bytes:
  222. try:
  223. label_bytes = label.encode("ascii")
  224. ulabel(label_bytes)
  225. if not valid_label_length(label_bytes):
  226. raise IDNAError("Label too long")
  227. return label_bytes
  228. except UnicodeEncodeError:
  229. pass
  230. check_label(label)
  231. label_bytes = _alabel_prefix + _punycode(label)
  232. if not valid_label_length(label_bytes):
  233. raise IDNAError("Label too long")
  234. return label_bytes
  235. def ulabel(label: Union[str, bytes, bytearray]) -> str:
  236. if not isinstance(label, (bytes, bytearray)):
  237. try:
  238. label_bytes = label.encode("ascii")
  239. except UnicodeEncodeError:
  240. check_label(label)
  241. return label
  242. else:
  243. label_bytes = label
  244. label_bytes = label_bytes.lower()
  245. if label_bytes.startswith(_alabel_prefix):
  246. label_bytes = label_bytes[len(_alabel_prefix) :]
  247. if not label_bytes:
  248. raise IDNAError("Malformed A-label, no Punycode eligible content found")
  249. if label_bytes.decode("ascii")[-1] == "-":
  250. raise IDNAError("A-label must not end with a hyphen")
  251. else:
  252. check_label(label_bytes)
  253. return label_bytes.decode("ascii")
  254. try:
  255. label = label_bytes.decode("punycode")
  256. except UnicodeError:
  257. raise IDNAError("Invalid A-label")
  258. check_label(label)
  259. return label
  260. def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
  261. """Re-map the characters in the string according to UTS46 processing."""
  262. from .uts46data import uts46data
  263. output = ""
  264. for pos, char in enumerate(domain):
  265. code_point = ord(char)
  266. try:
  267. uts46row = uts46data[code_point if code_point < 256 else bisect.bisect_left(uts46data, (code_point, "Z")) - 1]
  268. status = uts46row[1]
  269. replacement: Optional[str] = None
  270. if len(uts46row) == 3:
  271. replacement = uts46row[2]
  272. if (
  273. status == "V"
  274. or (status == "D" and not transitional)
  275. or (status == "3" and not std3_rules and replacement is None)
  276. ):
  277. output += char
  278. elif replacement is not None and (
  279. status == "M" or (status == "3" and not std3_rules) or (status == "D" and transitional)
  280. ):
  281. output += replacement
  282. elif status != "I":
  283. raise IndexError()
  284. except IndexError:
  285. raise InvalidCodepoint(
  286. "Codepoint {} not allowed at position {} in {}".format(_unot(code_point), pos + 1, repr(domain))
  287. )
  288. return unicodedata.normalize("NFC", output)
  289. def encode(
  290. s: Union[str, bytes, bytearray],
  291. strict: bool = False,
  292. uts46: bool = False,
  293. std3_rules: bool = False,
  294. transitional: bool = False,
  295. ) -> bytes:
  296. if not isinstance(s, str):
  297. try:
  298. s = str(s, "ascii")
  299. except UnicodeDecodeError:
  300. raise IDNAError("should pass a unicode string to the function rather than a byte string.")
  301. if uts46:
  302. s = uts46_remap(s, std3_rules, transitional)
  303. trailing_dot = False
  304. result = []
  305. if strict:
  306. labels = s.split(".")
  307. else:
  308. labels = _unicode_dots_re.split(s)
  309. if not labels or labels == [""]:
  310. raise IDNAError("Empty domain")
  311. if labels[-1] == "":
  312. del labels[-1]
  313. trailing_dot = True
  314. for label in labels:
  315. s = alabel(label)
  316. if s:
  317. result.append(s)
  318. else:
  319. raise IDNAError("Empty label")
  320. if trailing_dot:
  321. result.append(b"")
  322. s = b".".join(result)
  323. if not valid_string_length(s, trailing_dot):
  324. raise IDNAError("Domain too long")
  325. return s
  326. def decode(
  327. s: Union[str, bytes, bytearray],
  328. strict: bool = False,
  329. uts46: bool = False,
  330. std3_rules: bool = False,
  331. ) -> str:
  332. try:
  333. if not isinstance(s, str):
  334. s = str(s, "ascii")
  335. except UnicodeDecodeError:
  336. raise IDNAError("Invalid ASCII in A-label")
  337. if uts46:
  338. s = uts46_remap(s, std3_rules, False)
  339. trailing_dot = False
  340. result = []
  341. if not strict:
  342. labels = _unicode_dots_re.split(s)
  343. else:
  344. labels = s.split(".")
  345. if not labels or labels == [""]:
  346. raise IDNAError("Empty domain")
  347. if not labels[-1]:
  348. del labels[-1]
  349. trailing_dot = True
  350. for label in labels:
  351. s = ulabel(label)
  352. if s:
  353. result.append(s)
  354. else:
  355. raise IDNAError("Empty label")
  356. if trailing_dot:
  357. result.append("")
  358. return ".".join(result)