serde.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. from __future__ import annotations
  19. import dataclasses
  20. import enum
  21. import functools
  22. import logging
  23. import sys
  24. from fnmatch import fnmatch
  25. from importlib import import_module
  26. from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast
  27. import attr
  28. import re2
  29. import airflow.serialization.serializers
  30. from airflow.configuration import conf
  31. from airflow.stats import Stats
  32. from airflow.utils.module_loading import import_string, iter_namespace, qualname
  33. if TYPE_CHECKING:
  34. from types import ModuleType
  35. log = logging.getLogger(__name__)
  36. MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
  37. CLASSNAME = "__classname__"
  38. VERSION = "__version__"
  39. DATA = "__data__"
  40. SCHEMA_ID = "__id__"
  41. CACHE = "__cache__"
  42. OLD_TYPE = "__type"
  43. OLD_SOURCE = "__source"
  44. OLD_DATA = "__var"
  45. OLD_DICT = "dict"
  46. DEFAULT_VERSION = 0
  47. T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
  48. U = Union[bool, float, int, dict, list, str, tuple, set]
  49. S = Union[list, tuple, set]
  50. _serializers: dict[str, ModuleType] = {}
  51. _deserializers: dict[str, ModuleType] = {}
  52. _stringifiers: dict[str, ModuleType] = {}
  53. _extra_allowed: set[str] = set()
  54. _primitives = (int, bool, float, str)
  55. _builtin_collections = (frozenset, list, set, tuple) # dict is treated specially.
  56. def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
  57. """Encode an object so it can be understood by the deserializer."""
  58. return {CLASSNAME: cls, VERSION: version, DATA: data}
  59. def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
  60. classname = d[CLASSNAME]
  61. version = d[VERSION]
  62. if not isinstance(classname, str) or not isinstance(version, int):
  63. raise ValueError(f"cannot decode {d!r}")
  64. data = d.get(DATA)
  65. return classname, version, data
  66. def serialize(o: object, depth: int = 0) -> U | None:
  67. """
  68. Serialize an object into a representation consisting only built-in types.
  69. Primitives (int, float, bool, str) are returned as-is. Built-in collections
  70. are iterated over, where it is assumed that keys in a dict can be represented
  71. as str.
  72. Values that are not of a built-in type are serialized if a serializer is
  73. found for them. The order in which serializers are used is
  74. 1. A ``serialize`` function provided by the object.
  75. 2. A registered serializer in the namespace of ``airflow.serialization.serializers``
  76. 3. Annotations from attr or dataclass.
  77. Limitations: attr and dataclass objects can lose type information for nested objects
  78. as they do not store this when calling ``asdict``. This means that at deserialization values
  79. will be deserialized as a dict as opposed to reinstating the object. Provide
  80. your own serializer to work around this.
  81. :param o: The object to serialize.
  82. :param depth: Private tracker for nested serialization.
  83. :raise TypeError: A serializer cannot be found.
  84. :raise RecursionError: The object is too nested for the function to handle.
  85. :return: A representation of ``o`` that consists of only built-in types.
  86. """
  87. if depth == MAX_RECURSION_DEPTH:
  88. raise RecursionError("maximum recursion depth reached for serialization")
  89. # None remains None
  90. if o is None:
  91. return o
  92. # primitive types are returned as is
  93. if isinstance(o, _primitives):
  94. if isinstance(o, enum.Enum):
  95. return o.value
  96. return o
  97. if isinstance(o, list):
  98. return [serialize(d, depth + 1) for d in o]
  99. if isinstance(o, dict):
  100. if CLASSNAME in o or SCHEMA_ID in o:
  101. raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")
  102. return {str(k): serialize(v, depth + 1) for k, v in o.items()}
  103. cls = type(o)
  104. qn = qualname(o)
  105. classname = None
  106. # Serialize namedtuple like tuples
  107. # We also override the classname returned by the builtin.py serializer. The classname
  108. # has to be "builtins.tuple", so that the deserializer can deserialize the object into tuple.
  109. if _is_namedtuple(o):
  110. qn = "builtins.tuple"
  111. classname = qn
  112. # if there is a builtin serializer available use that
  113. if qn in _serializers:
  114. data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
  115. if is_serialized:
  116. return encode(classname or serialized_classname, version, serialize(data, depth + 1))
  117. # custom serializers
  118. dct = {
  119. CLASSNAME: qn,
  120. VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
  121. }
  122. # object / class brings their own
  123. if hasattr(o, "serialize"):
  124. data = getattr(o, "serialize")()
  125. # if we end up with a structure, ensure its values are serialized
  126. if isinstance(data, dict):
  127. data = serialize(data, depth + 1)
  128. dct[DATA] = data
  129. return dct
  130. # pydantic models are recursive
  131. if _is_pydantic(cls):
  132. data = o.model_dump() # type: ignore[attr-defined]
  133. dct[DATA] = serialize(data, depth + 1)
  134. return dct
  135. # dataclasses
  136. if dataclasses.is_dataclass(cls):
  137. # fixme: unfortunately using asdict with nested dataclasses it looses information
  138. data = dataclasses.asdict(o) # type: ignore[call-overload]
  139. dct[DATA] = serialize(data, depth + 1)
  140. return dct
  141. # attr annotated
  142. if attr.has(cls):
  143. # Only include attributes which we can pass back to the classes constructor
  144. data = attr.asdict(cast(attr.AttrsInstance, o), recurse=False, filter=lambda a, v: a.init)
  145. dct[DATA] = serialize(data, depth + 1)
  146. return dct
  147. raise TypeError(f"cannot serialize object of type {cls}")
  148. def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
  149. """
  150. Deserialize an object of primitive type and uses an allow list to determine if a class can be loaded.
  151. :param o: primitive to deserialize into an arbitrary object.
  152. :param full: if False it will return a stringified representation
  153. of an object and will not load any classes
  154. :param type_hint: if set it will be used to help determine what
  155. object to deserialize in. It does not override if another
  156. specification is found
  157. :return: object
  158. """
  159. if o is None:
  160. return o
  161. if isinstance(o, _primitives):
  162. return o
  163. # tuples, sets are included here for backwards compatibility
  164. if isinstance(o, _builtin_collections):
  165. col = [deserialize(d) for d in o]
  166. if isinstance(o, tuple):
  167. return tuple(col)
  168. if isinstance(o, set):
  169. return set(col)
  170. return col
  171. if not isinstance(o, dict):
  172. # if o is not a dict, then it's already deserialized
  173. # in this case we should return it as is
  174. return o
  175. o = _convert(o)
  176. # plain dict and no type hint
  177. if CLASSNAME not in o and not type_hint or VERSION not in o:
  178. return {str(k): deserialize(v, full) for k, v in o.items()}
  179. # custom deserialization starts here
  180. cls: Any
  181. version = 0
  182. value: Any = None
  183. classname = ""
  184. if type_hint:
  185. cls = type_hint
  186. classname = qualname(cls)
  187. version = 0 # type hinting always sets version to 0
  188. value = o
  189. if CLASSNAME in o and VERSION in o:
  190. classname, version, value = decode(o)
  191. if not classname:
  192. raise TypeError("classname cannot be empty")
  193. # only return string representation
  194. if not full:
  195. return _stringify(classname, version, value)
  196. if not _match(classname) and classname not in _extra_allowed:
  197. raise ImportError(
  198. f"{classname} was not found in allow list for deserialization imports. "
  199. f"To allow it, add it to allowed_deserialization_classes in the configuration"
  200. )
  201. cls = import_string(classname)
  202. # registered deserializer
  203. if classname in _deserializers:
  204. return _deserializers[classname].deserialize(classname, version, deserialize(value))
  205. # class has deserialization function
  206. if hasattr(cls, "deserialize"):
  207. return getattr(cls, "deserialize")(deserialize(value), version)
  208. # attr or dataclass or pydantic
  209. if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
  210. class_version = getattr(cls, "__version__", 0)
  211. if int(version) > class_version:
  212. raise TypeError(
  213. "serialized version of %s is newer than module version (%s > %s)",
  214. classname,
  215. version,
  216. class_version,
  217. )
  218. return cls(**deserialize(value))
  219. # no deserializer available
  220. raise TypeError(f"No deserializer found for {classname}")
  221. def _convert(old: dict) -> dict:
  222. """Convert an old style serialization to new style."""
  223. if OLD_TYPE in old and OLD_DATA in old:
  224. # Return old style dicts directly as they do not need wrapping
  225. if old[OLD_TYPE] == OLD_DICT:
  226. return old[OLD_DATA]
  227. else:
  228. return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]}
  229. return old
  230. def _match(classname: str) -> bool:
  231. """Check if the given classname matches a path pattern either using glob format or regexp format."""
  232. return _match_glob(classname) or _match_regexp(classname)
  233. @functools.lru_cache(maxsize=None)
  234. def _match_glob(classname: str):
  235. """Check if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
  236. patterns = _get_patterns()
  237. return any(fnmatch(classname, p.pattern) for p in patterns)
  238. @functools.lru_cache(maxsize=None)
  239. def _match_regexp(classname: str):
  240. """Check if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
  241. patterns = _get_regexp_patterns()
  242. return any(p.match(classname) is not None for p in patterns)
  243. def _stringify(classname: str, version: int, value: T | None) -> str:
  244. """
  245. Convert a previously serialized object in a somewhat human-readable format.
  246. This function is not designed to be exact, and will not extensively traverse
  247. the whole tree of an object.
  248. """
  249. if classname in _stringifiers:
  250. return _stringifiers[classname].stringify(classname, version, value)
  251. s = f"{classname}@version={version}("
  252. if isinstance(value, _primitives):
  253. s += f"{value}"
  254. elif isinstance(value, _builtin_collections):
  255. # deserialized values can be != str
  256. s += ",".join(str(deserialize(value, full=False)))
  257. elif isinstance(value, dict):
  258. s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
  259. s += ")"
  260. return s
  261. def _is_pydantic(cls: Any) -> bool:
  262. """
  263. Return True if the class is a pydantic model.
  264. Checking is done by attributes as it is significantly faster than
  265. using isinstance.
  266. """
  267. return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and hasattr(cls, "model_fields_set")
  268. def _is_namedtuple(cls: Any) -> bool:
  269. """
  270. Return True if the class is a namedtuple.
  271. Checking is done by attributes as it is significantly faster than
  272. using isinstance.
  273. """
  274. return hasattr(cls, "_asdict") and hasattr(cls, "_fields") and hasattr(cls, "_field_defaults")
  275. def _register():
  276. """Register builtin serializers and deserializers for types that don't have any themselves."""
  277. _serializers.clear()
  278. _deserializers.clear()
  279. _stringifiers.clear()
  280. with Stats.timer("serde.load_serializers") as timer:
  281. for _, name, _ in iter_namespace(airflow.serialization.serializers):
  282. name = import_module(name)
  283. for s in getattr(name, "serializers", ()):
  284. if not isinstance(s, str):
  285. s = qualname(s)
  286. if s in _serializers and _serializers[s] != name:
  287. raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}")
  288. log.debug("registering %s for serialization", s)
  289. _serializers[s] = name
  290. for d in getattr(name, "deserializers", ()):
  291. if not isinstance(d, str):
  292. d = qualname(d)
  293. if d in _deserializers and _deserializers[d] != name:
  294. raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}")
  295. log.debug("registering %s for deserialization", d)
  296. _deserializers[d] = name
  297. _extra_allowed.add(d)
  298. for c in getattr(name, "stringifiers", ()):
  299. if not isinstance(c, str):
  300. c = qualname(c)
  301. if c in _deserializers and _deserializers[c] != name:
  302. raise AttributeError(f"duplicate {c} for stringifiers in {name} and {_stringifiers[c]}")
  303. log.debug("registering %s for stringifying", c)
  304. _stringifiers[c] = name
  305. log.debug("loading serializers took %.3f seconds", timer.duration)
  306. @functools.lru_cache(maxsize=None)
  307. def _get_patterns() -> list[Pattern]:
  308. return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]
  309. @functools.lru_cache(maxsize=None)
  310. def _get_regexp_patterns() -> list[Pattern]:
  311. return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]
  312. _register()