auth_management.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # Copyright (c) "Neo4j"
  2. # Neo4j Sweden AB [https://neo4j.com]
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # https://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import annotations
  16. import abc
  17. import typing as t
  18. from logging import getLogger
  19. from .._async_compat.concurrency import (
  20. AsyncCooperativeLock,
  21. AsyncLock,
  22. )
  23. from .._auth_management import (
  24. AsyncAuthManager,
  25. AsyncClientCertificateProvider,
  26. ClientCertificate,
  27. expiring_auth_has_expired,
  28. ExpiringAuth,
  29. )
  30. if t.TYPE_CHECKING:
  31. from ..api import _TAuth
  32. from ..exceptions import Neo4jError
  33. log = getLogger("neo4j.auth_management")
  34. class AsyncStaticAuthManager(AsyncAuthManager):
  35. _auth: _TAuth
  36. def __init__(self, auth: _TAuth) -> None:
  37. self._auth = auth
  38. async def get_auth(self) -> _TAuth:
  39. return self._auth
  40. async def handle_security_exception(
  41. self, auth: _TAuth, error: Neo4jError
  42. ) -> bool:
  43. return False
  44. class AsyncNeo4jAuthTokenManager(AsyncAuthManager):
  45. _current_auth: ExpiringAuth | None
  46. _provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
  47. _handled_codes: frozenset[str]
  48. _lock: AsyncLock
  49. def __init__(
  50. self,
  51. provider: t.Callable[[], t.Awaitable[ExpiringAuth]],
  52. handled_codes: frozenset[str],
  53. ) -> None:
  54. self._provider = provider
  55. self._handled_codes = handled_codes
  56. self._current_auth = None
  57. self._lock = AsyncLock()
  58. async def _refresh_auth(self):
  59. try:
  60. self._current_auth = await self._provider()
  61. except BaseException as e:
  62. log.error("[ ] _: <AUTH MANAGER> provider failed: %r", e)
  63. raise
  64. if self._current_auth is None:
  65. raise TypeError(
  66. "Auth provider function passed to expiration_based "
  67. "AuthManager returned None, expected ExpiringAuth"
  68. )
  69. async def get_auth(self) -> _TAuth:
  70. async with self._lock:
  71. auth = self._current_auth
  72. if auth is None or expiring_auth_has_expired(auth):
  73. log.debug(
  74. "[ ] _: <AUTH MANAGER> refreshing (%s)",
  75. "init" if auth is None else "time out",
  76. )
  77. await self._refresh_auth()
  78. auth = self._current_auth
  79. assert auth is not None
  80. return auth.auth
  81. async def handle_security_exception(
  82. self, auth: _TAuth, error: Neo4jError
  83. ) -> bool:
  84. if error.code not in self._handled_codes:
  85. return False
  86. async with self._lock:
  87. cur_auth = self._current_auth
  88. if cur_auth is not None and cur_auth.auth == auth:
  89. log.debug(
  90. "[ ] _: <AUTH MANAGER> refreshing (error %s)",
  91. error.code,
  92. )
  93. await self._refresh_auth()
  94. return True
  95. class AsyncAuthManagers:
  96. """
  97. A collection of :class:`.AsyncAuthManager` factories.
  98. .. versionadded:: 5.8
  99. .. versionchanged:: 5.12
  100. * Method ``expiration_based()`` was renamed to :meth:`bearer`.
  101. * Added :meth:`basic`.
  102. .. versionchanged:: 5.14 Stabilized from preview.
  103. """
  104. @staticmethod
  105. def static(auth: _TAuth) -> AsyncAuthManager:
  106. """
  107. Create a static auth manager.
  108. The manager will always return the auth info provided at its creation.
  109. Example::
  110. # NOTE: this example is for illustration purposes only.
  111. # The driver will automatically wrap static auth info in a
  112. # static auth manager.
  113. import neo4j
  114. from neo4j.auth_management import AsyncAuthManagers
  115. auth = neo4j.basic_auth("neo4j", "password")
  116. with neo4j.GraphDatabase.driver(
  117. "neo4j://example.com:7687",
  118. auth=AsyncAuthManagers.static(auth)
  119. # auth=auth # this is equivalent
  120. ) as driver:
  121. ... # do stuff
  122. :param auth: The auth to return.
  123. :returns:
  124. An instance of an implementation of :class:`.AsyncAuthManager` that
  125. always returns the same auth.
  126. .. versionadded:: 5.8
  127. .. versionchanged:: 5.14 Stabilized from preview.
  128. """
  129. return AsyncStaticAuthManager(auth)
  130. @staticmethod
  131. def basic(
  132. provider: t.Callable[[], t.Awaitable[_TAuth]],
  133. ) -> AsyncAuthManager:
  134. """
  135. Create an auth manager handling basic auth password rotation.
  136. This factory wraps the provider function in an auth manager
  137. implementation that caches the provided auth info until the server
  138. notifies the driver that the auth info has expired (by returning
  139. an error that indicates that the password is invalid).
  140. Note that this implies that the provider function will be called again
  141. if it provides wrong auth info, potentially deferring failure due to a
  142. wrong password or username.
  143. .. warning::
  144. The provider function **must not** interact with the driver in any
  145. way as this can cause deadlocks and undefined behaviour.
  146. The provider function must only ever return auth information
  147. belonging to the same identity.
  148. Switching identities is undefined behavior.
  149. You may use :ref:`session-level authentication<session-auth-ref>`
  150. for such use-cases.
  151. Example::
  152. import neo4j
  153. from neo4j.auth_management import (
  154. AsyncAuthManagers,
  155. ExpiringAuth,
  156. )
  157. async def auth_provider():
  158. # some way of getting a token
  159. user, password = await get_current_auth()
  160. return (user, password)
  161. with neo4j.GraphDatabase.driver(
  162. "neo4j://example.com:7687",
  163. auth=AsyncAuthManagers.basic(auth_provider)
  164. ) as driver:
  165. ... # do stuff
  166. :param provider:
  167. A callable that provides new auth info whenever the server notifies
  168. the driver that the previous auth info is invalid.
  169. :returns:
  170. An instance of an implementation of :class:`.AsyncAuthManager` that
  171. returns auth info from the given provider and refreshes it, calling
  172. the provider again, when the auth info was rejected by the server.
  173. .. versionadded:: 5.12
  174. .. versionchanged:: 5.14 Stabilized from preview.
  175. """
  176. handled_codes = frozenset(("Neo.ClientError.Security.Unauthorized",))
  177. async def wrapped_provider() -> ExpiringAuth:
  178. return ExpiringAuth(await provider())
  179. return AsyncNeo4jAuthTokenManager(wrapped_provider, handled_codes)
  180. @staticmethod
  181. def bearer(
  182. provider: t.Callable[[], t.Awaitable[ExpiringAuth]],
  183. ) -> AsyncAuthManager:
  184. """
  185. Create an auth manager for potentially expiring bearer auth tokens.
  186. This factory wraps the provider function in an auth manager
  187. implementation that caches the provided auth info until either the
  188. :attr:`.ExpiringAuth.expires_at` exceeded or the server notified the
  189. driver that the auth info has expired (by returning an error that
  190. indicates that the bearer auth token has expired).
  191. .. warning::
  192. The provider function **must not** interact with the driver in any
  193. way as this can cause deadlocks and undefined behaviour.
  194. The provider function must only ever return auth information
  195. belonging to the same identity.
  196. Switching identities is undefined behavior.
  197. You may use :ref:`session-level authentication<session-auth-ref>`
  198. for such use-cases.
  199. Example::
  200. import neo4j
  201. from neo4j.auth_management import (
  202. AsyncAuthManagers,
  203. ExpiringAuth,
  204. )
  205. async def auth_provider():
  206. # some way of getting a token
  207. sso_token = await get_sso_token()
  208. # assume we know our tokens expire every 60 seconds
  209. expires_in = 60
  210. # Include a little buffer so that we fetch a new token
  211. # *before* the old one expires
  212. expires_in -= 10
  213. auth = neo4j.bearer_auth(sso_token)
  214. return ExpiringAuth(auth=auth).expires_in(expires_in)
  215. with neo4j.GraphDatabase.driver(
  216. "neo4j://example.com:7687",
  217. auth=AsyncAuthManagers.bearer(auth_provider)
  218. ) as driver:
  219. ... # do stuff
  220. :param provider:
  221. A callable that provides a :class:`.ExpiringAuth` instance.
  222. :returns:
  223. An instance of an implementation of :class:`.AsyncAuthManager` that
  224. returns auth info from the given provider and refreshes it, calling
  225. the provider again, when the auth info expires (either because it's
  226. reached its expiry time or because the server flagged it as
  227. expired).
  228. .. versionadded:: 5.12
  229. .. versionchanged:: 5.14 Stabilized from preview.
  230. """
  231. handled_codes = frozenset(
  232. (
  233. "Neo.ClientError.Security.TokenExpired",
  234. "Neo.ClientError.Security.Unauthorized",
  235. )
  236. )
  237. return AsyncNeo4jAuthTokenManager(provider, handled_codes)
  238. class _AsyncStaticClientCertificateProvider(AsyncClientCertificateProvider):
  239. _cert: ClientCertificate | None
  240. def __init__(self, cert: ClientCertificate) -> None:
  241. self._cert = cert
  242. async def get_certificate(self) -> ClientCertificate | None:
  243. cert, self._cert = self._cert, None
  244. return cert
  245. class AsyncRotatingClientCertificateProvider(AsyncClientCertificateProvider):
  246. """
  247. Abstract base class for certificate providers that can rotate certificates.
  248. The provider will make the driver use the initial certificate for all
  249. connections until the certificate is updated using the
  250. :meth:`update_certificate` method.
  251. From that point on, the new certificate will be used for all new
  252. connections until :meth:`update_certificate` is called again and so on.
  253. Example::
  254. from neo4j import AsyncGraphDatabase
  255. from neo4j.auth_management import (
  256. ClientCertificate,
  257. AsyncClientCertificateProviders,
  258. )
  259. provider = AsyncClientCertificateProviders.rotating(
  260. ClientCertificate(
  261. certfile="path/to/certfile.pem",
  262. keyfile="path/to/keyfile.pem",
  263. password=lambda: "super_secret_password"
  264. )
  265. )
  266. driver = AsyncGraphDatabase.driver(
  267. # secure driver must be configured for client certificate
  268. # to be used: (...+s[sc] scheme or encrypted=True)
  269. "neo4j+s://example.com:7687",
  270. # auth still required as before, unless server is configured to not
  271. # use authentication
  272. auth=("neo4j", "password"),
  273. client_certificate=provider
  274. )
  275. # do work with the driver, until the certificate needs to be rotated
  276. ...
  277. await provider.update_certificate(
  278. ClientCertificate(
  279. certfile="path/to/new/certfile.pem",
  280. keyfile="path/to/new/keyfile.pem",
  281. password=lambda: "new_super_secret_password"
  282. )
  283. )
  284. # do more work with the driver, until the certificate needs to be
  285. # rotated again
  286. ...
  287. .. versionadded:: 5.19
  288. .. versionchanged:: 5.24
  289. Turned this class into an abstract class to make the actual
  290. implementation internal. This entails removing the possibility to
  291. directly instantiate this class. Please use the factory method
  292. :meth:`.AsyncClientCertificateProviders.rotating` instead.
  293. .. versionchanged:: 5.27 Stabilized from preview.
  294. """
  295. @abc.abstractmethod
  296. async def update_certificate(self, cert: ClientCertificate) -> None:
  297. """Update the certificate to use for new connections."""
  298. class _AsyncNeo4jRotatingClientCertificateProvider(
  299. AsyncRotatingClientCertificateProvider
  300. ):
  301. def __init__(self, initial_cert: ClientCertificate) -> None:
  302. self._cert: ClientCertificate | None = initial_cert
  303. self._lock = AsyncCooperativeLock()
  304. async def get_certificate(self) -> ClientCertificate | None:
  305. async with self._lock:
  306. cert, self._cert = self._cert, None
  307. return cert
  308. async def update_certificate(self, cert: ClientCertificate) -> None:
  309. async with self._lock:
  310. self._cert = cert
  311. class AsyncClientCertificateProviders:
  312. """
  313. A collection of :class:`.AsyncClientCertificateProvider` factories.
  314. .. versionadded:: 5.19
  315. .. versionchanged:: 5.27 Stabilized from preview.
  316. """
  317. @staticmethod
  318. def static(cert: ClientCertificate) -> AsyncClientCertificateProvider:
  319. """
  320. Create a static client certificate provider.
  321. The provider simply makes the driver use the given certificate for all
  322. connections.
  323. """
  324. return _AsyncStaticClientCertificateProvider(cert)
  325. @staticmethod
  326. def rotating(
  327. initial_cert: ClientCertificate,
  328. ) -> AsyncRotatingClientCertificateProvider:
  329. """
  330. Create certificate provider that allows for rotating certificates.
  331. .. seealso:: :class:`.AsyncRotatingClientCertificateProvider`
  332. """
  333. return _AsyncNeo4jRotatingClientCertificateProvider(initial_cert)