_protocol.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import annotations
  2. import os
  3. import re
  4. from pathlib import PurePath
  5. from typing import TYPE_CHECKING
  6. from typing import Any
  7. if TYPE_CHECKING:
  8. from upath.core import UPath
  9. __all__ = [
  10. "get_upath_protocol",
  11. "normalize_empty_netloc",
  12. "compatible_protocol",
  13. ]
  14. # Regular expression to match fsspec style protocols.
  15. # Matches single slash usage too for compatibility.
  16. _PROTOCOL_RE = re.compile(
  17. r"^(?P<protocol>[A-Za-z][A-Za-z0-9+]+):(?P<slashes>//?)(?P<path>.*)"
  18. )
  19. # Matches data URIs
  20. _DATA_URI_RE = re.compile(r"^data:[^,]*,")
  21. def _match_protocol(pth: str) -> str:
  22. if m := _PROTOCOL_RE.match(pth):
  23. return m.group("protocol")
  24. elif _DATA_URI_RE.match(pth):
  25. return "data"
  26. return ""
  27. def get_upath_protocol(
  28. pth: str | PurePath | os.PathLike,
  29. *,
  30. protocol: str | None = None,
  31. storage_options: dict[str, Any] | None = None,
  32. ) -> str:
  33. """return the filesystem spec protocol"""
  34. if isinstance(pth, str):
  35. pth_protocol = _match_protocol(pth)
  36. elif isinstance(pth, PurePath):
  37. pth_protocol = getattr(pth, "protocol", "")
  38. elif hasattr(pth, "__fspath__"):
  39. pth_protocol = _match_protocol(pth.__fspath__())
  40. else:
  41. pth_protocol = _match_protocol(str(pth))
  42. # if storage_options and not protocol and not pth_protocol:
  43. # protocol = "file"
  44. if protocol and pth_protocol and not pth_protocol.startswith(protocol):
  45. raise ValueError(
  46. f"requested protocol {protocol!r} incompatible with {pth_protocol!r}"
  47. )
  48. return protocol or pth_protocol or ""
  49. def normalize_empty_netloc(pth: str) -> str:
  50. if m := _PROTOCOL_RE.match(pth):
  51. if len(m.group("slashes")) == 1:
  52. protocol = m.group("protocol")
  53. path = m.group("path")
  54. pth = f"{protocol}:///{path}"
  55. return pth
  56. def compatible_protocol(protocol: str, *args: str | os.PathLike[str] | UPath) -> bool:
  57. """check if UPath protocols are compatible"""
  58. for arg in args:
  59. other_protocol = get_upath_protocol(arg)
  60. # consider protocols equivalent if they match up to the first "+"
  61. other_protocol = other_protocol.partition("+")[0]
  62. # protocols: only identical (or empty "") protocols can combine
  63. if other_protocol and other_protocol != protocol:
  64. return False
  65. return True