convert.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. from typing import Any, Callable, Dict, List, Optional, Type
  2. from flask_appbuilder.models.sqla import Model
  3. from flask_appbuilder.models.sqla.interface import SQLAInterface
  4. from marshmallow import fields, Schema
  5. from marshmallow.fields import Field
  6. from marshmallow_sqlalchemy import field_for
  7. from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
  class TreeNode:
      """
      A named node of a column tree.

      Holds its ``name`` and an ordered list of child ``TreeNode`` objects.
      """

      def __init__(self, name: str) -> None:
          self.name = name
          # Children are kept in the order they were appended.
          self.children: List["TreeNode"] = []

      def __repr__(self) -> str:
          # Rendered as "<name>.<list of children>", e.g. "a.[b.[]]".
          return "{}.{}".format(self.name, self.children)
  class Tree:
      """
      A one-level tree of names hanging off a ``"+"`` root node.

      ``add`` appends a top-level node. ``add_child`` appends a node under the
      first top-level node with the given parent name, creating that parent
      first when it does not exist yet.
      """

      def __init__(self) -> None:
          self.root = TreeNode("+")

      def add(self, name: str) -> None:
          """Append a new top-level node called ``name``."""
          self.root.children.append(TreeNode(name))

      def add_child(self, parent: str, name: str) -> None:
          """Append ``name`` under the first top-level node named ``parent``."""
          holder = next(
              (branch for branch in self.root.children if branch.name == parent),
              None,
          )
          if holder is None:
              # Parent not seen yet: create it as a new top-level node.
              holder = TreeNode(parent)
              self.root.children.append(holder)
          holder.children.append(TreeNode(name))

      def __repr__(self) -> str:
          # Concatenate the representation of every top-level node.
          return "".join(str(branch) for branch in self.root.children)
  def columns2Tree(columns: List[str]) -> Tree:
      """
      Build a one-level ``Tree`` from a list of column names.

      Plain names (``"name"``) become top-level nodes. Dotted names
      (``"parent.child"``) add ``child`` under the ``parent`` node, creating the
      parent node if needed.

      :param columns: column names, each with at most one dot
      :return: the populated Tree
      :raises ValueError: if a column contains more than one dot, since the
          tree only supports one level of nesting
      """
      tree = Tree()
      for column in columns:
          if "." not in column:
              tree.add(column)
              continue
          parent, _, child = column.partition(".")
          if "." in child:
              # The tree is only one level deep; fail with an explicit message
              # instead of an opaque "too many values to unpack" error.
              raise ValueError(
                  f"Column '{column}' is nested more than one level deep; "
                  "only 'relation.column' notation is supported"
              )
          tree.add_child(parent, child)
      return tree
  class BaseModel2SchemaConverter:
      """
      Base for converters that turn a model into a marshmallow schema.

      Stores the configuration shared by converters. ``convert`` is a hook that
      subclasses override.
      """

      def __init__(
          self,
          datamodel: SQLAInterface,
          validators_columns: Dict[str, Callable[[Any], Any]],
      ):
          """
          :param datamodel: SQLAInterface
          :param validators_columns: mapping of column name to a validator
              callable
          """
          self.validators_columns = validators_columns
          self.datamodel = datamodel

      def convert(
          self,
          columns: List[str],
          model: Optional[Type[Model]] = None,
          nested: bool = True,
          parent_schema_name: Optional[str] = None,
      ) -> SQLAlchemyAutoSchema:
          """
          Conversion hook; this base implementation does nothing and returns None.

          Subclasses may call it through ``super()`` before doing their own work,
          so it must remain a side-effect-free no-op.

          :param columns: columns to include
          :param model: optional model overriding the datamodel's model
          :param nested: whether relations produce nested schemas
          :param parent_schema_name: name of the parent schema, if any
          """
  class Model2SchemaConverter(BaseModel2SchemaConverter):
      """
      Class that converts Models to marshmallow Schemas
      """

      def __init__(
          self,
          datamodel: SQLAInterface,
          validators_columns: Dict[str, Callable[[Any], Any]],
      ):
          """
          :param datamodel: SQLAInterface
          :param validators_columns: mapping of column name to an extra validator
              appended to that column's generated field
          """
          super(Model2SchemaConverter, self).__init__(datamodel, validators_columns)

      @staticmethod
      def _debug_schema(schema: SQLAlchemyAutoSchema) -> None:
          """Print every declared field of ``schema`` (debugging aid)."""
          for k, v in schema._declared_fields.items():
              print(k, v)

      def _meta_schema_factory(
          self,
          columns: List[str],
          model: Optional[Type[Model]],
          class_mixin: Type[Schema],
          parent_schema_name: Optional[str] = None,
      ) -> Type[SQLAlchemyAutoSchema]:
          """
          Creates ModelSchema marshmallow-sqlalchemy

          :param columns: a list of columns to mix; when empty, ``Meta.fields``
              is not set, so the schema is not restricted to a column list
          :param model: Model
          :param class_mixin: a marshmallow Schema to mix
          :param parent_schema_name: stored on ``Meta``; used to humanize
              nested schema names (comes from ModelRestApi)
          :return: ModelSchema class
          """
          # Aliases are required: a class body that assigns ``model`` cannot
          # read the enclosing function's ``model`` argument under that name.
          _model = model
          _parent_schema_name = parent_schema_name
          _session = self.datamodel.session

          class Meta:
              model = _model
              load_instance = True
              sqla_session = _session
              # The parent_schema_name is useful to humanize nested schema names
              # This name comes from ModelRestApi
              parent_schema_name = _parent_schema_name

          if columns:
              # Restrict the schema to the requested columns only.
              Meta.fields = columns  # type: ignore[attr-defined]

          _meta = Meta

          class MetaSchema(SQLAlchemyAutoSchema, class_mixin):  # type: ignore
              Meta = _meta

          return MetaSchema

      def _column2enum(self, datamodel: SQLAInterface, column: TreeNode) -> Field:
          """
          Build the field for an Enum column.

          Uses ``fields.Enum`` when the SQLAlchemy column type exposes a Python
          enum class, otherwise falls back to ``field_for``. The column's
          ``info["marshmallow_by_value"]`` (default True) is forwarded as
          ``by_value`` to ``fields.Enum``.
          """
          required = not datamodel.is_nullable(column.name)
          sqla_column = datamodel.list_columns[column.name]
          # get SQLAlchemy column user info, we use it to get the marshmallow enum options
          column_info = sqla_column.info
          # TODO: Default should be False, but keeping this to True to keep compatibility
          # Turn this to False in the next major release
          by_value = column_info.get("marshmallow_by_value", True)
          # Get the original enum class from SQLAlchemy Enum field
          enum_class = sqla_column.type.enum_class
          if not enum_class:
              field = field_for(datamodel.obj, column.name)
          else:
              field = fields.Enum(enum_class, required=required, by_value=by_value)
          field.unique = datamodel.is_unique(column.name)
          return field

      def _column2relation(
          self,
          datamodel: SQLAInterface,
          column: TreeNode,
          nested: bool = False,
          parent_schema_name: Optional[str] = None,
      ) -> Field:
          """
          Build the field for a relationship column.

          With ``nested``, the related model is converted (non-nested) into a
          schema limited to the column's dotted children. When there are no
          children, no column restriction is applied. That schema is wrapped in
          ``fields.Nested``, with ``many`` chosen from the relationship
          direction. Without ``nested``, ``field_for`` is used.
          """
          if nested:
              required = not datamodel.is_nullable(column.name)
              nested_model = datamodel.get_related_model(column.name)
              lst = [item.name for item in column.children]
              nested_schema = self.convert(
                  lst, nested_model, nested=False, parent_schema_name=parent_schema_name
              )
              if datamodel.is_relation_many_to_one(column.name):
                  many = False
              elif datamodel.is_relation_many_to_many(column.name):
                  many = True
                  required = False
              elif datamodel.is_relation_one_to_many(column.name):
                  many = True
              else:
                  many = False
              field = fields.Nested(nested_schema, many=many, required=required)
              field.unique = datamodel.is_unique(column.name)
              return field
          # Handle bug on marshmallow-sqlalchemy
          # https://github.com/marshmallow-code/marshmallow-sqlalchemy/issues/163
          # For *-to-many relations, "required" comes from the column info
          # instead of nullability.
          if datamodel.is_relation_many_to_many(
              column.name
          ) or datamodel.is_relation_one_to_many(column.name):
              required = datamodel.get_info(column.name).get("required", False)
          else:
              required = not datamodel.is_nullable(column.name)
          field = field_for(datamodel.obj, column.name)
          field.required = required
          field.unique = datamodel.is_unique(column.name)
          return field

      def _column2field(
          self,
          datamodel: SQLAInterface,
          column: TreeNode,
          nested: bool = True,
          parent_schema_name: Optional[str] = None,
      ) -> Field:
          """
          :param datamodel: SQLAInterface
          :param column: TreeNode column (children are dotted columns)
          :param nested: Boolean if will create nested fields
          :param parent_schema_name: forwarded to nested schemas
          :return: Schema.field
          """
          # Handle relations
          if datamodel.is_relation(column.name):
              return self._column2relation(
                  datamodel, column, nested=nested, parent_schema_name=parent_schema_name
              )
          # Handle Enums
          if datamodel.is_enum(column.name):
              return self._column2enum(datamodel, column)
          model_attr = getattr(datamodel.obj, column.name)
          # is custom property method field?
          if hasattr(model_attr, "fget"):
              return fields.Raw(dump_only=True)
          # its a model function
          if callable(model_attr):
              return fields.Function(model_attr, dump_only=True)
          # a normal model field
          field = field_for(datamodel.obj, column.name)
          field.unique = datamodel.is_unique(column.name)
          if column.name in self.validators_columns:
              validator = self.validators_columns[column.name]
              # marshmallow accepts ``validate`` as None, a single callable or an
              # iterable; normalize it to a list so the validator can be appended.
              if field.validate is None:
                  field.validate = []
              elif callable(field.validate):
                  field.validate = [field.validate]
              elif not isinstance(field.validate, list):
                  field.validate = list(field.validate)
              field.validate.append(validator)
              # marshmallow derives ``validators`` from ``validate`` at field
              # construction, so the already-built list must be updated too.
              field.validators.append(validator)
          return field

      def convert(
          self,
          columns: List[str],
          model: Optional[Type[Model]] = None,
          nested: bool = True,
          parent_schema_name: Optional[str] = None,
      ) -> SQLAlchemyAutoSchema:
          """
          Creates a Marshmallow ModelSchema instance

          :param columns: List with columns to include, if empty converts all on model
          :param model: Override Model to convert
          :param nested: Generate relation with nested schemas
          :param parent_schema_name: name of the parent schema, used to humanize
              nested schema names
          :return: ModelSchema object
          """
          super(Model2SchemaConverter, self).convert(
              columns, model=model, nested=nested, parent_schema_name=parent_schema_name
          )

          # A fresh mixin per call, so generated fields never leak between schemas.
          class SchemaMixin:
              pass

          _model = model or self.datamodel.obj
          _datamodel = self.datamodel.__class__(_model)
          _columns = []
          tree_columns = columns2Tree(columns)
          for column in tree_columns.root.children:
              # Dotted children of the column drive nested schema generation
              setattr(
                  SchemaMixin,
                  column.name,
                  self._column2field(
                      _datamodel, column, nested, parent_schema_name=parent_schema_name
                  ),
              )
              _columns.append(column.name)
          return self._meta_schema_factory(
              _columns, _model, SchemaMixin, parent_schema_name=parent_schema_name
          )()
  241. return self._meta_schema_factory(
  242. _columns, _model, SchemaMixin, parent_schema_name=parent_schema_name
  243. )()