dag_cycle_tester.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. """DAG Cycle tester."""
  18. from __future__ import annotations
  19. from collections import defaultdict, deque
  20. from typing import TYPE_CHECKING
  21. from airflow.exceptions import AirflowDagCycleException, RemovedInAirflow3Warning
  22. if TYPE_CHECKING:
  23. from airflow.models.dag import DAG
  24. CYCLE_NEW = 0
  25. CYCLE_IN_PROGRESS = 1
  26. CYCLE_DONE = 2
  27. def test_cycle(dag: DAG) -> None:
  28. """
  29. A wrapper function of `check_cycle` for backward compatibility purpose.
  30. New code should use `check_cycle` instead since this function name `test_cycle` starts
  31. with 'test_' and will be considered as a unit test by pytest, resulting in failure.
  32. """
  33. from warnings import warn
  34. warn(
  35. "Deprecated, please use `check_cycle` at the same module instead.",
  36. RemovedInAirflow3Warning,
  37. stacklevel=2,
  38. )
  39. return check_cycle(dag)
  40. def check_cycle(dag: DAG) -> None:
  41. """
  42. Check to see if there are any cycles in the DAG.
  43. :raises AirflowDagCycleException: If cycle is found in the DAG.
  44. """
  45. # default of int is 0 which corresponds to CYCLE_NEW
  46. visited: dict[str, int] = defaultdict(int)
  47. path_stack: deque[str] = deque()
  48. task_dict = dag.task_dict
  49. def _check_adjacent_tasks(task_id, current_task):
  50. """Return first untraversed child task, else None if all tasks traversed."""
  51. for adjacent_task in current_task.get_direct_relative_ids():
  52. if visited[adjacent_task] == CYCLE_IN_PROGRESS:
  53. msg = f"Cycle detected in DAG: {dag.dag_id}. Faulty task: {task_id}"
  54. raise AirflowDagCycleException(msg)
  55. elif visited[adjacent_task] == CYCLE_NEW:
  56. return adjacent_task
  57. return None
  58. for dag_task_id in dag.task_dict.keys():
  59. if visited[dag_task_id] == CYCLE_DONE:
  60. continue
  61. path_stack.append(dag_task_id)
  62. while path_stack:
  63. current_task_id = path_stack[-1]
  64. if visited[current_task_id] == CYCLE_NEW:
  65. visited[current_task_id] = CYCLE_IN_PROGRESS
  66. task = task_dict[current_task_id]
  67. child_to_check = _check_adjacent_tasks(current_task_id, task)
  68. if not child_to_check:
  69. visited[current_task_id] = CYCLE_DONE
  70. path_stack.pop()
  71. else:
  72. path_stack.append(child_to_check)