cookiejar.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. import asyncio
  2. import calendar
  3. import contextlib
  4. import datetime
  5. import heapq
  6. import itertools
  7. import os # noqa
  8. import pathlib
  9. import pickle
  10. import re
  11. import time
  12. import warnings
  13. from collections import defaultdict
  14. from http.cookies import BaseCookie, Morsel, SimpleCookie
  15. from typing import (
  16. DefaultDict,
  17. Dict,
  18. Iterable,
  19. Iterator,
  20. List,
  21. Mapping,
  22. Optional,
  23. Set,
  24. Tuple,
  25. Union,
  26. cast,
  27. )
  28. from yarl import URL
  29. from .abc import AbstractCookieJar, ClearCookiePredicate
  30. from .helpers import is_ip_address
  31. from .typedefs import LooseCookies, PathLike, StrOrURL
  32. __all__ = ("CookieJar", "DummyCookieJar")
  33. CookieItem = Union[str, "Morsel[str]"]
  34. # We cache these string methods here as their use is in performance critical code.
  35. _FORMAT_PATH = "{}/{}".format
  36. _FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
  37. # The minimum number of scheduled cookie expirations before we start cleaning up
  38. # the expiration heap. This is a performance optimization to avoid cleaning up the
  39. # heap too often when there are only a few scheduled expirations.
  40. _MIN_SCHEDULED_COOKIE_EXPIRATION = 100
  41. class CookieJar(AbstractCookieJar):
  42. """Implements cookie storage adhering to RFC 6265."""
  43. DATE_TOKENS_RE = re.compile(
  44. r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
  45. r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
  46. )
  47. DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
  48. DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
  49. DATE_MONTH_RE = re.compile(
  50. "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
  51. re.I,
  52. )
  53. DATE_YEAR_RE = re.compile(r"(\d{2,4})")
  54. # calendar.timegm() fails for timestamps after datetime.datetime.max
  55. # Minus one as a loss of precision occurs when timestamp() is called.
  56. MAX_TIME = (
  57. int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
  58. )
  59. try:
  60. calendar.timegm(time.gmtime(MAX_TIME))
  61. except (OSError, ValueError):
  62. # Hit the maximum representable time on Windows
  63. # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
  64. # Throws ValueError on PyPy 3.9, OSError elsewhere
  65. MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
  66. except OverflowError:
  67. # #4515: datetime.max may not be representable on 32-bit platforms
  68. MAX_TIME = 2**31 - 1
  69. # Avoid minuses in the future, 3x faster
  70. SUB_MAX_TIME = MAX_TIME - 1
  71. def __init__(
  72. self,
  73. *,
  74. unsafe: bool = False,
  75. quote_cookie: bool = True,
  76. treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
  77. loop: Optional[asyncio.AbstractEventLoop] = None,
  78. ) -> None:
  79. super().__init__(loop=loop)
  80. self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
  81. SimpleCookie
  82. )
  83. self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
  84. defaultdict(dict)
  85. )
  86. self._host_only_cookies: Set[Tuple[str, str]] = set()
  87. self._unsafe = unsafe
  88. self._quote_cookie = quote_cookie
  89. if treat_as_secure_origin is None:
  90. treat_as_secure_origin = []
  91. elif isinstance(treat_as_secure_origin, URL):
  92. treat_as_secure_origin = [treat_as_secure_origin.origin()]
  93. elif isinstance(treat_as_secure_origin, str):
  94. treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
  95. else:
  96. treat_as_secure_origin = [
  97. URL(url).origin() if isinstance(url, str) else url.origin()
  98. for url in treat_as_secure_origin
  99. ]
  100. self._treat_as_secure_origin = treat_as_secure_origin
  101. self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
  102. self._expirations: Dict[Tuple[str, str, str], float] = {}
  103. @property
  104. def quote_cookie(self) -> bool:
  105. return self._quote_cookie
  106. def save(self, file_path: PathLike) -> None:
  107. file_path = pathlib.Path(file_path)
  108. with file_path.open(mode="wb") as f:
  109. pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
  110. def load(self, file_path: PathLike) -> None:
  111. file_path = pathlib.Path(file_path)
  112. with file_path.open(mode="rb") as f:
  113. self._cookies = pickle.load(f)
  114. def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
  115. if predicate is None:
  116. self._expire_heap.clear()
  117. self._cookies.clear()
  118. self._morsel_cache.clear()
  119. self._host_only_cookies.clear()
  120. self._expirations.clear()
  121. return
  122. now = time.time()
  123. to_del = [
  124. key
  125. for (domain, path), cookie in self._cookies.items()
  126. for name, morsel in cookie.items()
  127. if (
  128. (key := (domain, path, name)) in self._expirations
  129. and self._expirations[key] <= now
  130. )
  131. or predicate(morsel)
  132. ]
  133. if to_del:
  134. self._delete_cookies(to_del)
  135. def clear_domain(self, domain: str) -> None:
  136. self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
  137. def __iter__(self) -> "Iterator[Morsel[str]]":
  138. self._do_expiration()
  139. for val in self._cookies.values():
  140. yield from val.values()
  141. def __len__(self) -> int:
  142. """Return number of cookies.
  143. This function does not iterate self to avoid unnecessary expiration
  144. checks.
  145. """
  146. return sum(len(cookie.values()) for cookie in self._cookies.values())
  147. def _do_expiration(self) -> None:
  148. """Remove expired cookies."""
  149. if not (expire_heap_len := len(self._expire_heap)):
  150. return
  151. # If the expiration heap grows larger than the number expirations
  152. # times two, we clean it up to avoid keeping expired entries in
  153. # the heap and consuming memory. We guard this with a minimum
  154. # threshold to avoid cleaning up the heap too often when there are
  155. # only a few scheduled expirations.
  156. if (
  157. expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
  158. and expire_heap_len > len(self._expirations) * 2
  159. ):
  160. # Remove any expired entries from the expiration heap
  161. # that do not match the expiration time in the expirations
  162. # as it means the cookie has been re-added to the heap
  163. # with a different expiration time.
  164. self._expire_heap = [
  165. entry
  166. for entry in self._expire_heap
  167. if self._expirations.get(entry[1]) == entry[0]
  168. ]
  169. heapq.heapify(self._expire_heap)
  170. now = time.time()
  171. to_del: List[Tuple[str, str, str]] = []
  172. # Find any expired cookies and add them to the to-delete list
  173. while self._expire_heap:
  174. when, cookie_key = self._expire_heap[0]
  175. if when > now:
  176. break
  177. heapq.heappop(self._expire_heap)
  178. # Check if the cookie hasn't been re-added to the heap
  179. # with a different expiration time as it will be removed
  180. # later when it reaches the top of the heap and its
  181. # expiration time is met.
  182. if self._expirations.get(cookie_key) == when:
  183. to_del.append(cookie_key)
  184. if to_del:
  185. self._delete_cookies(to_del)
  186. def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
  187. for domain, path, name in to_del:
  188. self._host_only_cookies.discard((domain, name))
  189. self._cookies[(domain, path)].pop(name, None)
  190. self._morsel_cache[(domain, path)].pop(name, None)
  191. self._expirations.pop((domain, path, name), None)
  192. def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
  193. cookie_key = (domain, path, name)
  194. if self._expirations.get(cookie_key) == when:
  195. # Avoid adding duplicates to the heap
  196. return
  197. heapq.heappush(self._expire_heap, (when, cookie_key))
  198. self._expirations[cookie_key] = when
  199. def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
  200. """Update cookies."""
  201. hostname = response_url.raw_host
  202. if not self._unsafe and is_ip_address(hostname):
  203. # Don't accept cookies from IPs
  204. return
  205. if isinstance(cookies, Mapping):
  206. cookies = cookies.items()
  207. for name, cookie in cookies:
  208. if not isinstance(cookie, Morsel):
  209. tmp = SimpleCookie()
  210. tmp[name] = cookie # type: ignore[assignment]
  211. cookie = tmp[name]
  212. domain = cookie["domain"]
  213. # ignore domains with trailing dots
  214. if domain and domain[-1] == ".":
  215. domain = ""
  216. del cookie["domain"]
  217. if not domain and hostname is not None:
  218. # Set the cookie's domain to the response hostname
  219. # and set its host-only-flag
  220. self._host_only_cookies.add((hostname, name))
  221. domain = cookie["domain"] = hostname
  222. if domain and domain[0] == ".":
  223. # Remove leading dot
  224. domain = domain[1:]
  225. cookie["domain"] = domain
  226. if hostname and not self._is_domain_match(domain, hostname):
  227. # Setting cookies for different domains is not allowed
  228. continue
  229. path = cookie["path"]
  230. if not path or path[0] != "/":
  231. # Set the cookie's path to the response path
  232. path = response_url.path
  233. if not path.startswith("/"):
  234. path = "/"
  235. else:
  236. # Cut everything from the last slash to the end
  237. path = "/" + path[1 : path.rfind("/")]
  238. cookie["path"] = path
  239. path = path.rstrip("/")
  240. if max_age := cookie["max-age"]:
  241. try:
  242. delta_seconds = int(max_age)
  243. max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
  244. self._expire_cookie(max_age_expiration, domain, path, name)
  245. except ValueError:
  246. cookie["max-age"] = ""
  247. elif expires := cookie["expires"]:
  248. if expire_time := self._parse_date(expires):
  249. self._expire_cookie(expire_time, domain, path, name)
  250. else:
  251. cookie["expires"] = ""
  252. key = (domain, path)
  253. if self._cookies[key].get(name) != cookie:
  254. # Don't blow away the cache if the same
  255. # cookie gets set again
  256. self._cookies[key][name] = cookie
  257. self._morsel_cache[key].pop(name, None)
  258. self._do_expiration()
  259. def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
  260. """Returns this jar's cookies filtered by their attributes."""
  261. filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
  262. SimpleCookie() if self._quote_cookie else BaseCookie()
  263. )
  264. if not self._cookies:
  265. # Skip do_expiration() if there are no cookies.
  266. return filtered
  267. self._do_expiration()
  268. if not self._cookies:
  269. # Skip rest of function if no non-expired cookies.
  270. return filtered
  271. if type(request_url) is not URL:
  272. warnings.warn(
  273. "filter_cookies expects yarl.URL instances only,"
  274. f"and will stop working in 4.x, got {type(request_url)}",
  275. DeprecationWarning,
  276. stacklevel=2,
  277. )
  278. request_url = URL(request_url)
  279. hostname = request_url.raw_host or ""
  280. is_not_secure = request_url.scheme not in ("https", "wss")
  281. if is_not_secure and self._treat_as_secure_origin:
  282. request_origin = URL()
  283. with contextlib.suppress(ValueError):
  284. request_origin = request_url.origin()
  285. is_not_secure = request_origin not in self._treat_as_secure_origin
  286. # Send shared cookie
  287. for c in self._cookies[("", "")].values():
  288. filtered[c.key] = c.value
  289. if is_ip_address(hostname):
  290. if not self._unsafe:
  291. return filtered
  292. domains: Iterable[str] = (hostname,)
  293. else:
  294. # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
  295. domains = itertools.accumulate(
  296. reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
  297. )
  298. # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
  299. paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
  300. # Create every combination of (domain, path) pairs.
  301. pairs = itertools.product(domains, paths)
  302. path_len = len(request_url.path)
  303. # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
  304. for p in pairs:
  305. for name, cookie in self._cookies[p].items():
  306. domain = cookie["domain"]
  307. if (domain, name) in self._host_only_cookies and domain != hostname:
  308. continue
  309. # Skip edge case when the cookie has a trailing slash but request doesn't.
  310. if len(cookie["path"]) > path_len:
  311. continue
  312. if is_not_secure and cookie["secure"]:
  313. continue
  314. # We already built the Morsel so reuse it here
  315. if name in self._morsel_cache[p]:
  316. filtered[name] = self._morsel_cache[p][name]
  317. continue
  318. # It's critical we use the Morsel so the coded_value
  319. # (based on cookie version) is preserved
  320. mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
  321. mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
  322. self._morsel_cache[p][name] = mrsl_val
  323. filtered[name] = mrsl_val
  324. return filtered
  325. @staticmethod
  326. def _is_domain_match(domain: str, hostname: str) -> bool:
  327. """Implements domain matching adhering to RFC 6265."""
  328. if hostname == domain:
  329. return True
  330. if not hostname.endswith(domain):
  331. return False
  332. non_matching = hostname[: -len(domain)]
  333. if not non_matching.endswith("."):
  334. return False
  335. return not is_ip_address(hostname)
  336. @classmethod
  337. def _parse_date(cls, date_str: str) -> Optional[int]:
  338. """Implements date string parsing adhering to RFC 6265."""
  339. if not date_str:
  340. return None
  341. found_time = False
  342. found_day = False
  343. found_month = False
  344. found_year = False
  345. hour = minute = second = 0
  346. day = 0
  347. month = 0
  348. year = 0
  349. for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
  350. token = token_match.group("token")
  351. if not found_time:
  352. time_match = cls.DATE_HMS_TIME_RE.match(token)
  353. if time_match:
  354. found_time = True
  355. hour, minute, second = (int(s) for s in time_match.groups())
  356. continue
  357. if not found_day:
  358. day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
  359. if day_match:
  360. found_day = True
  361. day = int(day_match.group())
  362. continue
  363. if not found_month:
  364. month_match = cls.DATE_MONTH_RE.match(token)
  365. if month_match:
  366. found_month = True
  367. assert month_match.lastindex is not None
  368. month = month_match.lastindex
  369. continue
  370. if not found_year:
  371. year_match = cls.DATE_YEAR_RE.match(token)
  372. if year_match:
  373. found_year = True
  374. year = int(year_match.group())
  375. if 70 <= year <= 99:
  376. year += 1900
  377. elif 0 <= year <= 69:
  378. year += 2000
  379. if False in (found_day, found_month, found_year, found_time):
  380. return None
  381. if not 1 <= day <= 31:
  382. return None
  383. if year < 1601 or hour > 23 or minute > 59 or second > 59:
  384. return None
  385. return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
  386. class DummyCookieJar(AbstractCookieJar):
  387. """Implements a dummy cookie storage.
  388. It can be used with the ClientSession when no cookie processing is needed.
  389. """
  390. def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
  391. super().__init__(loop=loop)
  392. def __iter__(self) -> "Iterator[Morsel[str]]":
  393. while False:
  394. yield None
  395. def __len__(self) -> int:
  396. return 0
  397. @property
  398. def quote_cookie(self) -> bool:
  399. return True
  400. def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
  401. pass
  402. def clear_domain(self, domain: str) -> None:
  403. pass
  404. def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
  405. pass
  406. def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
  407. return SimpleCookie()