openapi.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
"""Utilities for generating OpenAPI Specification (fka Swagger) entities from
marshmallow :class:`Schemas <marshmallow.Schema>` and :class:`Fields <marshmallow.fields.Field>`.

.. warning::

    This module is treated as private API.
    Users should not need to use this module directly.
"""
  7. from __future__ import annotations
  8. import typing
  9. import marshmallow
  10. import marshmallow.exceptions
  11. from marshmallow.utils import is_collection
  12. from packaging.version import Version
  13. from apispec import APISpec
  14. from apispec.exceptions import APISpecError
  15. from .common import (
  16. get_fields,
  17. get_unique_schema_name,
  18. make_schema_key,
  19. resolve_schema_instance,
  20. )
  21. from .field_converter import FieldConverterMixin
  22. __location_map__ = {
  23. "match_info": "path",
  24. "query": "query",
  25. "querystring": "query",
  26. "json": "body",
  27. "headers": "header",
  28. "cookies": "cookie",
  29. "form": "formData",
  30. "files": "formData",
  31. }
  32. class OpenAPIConverter(FieldConverterMixin):
  33. """Adds methods for generating OpenAPI specification from marshmallow schemas and fields.
  34. :param Version|str openapi_version: The OpenAPI version to use.
  35. Should be in the form '2.x' or '3.x.x' to comply with the OpenAPI standard.
  36. :param callable schema_name_resolver: Callable to generate the schema definition name.
  37. Receives the `Schema` class and returns the name to be used in refs within
  38. the generated spec. When working with circular referencing this function
  39. must must not return `None` for schemas in a circular reference chain.
  40. :param APISpec spec: An initialized spec. Nested schemas will be added to the spec
  41. """
  42. def __init__(
  43. self,
  44. openapi_version: Version | str,
  45. schema_name_resolver,
  46. spec: APISpec,
  47. ) -> None:
  48. self.openapi_version = (
  49. Version(openapi_version)
  50. if isinstance(openapi_version, str)
  51. else openapi_version
  52. )
  53. self.schema_name_resolver = schema_name_resolver
  54. self.spec = spec
  55. self.init_attribute_functions()
  56. self.init_parameter_attribute_functions()
  57. # Schema references
  58. self.refs: dict = {}
  59. def init_parameter_attribute_functions(self) -> None:
  60. self.parameter_attribute_functions = [
  61. self.field2required,
  62. self.list2param,
  63. ]
  64. def add_parameter_attribute_function(self, func) -> None:
  65. """Method to add a field parameter function to the list of field
  66. parameter functions that will be called on a field to convert it to a
  67. field parameter.
  68. :param func func: the field parameter function to add
  69. The attribute function will be bound to the
  70. `OpenAPIConverter <apispec.ext.marshmallow.openapi.OpenAPIConverter>`
  71. instance.
  72. It will be called for each field in a schema with
  73. `self <apispec.ext.marshmallow.openapi.OpenAPIConverter>` and a
  74. `field <marshmallow.fields.Field>` instance
  75. positional arguments and `ret <dict>` keyword argument.
  76. May mutate `ret`.
  77. User added field parameter functions will be called after all built-in
  78. field parameter functions in the order they were added.
  79. """
  80. bound_func = func.__get__(self)
  81. setattr(self, func.__name__, bound_func)
  82. self.parameter_attribute_functions.append(bound_func)
  83. def resolve_nested_schema(self, schema):
  84. """Return the OpenAPI representation of a marshmallow Schema.
  85. Adds the schema to the spec if it isn't already present.
  86. Typically will return a dictionary with the reference to the schema's
  87. path in the spec unless the `schema_name_resolver` returns `None`, in
  88. which case the returned dictionary will contain a JSON Schema Object
  89. representation of the schema.
  90. :param schema: schema to add to the spec
  91. """
  92. try:
  93. schema_instance = resolve_schema_instance(schema)
  94. # If schema is a string and is not found in registry,
  95. # assume it is a schema reference
  96. except marshmallow.exceptions.RegistryError:
  97. return schema
  98. schema_key = make_schema_key(schema_instance)
  99. if schema_key not in self.refs:
  100. name = self.schema_name_resolver(schema)
  101. if not name:
  102. try:
  103. json_schema = self.schema2jsonschema(schema_instance)
  104. except RuntimeError as exc:
  105. raise APISpecError(
  106. f"Name resolver returned None for schema {schema} which is "
  107. "part of a chain of circular referencing schemas. Please"
  108. " ensure that the schema_name_resolver passed to"
  109. " MarshmallowPlugin returns a string for all circular"
  110. " referencing schemas."
  111. ) from exc
  112. if getattr(schema, "many", False):
  113. return {"type": "array", "items": json_schema}
  114. return json_schema
  115. name = get_unique_schema_name(self.spec.components, name)
  116. self.spec.components.schema(name, schema=schema)
  117. return self.get_ref_dict(schema_instance)
  118. def schema2parameters(
  119. self,
  120. schema,
  121. *,
  122. location,
  123. name: str = "body",
  124. required: bool = False,
  125. description: str | None = None,
  126. ):
  127. """Return an array of OpenAPI parameters given a given marshmallow
  128. :class:`Schema <marshmallow.Schema>`. If `location` is "body", then return an array
  129. of a single parameter; else return an array of a parameter for each included field in
  130. the :class:`Schema <marshmallow.Schema>`.
  131. In OpenAPI 3, only "query", "header", "path" or "cookie" are allowed for the location
  132. of parameters. "requestBody" is used when fields are in the body.
  133. https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
  134. """
  135. location = __location_map__.get(location, location)
  136. # OAS 2 body parameter
  137. if location == "body":
  138. param = {
  139. "in": location,
  140. "required": required,
  141. "name": name,
  142. "schema": self.resolve_nested_schema(schema),
  143. }
  144. if description:
  145. param["description"] = description
  146. return [param]
  147. assert not getattr(
  148. schema, "many", False
  149. ), "Schemas with many=True are only supported for 'json' location (aka 'in: body')"
  150. fields = get_fields(schema, exclude_dump_only=True)
  151. return [
  152. self._field2parameter(
  153. field_obj,
  154. name=field_obj.data_key or field_name,
  155. location=location,
  156. )
  157. for field_name, field_obj in fields.items()
  158. ]
  159. def _field2parameter(
  160. self, field: marshmallow.fields.Field, *, name: str, location: str
  161. ) -> dict:
  162. """Return an OpenAPI parameter as a `dict`, given a marshmallow
  163. :class:`Field <marshmallow.Field>`.
  164. https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
  165. """
  166. ret: dict = {"in": location, "name": name}
  167. prop = self.field2property(field)
  168. if self.openapi_version.major < 3:
  169. ret.update(prop)
  170. else:
  171. if "description" in prop:
  172. ret["description"] = prop.pop("description")
  173. if "deprecated" in prop:
  174. ret["deprecated"] = prop.pop("deprecated")
  175. ret["schema"] = prop
  176. for param_attr_func in self.parameter_attribute_functions:
  177. ret.update(param_attr_func(field, ret=ret))
  178. return ret
  179. def field2required(
  180. self, field: marshmallow.fields.Field, **kwargs: typing.Any
  181. ) -> dict:
  182. """Return the dictionary of OpenAPI parameter attributes for a required field.
  183. :param Field field: A marshmallow field.
  184. :rtype: dict
  185. """
  186. ret = {}
  187. partial = getattr(field.parent, "partial", False)
  188. ret["required"] = field.required and (
  189. not partial or (is_collection(partial) and field.name not in partial) # type:ignore
  190. )
  191. return ret
  192. def list2param(self, field: marshmallow.fields.Field, **kwargs: typing.Any) -> dict:
  193. """Return a dictionary of parameter properties from
  194. :class:`List <marshmallow.fields.List` fields.
  195. :param Field field: A marshmallow field.
  196. :rtype: dict
  197. """
  198. ret: dict = {}
  199. if isinstance(field, marshmallow.fields.List):
  200. if self.openapi_version.major < 3:
  201. ret["collectionFormat"] = "multi"
  202. else:
  203. ret["explode"] = True
  204. ret["style"] = "form"
  205. return ret
  206. def schema2jsonschema(self, schema):
  207. """Return the JSON Schema Object for a given marshmallow
  208. :class:`Schema <marshmallow.Schema>` instance. Schema may optionally
  209. provide the ``title`` and ``description`` class Meta options.
  210. https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schemaObject
  211. :param Schema schema: A marshmallow Schema instance
  212. :rtype: dict, a JSON Schema Object
  213. """
  214. fields = get_fields(schema)
  215. Meta = getattr(schema, "Meta", None)
  216. partial = getattr(schema, "partial", None)
  217. jsonschema = self.fields2jsonschema(fields, partial=partial)
  218. if hasattr(Meta, "title"):
  219. jsonschema["title"] = Meta.title
  220. if hasattr(Meta, "description"):
  221. jsonschema["description"] = Meta.description
  222. if hasattr(Meta, "unknown") and Meta.unknown != marshmallow.EXCLUDE:
  223. jsonschema["additionalProperties"] = Meta.unknown == marshmallow.INCLUDE
  224. return jsonschema
  225. def fields2jsonschema(self, fields, *, partial=None):
  226. """Return the JSON Schema Object given a mapping between field names and
  227. :class:`Field <marshmallow.Field>` objects.
  228. :param dict fields: A dictionary of field name field object pairs
  229. :param bool|tuple partial: Whether to override a field's required flag.
  230. If `True` no fields will be set as required. If an iterable fields
  231. in the iterable will not be marked as required.
  232. :rtype: dict, a JSON Schema Object
  233. """
  234. jsonschema = {"type": "object", "properties": {}}
  235. for field_name, field_obj in fields.items():
  236. observed_field_name = field_obj.data_key or field_name
  237. prop = self.field2property(field_obj)
  238. jsonschema["properties"][observed_field_name] = prop
  239. if field_obj.required:
  240. if not partial or (
  241. is_collection(partial) and field_name not in partial
  242. ):
  243. jsonschema.setdefault("required", []).append(observed_field_name)
  244. if "required" in jsonschema:
  245. jsonschema["required"].sort()
  246. return jsonschema
  247. def get_ref_dict(self, schema):
  248. """Method to create a dictionary containing a JSON reference to the
  249. schema in the spec
  250. """
  251. schema_key = make_schema_key(schema)
  252. ref_schema = self.spec.components.get_ref("schema", self.refs[schema_key])
  253. if getattr(schema, "many", False):
  254. return {"type": "array", "items": ref_schema}
  255. return ref_schema