param.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. import contextlib
  19. import copy
  20. import datetime
  21. import json
  22. import logging
  23. import warnings
  24. from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView
  25. from pendulum.parsing import parse_iso8601
  26. from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning
  27. from airflow.utils import timezone
  28. from airflow.utils.mixins import ResolveMixin
  29. from airflow.utils.types import NOTSET, ArgNotSet
  30. if TYPE_CHECKING:
  31. from airflow.models.dag import DAG
  32. from airflow.models.dagrun import DagRun
  33. from airflow.models.operator import Operator
  34. from airflow.serialization.pydantic.dag_run import DagRunPydantic
  35. from airflow.utils.context import Context
  36. logger = logging.getLogger(__name__)
  37. class Param:
  38. """
  39. Class to hold the default value of a Param and rule set to do the validations.
  40. Without the rule set it always validates and returns the default value.
  41. :param default: The value this Param object holds
  42. :param description: Optional help text for the Param
  43. :param schema: The validation schema of the Param, if not given then all kwargs except
  44. default & description will form the schema
  45. """
  46. __version__: ClassVar[int] = 1
  47. CLASS_IDENTIFIER = "__class"
  48. def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
  49. if default is not NOTSET:
  50. self._warn_if_not_json(default)
  51. self.value = default
  52. self.description = description
  53. self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
  54. def __copy__(self) -> Param:
  55. return Param(self.value, self.description, schema=self.schema)
  56. @staticmethod
  57. def _warn_if_not_json(value):
  58. try:
  59. json.dumps(value)
  60. except Exception:
  61. warnings.warn(
  62. "The use of non-json-serializable params is deprecated and will be removed in "
  63. "a future release",
  64. RemovedInAirflow3Warning,
  65. stacklevel=1,
  66. )
  67. @staticmethod
  68. def _warn_if_not_rfc3339_dt(value):
  69. """Fallback to iso8601 datetime validation if rfc3339 failed."""
  70. try:
  71. iso8601_value = parse_iso8601(value)
  72. except Exception:
  73. return None
  74. if not isinstance(iso8601_value, datetime.datetime):
  75. return None
  76. warnings.warn(
  77. f"The use of non-RFC3339 datetime: {value!r} is deprecated "
  78. "and will be removed in a future release",
  79. RemovedInAirflow3Warning,
  80. stacklevel=1,
  81. )
  82. if timezone.is_naive(iso8601_value):
  83. warnings.warn(
  84. "The use naive datetime is deprecated and will be removed in a future release",
  85. RemovedInAirflow3Warning,
  86. stacklevel=1,
  87. )
  88. return value
  89. def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
  90. """
  91. Run the validations and returns the Param's final value.
  92. May raise ValueError on failed validations, or TypeError
  93. if no value is passed and no value already exists.
  94. We first check that value is json-serializable; if not, warn.
  95. In future release we will require the value to be json-serializable.
  96. :param value: The value to be updated for the Param
  97. :param suppress_exception: To raise an exception or not when the validations fails.
  98. If true and validations fails, the return value would be None.
  99. """
  100. import jsonschema
  101. from jsonschema import FormatChecker
  102. from jsonschema.exceptions import ValidationError
  103. if value is not NOTSET:
  104. self._warn_if_not_json(value)
  105. final_val = self.value if value is NOTSET else value
  106. if isinstance(final_val, ArgNotSet):
  107. if suppress_exception:
  108. return None
  109. raise ParamValidationError("No value passed and Param has no default value")
  110. try:
  111. jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
  112. except ValidationError as err:
  113. if err.schema.get("format") == "date-time":
  114. rfc3339_value = self._warn_if_not_rfc3339_dt(final_val)
  115. if rfc3339_value:
  116. self.value = rfc3339_value
  117. return rfc3339_value
  118. if suppress_exception:
  119. return None
  120. raise ParamValidationError(err) from None
  121. self.value = final_val
  122. return final_val
  123. def dump(self) -> dict:
  124. """Dump the Param as a dictionary."""
  125. out_dict: dict[str, str | None] = {
  126. self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
  127. }
  128. out_dict.update(self.__dict__)
  129. # Ensure that not set is translated to None
  130. if self.value is NOTSET:
  131. out_dict["value"] = None
  132. return out_dict
  133. @property
  134. def has_value(self) -> bool:
  135. return self.value is not NOTSET and self.value is not None
  136. def serialize(self) -> dict:
  137. return {"value": self.value, "description": self.description, "schema": self.schema}
  138. @staticmethod
  139. def deserialize(data: dict[str, Any], version: int) -> Param:
  140. if version > Param.__version__:
  141. raise TypeError("serialized version > class version")
  142. return Param(default=data["value"], description=data["description"], schema=data["schema"])
  143. class ParamsDict(MutableMapping[str, Any]):
  144. """
  145. Class to hold all params for dags or tasks.
  146. All the keys are strictly string and values are converted into Param's object
  147. if they are not already. This class is to replace param's dictionary implicitly
  148. and ideally not needed to be used directly.
  149. :param dict_obj: A dict or dict like object to init ParamsDict
  150. :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
  151. """
  152. __version__: ClassVar[int] = 1
  153. __slots__ = ["__dict", "suppress_exception"]
  154. def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
  155. params_dict: dict[str, Param] = {}
  156. dict_obj = dict_obj or {}
  157. for k, v in dict_obj.items():
  158. if not isinstance(v, Param):
  159. params_dict[k] = Param(v)
  160. else:
  161. params_dict[k] = v
  162. self.__dict = params_dict
  163. self.suppress_exception = suppress_exception
  164. def __bool__(self) -> bool:
  165. return bool(self.__dict)
  166. def __eq__(self, other: Any) -> bool:
  167. if isinstance(other, ParamsDict):
  168. return self.dump() == other.dump()
  169. if isinstance(other, dict):
  170. return self.dump() == other
  171. return NotImplemented
  172. def __copy__(self) -> ParamsDict:
  173. return ParamsDict(self.__dict, self.suppress_exception)
  174. def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
  175. return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception)
  176. def __contains__(self, o: object) -> bool:
  177. return o in self.__dict
  178. def __len__(self) -> int:
  179. return len(self.__dict)
  180. def __delitem__(self, v: str) -> None:
  181. del self.__dict[v]
  182. def __iter__(self):
  183. return iter(self.__dict)
  184. def __repr__(self):
  185. return repr(self.dump())
  186. def __setitem__(self, key: str, value: Any) -> None:
  187. """
  188. Override for dictionary's ``setitem`` method to ensure all values are of Param's type only.
  189. :param key: A key which needs to be inserted or updated in the dict
  190. :param value: A value which needs to be set against the key. It could be of any
  191. type but will be converted and stored as a Param object eventually.
  192. """
  193. if isinstance(value, Param):
  194. param = value
  195. elif key in self.__dict:
  196. param = self.__dict[key]
  197. try:
  198. param.resolve(value=value, suppress_exception=self.suppress_exception)
  199. except ParamValidationError as ve:
  200. raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
  201. else:
  202. # if the key isn't there already and if the value isn't of Param type create a new Param object
  203. param = Param(value)
  204. self.__dict[key] = param
  205. def __getitem__(self, key: str) -> Any:
  206. """
  207. Override for dictionary's ``getitem`` method to call the resolve method after fetching the key.
  208. :param key: The key to fetch
  209. """
  210. param = self.__dict[key]
  211. return param.resolve(suppress_exception=self.suppress_exception)
  212. def get_param(self, key: str) -> Param:
  213. """Get the internal :class:`.Param` object for this key."""
  214. return self.__dict[key]
  215. def items(self):
  216. return ItemsView(self.__dict)
  217. def values(self):
  218. return ValuesView(self.__dict)
  219. def update(self, *args, **kwargs) -> None:
  220. if len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict):
  221. return super().update(args[0].__dict)
  222. super().update(*args, **kwargs)
  223. def dump(self) -> dict[str, Any]:
  224. """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
  225. return {k: v.resolve(suppress_exception=True) for k, v in self.items()}
  226. def validate(self) -> dict[str, Any]:
  227. """Validate & returns all the Params object stored in the dictionary."""
  228. resolved_dict = {}
  229. try:
  230. for k, v in self.items():
  231. resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception)
  232. except ParamValidationError as ve:
  233. raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None
  234. return resolved_dict
  235. def serialize(self) -> dict[str, Any]:
  236. return self.dump()
  237. @staticmethod
  238. def deserialize(data: dict, version: int) -> ParamsDict:
  239. if version > ParamsDict.__version__:
  240. raise TypeError("serialized version > class version")
  241. return ParamsDict(data)
  242. class DagParam(ResolveMixin):
  243. """
  244. DAG run parameter reference.
  245. This binds a simple Param object to a name within a DAG instance, so that it
  246. can be resolved during the runtime via the ``{{ context }}`` dictionary. The
  247. ideal use case of this class is to implicitly convert args passed to a
  248. method decorated by ``@dag``.
  249. It can be used to parameterize a DAG. You can overwrite its value by setting
  250. it on conf when you trigger your DagRun.
  251. This can also be used in templates by accessing ``{{ context.params }}``.
  252. **Example**:
  253. with DAG(...) as dag:
  254. EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))
  255. :param current_dag: Dag being used for parameter.
  256. :param name: key value which is used to set the parameter
  257. :param default: Default value used if no parameter was set.
  258. """
  259. def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
  260. if default is not NOTSET:
  261. current_dag.params[name] = default
  262. self._name = name
  263. self._default = default
  264. def iter_references(self) -> Iterable[tuple[Operator, str]]:
  265. return ()
  266. def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
  267. """Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
  268. with contextlib.suppress(KeyError):
  269. return context["dag_run"].conf[self._name]
  270. if self._default is not NOTSET:
  271. return self._default
  272. with contextlib.suppress(KeyError):
  273. return context["params"][self._name]
  274. raise AirflowException(f"No value could be resolved for parameter {self._name}")
  275. def process_params(
  276. dag: DAG,
  277. task: Operator,
  278. dag_run: DagRun | DagRunPydantic | None,
  279. *,
  280. suppress_exception: bool,
  281. ) -> dict[str, Any]:
  282. """Merge, validate params, and convert them into a simple dict."""
  283. from airflow.configuration import conf
  284. params = ParamsDict(suppress_exception=suppress_exception)
  285. with contextlib.suppress(AttributeError):
  286. params.update(dag.params)
  287. if task.params:
  288. params.update(task.params)
  289. if conf.getboolean("core", "dag_run_conf_overrides_params") and dag_run and dag_run.conf:
  290. logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
  291. params.update(dag_run.conf)
  292. return params.validate()