asyncbackend.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. from typing import Dict
  3. import dns.exception
  4. # pylint: disable=unused-import
  5. from dns._asyncbackend import ( # noqa: F401 lgtm[py/unused-import]
  6. Backend,
  7. DatagramSocket,
  8. Socket,
  9. StreamSocket,
  10. )
  11. # pylint: enable=unused-import
# The module-wide default backend.  It stays None until it is assigned by
# set_default_backend(), which get_default_backend() calls lazily.
_default_backend = None
# Cache of backend instances created by get_backend(), keyed by backend name,
# so each backend is instantiated at most once.
_backends: Dict[str, Backend] = {}
# Allow sniffio import to be disabled for testing purposes
_no_sniffio = False
# Raised by sniff() when the asynchronous I/O library currently in use
# cannot be determined.
# NOTE(review): a class docstring is deliberately not added here -- the base
# class dns.exception.DNSException may use the docstring as a default
# message, which would change behavior; confirm before adding one.
class AsyncLibraryNotFoundError(dns.exception.DNSException):
    pass
  18. def get_backend(name: str) -> Backend:
  19. """Get the specified asynchronous backend.
  20. *name*, a ``str``, the name of the backend. Currently the "trio"
  21. and "asyncio" backends are available.
  22. Raises NotImplementedError if an unknown backend name is specified.
  23. """
  24. # pylint: disable=import-outside-toplevel,redefined-outer-name
  25. backend = _backends.get(name)
  26. if backend:
  27. return backend
  28. if name == "trio":
  29. import dns._trio_backend
  30. backend = dns._trio_backend.Backend()
  31. elif name == "asyncio":
  32. import dns._asyncio_backend
  33. backend = dns._asyncio_backend.Backend()
  34. else:
  35. raise NotImplementedError(f"unimplemented async backend {name}")
  36. _backends[name] = backend
  37. return backend
  38. def sniff() -> str:
  39. """Attempt to determine the in-use asynchronous I/O library by using
  40. the ``sniffio`` module if it is available.
  41. Returns the name of the library, or raises AsyncLibraryNotFoundError
  42. if the library cannot be determined.
  43. """
  44. # pylint: disable=import-outside-toplevel
  45. try:
  46. if _no_sniffio:
  47. raise ImportError
  48. import sniffio
  49. try:
  50. return sniffio.current_async_library()
  51. except sniffio.AsyncLibraryNotFoundError:
  52. raise AsyncLibraryNotFoundError("sniffio cannot determine async library")
  53. except ImportError:
  54. import asyncio
  55. try:
  56. asyncio.get_running_loop()
  57. return "asyncio"
  58. except RuntimeError:
  59. raise AsyncLibraryNotFoundError("no async library detected")
  60. def get_default_backend() -> Backend:
  61. """Get the default backend, initializing it if necessary."""
  62. if _default_backend:
  63. return _default_backend
  64. return set_default_backend(sniff())
  65. def set_default_backend(name: str) -> Backend:
  66. """Set the default backend.
  67. It's not normally necessary to call this method, as
  68. ``get_default_backend()`` will initialize the backend
  69. appropriately in many cases. If ``sniffio`` is not installed, or
  70. in testing situations, this function allows the backend to be set
  71. explicitly.
  72. """
  73. global _default_backend
  74. _default_backend = get_backend(name)
  75. return _default_backend