load_instance_mixin.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. """Mixin that adds model instance loading behavior.
  2. .. warning::
  3. This module is treated as private API.
  4. Users should not need to use this module directly.
  5. """
  6. import marshmallow as ma
  7. from .fields import get_primary_keys
  8. class LoadInstanceMixin:
  9. class Opts:
  10. def __init__(self, meta, *args, **kwargs):
  11. super().__init__(meta, *args, **kwargs)
  12. self.sqla_session = getattr(meta, "sqla_session", None)
  13. self.load_instance = getattr(meta, "load_instance", False)
  14. self.transient = getattr(meta, "transient", False)
  15. class Schema:
  16. @property
  17. def session(self):
  18. return self._session or self.opts.sqla_session
  19. @session.setter
  20. def session(self, session):
  21. self._session = session
  22. @property
  23. def transient(self):
  24. if self._transient is not None:
  25. return self._transient
  26. return self.opts.transient
  27. @transient.setter
  28. def transient(self, transient):
  29. self._transient = transient
  30. def __init__(self, *args, **kwargs):
  31. self._session = kwargs.pop("session", None)
  32. self.instance = kwargs.pop("instance", None)
  33. self._transient = kwargs.pop("transient", None)
  34. self._load_instance = kwargs.pop("load_instance", self.opts.load_instance)
  35. super().__init__(*args, **kwargs)
  36. def get_instance(self, data):
  37. """Retrieve an existing record by primary key(s). If the schema instance
  38. is transient, return None.
  39. :param data: Serialized data to inform lookup.
  40. """
  41. if self.transient:
  42. return None
  43. props = get_primary_keys(self.opts.model)
  44. filters = {prop.key: data.get(prop.key) for prop in props}
  45. if None not in filters.values():
  46. return self.session.query(self.opts.model).filter_by(**filters).first()
  47. return None
  48. @ma.post_load
  49. def make_instance(self, data, **kwargs):
  50. """Deserialize data to an instance of the model if self.load_instance is True.
  51. Update an existing row if specified in `self.instance` or loaded by primary
  52. key(s) in the data; else create a new row.
  53. :param data: Data to deserialize.
  54. """
  55. if not self._load_instance:
  56. return data
  57. instance = self.instance or self.get_instance(data)
  58. if instance is not None:
  59. for key, value in data.items():
  60. setattr(instance, key, value)
  61. return instance
  62. kwargs, association_attrs = self._split_model_kwargs_association(data)
  63. instance = self.opts.model(**kwargs)
  64. for attr, value in association_attrs.items():
  65. setattr(instance, attr, value)
  66. return instance
  67. def load(self, data, *, session=None, instance=None, transient=False, **kwargs):
  68. """Deserialize data to internal representation.
  69. :param session: Optional SQLAlchemy session.
  70. :param instance: Optional existing instance to modify.
  71. :param transient: Optional switch to allow transient instantiation.
  72. """
  73. self._session = session or self._session
  74. self._transient = transient or self._transient
  75. if self._load_instance and not (self.transient or self.session):
  76. raise ValueError("Deserialization requires a session")
  77. self.instance = instance or self.instance
  78. try:
  79. return super().load(data, **kwargs)
  80. finally:
  81. self.instance = None
  82. def validate(self, data, *, session=None, **kwargs):
  83. self._session = session or self._session
  84. if not (self.transient or self.session):
  85. raise ValueError("Validation requires a session")
  86. return super().validate(data, **kwargs)
  87. def _split_model_kwargs_association(self, data):
  88. """Split serialized attrs to ensure association proxies are passed separately.
  89. This is necessary for Python < 3.6.0, as the order in which kwargs are passed
  90. is non-deterministic, and associations must be parsed by sqlalchemy after their
  91. intermediate relationship, unless their `creator` has been set.
  92. Ignore invalid keys at this point - behaviour for unknowns should be
  93. handled elsewhere.
  94. :param data: serialized dictionary of attrs to split on association_proxy.
  95. """
  96. association_attrs = {
  97. key: value
  98. for key, value in data.items()
  99. # association proxy
  100. if hasattr(getattr(self.opts.model, key, None), "remote_attr")
  101. }
  102. kwargs = {
  103. key: value
  104. for key, value in data.items()
  105. if (hasattr(self.opts.model, key) and key not in association_attrs)
  106. }
  107. return kwargs, association_attrs