schema.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from marshmallow.fields import Field
  2. from marshmallow.schema import Schema, SchemaMeta, SchemaOpts
  3. import sqlalchemy as sa
  4. from sqlalchemy.ext.declarative import DeclarativeMeta
  5. from .convert import ModelConverter
  6. from .exceptions import IncorrectSchemaTypeError
  7. from .load_instance_mixin import LoadInstanceMixin
  8. # This isn't really a field; it's a placeholder for the metaclass.
  9. # This should be considered private API.
  10. class SQLAlchemyAutoField(Field):
  11. def __init__(self, *, column_name=None, model=None, table=None, field_kwargs):
  12. super().__init__()
  13. if model and table:
  14. raise ValueError("Cannot pass both `model` and `table` options.")
  15. self.column_name = column_name
  16. self.model = model
  17. self.table = table
  18. self.field_kwargs = field_kwargs
  19. def create_field(self, schema_opts, column_name, converter):
  20. model = self.model or schema_opts.model
  21. if model:
  22. return converter.field_for(model, column_name, **self.field_kwargs)
  23. else:
  24. table = self.table if self.table is not None else schema_opts.table
  25. column = getattr(table.columns, column_name)
  26. return converter.column2field(column, **self.field_kwargs)
  27. # This field should never be bound to a schema.
  28. # If this method is called, it's probably because the schema is not a SQLAlchemySchema.
  29. def _bind_to_schema(self, field_name, schema):
  30. raise IncorrectSchemaTypeError(
  31. f"Cannot bind SQLAlchemyAutoField. Make sure that {schema} is a SQLAlchemySchema or SQLAlchemyAutoSchema."
  32. )
  33. class SQLAlchemySchemaOpts(LoadInstanceMixin.Opts, SchemaOpts):
  34. """Options class for `SQLAlchemySchema`.
  35. Adds the following options:
  36. - ``model``: The SQLAlchemy model to generate the `Schema` from (mutually exclusive with ``table``).
  37. - ``table``: The SQLAlchemy table to generate the `Schema` from (mutually exclusive with ``model``).
  38. - ``load_instance``: Whether to load model instances.
  39. - ``sqla_session``: SQLAlchemy session to be used for deserialization.
  40. This is only needed when ``load_instance`` is `True`. You can also pass a session to the Schema's `load` method.
  41. - ``transient``: Whether to load model instances in a transient state (effectively ignoring the session).
  42. Only relevant when ``load_instance`` is `True`.
  43. - ``model_converter``: `ModelConverter` class to use for converting the SQLAlchemy model to marshmallow fields.
  44. """
  45. def __init__(self, meta, *args, **kwargs):
  46. super().__init__(meta, *args, **kwargs)
  47. self.model = getattr(meta, "model", None)
  48. self.table = getattr(meta, "table", None)
  49. if self.model is not None and self.table is not None:
  50. raise ValueError("Cannot set both `model` and `table` options.")
  51. self.model_converter = getattr(meta, "model_converter", ModelConverter)
  52. class SQLAlchemyAutoSchemaOpts(SQLAlchemySchemaOpts):
  53. """Options class for `SQLAlchemyAutoSchema`.
  54. Has the same options as `SQLAlchemySchemaOpts`, with the addition of:
  55. - ``include_fk``: Whether to include foreign fields; defaults to `False`.
  56. - ``include_relationships``: Whether to include relationships; defaults to `False`.
  57. """
  58. def __init__(self, meta, *args, **kwargs):
  59. super().__init__(meta, *args, **kwargs)
  60. self.include_fk = getattr(meta, "include_fk", False)
  61. self.include_relationships = getattr(meta, "include_relationships", False)
  62. if self.table is not None and self.include_relationships:
  63. raise ValueError("Cannot set `table` and `include_relationships = True`.")
  64. class SQLAlchemySchemaMeta(SchemaMeta):
  65. @classmethod
  66. def get_declared_fields(mcs, klass, cls_fields, inherited_fields, dict_cls):
  67. opts = klass.opts
  68. Converter = opts.model_converter
  69. converter = Converter(schema_cls=klass)
  70. fields = super().get_declared_fields(
  71. klass, cls_fields, inherited_fields, dict_cls
  72. )
  73. fields.update(mcs.get_declared_sqla_fields(fields, converter, opts, dict_cls))
  74. fields.update(mcs.get_auto_fields(fields, converter, opts, dict_cls))
  75. return fields
  76. @classmethod
  77. def get_declared_sqla_fields(mcs, base_fields, converter, opts, dict_cls):
  78. return {}
  79. @classmethod
  80. def get_auto_fields(mcs, fields, converter, opts, dict_cls):
  81. return dict_cls(
  82. {
  83. field_name: field.create_field(
  84. opts, field.column_name or field_name, converter
  85. )
  86. for field_name, field in fields.items()
  87. if isinstance(field, SQLAlchemyAutoField)
  88. and field_name not in opts.exclude
  89. }
  90. )
  91. class SQLAlchemyAutoSchemaMeta(SQLAlchemySchemaMeta):
  92. @classmethod
  93. def get_declared_sqla_fields(cls, base_fields, converter, opts, dict_cls):
  94. fields = dict_cls()
  95. if opts.table is not None:
  96. fields.update(
  97. converter.fields_for_table(
  98. opts.table,
  99. fields=opts.fields,
  100. exclude=opts.exclude,
  101. include_fk=opts.include_fk,
  102. base_fields=base_fields,
  103. dict_cls=dict_cls,
  104. )
  105. )
  106. elif opts.model is not None:
  107. fields.update(
  108. converter.fields_for_model(
  109. opts.model,
  110. fields=opts.fields,
  111. exclude=opts.exclude,
  112. include_fk=opts.include_fk,
  113. include_relationships=opts.include_relationships,
  114. base_fields=base_fields,
  115. dict_cls=dict_cls,
  116. )
  117. )
  118. return fields
  119. class SQLAlchemySchema(
  120. LoadInstanceMixin.Schema, Schema, metaclass=SQLAlchemySchemaMeta
  121. ):
  122. """Schema for a SQLAlchemy model or table.
  123. Use together with `auto_field` to generate fields from columns.
  124. Example: ::
  125. from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
  126. from mymodels import User
  127. class UserSchema(SQLAlchemySchema):
  128. class Meta:
  129. model = User
  130. id = auto_field()
  131. created_at = auto_field(dump_only=True)
  132. name = auto_field()
  133. """
  134. OPTIONS_CLASS = SQLAlchemySchemaOpts
  135. class SQLAlchemyAutoSchema(SQLAlchemySchema, metaclass=SQLAlchemyAutoSchemaMeta):
  136. """Schema that automatically generates fields from the columns of
  137. a SQLAlchemy model or table.
  138. Example: ::
  139. from marshmallow_sqlalchemy import SQLAlchemyAutoSchema, auto_field
  140. from mymodels import User
  141. class UserSchema(SQLAlchemyAutoSchema):
  142. class Meta:
  143. model = User
  144. # OR
  145. # table = User.__table__
  146. created_at = auto_field(dump_only=True)
  147. """
  148. OPTIONS_CLASS = SQLAlchemyAutoSchemaOpts
  149. def auto_field(
  150. column_name: str = None,
  151. *,
  152. model: DeclarativeMeta = None,
  153. table: sa.Table = None,
  154. **kwargs,
  155. ):
  156. """Mark a field to autogenerate from a model or table.
  157. :param column_name: Name of the column to generate the field from.
  158. If ``None``, matches the field name. If ``attribute`` is unspecified,
  159. ``attribute`` will be set to the same value as ``column_name``.
  160. :param model: Model to generate the field from.
  161. If ``None``, uses ``model`` specified on ``class Meta``.
  162. :param table: Table to generate the field from.
  163. If ``None``, uses ``table`` specified on ``class Meta``.
  164. :param kwargs: Field argument overrides.
  165. """
  166. if column_name is not None:
  167. kwargs.setdefault("attribute", column_name)
  168. return SQLAlchemyAutoField(
  169. column_name=column_name, model=model, table=table, field_kwargs=kwargs
  170. )