_auth.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. from __future__ import annotations
  2. import hashlib
  3. import os
  4. import re
  5. import time
  6. import typing
  7. from base64 import b64encode
  8. from urllib.request import parse_http_list
  9. from ._exceptions import ProtocolError
  10. from ._models import Cookies, Request, Response
  11. from ._utils import to_bytes, to_str, unquote
  12. if typing.TYPE_CHECKING: # pragma: no cover
  13. from hashlib import _Hash
  14. __all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"]
  15. class Auth:
  16. """
  17. Base class for all authentication schemes.
  18. To implement a custom authentication scheme, subclass `Auth` and override
  19. the `.auth_flow()` method.
  20. If the authentication scheme does I/O such as disk access or network calls, or uses
  21. synchronization primitives such as locks, you should override `.sync_auth_flow()`
  22. and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized
  23. implementations that will be used by `Client` and `AsyncClient` respectively.
  24. """
  25. requires_request_body = False
  26. requires_response_body = False
  27. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  28. """
  29. Execute the authentication flow.
  30. To dispatch a request, `yield` it:
  31. ```
  32. yield request
  33. ```
  34. The client will `.send()` the response back into the flow generator. You can
  35. access it like so:
  36. ```
  37. response = yield request
  38. ```
  39. A `return` (or reaching the end of the generator) will result in the
  40. client returning the last response obtained from the server.
  41. You can dispatch as many requests as is necessary.
  42. """
  43. yield request
  44. def sync_auth_flow(
  45. self, request: Request
  46. ) -> typing.Generator[Request, Response, None]:
  47. """
  48. Execute the authentication flow synchronously.
  49. By default, this defers to `.auth_flow()`. You should override this method
  50. when the authentication scheme does I/O and/or uses concurrency primitives.
  51. """
  52. if self.requires_request_body:
  53. request.read()
  54. flow = self.auth_flow(request)
  55. request = next(flow)
  56. while True:
  57. response = yield request
  58. if self.requires_response_body:
  59. response.read()
  60. try:
  61. request = flow.send(response)
  62. except StopIteration:
  63. break
  64. async def async_auth_flow(
  65. self, request: Request
  66. ) -> typing.AsyncGenerator[Request, Response]:
  67. """
  68. Execute the authentication flow asynchronously.
  69. By default, this defers to `.auth_flow()`. You should override this method
  70. when the authentication scheme does I/O and/or uses concurrency primitives.
  71. """
  72. if self.requires_request_body:
  73. await request.aread()
  74. flow = self.auth_flow(request)
  75. request = next(flow)
  76. while True:
  77. response = yield request
  78. if self.requires_response_body:
  79. await response.aread()
  80. try:
  81. request = flow.send(response)
  82. except StopIteration:
  83. break
  84. class FunctionAuth(Auth):
  85. """
  86. Allows the 'auth' argument to be passed as a simple callable function,
  87. that takes the request, and returns a new, modified request.
  88. """
  89. def __init__(self, func: typing.Callable[[Request], Request]) -> None:
  90. self._func = func
  91. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  92. yield self._func(request)
  93. class BasicAuth(Auth):
  94. """
  95. Allows the 'auth' argument to be passed as a (username, password) pair,
  96. and uses HTTP Basic authentication.
  97. """
  98. def __init__(self, username: str | bytes, password: str | bytes) -> None:
  99. self._auth_header = self._build_auth_header(username, password)
  100. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  101. request.headers["Authorization"] = self._auth_header
  102. yield request
  103. def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
  104. userpass = b":".join((to_bytes(username), to_bytes(password)))
  105. token = b64encode(userpass).decode()
  106. return f"Basic {token}"
  107. class NetRCAuth(Auth):
  108. """
  109. Use a 'netrc' file to lookup basic auth credentials based on the url host.
  110. """
  111. def __init__(self, file: str | None = None) -> None:
  112. # Lazily import 'netrc'.
  113. # There's no need for us to load this module unless 'NetRCAuth' is being used.
  114. import netrc
  115. self._netrc_info = netrc.netrc(file)
  116. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  117. auth_info = self._netrc_info.authenticators(request.url.host)
  118. if auth_info is None or not auth_info[2]:
  119. # The netrc file did not have authentication credentials for this host.
  120. yield request
  121. else:
  122. # Build a basic auth header with credentials from the netrc file.
  123. request.headers["Authorization"] = self._build_auth_header(
  124. username=auth_info[0], password=auth_info[2]
  125. )
  126. yield request
  127. def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
  128. userpass = b":".join((to_bytes(username), to_bytes(password)))
  129. token = b64encode(userpass).decode()
  130. return f"Basic {token}"
  131. class DigestAuth(Auth):
  132. _ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
  133. "MD5": hashlib.md5,
  134. "MD5-SESS": hashlib.md5,
  135. "SHA": hashlib.sha1,
  136. "SHA-SESS": hashlib.sha1,
  137. "SHA-256": hashlib.sha256,
  138. "SHA-256-SESS": hashlib.sha256,
  139. "SHA-512": hashlib.sha512,
  140. "SHA-512-SESS": hashlib.sha512,
  141. }
  142. def __init__(self, username: str | bytes, password: str | bytes) -> None:
  143. self._username = to_bytes(username)
  144. self._password = to_bytes(password)
  145. self._last_challenge: _DigestAuthChallenge | None = None
  146. self._nonce_count = 1
  147. def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
  148. if self._last_challenge:
  149. request.headers["Authorization"] = self._build_auth_header(
  150. request, self._last_challenge
  151. )
  152. response = yield request
  153. if response.status_code != 401 or "www-authenticate" not in response.headers:
  154. # If the response is not a 401 then we don't
  155. # need to build an authenticated request.
  156. return
  157. for auth_header in response.headers.get_list("www-authenticate"):
  158. if auth_header.lower().startswith("digest "):
  159. break
  160. else:
  161. # If the response does not include a 'WWW-Authenticate: Digest ...'
  162. # header, then we don't need to build an authenticated request.
  163. return
  164. self._last_challenge = self._parse_challenge(request, response, auth_header)
  165. self._nonce_count = 1
  166. request.headers["Authorization"] = self._build_auth_header(
  167. request, self._last_challenge
  168. )
  169. if response.cookies:
  170. Cookies(response.cookies).set_cookie_header(request=request)
  171. yield request
  172. def _parse_challenge(
  173. self, request: Request, response: Response, auth_header: str
  174. ) -> _DigestAuthChallenge:
  175. """
  176. Returns a challenge from a Digest WWW-Authenticate header.
  177. These take the form of:
  178. `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
  179. """
  180. scheme, _, fields = auth_header.partition(" ")
  181. # This method should only ever have been called with a Digest auth header.
  182. assert scheme.lower() == "digest"
  183. header_dict: dict[str, str] = {}
  184. for field in parse_http_list(fields):
  185. key, value = field.strip().split("=", 1)
  186. header_dict[key] = unquote(value)
  187. try:
  188. realm = header_dict["realm"].encode()
  189. nonce = header_dict["nonce"].encode()
  190. algorithm = header_dict.get("algorithm", "MD5")
  191. opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
  192. qop = header_dict["qop"].encode() if "qop" in header_dict else None
  193. return _DigestAuthChallenge(
  194. realm=realm, nonce=nonce, algorithm=algorithm, opaque=opaque, qop=qop
  195. )
  196. except KeyError as exc:
  197. message = "Malformed Digest WWW-Authenticate header"
  198. raise ProtocolError(message, request=request) from exc
  199. def _build_auth_header(
  200. self, request: Request, challenge: _DigestAuthChallenge
  201. ) -> str:
  202. hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]
  203. def digest(data: bytes) -> bytes:
  204. return hash_func(data).hexdigest().encode()
  205. A1 = b":".join((self._username, challenge.realm, self._password))
  206. path = request.url.raw_path
  207. A2 = b":".join((request.method.encode(), path))
  208. # TODO: implement auth-int
  209. HA2 = digest(A2)
  210. nc_value = b"%08x" % self._nonce_count
  211. cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
  212. self._nonce_count += 1
  213. HA1 = digest(A1)
  214. if challenge.algorithm.lower().endswith("-sess"):
  215. HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))
  216. qop = self._resolve_qop(challenge.qop, request=request)
  217. if qop is None:
  218. # Following RFC 2069
  219. digest_data = [HA1, challenge.nonce, HA2]
  220. else:
  221. # Following RFC 2617/7616
  222. digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]
  223. format_args = {
  224. "username": self._username,
  225. "realm": challenge.realm,
  226. "nonce": challenge.nonce,
  227. "uri": path,
  228. "response": digest(b":".join(digest_data)),
  229. "algorithm": challenge.algorithm.encode(),
  230. }
  231. if challenge.opaque:
  232. format_args["opaque"] = challenge.opaque
  233. if qop:
  234. format_args["qop"] = b"auth"
  235. format_args["nc"] = nc_value
  236. format_args["cnonce"] = cnonce
  237. return "Digest " + self._get_header_value(format_args)
  238. def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
  239. s = str(nonce_count).encode()
  240. s += nonce
  241. s += time.ctime().encode()
  242. s += os.urandom(8)
  243. return hashlib.sha1(s).hexdigest()[:16].encode()
  244. def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
  245. NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
  246. QUOTED_TEMPLATE = '{}="{}"'
  247. NON_QUOTED_TEMPLATE = "{}={}"
  248. header_value = ""
  249. for i, (field, value) in enumerate(header_fields.items()):
  250. if i > 0:
  251. header_value += ", "
  252. template = (
  253. QUOTED_TEMPLATE
  254. if field not in NON_QUOTED_FIELDS
  255. else NON_QUOTED_TEMPLATE
  256. )
  257. header_value += template.format(field, to_str(value))
  258. return header_value
  259. def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
  260. if qop is None:
  261. return None
  262. qops = re.split(b", ?", qop)
  263. if b"auth" in qops:
  264. return b"auth"
  265. if qops == [b"auth-int"]:
  266. raise NotImplementedError("Digest auth-int support is not yet implemented")
  267. message = f'Unexpected qop value "{qop!r}" in digest auth'
  268. raise ProtocolError(message, request=request)
  269. class _DigestAuthChallenge(typing.NamedTuple):
  270. realm: bytes
  271. nonce: bytes
  272. algorithm: str
  273. opaque: bytes | None
  274. qop: bytes | None