addressing.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # Copyright (c) "Neo4j"
  2. # Neo4j Sweden AB [https://neo4j.com]
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # https://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import annotations
  16. import typing as t
  17. from contextlib import suppress as _suppress
  18. from socket import (
  19. AddressFamily,
  20. AF_INET,
  21. AF_INET6,
  22. getservbyname,
  23. )
  24. if t.TYPE_CHECKING:
  25. import typing_extensions as te
  26. _T = t.TypeVar("_T")
  27. if t.TYPE_CHECKING:
  28. class _WithPeerName(te.Protocol):
  29. def getpeername(self) -> tuple: ...
  30. __all__ = [
  31. "Address",
  32. "IPv4Address",
  33. "IPv6Address",
  34. "ResolvedAddress",
  35. "ResolvedIPv4Address",
  36. "ResolvedIPv6Address",
  37. ]
  38. class _AddressMeta(type(tuple)): # type: ignore[misc]
  39. def __init__(cls, *args, **kwargs):
  40. super().__init__(*args, **kwargs)
  41. cls._ipv4_cls = None
  42. cls._ipv6_cls = None
  43. def _subclass_by_family(cls, family):
  44. subclasses = [
  45. sc
  46. for sc in cls.__subclasses__()
  47. if (
  48. sc.__module__ == cls.__module__
  49. and getattr(sc, "family", None) == family
  50. )
  51. ]
  52. if len(subclasses) != 1:
  53. raise ValueError(
  54. f"Class {cls} needs exactly one direct subclass with "
  55. f"attribute `family == {family}` within this module. "
  56. f"Found: {subclasses}"
  57. )
  58. return subclasses[0]
  59. @property
  60. def ipv4_cls(cls):
  61. if cls._ipv4_cls is None:
  62. cls._ipv4_cls = cls._subclass_by_family(AF_INET)
  63. return cls._ipv4_cls
  64. @property
  65. def ipv6_cls(cls):
  66. if cls._ipv6_cls is None:
  67. cls._ipv6_cls = cls._subclass_by_family(AF_INET6)
  68. return cls._ipv6_cls
  69. class Address(tuple, metaclass=_AddressMeta):
  70. """
  71. Base class to represent server addresses within the driver.
  72. A tuple of two (IPv4) or four (IPv6) elements, representing the address
  73. parts. See also python's :mod:`socket` module for more information.
  74. >>> Address(("example.com", 7687))
  75. IPv4Address(('example.com', 7687))
  76. >>> Address(("127.0.0.1", 7687))
  77. IPv4Address(('127.0.0.1', 7687))
  78. >>> Address(("::1", 7687, 0, 0))
  79. IPv6Address(('::1', 7687, 0, 0))
  80. :param iterable: A collection of two or four elements creating an
  81. :class:`.IPv4Address` or :class:`.IPv6Address` instance respectively.
  82. """
  83. #: Address family (:data:`socket.AF_INET` or :data:`socket.AF_INET6`).
  84. family: AddressFamily | None = None
  85. def __new__(cls, iterable: t.Collection) -> Address:
  86. if isinstance(iterable, cls):
  87. return iterable
  88. n_parts = len(iterable)
  89. inst = tuple.__new__(cls, iterable)
  90. if n_parts == 2:
  91. inst.__class__ = cls.ipv4_cls
  92. elif n_parts == 4:
  93. inst.__class__ = cls.ipv6_cls
  94. else:
  95. raise ValueError(
  96. "Addresses must consist of either "
  97. "two parts (IPv4) or four parts (IPv6)"
  98. )
  99. return inst
  100. @classmethod
  101. def from_socket(cls, socket: _WithPeerName) -> Address:
  102. """
  103. Create an address from a socket object.
  104. Uses the socket's ``getpeername`` method to retrieve the remote
  105. address the socket is connected to.
  106. """
  107. address = socket.getpeername()
  108. return cls(address)
  109. @classmethod
  110. def parse(
  111. cls,
  112. s: str,
  113. default_host: str | None = None,
  114. default_port: int | None = None,
  115. ) -> Address:
  116. """
  117. Parse a string into an address.
  118. The string must be in the format ``host:port`` (IPv4) or
  119. ``[host]:port`` (IPv6).
  120. If no port is specified, or is empty, ``default_port`` will be used.
  121. If no host is specified, or is empty, ``default_host`` will be used.
  122. >>> Address.parse("localhost:7687")
  123. IPv4Address(('localhost', 7687))
  124. >>> Address.parse("[::1]:7687")
  125. IPv6Address(('::1', 7687, 0, 0))
  126. >>> Address.parse("localhost")
  127. IPv4Address(('localhost', 0))
  128. >>> Address.parse("localhost", default_port=1234)
  129. IPv4Address(('localhost', 1234))
  130. :param s: The string to parse.
  131. :param default_host: The default host to use if none is specified.
  132. :data:`None` indicates to use ``"localhost"`` as default.
  133. :param default_port: The default port to use if none is specified.
  134. :data:`None` indicates to use ``0`` as default.
  135. :returns: The parsed address.
  136. """
  137. if not isinstance(s, str):
  138. raise TypeError("Address.parse requires a string argument")
  139. if s.startswith("["):
  140. # IPv6
  141. port: str | int
  142. host, _, port = s[1:].rpartition("]")
  143. port = port.lstrip(":")
  144. with _suppress(TypeError, ValueError):
  145. port = int(port)
  146. host = host or default_host or "localhost"
  147. port = port or default_port or 0
  148. return cls((host, port, 0, 0))
  149. else:
  150. # IPv4
  151. host, _, port = s.partition(":")
  152. with _suppress(TypeError, ValueError):
  153. port = int(port)
  154. host = host or default_host or "localhost"
  155. port = port or default_port or 0
  156. return cls((host, port))
  157. @classmethod
  158. def parse_list(
  159. cls,
  160. *s: str,
  161. default_host: str | None = None,
  162. default_port: int | None = None,
  163. ) -> list[Address]:
  164. """
  165. Parse multiple addresses into a list.
  166. See :meth:`.parse` for details on the string format.
  167. Either a whitespace-separated list of strings or multiple strings
  168. can be used.
  169. >>> Address.parse_list("localhost:7687", "[::1]:7687")
  170. [IPv4Address(('localhost', 7687)), IPv6Address(('::1', 7687, 0, 0))]
  171. >>> Address.parse_list("localhost:7687 [::1]:7687")
  172. [IPv4Address(('localhost', 7687)), IPv6Address(('::1', 7687, 0, 0))]
  173. :param s: The string(s) to parse.
  174. :param default_host: The default host to use if none is specified.
  175. :data:`None` indicates to use ``"localhost"`` as default.
  176. :param default_port: The default port to use if none is specified.
  177. :data:`None` indicates to use ``0`` as default.
  178. :returns: The list of parsed addresses.
  179. """ # noqa: E501 can't split the doctest lines
  180. if not all(isinstance(s0, str) for s0 in s):
  181. raise TypeError("Address.parse_list requires a string argument")
  182. return [
  183. cls.parse(a, default_host, default_port)
  184. for a in " ".join(s).split()
  185. ]
  186. def __repr__(self):
  187. return f"{self.__class__.__name__}({tuple(self)!r})"
  188. @property
  189. def _host_name(self) -> t.Any:
  190. return self[0]
  191. @property
  192. def host(self) -> t.Any:
  193. """
  194. The host part of the address.
  195. This is the first part of the address tuple.
  196. >>> Address(("localhost", 7687)).host
  197. 'localhost'
  198. """
  199. return self[0]
  200. @property
  201. def port(self) -> t.Any:
  202. """
  203. The port part of the address.
  204. This is the second part of the address tuple.
  205. >>> Address(("localhost", 7687)).port
  206. 7687
  207. >>> Address(("localhost", 7687, 0, 0)).port
  208. 7687
  209. >>> Address(("localhost", "7687")).port
  210. '7687'
  211. >>> Address(("localhost", "http")).port
  212. 'http'
  213. """
  214. return self[1]
  215. @property
  216. def _unresolved(self) -> Address:
  217. return self
  218. @property
  219. def port_number(self) -> int:
  220. """
  221. The port part of the address as an integer.
  222. First try to resolve the port as an integer, using
  223. :func:`socket.getservbyname`. If that fails, fall back to parsing the
  224. port as an integer.
  225. >>> Address(("localhost", 7687)).port_number
  226. 7687
  227. >>> Address(("localhost", "http")).port_number
  228. 80
  229. >>> Address(("localhost", "7687")).port_number
  230. 7687
  231. >>> Address(("localhost", [])).port_number
  232. Traceback (most recent call last):
  233. ...
  234. TypeError: Unknown port value []
  235. >>> Address(("localhost", "banana-protocol")).port_number
  236. Traceback (most recent call last):
  237. ...
  238. ValueError: Unknown port value 'banana-protocol'
  239. :returns: The resolved port number.
  240. :raise ValueError: If the port cannot be resolved.
  241. :raise TypeError: If the port cannot be resolved.
  242. """
  243. error_cls: type = TypeError
  244. try:
  245. return getservbyname(self[1])
  246. except OSError:
  247. # OSError: service/proto not found
  248. error_cls = ValueError
  249. except TypeError:
  250. # TypeError: getservbyname() argument 1 must be str, not X
  251. pass
  252. try:
  253. return int(self[1])
  254. except ValueError:
  255. error_cls = ValueError
  256. except TypeError:
  257. pass
  258. raise error_cls(f"Unknown port value {self[1]!r}")
  259. class IPv4Address(Address):
  260. """
  261. An IPv4 address (family ``AF_INET``).
  262. This class is also used for addresses that specify a host name instead of
  263. an IP address. E.g.,
  264. >>> Address(("example.com", 7687))
  265. IPv4Address(('example.com', 7687))
  266. This class should not be instantiated directly. Instead, use
  267. :class:`.Address` or one of its factory methods.
  268. """
  269. family = AF_INET
  270. def __str__(self) -> str:
  271. return "{}:{}".format(*self)
  272. class IPv6Address(Address):
  273. """
  274. An IPv6 address (family ``AF_INET6``).
  275. This class should not be instantiated directly. Instead, use
  276. :class:`.Address` or one of its factory methods.
  277. """
  278. family = AF_INET6
  279. def __str__(self) -> str:
  280. return "[{}]:{}".format(*self)
  281. # TODO: 6.0 - make this class private
  282. class ResolvedAddress(Address):
  283. _unresolved_host_name: str
  284. @property
  285. def _host_name(self) -> str:
  286. return self._unresolved_host_name
  287. @property
  288. def _unresolved(self) -> Address:
  289. return super().__new__(Address, (self._host_name, *self[1:]))
  290. def __new__(cls, iterable, *, host_name: str) -> ResolvedAddress:
  291. new = super().__new__(cls, iterable)
  292. new = t.cast(ResolvedAddress, new)
  293. new._unresolved_host_name = host_name
  294. return new
  295. # TODO: 6.0 - make this class private
  296. class ResolvedIPv4Address(IPv4Address, ResolvedAddress):
  297. pass
  298. # TODO: 6.0 - make this class private
  299. class ResolvedIPv6Address(IPv6Address, ResolvedAddress):
  300. pass