generic.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from collections.abc import Iterable
  2. import sqlalchemy as sa
  3. from sqlalchemy.ext.hybrid import hybrid_property
  4. from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
  5. from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
  6. from sqlalchemy.orm.session import _state_session
  7. from sqlalchemy.util import set_creation_order
  8. from .exceptions import ImproperlyConfigured
  9. from .functions import identity
  10. from .functions.orm import _get_class_registry
  11. class GenericAttributeImpl(attributes.ScalarAttributeImpl):
  12. def __init__(self, *args, **kwargs):
  13. """
  14. The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22,
  15. adding a 'default_function' required positional argument before 'dispatch'.
  16. This adjustment ensures compatibility across versions by inserting None for
  17. 'default_function' in versions >= 2.0.22.
  18. Arguments received: (class, key, dispatch)
  19. Required by AttributeImpl: (class, key, default_function, dispatch)
  20. Setting None as default_function here.
  21. """
  22. # Adjust for SQLAlchemy version change
  23. sqlalchemy_version = tuple(map(int, sa.__version__.split('.')))
  24. if sqlalchemy_version >= (2, 0, 22):
  25. args = (*args[:2], None, *args[2:])
  26. super().__init__(*args, **kwargs)
  27. def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
  28. if self.key in dict_:
  29. return dict_[self.key]
  30. # Retrieve the session bound to the state in order to perform
  31. # a lazy query for the attribute.
  32. session = _state_session(state)
  33. if session is None:
  34. # State is not bound to a session; we cannot proceed.
  35. return None
  36. # Find class for discriminator.
  37. # TODO: Perhaps optimize with some sort of lookup?
  38. discriminator = self.get_state_discriminator(state)
  39. target_class = _get_class_registry(state.class_).get(discriminator)
  40. if target_class is None:
  41. # Unknown discriminator; return nothing.
  42. return None
  43. id = self.get_state_id(state)
  44. try:
  45. target = session.get(target_class, id)
  46. except AttributeError:
  47. # sqlalchemy 1.3
  48. target = session.query(target_class).get(id)
  49. # Return found (or not found) target.
  50. return target
  51. def get_state_discriminator(self, state):
  52. discriminator = self.parent_token.discriminator
  53. if isinstance(discriminator, hybrid_property):
  54. return getattr(state.obj(), discriminator.__name__)
  55. else:
  56. return state.attrs[discriminator.key].value
  57. def get_state_id(self, state):
  58. # Lookup row with the discriminator and id.
  59. return tuple(state.attrs[id.key].value for id in self.parent_token.id)
  60. def set(self, state, dict_, initiator,
  61. passive=attributes.PASSIVE_OFF,
  62. check_old=None,
  63. pop=False):
  64. # Set us on the state.
  65. dict_[self.key] = initiator
  66. if initiator is None:
  67. # Nullify relationship args
  68. for id in self.parent_token.id:
  69. dict_[id.key] = None
  70. dict_[self.parent_token.discriminator.key] = None
  71. else:
  72. # Get the primary key of the initiator and ensure we
  73. # can support this assignment.
  74. class_ = type(initiator)
  75. mapper = class_mapper(class_)
  76. pk = mapper.identity_key_from_instance(initiator)[1]
  77. # Set the identifier and the discriminator.
  78. discriminator = class_.__name__
  79. for index, id in enumerate(self.parent_token.id):
  80. dict_[id.key] = pk[index]
  81. dict_[self.parent_token.discriminator.key] = discriminator
  82. class GenericRelationshipProperty(MapperProperty):
  83. """A generic form of the relationship property.
  84. Creates a 1 to many relationship between the parent model
  85. and any other models using a descriminator (the table name).
  86. :param discriminator
  87. Field to discriminate which model we are referring to.
  88. :param id:
  89. Field to point to the model we are referring to.
  90. """
  91. def __init__(self, discriminator, id, doc=None):
  92. super().__init__()
  93. self._discriminator_col = discriminator
  94. self._id_cols = id
  95. self._id = None
  96. self._discriminator = None
  97. self.doc = doc
  98. set_creation_order(self)
  99. def _column_to_property(self, column):
  100. if isinstance(column, hybrid_property):
  101. attr_key = column.__name__
  102. for key, attr in self.parent.all_orm_descriptors.items():
  103. if key == attr_key:
  104. return attr
  105. else:
  106. for attr in self.parent.attrs.values():
  107. if isinstance(attr, ColumnProperty):
  108. if attr.columns[0].name == column.name:
  109. return attr
  110. def init(self):
  111. def convert_strings(column):
  112. if isinstance(column, str):
  113. return self.parent.columns[column]
  114. return column
  115. self._discriminator_col = convert_strings(self._discriminator_col)
  116. self._id_cols = convert_strings(self._id_cols)
  117. if isinstance(self._id_cols, Iterable):
  118. self._id_cols = list(map(convert_strings, self._id_cols))
  119. else:
  120. self._id_cols = [self._id_cols]
  121. self.discriminator = self._column_to_property(self._discriminator_col)
  122. if self.discriminator is None:
  123. raise ImproperlyConfigured(
  124. 'Could not find discriminator descriptor.'
  125. )
  126. self.id = list(map(self._column_to_property, self._id_cols))
  127. class Comparator(PropComparator):
  128. def __init__(self, prop, parentmapper):
  129. self.property = prop
  130. self._parententity = parentmapper
  131. def __eq__(self, other):
  132. discriminator = type(other).__name__
  133. q = self.property._discriminator_col == discriminator
  134. other_id = identity(other)
  135. for index, id in enumerate(self.property._id_cols):
  136. q &= id == other_id[index]
  137. return q
  138. def __ne__(self, other):
  139. return ~(self == other)
  140. def is_type(self, other):
  141. mapper = sa.inspect(other)
  142. # Iterate through the weak sequence in order to get the actual
  143. # mappers
  144. class_names = [other.__name__]
  145. class_names.extend([
  146. submapper.class_.__name__
  147. for submapper in mapper._inheriting_mappers
  148. ])
  149. return self.property._discriminator_col.in_(class_names)
  150. def instrument_class(self, mapper):
  151. attributes.register_attribute(
  152. mapper.class_,
  153. self.key,
  154. comparator=self.Comparator(self, mapper),
  155. parententity=mapper,
  156. doc=self.doc,
  157. impl_class=GenericAttributeImpl,
  158. parent_token=self
  159. )
  160. def generic_relationship(*args, **kwargs):
  161. return GenericRelationshipProperty(*args, **kwargs)