fields.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import warnings
  2. from marshmallow import fields
  3. from marshmallow.utils import is_iterable_but_not_string
  4. from sqlalchemy import inspect
  5. from sqlalchemy.orm.exc import NoResultFound
  6. def get_primary_keys(model):
  7. """Get primary key properties for a SQLAlchemy model.
  8. :param model: SQLAlchemy model class
  9. """
  10. mapper = model.__mapper__
  11. return [mapper.get_property_by_column(column) for column in mapper.primary_key]
  12. def ensure_list(value):
  13. return value if is_iterable_but_not_string(value) else [value]
  14. class RelatedList(fields.List):
  15. def get_value(self, obj, attr, accessor=None):
  16. # Do not call `fields.List`'s get_value as it calls the container's
  17. # `get_value` if the container has `attribute`.
  18. # Instead call the `get_value` from the parent of `fields.List`
  19. # so the special handling is avoided.
  20. return super(fields.List, self).get_value(obj, attr, accessor=accessor)
  21. class Related(fields.Field):
  22. """Related data represented by a SQLAlchemy `relationship`. Must be attached
  23. to a :class:`Schema` class whose options includes a SQLAlchemy `model`, such
  24. as :class:`SQLAlchemySchema`.
  25. :param list columns: Optional column names on related model. If not provided,
  26. the primary key(s) of the related model will be used.
  27. """
  28. default_error_messages = {
  29. "invalid": "Could not deserialize related value {value!r}; "
  30. "expected a dictionary with keys {keys!r}"
  31. }
  32. def __init__(self, columns=None, column=None, **kwargs):
  33. if column is not None:
  34. warnings.warn(
  35. "`column` parameter is deprecated and will be removed in future releases. "
  36. "Use `columns` instead.",
  37. DeprecationWarning,
  38. )
  39. if columns is None:
  40. columns = column
  41. super().__init__(**kwargs)
  42. self.columns = ensure_list(columns or [])
  43. @property
  44. def model(self):
  45. return self.root.opts.model
  46. @property
  47. def related_model(self):
  48. model_attr = getattr(self.model, self.attribute or self.name)
  49. if hasattr(model_attr, "remote_attr"): # handle association proxies
  50. model_attr = model_attr.remote_attr
  51. return model_attr.property.mapper.class_
  52. @property
  53. def related_keys(self):
  54. if self.columns:
  55. insp = inspect(self.related_model)
  56. return [insp.attrs[column] for column in self.columns]
  57. return get_primary_keys(self.related_model)
  58. @property
  59. def session(self):
  60. return self.root.session
  61. @property
  62. def transient(self):
  63. return self.root.transient
  64. def _serialize(self, value, attr, obj):
  65. ret = {prop.key: getattr(value, prop.key, None) for prop in self.related_keys}
  66. return ret if len(ret) > 1 else list(ret.values())[0]
  67. def _deserialize(self, value, *args, **kwargs):
  68. """Deserialize a serialized value to a model instance.
  69. If the parent schema is transient, create a new (transient) instance.
  70. Otherwise, attempt to find an existing instance in the database.
  71. :param value: The value to deserialize.
  72. """
  73. if not isinstance(value, dict):
  74. if len(self.related_keys) != 1:
  75. keys = [prop.key for prop in self.related_keys]
  76. raise self.make_error("invalid", value=value, keys=keys)
  77. value = {self.related_keys[0].key: value}
  78. if self.transient:
  79. return self.related_model(**value)
  80. try:
  81. result = self._get_existing_instance(
  82. self.session.query(self.related_model), value
  83. )
  84. except NoResultFound:
  85. # The related-object DNE in the DB, but we still want to deserialize it
  86. # ...perhaps we want to add it to the DB later
  87. return self.related_model(**value)
  88. return result
  89. def _get_existing_instance(self, query, value):
  90. """Retrieve the related object from an existing instance in the DB.
  91. :param query: A SQLAlchemy `Query <sqlalchemy.orm.query.Query>` object.
  92. :param value: The serialized value to mapto an existing instance.
  93. :raises NoResultFound: if there is no matching record.
  94. """
  95. if self.columns:
  96. result = query.filter_by(
  97. **{prop.key: value.get(prop.key) for prop in self.related_keys}
  98. ).one()
  99. else:
  100. # Use a faster path if the related key is the primary key.
  101. lookup_values = [value.get(prop.key) for prop in self.related_keys]
  102. try:
  103. result = query.get(lookup_values)
  104. except TypeError:
  105. keys = [prop.key for prop in self.related_keys]
  106. raise self.make_error("invalid", value=value, keys=keys)
  107. if result is None:
  108. raise NoResultFound
  109. return result
  110. class Nested(fields.Nested):
  111. """Nested field that inherits the session from its parent."""
  112. def _deserialize(self, *args, **kwargs):
  113. if hasattr(self.schema, "session"):
  114. self.schema.session = self.root.session
  115. self.schema.transient = self.root.transient
  116. return super()._deserialize(*args, **kwargs)