dot_renderer.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. #!/usr/bin/env python
  2. #
  3. # Licensed to the Apache Software Foundation (ASF) under one
  4. # or more contributor license agreements. See the NOTICE file
  5. # distributed with this work for additional information
  6. # regarding copyright ownership. The ASF licenses this file
  7. # to you under the Apache License, Version 2.0 (the
  8. # "License"); you may not use this file except in compliance
  9. # with the License. You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing,
  14. # software distributed under the License is distributed on an
  15. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  16. # KIND, either express or implied. See the License for the
  17. # specific language governing permissions and limitations
  18. # under the License.
  19. """Render a DAG (tasks and dependencies) to a graphviz object."""
  20. from __future__ import annotations
  21. import warnings
  22. from typing import TYPE_CHECKING, Any
  23. try:
  24. import graphviz
  25. except ImportError:
  26. warnings.warn(
  27. "Could not import graphviz. Rendering graph to the graphical format will not be possible.",
  28. UserWarning,
  29. stacklevel=2,
  30. )
  31. graphviz = None
  32. from airflow.exceptions import AirflowException
  33. from airflow.models.baseoperator import BaseOperator
  34. from airflow.models.mappedoperator import MappedOperator
  35. from airflow.utils.dag_edges import dag_edges
  36. from airflow.utils.state import State
  37. from airflow.utils.task_group import TaskGroup
  38. if TYPE_CHECKING:
  39. from airflow.models import TaskInstance
  40. from airflow.models.dag import DAG
  41. from airflow.models.taskmixin import DependencyMixin
  42. from airflow.serialization.dag_dependency import DagDependency
  43. def _refine_color(color: str):
  44. """
  45. Convert color in #RGB (12 bits) format to #RRGGBB (32 bits), if it possible.
  46. Otherwise, it returns the original value. Graphviz does not support colors in #RGB format.
  47. :param color: Text representation of color
  48. :return: Refined representation of color
  49. """
  50. if len(color) == 4 and color[0] == "#":
  51. color_r = color[1]
  52. color_g = color[2]
  53. color_b = color[3]
  54. return "#" + color_r + color_r + color_g + color_g + color_b + color_b
  55. return color
  56. def _draw_task(
  57. task: MappedOperator | BaseOperator,
  58. parent_graph: graphviz.Digraph,
  59. states_by_task_id: dict[Any, Any] | None,
  60. ) -> None:
  61. """Draw a single task on the given parent_graph."""
  62. if states_by_task_id:
  63. state = states_by_task_id.get(task.task_id)
  64. color = State.color_fg(state)
  65. fill_color = State.color(state)
  66. else:
  67. color = task.ui_fgcolor
  68. fill_color = task.ui_color
  69. parent_graph.node(
  70. task.task_id,
  71. _attributes={
  72. "label": task.label,
  73. "shape": "rectangle",
  74. "style": "filled,rounded",
  75. "color": _refine_color(color),
  76. "fillcolor": _refine_color(fill_color),
  77. },
  78. )
  79. def _draw_task_group(
  80. task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
  81. ) -> None:
  82. """Draw the given task_group and its children on the given parent_graph."""
  83. # Draw joins
  84. if task_group.upstream_group_ids or task_group.upstream_task_ids:
  85. parent_graph.node(
  86. task_group.upstream_join_id,
  87. _attributes={
  88. "label": "",
  89. "shape": "circle",
  90. "style": "filled,rounded",
  91. "color": _refine_color(task_group.ui_fgcolor),
  92. "fillcolor": _refine_color(task_group.ui_color),
  93. "width": "0.2",
  94. "height": "0.2",
  95. },
  96. )
  97. if task_group.downstream_group_ids or task_group.downstream_task_ids:
  98. parent_graph.node(
  99. task_group.downstream_join_id,
  100. _attributes={
  101. "label": "",
  102. "shape": "circle",
  103. "style": "filled,rounded",
  104. "color": _refine_color(task_group.ui_fgcolor),
  105. "fillcolor": _refine_color(task_group.ui_color),
  106. "width": "0.2",
  107. "height": "0.2",
  108. },
  109. )
  110. # Draw children
  111. for child in sorted(task_group.children.values(), key=lambda t: t.node_id if t.node_id else ""):
  112. _draw_nodes(child, parent_graph, states_by_task_id)
  113. def _draw_nodes(
  114. node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
  115. ) -> None:
  116. """Draw the node and its children on the given parent_graph recursively."""
  117. if isinstance(node, (BaseOperator, MappedOperator)):
  118. _draw_task(node, parent_graph, states_by_task_id)
  119. else:
  120. if not isinstance(node, TaskGroup):
  121. raise AirflowException(f"The node {node} should be TaskGroup and is not")
  122. # Draw TaskGroup
  123. if node.is_root:
  124. # No need to draw background for root TaskGroup.
  125. _draw_task_group(node, parent_graph, states_by_task_id)
  126. else:
  127. with parent_graph.subgraph(name=f"cluster_{node.group_id}") as sub:
  128. sub.attr(
  129. shape="rectangle",
  130. style="filled",
  131. color=_refine_color(node.ui_fgcolor),
  132. # Partially transparent CornflowerBlue
  133. fillcolor="#6495ed7f",
  134. label=node.label,
  135. )
  136. _draw_task_group(node, sub, states_by_task_id)
  137. def render_dag_dependencies(deps: dict[str, list[DagDependency]]) -> graphviz.Digraph:
  138. """
  139. Render the DAG dependency to the DOT object.
  140. :param deps: List of DAG dependencies
  141. :return: Graphviz object
  142. """
  143. if not graphviz:
  144. raise AirflowException(
  145. "Could not import graphviz. Install the graphviz python package to fix this error."
  146. )
  147. dot = graphviz.Digraph(graph_attr={"rankdir": "LR"})
  148. for dag, dependencies in deps.items():
  149. for dep in dependencies:
  150. with dot.subgraph(
  151. name=dag,
  152. graph_attr={
  153. "rankdir": "LR",
  154. "labelloc": "t",
  155. "label": dag,
  156. },
  157. ) as dep_subgraph:
  158. dep_subgraph.edge(dep.source, dep.dependency_id)
  159. dep_subgraph.edge(dep.dependency_id, dep.target)
  160. return dot
  161. def render_dag(dag: DAG, tis: list[TaskInstance] | None = None) -> graphviz.Digraph:
  162. """
  163. Render the DAG object to the DOT object.
  164. If an task instance list is passed, the nodes will be painted according to task statuses.
  165. :param dag: DAG that will be rendered.
  166. :param tis: List of task instances
  167. :return: Graphviz object
  168. """
  169. if not graphviz:
  170. raise AirflowException(
  171. "Could not import graphviz. Install the graphviz python package to fix this error."
  172. )
  173. dot = graphviz.Digraph(
  174. dag.dag_id,
  175. graph_attr={
  176. "rankdir": dag.orientation if dag.orientation else "LR",
  177. "labelloc": "t",
  178. "label": dag.dag_id,
  179. },
  180. )
  181. states_by_task_id = None
  182. if tis is not None:
  183. states_by_task_id = {ti.task_id: ti.state for ti in tis}
  184. _draw_nodes(dag.task_group, dot, states_by_task_id)
  185. for edge in dag_edges(dag):
  186. # Gets an optional label for the edge; this will be None if none is specified.
  187. label = dag.get_edge_info(edge["source_id"], edge["target_id"]).get("label")
  188. # Add the edge to the graph with optional label
  189. # (we can just use the maybe-None label variable directly)
  190. dot.edge(edge["source_id"], edge["target_id"], label)
  191. return dot