task_command.py

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Task sub-commands."""
from __future__ import annotations

import functools
import importlib
import json
import logging
import os
import sys
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from typing import TYPE_CHECKING, Generator, Protocol, Union, cast

import pendulum
from pendulum.parsing.exceptions import ParserError
from sqlalchemy import select

from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
    get_dag,
    get_dag_by_file_location,
    get_dag_by_pickle,
    get_dags,
    should_ignore_depends_on_past,
    suppress_logs_and_warning,
)
from airflow.utils.dates import timezone
from airflow.utils.log.file_task_handler import _set_task_deferred_context_var
from airflow.utils.log.logging_mixin import StreamLogWriter
from airflow.utils.log.secrets_masker import RedactedIO
from airflow.utils.net import get_hostname
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.task_instance_session import set_current_task_instance_session

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.models.operator import Operator
    from airflow.serialization.pydantic.dag_run import DagRunPydantic

log = logging.getLogger(__name__)

CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]]


def _generate_temporary_run_id() -> str:
    """
    Generate a ``run_id`` for a DAG run that will be created temporarily.

    This is used mostly by ``airflow tasks test`` to create a DAG run that will
    be deleted after the task is run.
    """
    return f"__airflow_temporary_run_{timezone.utcnow().isoformat()}__"
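
# Illustration (added note, not in the original module): a temporary run created
# at 2024-01-01T00:00:00+00:00 would get the run_id
#     __airflow_temporary_run_2024-01-01T00:00:00+00:00__
# which makes such runs easy to recognize and clean up after ``airflow tasks test``.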


def _get_dag_run(
    *,
    dag: DAG,
    create_if_necessary: CreateIfNecessary,
    exec_date_or_run_id: str | None = None,
    session: Session | None = None,
) -> tuple[DagRun | DagRunPydantic, bool]:
    """
    Try to retrieve a DAG run from a string representing either a run ID or logical date.

    This checks DAG runs like this:

    1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run.
    2. Try to parse the input as a date. If that works, and the resulting
       date matches a DAG run's logical date, return the run.
    3. If ``create_if_necessary`` is *False* and the input works for neither of
       the above, raise ``DagRunNotFound``.
    4. Try to create a new DAG run. If the input looks like a date, use it as
       the logical date; otherwise use it as a run ID and set the logical date
       to the current time.
    """
    if not exec_date_or_run_id and not create_if_necessary:
        raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
    execution_date: pendulum.DateTime | None = None
    if exec_date_or_run_id:
        dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=exec_date_or_run_id, session=session)
        if dag_run:
            return dag_run, False
        with suppress(ParserError, TypeError):
            execution_date = timezone.parse(exec_date_or_run_id)
        if execution_date:
            dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id, execution_date=execution_date, session=session)
        if dag_run:
            return dag_run, False
        elif not create_if_necessary:
            raise DagRunNotFound(
                f"DagRun for {dag.dag_id} with run_id or execution_date "
                f"of {exec_date_or_run_id!r} not found"
            )

    if execution_date is not None:
        dag_run_execution_date = execution_date
    else:
        dag_run_execution_date = pendulum.instance(timezone.utcnow())

    if create_if_necessary == "memory":
        dag_run = DagRun(
            dag_id=dag.dag_id,
            run_id=exec_date_or_run_id,
            execution_date=dag_run_execution_date,
            data_interval=dag.timetable.infer_manual_data_interval(run_after=dag_run_execution_date),
        )
        return dag_run, True
    elif create_if_necessary == "db":
        dag_run = dag.create_dagrun(
            state=DagRunState.QUEUED,
            execution_date=dag_run_execution_date,
            run_id=_generate_temporary_run_id(),
            data_interval=dag.timetable.infer_manual_data_interval(run_after=dag_run_execution_date),
            session=session,
        )
        return dag_run, True
    raise ValueError(f"unknown create_if_necessary value: {create_if_necessary!r}")
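
# Illustrative usage sketch (hypothetical values, not part of the original module):
#
#     dag_run, created = _get_dag_run(
#         dag=dag,
#         exec_date_or_run_id="manual__2024-01-01T00:00:00+00:00",
#         create_if_necessary="memory",
#         session=session,
#     )
#
# The string is first tried as a run_id, then parsed as a logical date; only if
# both lookups miss is a new DagRun created (here in memory, without DB writes).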


@internal_api_call
@provide_session
def _get_ti_db_access(
    dag: DAG,
    task: Operator,
    map_index: int,
    *,
    exec_date_or_run_id: str | None = None,
    pool: str | None = None,
    create_if_necessary: CreateIfNecessary = False,
    session: Session = NEW_SESSION,
) -> tuple[TaskInstance | TaskInstancePydantic, bool]:
    """Get the task instance through DagRun.run_id, if that fails, get the TI the old way."""
    # this check is imperfect because diff dags could have tasks with same name
    # but in a task, dag_id is a property that accesses its dag, and we don't
    # currently include the dag when serializing an operator
    if task.task_id not in dag.task_dict:
        raise ValueError(f"Provided task {task.task_id} is not in dag '{dag.dag_id}'.")
    if not exec_date_or_run_id and not create_if_necessary:
        raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
    if task.get_needs_expansion():
        if map_index < 0:
            raise RuntimeError("No map_index passed to mapped task")
    elif map_index >= 0:
        raise RuntimeError("map_index passed to non-mapped task")
    dag_run, dr_created = _get_dag_run(
        dag=dag,
        exec_date_or_run_id=exec_date_or_run_id,
        create_if_necessary=create_if_necessary,
        session=session,
    )

    ti_or_none = dag_run.get_task_instance(task.task_id, map_index=map_index, session=session)
    ti: TaskInstance | TaskInstancePydantic
    if ti_or_none is None:
        if not create_if_necessary:
            raise TaskInstanceNotFound(
                f"TaskInstance for {dag.dag_id}, {task.task_id}, map={map_index} with "
                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
            )
        # TODO: Validate map_index is in range?
        ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
        if dag_run in session:
            session.add(ti)
        ti.dag_run = dag_run
    else:
        ti = ti_or_none
    ti.refresh_from_task(task, pool_override=pool)
    return ti, dr_created


def _get_ti(
    task: Operator,
    map_index: int,
    *,
    exec_date_or_run_id: str | None = None,
    pool: str | None = None,
    create_if_necessary: CreateIfNecessary = False,
):
    dag = task.dag
    if dag is None:
        raise ValueError("Cannot get task instance for a task not assigned to a DAG")

    ti, dr_created = _get_ti_db_access(
        dag=dag,
        task=task,
        map_index=map_index,
        exec_date_or_run_id=exec_date_or_run_id,
        pool=pool,
        create_if_necessary=create_if_necessary,
    )

    # we do refresh_from_task so that if TI has come back via RPC, we ensure that ti.task
    # is the original task object and not the result of the round trip
    ti.refresh_from_task(task, pool_override=pool)
    return ti, dr_created


def _run_task_by_selected_method(
    args, dag: DAG, ti: TaskInstance | TaskInstancePydantic
) -> None | TaskReturnCode:
    """
    Run the task based on a mode.

    Three modes are available:

    - using LocalTaskJob
    - as raw task
    - by executor
    """
    if TYPE_CHECKING:
        assert not isinstance(ti, TaskInstancePydantic)  # Wait for AIP-44 implementation to complete
    if args.local:
        return _run_task_by_local_task_job(args, ti)
    if args.raw:
        return _run_raw_task(args, ti)
    _run_task_by_executor(args, dag, ti)
    return None
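
# Summary (added commentary, not in the original module): the dispatch above
# corresponds to the flags of ``airflow tasks run``:
#
#     airflow tasks run DAG TASK RUN_ID --local   ->  _run_task_by_local_task_job
#     airflow tasks run DAG TASK RUN_ID --raw     ->  _run_raw_task
#     airflow tasks run DAG TASK RUN_ID           ->  _run_task_by_executor
#
# task_run() below rejects --local together with --raw as mutually exclusive.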


def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
    """
    Send the task to the executor for execution.

    Depending on the executor implementation, this can result in the task being
    started by another host.
    """
    pickle_id = None
    if args.ship_dag:
        try:
            # Running remotely, so pickling the DAG
            with create_session() as session:
                pickle = DagPickle(dag)
                session.add(pickle)
                pickle_id = pickle.id
                # TODO: This should be written to a log
                print(f"Pickled dag {dag} as pickle_id: {pickle_id}")
        except Exception as e:
            print("Could not pickle the DAG")
            print(e)
            raise e
    if ti.executor:
        executor = ExecutorLoader.load_executor(ti.executor)
    else:
        executor = ExecutorLoader.get_default_executor()
    executor.job_id = None
    executor.start()
    print("Sending to executor.")
    executor.queue_task_instance(
        ti,
        mark_success=args.mark_success,
        pickle_id=pickle_id,
        ignore_all_deps=args.ignore_all_dependencies,
        ignore_depends_on_past=should_ignore_depends_on_past(args),
        wait_for_past_depends_before_skipping=(args.depends_on_past == "wait"),
        ignore_task_deps=args.ignore_dependencies,
        ignore_ti_state=args.force,
        pool=args.pool,
    )
    executor.heartbeat()
    executor.end()


def _run_task_by_local_task_job(args, ti: TaskInstance | TaskInstancePydantic) -> TaskReturnCode | None:
    """Run LocalTaskJob, which monitors the raw task execution process."""
    if InternalApiConfig.get_use_internal_api():
        from airflow.models.renderedtifields import RenderedTaskInstanceFields  # noqa: F401
        from airflow.models.trigger import Trigger  # noqa: F401

    job_runner = LocalTaskJobRunner(
        job=Job(dag_id=ti.dag_id),
        task_instance=ti,
        mark_success=args.mark_success,
        pickle_id=args.pickle,
        ignore_all_deps=args.ignore_all_dependencies,
        ignore_depends_on_past=should_ignore_depends_on_past(args),
        wait_for_past_depends_before_skipping=(args.depends_on_past == "wait"),
        ignore_task_deps=args.ignore_dependencies,
        ignore_ti_state=args.force,
        pool=args.pool,
        external_executor_id=_extract_external_executor_id(args),
    )
    try:
        ret = run_job(job=job_runner.job, execute_callable=job_runner._execute)
    finally:
        if args.shut_down_logging:
            logging.shutdown()
    with suppress(ValueError):
        return TaskReturnCode(ret)
    return None
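
# Note (added commentary, not in the original module): ``suppress(ValueError)``
# above converts the integer exit code from run_job() into a TaskReturnCode
# member only when one is defined for that value (DEFERRED maps to exit code 100
# in this Airflow version); any other code falls through and None is returned.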


RAW_TASK_UNSUPPORTED_OPTION = [
    "ignore_all_dependencies",
    "ignore_depends_on_past",
    "ignore_dependencies",
    "force",
]


def _run_raw_task(args, ti: TaskInstance) -> None | TaskReturnCode:
    """Run the main task handling code."""
    return ti._run_raw_task(
        mark_success=args.mark_success,
        job_id=args.job_id,
        pool=args.pool,
    )


def _extract_external_executor_id(args) -> str | None:
    if hasattr(args, "external_executor_id"):
        return getattr(args, "external_executor_id")
    return os.environ.get("external_executor_id", None)


@contextmanager
def _move_task_handlers_to_root(ti: TaskInstance | TaskInstancePydantic) -> Generator[None, None, None]:
    """
    Move handlers for task logging to root logger.

    We want anything logged during task run to be propagated to task log handlers.
    If running in a k8s executor pod, also keep the stream handler on root logger
    so that logs are still emitted to stdout.
    """
    # nothing to do
    if not ti.log.handlers or settings.DONOT_MODIFY_HANDLERS:
        yield
        return

    # Move task handlers to root and reset task logger and restore original logger settings after exit.
    # If k8s executor, we need to ensure that root logger has a console handler, so that
    # task logs propagate to stdout (this is how webserver retrieves them while task is running).
    root_logger = logging.getLogger()
    console_handler = next((h for h in root_logger.handlers if h.name == "console"), None)
    with LoggerMutationHelper(root_logger), LoggerMutationHelper(ti.log) as task_helper:
        task_helper.move(root_logger)
        if IS_K8S_EXECUTOR_POD or IS_EXECUTOR_CONTAINER:
            if console_handler and console_handler not in root_logger.handlers:
                root_logger.addHandler(console_handler)
        yield


@contextmanager
def _redirect_stdout_to_ti_log(ti: TaskInstance | TaskInstancePydantic) -> Generator[None, None, None]:
    """
    Redirect stdout to ti logger.

    Redirect stdout and stderr to the task instance log as INFO and WARNING
    level messages, respectively.

    If stdout is already redirected (possible when the task runs with option
    ``--local``), don't redirect again.
    """
    # if sys.stdout is StreamLogWriter, it means we already redirected
    # likely before forking in LocalTaskJob
    if not isinstance(sys.stdout, StreamLogWriter):
        info_writer = StreamLogWriter(ti.log, logging.INFO)
        warning_writer = StreamLogWriter(ti.log, logging.WARNING)
        with redirect_stdout(info_writer), redirect_stderr(warning_writer):
            yield
    else:
        yield
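
# Minimal sketch of the redirection idea (standalone example, not part of the
# original module): StreamLogWriter is a file-like wrapper around a logger, so
# anything written to stdout becomes a log record at the chosen level.
#
#     from contextlib import redirect_stdout
#     from airflow.utils.log.logging_mixin import StreamLogWriter
#
#     writer = StreamLogWriter(logging.getLogger("example"), logging.INFO)
#     with redirect_stdout(writer):
#         print("hello")  # emitted as an INFO record on the "example" logger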


class TaskCommandMarker:
    """Marker for listener hooks, to properly detect from which component they are called."""


@cli_utils.action_cli(check_db=False)
def task_run(args, dag: DAG | None = None) -> TaskReturnCode | None:
    """
    Run a single task instance.

    Note that there must be at least one DagRun for this to start,
    i.e. it must have been scheduled and/or triggered previously.
    Alternatively, if you just need to run it for testing then use the
    "airflow tasks test ..." command instead.
    """
    # Load custom airflow config

    if args.local and args.raw:
        raise AirflowException(
            "Option --raw and --local are mutually exclusive. "
            "Please remove one option to execute the command."
        )

    if args.raw:
        unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]
        if unsupported_options:
            unsupported_raw_task_flags = ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION)
            unsupported_flags = ", ".join(f"--{o}" for o in unsupported_options)
            raise AirflowException(
                "Option --raw does not work with some of the other options on this command. "
                "You can't use --raw option and the following options: "
                f"{unsupported_raw_task_flags}. "
                f"You provided the option {unsupported_flags}. "
                "Delete it to execute the command."
            )
    if dag and args.pickle:
        raise AirflowException("You cannot use the --pickle option when using DAG.cli() method.")
    if args.cfg_path:
        with open(args.cfg_path) as conf_file:
            conf_dict = json.load(conf_file)

        if os.path.exists(args.cfg_path):
            os.remove(args.cfg_path)

        conf.read_dict(conf_dict, source=args.cfg_path)
        settings.configure_vars()

    settings.MASK_SECRETS_IN_LOGS = True

    get_listener_manager().hook.on_starting(component=TaskCommandMarker())

    if args.pickle:
        print(f"Loading pickle id: {args.pickle}")
        _dag = get_dag_by_pickle(args.pickle)
    elif not dag:
        _dag = get_dag(args.subdir, args.dag_id, args.read_from_db)
    else:
        _dag = dag
    task = _dag.get_task(task_id=args.task_id)
    ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool)
    ti.init_run_context(raw=args.raw)

    hostname = get_hostname()

    log.info("Running %s on host %s", ti, hostname)

    if not InternalApiConfig.get_use_internal_api():
        # IMPORTANT, have to re-configure ORM with the NullPool, otherwise, each "run" command may leave
        # behind multiple open sleeping connections while heartbeating, which could
        # easily exceed the database connection limit when
        # processing hundreds of simultaneous tasks.
        # this should be last thing before running, to reduce likelihood of an open session
        # which can cause trouble if running process in a fork.
        settings.reconfigure_orm(disable_connection_pool=True)
    task_return_code = None
    try:
        if args.interactive:
            task_return_code = _run_task_by_selected_method(args, _dag, ti)
        else:
            with _move_task_handlers_to_root(ti), _redirect_stdout_to_ti_log(ti):
                task_return_code = _run_task_by_selected_method(args, _dag, ti)
                if task_return_code == TaskReturnCode.DEFERRED:
                    _set_task_deferred_context_var()
    finally:
        try:
            get_listener_manager().hook.before_stopping(component=TaskCommandMarker())
        except Exception:
            pass
    return task_return_code


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def task_failed_deps(args) -> None:
    """
    Get task instance dependencies that were not met.

    Returns the unmet dependencies for a task instance from the perspective of the
    scheduler (i.e. why a task instance doesn't get scheduled and then queued by the
    scheduler, and then run by an executor).

    >>> airflow tasks failed-deps tutorial sleep 2015-01-01
    Task instance dependencies not met:
    Dagrun Running: Task instance's dagrun did not exist: Unknown reason
    Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks
    to have succeeded, but found 1 non-success(es).
    """
    dag = get_dag(args.subdir, args.dag_id)
    task = dag.get_task(task_id=args.task_id)
    ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
    # tasks_failed-deps is executed with access to the database.
    if isinstance(ti, TaskInstancePydantic):
        raise ValueError("not a TaskInstance")
    dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
    failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
    # TODO, Do we want to print or log this
    if failed_deps:
        print("Task instance dependencies not met:")
        for dep in failed_deps:
            print(f"{dep.dep_name}: {dep.reason}")
    else:
        print("Task instance dependencies are all met.")


@cli_utils.action_cli(check_db=False)
@suppress_logs_and_warning
@providers_configuration_loaded
def task_state(args) -> None:
    """
    Return the state of a TaskInstance at the command line.

    >>> airflow tasks state tutorial sleep 2015-01-01
    success
    """
    dag = get_dag(args.subdir, args.dag_id)
    task = dag.get_task(task_id=args.task_id)
    ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
    # task_state is executed with access to the database.
    if isinstance(ti, TaskInstancePydantic):
        raise ValueError("not a TaskInstance")
    print(ti.current_state())


@cli_utils.action_cli(check_db=False)
@suppress_logs_and_warning
@providers_configuration_loaded
def task_list(args, dag: DAG | None = None) -> None:
    """List the tasks within a DAG at the command line."""
    dag = dag or get_dag(args.subdir, args.dag_id)
    if args.tree:
        dag.tree_view()
    else:
        tasks = sorted(t.task_id for t in dag.tasks)
        print("\n".join(tasks))


class _SupportedDebugger(Protocol):
    def post_mortem(self) -> None: ...


SUPPORTED_DEBUGGER_MODULES = [
    "pudb",
    "web_pdb",
    "ipdb",
    "pdb",
]


def _guess_debugger() -> _SupportedDebugger:
    """
    Try to guess the debugger used by the user.

    When it doesn't find any user-installed debugger, returns ``pdb``.

    List of supported debuggers:

    * `pudb <https://github.com/inducer/pudb>`__
    * `web_pdb <https://github.com/romanvm/python-web-pdb>`__
    * `ipdb <https://github.com/gotcha/ipdb>`__
    * `pdb <https://docs.python.org/3/library/pdb.html>`__
    """
    exc: Exception
    for mod_name in SUPPORTED_DEBUGGER_MODULES:
        try:
            return cast(_SupportedDebugger, importlib.import_module(mod_name))
        except ImportError as e:
            exc = e
    raise exc
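
# Usage sketch (illustrative, not part of the original module): task_test()
# below uses this when --post-mortem is passed and the task raises.
#
#     debugger = _guess_debugger()   # first importable of pudb/web_pdb/ipdb/pdb
#     debugger.post_mortem()         # drop into the frame of the active exception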


@cli_utils.action_cli(check_db=False)
@suppress_logs_and_warning
@providers_configuration_loaded
@provide_session
def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None:
    """Get the status of all task instances in a DagRun."""
    dag_run = session.scalar(
        select(DagRun).where(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id)
    )
    if not dag_run:
        try:
            execution_date = timezone.parse(args.execution_date_or_run_id)
            dag_run = session.scalar(
                select(DagRun).where(DagRun.execution_date == execution_date, DagRun.dag_id == args.dag_id)
            )
        except (ParserError, TypeError) as err:
            raise AirflowException(f"Error parsing the supplied execution_date. Error: {err}")

    if dag_run is None:
        raise DagRunNotFound(
            f"DagRun for {args.dag_id} with run_id or execution_date of {args.execution_date_or_run_id!r} "
            "not found"
        )

    has_mapped_instances = any(ti.map_index >= 0 for ti in dag_run.task_instances)

    def format_task_instance(ti: TaskInstance) -> dict[str, str]:
        data = {
            "dag_id": ti.dag_id,
            "execution_date": dag_run.execution_date.isoformat(),
            "task_id": ti.task_id,
            "state": ti.state,
            "start_date": ti.start_date.isoformat() if ti.start_date else "",
            "end_date": ti.end_date.isoformat() if ti.end_date else "",
        }
        if has_mapped_instances:
            data["map_index"] = str(ti.map_index) if ti.map_index >= 0 else ""
        return data

    AirflowConsole().print_as(data=dag_run.task_instances, output=args.output, mapper=format_task_instance)


@cli_utils.action_cli(check_db=False)
@provide_session
def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
    """Test task for a given dag_id."""
    # We want to log output from operators etc to show up here. Normally
    # airflow.task would redirect to a file, but here we want it to propagate
    # up to the normal airflow handler.
    settings.MASK_SECRETS_IN_LOGS = True

    handlers = logging.getLogger("airflow.task").handlers
    already_has_stream_handler = False
    for handler in handlers:
        already_has_stream_handler = isinstance(handler, logging.StreamHandler)
        if already_has_stream_handler:
            break
    if not already_has_stream_handler:
        logging.getLogger("airflow.task").propagate = True

    env_vars = {"AIRFLOW_TEST_MODE": "True"}
    if args.env_vars:
        env_vars.update(args.env_vars)
    os.environ.update(env_vars)

    dag = dag or get_dag(args.subdir, args.dag_id)

    task = dag.get_task(task_id=args.task_id)
    # Add CLI provided task_params to task.params
    if args.task_params:
        passed_in_params = json.loads(args.task_params)
        task.params.update(passed_in_params)

    if task.params and isinstance(task.params, ParamsDict):
        task.params.validate()

    ti, dr_created = _get_ti(
        task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db"
    )
    # task_test is executed with access to the database.
    if isinstance(ti, TaskInstancePydantic):
        raise ValueError("not a TaskInstance")
    try:
        with redirect_stdout(RedactedIO()):
            if args.dry_run:
                ti.dry_run()
            else:
                ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
    except TaskDeferred as defer:
        ti.defer_task(exception=defer, session=session)
        log.info("[TASK TEST] running trigger inline")
        event = _run_inline_trigger(defer.trigger)
        ti.next_method = defer.method_name
        ti.next_kwargs = {"event": event.payload} if event else defer.kwargs
        execute_callable = getattr(task, ti.next_method)
        if ti.next_kwargs:
            execute_callable = functools.partial(execute_callable, **ti.next_kwargs)
        context = ti.get_template_context(ignore_param_exceptions=False)
        execute_callable(context)
        log.info("[TASK TEST] Trigger completed")
    except Exception:
        if args.post_mortem:
            debugger = _guess_debugger()
            debugger.post_mortem()
        else:
            raise
    finally:
        if not already_has_stream_handler:
            # Make sure to reset back to normal. When run for CLI this doesn't
            # matter, but it does for test suite
            logging.getLogger("airflow.task").propagate = False
        if dr_created:
            with create_session() as session:
                session.delete(ti.dag_run)


@cli_utils.action_cli(check_db=False)
@suppress_logs_and_warning
@providers_configuration_loaded
def task_render(args, dag: DAG | None = None) -> None:
    """Render and display templated fields for a given task."""
    if not dag:
        dag = get_dag(args.subdir, args.dag_id)
    task = dag.get_task(task_id=args.task_id)
    ti, _ = _get_ti(
        task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory"
    )
    # task_render is executed with access to the database.
    if isinstance(ti, TaskInstancePydantic):
        raise ValueError("not a TaskInstance")
    with create_session() as session, set_current_task_instance_session(session=session):
        ti.render_templates()
        for attr in task.template_fields:
            print(
                textwrap.dedent(
                    f"""        # ----------------------------------------------------------
        # property: {attr}
        # ----------------------------------------------------------
        {getattr(ti.task, attr)}
        """
                )
            )
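
# Example invocation (illustrative, not part of the original module):
#
#     airflow tasks render tutorial templated 2015-01-01
#
# prints each entry of the task's ``template_fields`` (e.g. ``bash_command``)
# after Jinja rendering, creating an in-memory DAG run if none exists.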


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def task_clear(args) -> None:
    """Clear all task instances or only those matched by regex for a DAG(s)."""
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)

    if args.dag_id and not args.subdir and not args.dag_regex and not args.task_regex:
        dags = [get_dag_by_file_location(args.dag_id)]
    else:
        # todo clear command only accepts a single dag_id. no reason for get_dags with 's' except regex?
        dags = get_dags(args.subdir, args.dag_id, use_regex=args.dag_regex)

        if args.task_regex:
            for idx, dag in enumerate(dags):
                dags[idx] = dag.partial_subset(
                    task_ids_or_regex=args.task_regex,
                    include_downstream=args.downstream,
                    include_upstream=args.upstream,
                )

    DAG.clear_dags(
        dags,
        start_date=args.start_date,
        end_date=args.end_date,
        only_failed=args.only_failed,
        only_running=args.only_running,
        confirm_prompt=not args.yes,
        include_subdags=not args.exclude_subdags,
        include_parentdag=not args.exclude_parentdag,
    )


class LoggerMutationHelper:
    """
    Helper for moving and resetting handlers and other logger attrs.

    :meta private:
    """

    def __init__(self, logger: logging.Logger) -> None:
        self.handlers = logger.handlers[:]
        self.level = logger.level
        self.propagate = logger.propagate
        self.source_logger = logger

    def apply(self, logger: logging.Logger, replace: bool = True) -> None:
        """
        Set ``logger`` with attrs stored on instance.

        If ``logger`` is root logger, don't change propagate.
        """
        if replace:
            logger.handlers[:] = self.handlers
        else:
            for h in self.handlers:
                if h not in logger.handlers:
                    logger.addHandler(h)
        logger.level = self.level
        if logger is not logging.getLogger():
            logger.propagate = self.propagate

    def move(self, logger: logging.Logger, replace: bool = True) -> None:
        """
        Replace ``logger`` attrs with those from source.

        :param logger: target logger
        :param replace: if True, remove all handlers from target first; otherwise add if not present.
        """
        self.apply(logger, replace=replace)
        self.source_logger.propagate = True
        self.source_logger.handlers[:] = []

    def reset(self) -> None:
        self.apply(self.source_logger)

    def __enter__(self) -> LoggerMutationHelper:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.reset()
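
# Usage sketch (hypothetical, not part of the original module), mirroring
# _move_task_handlers_to_root() above:
#
#     root = logging.getLogger()
#     task_logger = logging.getLogger("airflow.task")
#     with LoggerMutationHelper(root), LoggerMutationHelper(task_logger) as helper:
#         helper.move(root)  # root temporarily owns the task log handlers
#         ...                # anything logged now reaches the task log
#     # on exit, both loggers get their original handlers/levels/propagate back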