convert.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import inspect
  2. import functools
  3. import warnings
  4. import uuid
  5. import marshmallow as ma
  6. from marshmallow import validate, fields
  7. from packaging.version import Version
  8. from sqlalchemy.dialects import postgresql, mysql, mssql
  9. from sqlalchemy.orm import SynonymProperty
  10. import sqlalchemy as sa
  11. from .exceptions import ModelConversionError
  12. from .fields import Related, RelatedList
# True when the installed marshmallow is version 3.10.0 or newer.
# `_set_meta_kwarg` reads this flag to decide whether an extra keyword is
# stored inside the ``metadata`` dict or as a top-level field keyword.
_META_KWARGS_DEPRECATED = Version(ma.__version__) >= Version("3.10.0")
  14. def _is_field(value):
  15. return isinstance(value, type) and issubclass(value, fields.Field)
  16. def _base_column(column):
  17. """Unwrap proxied columns"""
  18. if column not in column.base_columns and len(column.base_columns) == 1:
  19. [base] = column.base_columns
  20. return base
  21. return column
  22. def _has_default(column):
  23. return (
  24. column.default is not None
  25. or column.server_default is not None
  26. or _is_auto_increment(column)
  27. )
  28. def _is_auto_increment(column):
  29. return column.table is not None and column is column.table._autoincrement_column
  30. def _postgres_array_factory(converter, data_type):
  31. return functools.partial(
  32. fields.List, converter._get_field_class_for_data_type(data_type.item_type)
  33. )
  34. def _set_meta_kwarg(field_kwargs, key, value):
  35. if _META_KWARGS_DEPRECATED:
  36. field_kwargs["metadata"][key] = value
  37. else:
  38. field_kwargs[key] = value
  39. def _field_update_kwargs(field_class, field_kwargs, kwargs):
  40. if not kwargs:
  41. return field_kwargs
  42. if isinstance(field_class, functools.partial):
  43. # Unwrap partials, assuming that they bind a Field to arguments
  44. field_class = field_class.func
  45. possible_field_keywords = {
  46. key
  47. for cls in inspect.getmro(field_class)
  48. for key, param in inspect.signature(cls).parameters.items()
  49. if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
  50. or param.kind is inspect.Parameter.KEYWORD_ONLY
  51. }
  52. for k, v in kwargs.items():
  53. if k in possible_field_keywords:
  54. field_kwargs[k] = v
  55. else:
  56. _set_meta_kwarg(field_kwargs, k, v)
  57. return field_kwargs
  58. class ModelConverter:
  59. """Class that converts a SQLAlchemy model into a dictionary of corresponding
  60. marshmallow `Fields <marshmallow.fields.Field>`.
  61. """
  62. SQLA_TYPE_MAPPING = {
  63. sa.Enum: fields.Field,
  64. sa.JSON: fields.Raw,
  65. postgresql.BIT: fields.Integer,
  66. postgresql.OID: fields.Integer,
  67. postgresql.UUID: fields.UUID,
  68. postgresql.MACADDR: fields.String,
  69. postgresql.INET: fields.String,
  70. postgresql.CIDR: fields.String,
  71. postgresql.JSON: fields.Raw,
  72. postgresql.JSONB: fields.Raw,
  73. postgresql.HSTORE: fields.Raw,
  74. postgresql.ARRAY: _postgres_array_factory,
  75. postgresql.MONEY: fields.Decimal,
  76. postgresql.DATE: fields.Date,
  77. postgresql.TIME: fields.Time,
  78. mysql.BIT: fields.Integer,
  79. mysql.YEAR: fields.Integer,
  80. mysql.SET: fields.List,
  81. mysql.ENUM: fields.Field,
  82. mysql.INTEGER: fields.Integer,
  83. mysql.DATETIME: fields.DateTime,
  84. mssql.BIT: fields.Integer,
  85. mssql.UNIQUEIDENTIFIER: fields.UUID,
  86. }
  87. DIRECTION_MAPPING = {"MANYTOONE": False, "MANYTOMANY": True, "ONETOMANY": True}
  88. def __init__(self, schema_cls=None):
  89. self.schema_cls = schema_cls
  90. @property
  91. def type_mapping(self):
  92. if self.schema_cls:
  93. return self.schema_cls.TYPE_MAPPING
  94. else:
  95. return ma.Schema.TYPE_MAPPING
  96. def fields_for_model(
  97. self,
  98. model,
  99. *,
  100. include_fk=False,
  101. include_relationships=False,
  102. fields=None,
  103. exclude=None,
  104. base_fields=None,
  105. dict_cls=dict,
  106. ):
  107. result = dict_cls()
  108. base_fields = base_fields or {}
  109. for prop in model.__mapper__.iterate_properties:
  110. key = self._get_field_name(prop)
  111. if self._should_exclude_field(prop, fields=fields, exclude=exclude):
  112. # Allow marshmallow to validate and exclude the field key.
  113. result[key] = None
  114. continue
  115. if isinstance(prop, SynonymProperty):
  116. continue
  117. if hasattr(prop, "columns"):
  118. if not include_fk:
  119. # Only skip a column if there is no overriden column
  120. # which does not have a Foreign Key.
  121. for column in prop.columns:
  122. if not column.foreign_keys:
  123. break
  124. else:
  125. continue
  126. if not include_relationships and hasattr(prop, "direction"):
  127. continue
  128. field = base_fields.get(key) or self.property2field(prop)
  129. if field:
  130. result[key] = field
  131. return result
  132. def fields_for_table(
  133. self,
  134. table,
  135. *,
  136. include_fk=False,
  137. fields=None,
  138. exclude=None,
  139. base_fields=None,
  140. dict_cls=dict,
  141. ):
  142. result = dict_cls()
  143. base_fields = base_fields or {}
  144. for column in table.columns:
  145. key = self._get_field_name(column)
  146. if self._should_exclude_field(column, fields=fields, exclude=exclude):
  147. # Allow marshmallow to validate and exclude the field key.
  148. result[key] = None
  149. continue
  150. if not include_fk and column.foreign_keys:
  151. continue
  152. # Overridden fields are specified relative to key generated by
  153. # self._get_key_for_column(...), rather than keys in source model
  154. field = base_fields.get(key) or self.column2field(column)
  155. if field:
  156. result[key] = field
  157. return result
  158. def property2field(self, prop, *, instance=True, field_class=None, **kwargs):
  159. # handle synonyms
  160. # Attribute renamed "_proxied_object" in 1.4
  161. for attr in ("_proxied_property", "_proxied_object"):
  162. proxied_obj = getattr(prop, attr, None)
  163. if proxied_obj is not None:
  164. prop = proxied_obj
  165. field_class = field_class or self._get_field_class_for_property(prop)
  166. if not instance:
  167. return field_class
  168. field_kwargs = self._get_field_kwargs_for_property(prop)
  169. _field_update_kwargs(field_class, field_kwargs, kwargs)
  170. ret = field_class(**field_kwargs)
  171. if (
  172. hasattr(prop, "direction")
  173. and self.DIRECTION_MAPPING[prop.direction.name]
  174. and prop.uselist is True
  175. ):
  176. related_list_kwargs = _field_update_kwargs(
  177. RelatedList, self.get_base_kwargs(), kwargs
  178. )
  179. ret = RelatedList(ret, **related_list_kwargs)
  180. return ret
  181. def column2field(self, column, *, instance=True, **kwargs):
  182. field_class = self._get_field_class_for_column(column)
  183. if not instance:
  184. return field_class
  185. field_kwargs = self.get_base_kwargs()
  186. self._add_column_kwargs(field_kwargs, column)
  187. _field_update_kwargs(field_class, field_kwargs, kwargs)
  188. return field_class(**field_kwargs)
  189. def field_for(self, model, property_name, **kwargs):
  190. target_model = model
  191. prop_name = property_name
  192. attr = getattr(model, property_name)
  193. remote_with_local_multiplicity = False
  194. if hasattr(attr, "remote_attr"):
  195. target_model = attr.target_class
  196. prop_name = attr.value_attr
  197. remote_with_local_multiplicity = attr.local_attr.prop.uselist
  198. prop = target_model.__mapper__.get_property(prop_name)
  199. converted_prop = self.property2field(prop, **kwargs)
  200. if remote_with_local_multiplicity:
  201. related_list_kwargs = _field_update_kwargs(
  202. RelatedList, self.get_base_kwargs(), kwargs
  203. )
  204. return RelatedList(converted_prop, **related_list_kwargs)
  205. else:
  206. return converted_prop
  207. def _get_field_name(self, prop_or_column):
  208. return prop_or_column.key
  209. def _get_field_class_for_column(self, column):
  210. return self._get_field_class_for_data_type(column.type)
  211. def _get_field_class_for_data_type(self, data_type):
  212. field_cls = None
  213. types = inspect.getmro(type(data_type))
  214. # First search for a field class from self.SQLA_TYPE_MAPPING
  215. for col_type in types:
  216. if col_type in self.SQLA_TYPE_MAPPING:
  217. field_cls = self.SQLA_TYPE_MAPPING[col_type]
  218. if callable(field_cls) and not _is_field(field_cls):
  219. field_cls = field_cls(self, data_type)
  220. break
  221. else:
  222. # Try to find a field class based on the column's python_type
  223. try:
  224. python_type = data_type.python_type
  225. except NotImplementedError:
  226. python_type = None
  227. if python_type in self.type_mapping:
  228. field_cls = self.type_mapping[python_type]
  229. else:
  230. if hasattr(data_type, "impl"):
  231. return self._get_field_class_for_data_type(data_type.impl)
  232. raise ModelConversionError(
  233. f"Could not find field column of type {types[0]}."
  234. )
  235. return field_cls
  236. def _get_field_class_for_property(self, prop):
  237. if hasattr(prop, "direction"):
  238. field_cls = Related
  239. else:
  240. column = _base_column(prop.columns[0])
  241. field_cls = self._get_field_class_for_column(column)
  242. return field_cls
  243. def _merge_validators(self, defaults, new):
  244. new_classes = [validator.__class__ for validator in new]
  245. return [
  246. validator
  247. for validator in defaults
  248. if validator.__class__ not in new_classes
  249. ] + new
  250. def _get_field_kwargs_for_property(self, prop):
  251. kwargs = self.get_base_kwargs()
  252. if hasattr(prop, "columns"):
  253. column = _base_column(prop.columns[0])
  254. self._add_column_kwargs(kwargs, column)
  255. prop = column
  256. if hasattr(prop, "direction"): # Relationship property
  257. self._add_relationship_kwargs(kwargs, prop)
  258. if getattr(prop, "doc", None): # Useful for documentation generation
  259. _set_meta_kwarg(kwargs, "description", prop.doc)
  260. info = getattr(prop, "info", dict())
  261. overrides = info.get("marshmallow")
  262. if overrides is not None:
  263. warnings.warn(
  264. 'Passing `info={"marshmallow": ...}` is deprecated. '
  265. "Use `SQLAlchemySchema` and `auto_field` instead.",
  266. DeprecationWarning,
  267. )
  268. validate = overrides.pop("validate", [])
  269. kwargs["validate"] = self._merge_validators(
  270. kwargs["validate"], validate
  271. ) # Ensure we do not override the generated validators.
  272. kwargs.update(overrides) # Override other kwargs.
  273. return kwargs
  274. def _add_column_kwargs(self, kwargs, column):
  275. """Add keyword arguments to kwargs (in-place) based on the passed in
  276. `Column <sqlalchemy.schema.Column>`.
  277. """
  278. if hasattr(column, "nullable"):
  279. if column.nullable:
  280. kwargs["allow_none"] = True
  281. kwargs["required"] = not column.nullable and not _has_default(column)
  282. # If there is no nullable attribute, we are dealing with a property
  283. # that does not derive from the Column class. Mark as dump_only.
  284. else:
  285. kwargs["dump_only"] = True
  286. if hasattr(column.type, "enums") and not kwargs.get("dump_only"):
  287. kwargs["validate"].append(validate.OneOf(choices=column.type.enums))
  288. # Add a length validator if a max length is set on the column
  289. # Skip UUID columns
  290. # (see https://github.com/marshmallow-code/marshmallow-sqlalchemy/issues/54)
  291. if hasattr(column.type, "length") and not kwargs.get("dump_only"):
  292. column_length = column.type.length
  293. if column_length is not None:
  294. try:
  295. python_type = column.type.python_type
  296. except (AttributeError, NotImplementedError):
  297. python_type = None
  298. if not python_type or not issubclass(python_type, uuid.UUID):
  299. kwargs["validate"].append(validate.Length(max=column_length))
  300. if getattr(column.type, "asdecimal", False):
  301. kwargs["places"] = getattr(column.type, "scale", None)
  302. def _add_relationship_kwargs(self, kwargs, prop):
  303. """Add keyword arguments to kwargs (in-place) based on the passed in
  304. relationship `Property`.
  305. """
  306. nullable = True
  307. for pair in prop.local_remote_pairs:
  308. if not pair[0].nullable:
  309. if prop.uselist is True:
  310. nullable = False
  311. break
  312. kwargs.update({"allow_none": nullable, "required": not nullable})
  313. def _should_exclude_field(self, column, fields=None, exclude=None):
  314. key = self._get_field_name(column)
  315. if fields and key not in fields:
  316. return True
  317. if exclude and key in exclude:
  318. return True
  319. return False
  320. def get_base_kwargs(self):
  321. kwargs = {"validate": []}
  322. if _META_KWARGS_DEPRECATED:
  323. kwargs["metadata"] = {}
  324. return kwargs
