path.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import sqlalchemy as sa
  2. from sqlalchemy.orm.attributes import InstrumentedAttribute
  3. from sqlalchemy.util.langhelpers import symbol
  4. from .utils import str_coercible
  5. @str_coercible
  6. class Path:
  7. def __init__(self, path, separator='.'):
  8. if isinstance(path, Path):
  9. self.path = path.path
  10. else:
  11. self.path = path
  12. self.separator = separator
  13. @property
  14. def parts(self):
  15. return self.path.split(self.separator)
  16. def __iter__(self):
  17. yield from self.parts
  18. def __len__(self):
  19. return len(self.parts)
  20. def __repr__(self):
  21. return f"{self.__class__.__name__}('{self.path}')"
  22. def index(self, element):
  23. return self.parts.index(element)
  24. def __getitem__(self, slice):
  25. result = self.parts[slice]
  26. if isinstance(result, list):
  27. return self.__class__(
  28. self.separator.join(result),
  29. separator=self.separator
  30. )
  31. return result
  32. def __eq__(self, other):
  33. return self.path == other.path and self.separator == other.separator
  34. def __ne__(self, other):
  35. return not (self == other)
  36. def __unicode__(self):
  37. return self.path
  38. def get_attr(mixed, attr):
  39. if isinstance(mixed, InstrumentedAttribute):
  40. return getattr(
  41. mixed.property.mapper.class_,
  42. attr
  43. )
  44. else:
  45. return getattr(mixed, attr)
  46. @str_coercible
  47. class AttrPath:
  48. def __init__(self, class_, path):
  49. self.class_ = class_
  50. self.path = Path(path)
  51. self.parts = []
  52. last_attr = class_
  53. for value in self.path:
  54. last_attr = get_attr(last_attr, value)
  55. self.parts.append(last_attr)
  56. def __iter__(self):
  57. yield from self.parts
  58. def __invert__(self):
  59. def get_backref(part):
  60. prop = part.property
  61. backref = prop.backref or prop.back_populates
  62. if backref is None:
  63. raise Exception(
  64. "Invert failed because property '%s' of class "
  65. "%s has no backref." % (
  66. prop.key,
  67. prop.parent.class_.__name__
  68. )
  69. )
  70. if isinstance(backref, tuple):
  71. return backref[0]
  72. else:
  73. return backref
  74. if isinstance(self.parts[-1].property, sa.orm.ColumnProperty):
  75. class_ = self.parts[-1].class_
  76. else:
  77. class_ = self.parts[-1].mapper.class_
  78. return self.__class__(
  79. class_,
  80. '.'.join(map(get_backref, reversed(self.parts)))
  81. )
  82. def index(self, element):
  83. for index, el in enumerate(self.parts):
  84. if el is element:
  85. return index
  86. @property
  87. def direction(self):
  88. symbols = [part.property.direction for part in self.parts]
  89. if symbol('MANYTOMANY') in symbols:
  90. return symbol('MANYTOMANY')
  91. elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols:
  92. return symbol('MANYTOMANY')
  93. return symbols[0]
  94. @property
  95. def uselist(self):
  96. return any(part.property.uselist for part in self.parts)
  97. def __getitem__(self, slice):
  98. result = self.parts[slice]
  99. if isinstance(result, list) and result:
  100. if result[0] is self.parts[0]:
  101. class_ = self.class_
  102. else:
  103. class_ = result[0].parent.class_
  104. return self.__class__(
  105. class_,
  106. self.path[slice]
  107. )
  108. else:
  109. return result
  110. def __len__(self):
  111. return len(self.path)
  112. def __repr__(self):
  113. return "{}({}, {!r})".format(
  114. self.__class__.__name__,
  115. self.class_.__name__,
  116. self.path.path
  117. )
  118. def __eq__(self, other):
  119. return self.path == other.path and self.class_ == other.class_
  120. def __ne__(self, other):
  121. return not (self == other)
  122. def __unicode__(self):
  123. return str(self.path)