common.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. """Utilities to get schema instances/classes"""
  2. from __future__ import annotations
  3. import copy
  4. import warnings
  5. import marshmallow
  6. import marshmallow.class_registry
  7. from marshmallow import fields
  8. from apispec.core import Components
  9. MODIFIERS = ["only", "exclude", "load_only", "dump_only", "partial"]
  10. def resolve_schema_instance(
  11. schema: type[marshmallow.Schema] | marshmallow.Schema | str,
  12. ) -> marshmallow.Schema:
  13. """Return schema instance for given schema (instance or class).
  14. :param type|Schema|str schema: instance, class or class name of marshmallow.Schema
  15. :return: schema instance of given schema (instance or class)
  16. """
  17. if isinstance(schema, type) and issubclass(schema, marshmallow.Schema):
  18. return schema()
  19. if isinstance(schema, marshmallow.Schema):
  20. return schema
  21. return marshmallow.class_registry.get_class(schema)()
  22. def resolve_schema_cls(
  23. schema: type[marshmallow.Schema] | str | marshmallow.Schema,
  24. ) -> type[marshmallow.Schema] | list[type[marshmallow.Schema]]:
  25. """Return schema class for given schema (instance or class).
  26. :param type|Schema|str: instance, class or class name of marshmallow.Schema
  27. :return: schema class of given schema (instance or class)
  28. """
  29. if isinstance(schema, type) and issubclass(schema, marshmallow.Schema):
  30. return schema
  31. if isinstance(schema, marshmallow.Schema):
  32. return type(schema)
  33. return marshmallow.class_registry.get_class(str(schema))
  34. def get_fields(
  35. schema: type[marshmallow.Schema] | marshmallow.Schema,
  36. *,
  37. exclude_dump_only: bool = False,
  38. ) -> dict[str, fields.Field]:
  39. """Return fields from schema.
  40. :param Schema schema: A marshmallow Schema instance or a class object
  41. :param bool exclude_dump_only: whether to filter fields in Meta.dump_only
  42. :rtype: dict, of field name field object pairs
  43. """
  44. if isinstance(schema, marshmallow.Schema):
  45. fields = schema.fields
  46. elif isinstance(schema, type) and issubclass(schema, marshmallow.Schema):
  47. fields = copy.deepcopy(schema._declared_fields)
  48. else:
  49. raise ValueError(f"{schema!r} is neither a Schema class nor a Schema instance.")
  50. Meta = getattr(schema, "Meta", None)
  51. warn_if_fields_defined_in_meta(fields, Meta)
  52. return filter_excluded_fields(fields, Meta, exclude_dump_only=exclude_dump_only)
  53. def warn_if_fields_defined_in_meta(fields: dict[str, fields.Field], Meta):
  54. """Warns user that fields defined in Meta.fields or Meta.additional will be ignored.
  55. :param dict fields: A dictionary of fields name field object pairs
  56. :param Meta: the schema's Meta class
  57. """
  58. if getattr(Meta, "fields", None) or getattr(Meta, "additional", None):
  59. declared_fields = set(fields.keys())
  60. if (
  61. set(getattr(Meta, "fields", set())) > declared_fields
  62. or set(getattr(Meta, "additional", set())) > declared_fields
  63. ):
  64. warnings.warn(
  65. "Only explicitly-declared fields will be included in the Schema Object. "
  66. "Fields defined in Meta.fields or Meta.additional are ignored.",
  67. UserWarning,
  68. stacklevel=2,
  69. )
  70. def filter_excluded_fields(
  71. fields: dict[str, fields.Field], Meta, *, exclude_dump_only: bool
  72. ) -> dict[str, fields.Field]:
  73. """Filter fields that should be ignored in the OpenAPI spec.
  74. :param dict fields: A dictionary of fields name field object pairs
  75. :param Meta: the schema's Meta class
  76. :param bool exclude_dump_only: whether to filter dump_only fields
  77. """
  78. exclude = list(getattr(Meta, "exclude", []))
  79. if exclude_dump_only:
  80. exclude.extend(getattr(Meta, "dump_only", []))
  81. filtered_fields = {
  82. key: value
  83. for key, value in fields.items()
  84. if key not in exclude and not (exclude_dump_only and value.dump_only)
  85. }
  86. return filtered_fields
  87. def make_schema_key(schema: marshmallow.Schema) -> tuple[type[marshmallow.Schema], ...]:
  88. if not isinstance(schema, marshmallow.Schema):
  89. raise TypeError("can only make a schema key based on a Schema instance.")
  90. modifiers = []
  91. for modifier in MODIFIERS:
  92. attribute = getattr(schema, modifier)
  93. try:
  94. # Hashable (string, tuple)
  95. hash(attribute)
  96. except TypeError:
  97. # Unhashable iterable (list, set)
  98. attribute = frozenset(attribute)
  99. modifiers.append(attribute)
  100. return tuple([schema.__class__, *modifiers])
  101. def get_unique_schema_name(components: Components, name: str, counter: int = 0) -> str:
  102. """Function to generate a unique name based on the provided name and names
  103. already in the spec. Will append a number to the name to make it unique if
  104. the name is already in the spec.
  105. :param Components components: instance of the components of the spec
  106. :param string name: the name to use as a basis for the unique name
  107. :param int counter: the counter of the number of recursions
  108. :return: the unique name
  109. """
  110. if name not in components.schemas:
  111. return name
  112. if not counter: # first time through recursion
  113. warnings.warn(
  114. f"Multiple schemas resolved to the name {name}. The name has been modified. "
  115. "Either manually add each of the schemas with a different name or "
  116. "provide a custom schema_name_resolver.",
  117. UserWarning,
  118. stacklevel=2,
  119. )
  120. else: # subsequent recursions
  121. name = name[: -len(str(counter))]
  122. counter += 1
  123. return get_unique_schema_name(components, name + str(counter), counter)