# Module-level converter behind the convenience aliases below. It is created
# without a schema class, so its `type_mapping` is `ma.Schema.TYPE_MAPPING`.
default_converter = ModelConverter()

# Each alias is a bound method of `default_converter`; the bare string that
# follows an assignment is that alias's documentation and must stay directly
# after it.
fields_for_model = default_converter.fields_for_model
"""Generate a dict of field_name: `marshmallow.fields.Field` pairs for the given model.
Note: SynonymProperties are ignored. Use an explicit field if you want to include a synonym.
:param model: The SQLAlchemy model
:param bool include_fk: Whether to include foreign key fields in the output.
:param bool include_relationships: Whether to include relationships fields in the output.
:return: dict of field_name: Field instance pairs
"""

property2field = default_converter.property2field
"""Convert a SQLAlchemy `Property` to a field instance or class.
:param Property prop: SQLAlchemy Property.
:param bool instance: If `True`, return `Field` instance, computing relevant kwargs
from the given property. If `False`, return the `Field` class.
:param kwargs: Additional keyword arguments to pass to the field constructor.
:return: A `marshmallow.fields.Field` class or instance.
"""

# NOTE(review): the text below says "from the given property", but the
# argument here is a column. The method also accepts extra **kwargs for the
# field constructor, which the text does not mention.
column2field = default_converter.column2field
"""Convert a SQLAlchemy `Column <sqlalchemy.schema.Column>` to a field instance or class.
:param sqlalchemy.schema.Column column: SQLAlchemy Column.
:param bool instance: If `True`, return `Field` instance, computing relevant kwargs
from the given property. If `False`, return the `Field` class.
:return: A `marshmallow.fields.Field` class or instance.
"""

field_for = default_converter.field_for
"""Convert a property for a mapped SQLAlchemy class to a marshmallow `Field`.
Example: ::
date_created = field_for(Author, 'date_created', dump_only=True)
author = field_for(Book, 'author')
:param type model: A SQLAlchemy mapped class.
:param str property_name: The name of the property to convert.
:param kwargs: Extra keyword arguments to pass to `property2field`
:return: A `marshmallow.fields.Field` class or instance.
"""