123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- import asyncio
- import calendar
- import contextlib
- import datetime
- import heapq
- import itertools
- import os # noqa
- import pathlib
- import pickle
- import re
- import time
- import warnings
- from collections import defaultdict
- from http.cookies import BaseCookie, Morsel, SimpleCookie
- from typing import (
- DefaultDict,
- Dict,
- Iterable,
- Iterator,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
- Union,
- cast,
- )
- from yarl import URL
- from .abc import AbstractCookieJar, ClearCookiePredicate
- from .helpers import is_ip_address
- from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")


# A cookie as accepted by update_cookies(): either a raw value string or an
# already-parsed Morsel.
CookieItem = Union[str, "Morsel[str]"]

# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format

# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265."""

    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1

    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Initialize the cookie jar.

        unsafe: if True, also accept cookies for IP-address hosts.
        quote_cookie: if True, filter_cookies() builds a SimpleCookie
            (quoted output); otherwise a plain BaseCookie.
        treat_as_secure_origin: origin(s) treated as secure even when the
            request scheme is not https/wss; accepts a URL, a string, or a
            list of either.
        loop: event loop forwarded to AbstractCookieJar.
        """
        super().__init__(loop=loop)
        # Cookies stored per (domain, path) key.
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        # Morsels already built by filter_cookies(), cached under the same
        # (domain, path) key and invalidated when the cookie is re-set.
        self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
            defaultdict(dict)
        )
        # (hostname, cookie-name) pairs whose cookies carry the host-only flag.
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        # Min-heap of (when, (domain, path, name)) scheduled expirations;
        # _expirations holds the authoritative latest expiration per key, so
        # stale heap entries can be detected and skipped.
        self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    @property
    def quote_cookie(self) -> bool:
        """Whether filter_cookies() produces quoted (SimpleCookie) output."""
        return self._quote_cookie

    def save(self, file_path: PathLike) -> None:
        """Pickle the stored cookies to *file_path*."""
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the stored cookies with those unpickled from *file_path*.

        NOTE(review): pickle.load() can execute arbitrary code embedded in
        the file — only load files written by save() from a trusted source.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove all cookies, or only those matching *predicate*.

        With a predicate, cookies whose scheduled expiration has already
        passed are removed as well.
        """
        if predicate is None:
            self._expire_heap.clear()
            self._cookies.clear()
            self._morsel_cache.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        now = time.time()
        to_del = [
            key
            for (domain, path), cookie in self._cookies.items()
            for name, morsel in cookie.items()
            if (
                (key := (domain, path, name)) in self._expirations
                and self._expirations[key] <= now
            )
            or predicate(morsel)
        ]
        if to_del:
            self._delete_cookies(to_del)

    def clear_domain(self, domain: str) -> None:
        """Remove every cookie whose domain matches *domain* (RFC 6265)."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Yield every stored Morsel after purging expired cookies."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie.values()) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Remove expired cookies."""
        if not (expire_heap_len := len(self._expire_heap)):
            return

        # If the expiration heap grows larger than the number expirations
        # times two, we clean it up to avoid keeping expired entries in
        # the heap and consuming memory. We guard this with a minimum
        # threshold to avoid cleaning up the heap too often when there are
        # only a few scheduled expirations.
        if (
            expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
            and expire_heap_len > len(self._expirations) * 2
        ):
            # Remove any expired entries from the expiration heap
            # that do not match the expiration time in the expirations
            # as it means the cookie has been re-added to the heap
            # with a different expiration time.
            self._expire_heap = [
                entry
                for entry in self._expire_heap
                if self._expirations.get(entry[1]) == entry[0]
            ]
            heapq.heapify(self._expire_heap)

        now = time.time()
        to_del: List[Tuple[str, str, str]] = []
        # Find any expired cookies and add them to the to-delete list
        while self._expire_heap:
            when, cookie_key = self._expire_heap[0]
            if when > now:
                break
            heapq.heappop(self._expire_heap)
            # Check if the cookie hasn't been re-added to the heap
            # with a different expiration time as it will be removed
            # later when it reaches the top of the heap and its
            # expiration time is met.
            if self._expirations.get(cookie_key) == when:
                to_del.append(cookie_key)

        if to_del:
            self._delete_cookies(to_del)

    def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
        """Drop each (domain, path, name) cookie and all its bookkeeping."""
        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._cookies[(domain, path)].pop(name, None)
            self._morsel_cache[(domain, path)].pop(name, None)
            self._expirations.pop((domain, path, name), None)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Schedule the cookie to expire at *when* (epoch seconds)."""
        cookie_key = (domain, path, name)
        if self._expirations.get(cookie_key) == when:
            # Avoid adding duplicates to the heap
            return
        heapq.heappush(self._expire_heap, (when, cookie_key))
        self._expirations[cookie_key] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies."""
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                # Parse a raw value string into a Morsel via SimpleCookie.
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain and domain[-1] == ".":
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain and domain[0] == ".":
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or path[0] != "/":
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            path = path.rstrip("/")

            if max_age := cookie["max-age"]:
                try:
                    delta_seconds = int(max_age)
                    # Clamp to MAX_TIME so calendar/time math stays valid.
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            elif expires := cookie["expires"]:
                if expire_time := self._parse_date(expires):
                    self._expire_cookie(expire_time, domain, path, name)
                else:
                    cookie["expires"] = ""

            key = (domain, path)
            if self._cookies[key].get(name) != cookie:
                # Don't blow away the cache if the same
                # cookie gets set again
                self._cookies[key][name] = cookie
                self._morsel_cache[key].pop(name, None)

        self._do_expiration()

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes."""
        filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        if type(request_url) is not URL:
            warnings.warn(
                "filter_cookies expects yarl.URL instances only,"
                f"and will stop working in 4.x, got {type(request_url)}",
                DeprecationWarning,
                stacklevel=2,
            )
            request_url = URL(request_url)
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        for c in self._cookies[("", "")].values():
            filtered[c.key] = c.value

        if is_ip_address(hostname):
            if not self._unsafe:
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
            )

        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for p in pairs:
            for name, cookie in self._cookies[p].items():
                domain = cookie["domain"]

                if (domain, name) in self._host_only_cookies and domain != hostname:
                    continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # We already built the Morsel so reuse it here
                if name in self._morsel_cache[p]:
                    filtered[name] = self._morsel_cache[p][name]
                    continue

                # It's critical we use the Morsel so the coded_value
                # (based on cookie version) is preserved
                mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
                mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
                self._morsel_cache[p][name] = mrsl_val
                filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265."""
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        # The hostname must end with ".domain" (a proper subdomain),
        # and an IP address never domain-matches.
        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265."""
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Each date-token is tried, in order, as a time, then a day of
        # month, then a month name, then a year; the first match of each
        # kind wins.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    # The index of the alternation group that matched is
                    # the month number (jan == 1 ... dec == 12).
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years per RFC 6265: 70-99 -> 19xx, 0-69 -> 20xx.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
    """Implements a dummy cookie storage.

    It can be used with the ClientSession when no cookie processing is needed.
    Every mutating call is ignored and every query yields an empty result.
    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # An empty generator: return immediately; the yield below is
        # unreachable but makes this function a generator.
        return
        yield None

    def __len__(self) -> int:
        # Nothing is ever stored.
        return 0

    @property
    def quote_cookie(self) -> bool:
        return True

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """No-op: there is nothing to clear."""

    def clear_domain(self, domain: str) -> None:
        """No-op: there is nothing to clear."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """No-op: incoming cookies are discarded."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Always return an empty cookie collection."""
        return SimpleCookie()
|