ltree.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import re
  2. from ..utils import str_coercible
  3. path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')
  4. @str_coercible
  5. class Ltree:
  6. """
  7. Ltree class wraps a valid string label path. It provides various
  8. convenience properties and methods.
  9. ::
  10. from sqlalchemy_utils import Ltree
  11. Ltree('1.2.3').path # '1.2.3'
  12. Ltree always validates the given path.
  13. ::
  14. Ltree(None) # raises TypeError
  15. Ltree('..') # raises ValueError
  16. Validator is also available as class method.
  17. ::
  18. Ltree.validate('1.2.3')
  19. Ltree.validate(None) # raises TypeError
  20. Ltree supports equality operators.
  21. ::
  22. Ltree('Countries.Finland') == Ltree('Countries.Finland')
  23. Ltree('Countries.Germany') != Ltree('Countries.Finland')
  24. Ltree objects are hashable.
  25. ::
  26. assert hash(Ltree('Finland')) == hash('Finland')
  27. Ltree objects have length.
  28. ::
  29. assert len(Ltree('1.2')) == 2
  30. assert len(Ltree('some.one.some.where')) # 4
  31. You can easily find subpath indexes.
  32. ::
  33. assert Ltree('1.2.3').index('2.3') == 1
  34. assert Ltree('1.2.3.4.5').index('3.4') == 2
  35. Ltree objects can be sliced.
  36. ::
  37. assert Ltree('1.2.3')[0:2] == Ltree('1.2')
  38. assert Ltree('1.2.3')[1:] == Ltree('2.3')
  39. Finding longest common ancestor.
  40. ::
  41. assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
  42. assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'
  43. Ltree objects can be concatenated.
  44. ::
  45. assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
  46. """
  47. def __init__(self, path_or_ltree):
  48. if isinstance(path_or_ltree, Ltree):
  49. self.path = path_or_ltree.path
  50. elif isinstance(path_or_ltree, str):
  51. self.validate(path_or_ltree)
  52. self.path = path_or_ltree
  53. else:
  54. raise TypeError(
  55. "Ltree() argument must be a string or an Ltree, not '{}'"
  56. .format(
  57. type(path_or_ltree).__name__
  58. )
  59. )
  60. @classmethod
  61. def validate(cls, path):
  62. if path_matcher.match(path) is None:
  63. raise ValueError(
  64. f"'{path}' is not a valid ltree path."
  65. )
  66. def __len__(self):
  67. return len(self.path.split('.'))
  68. def index(self, other):
  69. subpath = Ltree(other).path.split('.')
  70. parts = self.path.split('.')
  71. for index, _ in enumerate(parts):
  72. if parts[index:len(subpath) + index] == subpath:
  73. return index
  74. raise ValueError('subpath not found')
  75. def descendant_of(self, other):
  76. """
  77. is left argument a descendant of right (or equal)?
  78. ::
  79. assert Ltree('1.2.3.4.5').descendant_of('1.2.3')
  80. """
  81. subpath = self[:len(Ltree(other))]
  82. return subpath == other
  83. def ancestor_of(self, other):
  84. """
  85. is left argument an ancestor of right (or equal)?
  86. ::
  87. assert Ltree('1.2.3').ancestor_of('1.2.3.4.5')
  88. """
  89. subpath = Ltree(other)[:len(self)]
  90. return subpath == self
  91. def __getitem__(self, key):
  92. if isinstance(key, int):
  93. return Ltree(self.path.split('.')[key])
  94. elif isinstance(key, slice):
  95. return Ltree('.'.join(self.path.split('.')[key]))
  96. raise TypeError(
  97. 'Ltree indices must be integers, not {}'.format(
  98. key.__class__.__name__
  99. )
  100. )
  101. def lca(self, *others):
  102. """
  103. Lowest common ancestor, i.e., longest common prefix of paths
  104. ::
  105. assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
  106. """
  107. other_parts = [Ltree(other).path.split('.') for other in others]
  108. parts = self.path.split('.')
  109. for index, element in enumerate(parts):
  110. if any(
  111. other[index] != element or
  112. len(other) <= index + 1 or
  113. len(parts) == index + 1
  114. for other in other_parts
  115. ):
  116. if index == 0:
  117. return None
  118. return Ltree('.'.join(parts[0:index]))
  119. def __add__(self, other):
  120. return Ltree(self.path + '.' + Ltree(other).path)
  121. def __radd__(self, other):
  122. return Ltree(other) + self
  123. def __eq__(self, other):
  124. if isinstance(other, Ltree):
  125. return self.path == other.path
  126. elif isinstance(other, str):
  127. return self.path == other
  128. else:
  129. return NotImplemented
  130. def __hash__(self):
  131. return hash(self.path)
  132. def __ne__(self, other):
  133. return not (self == other)
  134. def __repr__(self):
  135. return f'{self.__class__.__name__}({self.path!r})'
  136. def __unicode__(self):
  137. return self.path
  138. def __contains__(self, label):
  139. return label in self.path.split('.')
  140. def __gt__(self, other):
  141. return self.path > other.path
  142. def __lt__(self, other):
  143. return self.path < other.path
  144. def __ge__(self, other):
  145. return self.path >= other.path
  146. def __le__(self, other):
  147. return self.path <= other.path