123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- #!/usr/bin/env python
- #
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- """Renderer DAG (tasks and dependencies) to the graphviz object."""
- from __future__ import annotations
- import warnings
- from typing import TYPE_CHECKING, Any
- try:
- import graphviz
- except ImportError:
- warnings.warn(
- "Could not import graphviz. Rendering graph to the graphical format will not be possible.",
- UserWarning,
- stacklevel=2,
- )
- graphviz = None
- from airflow.exceptions import AirflowException
- from airflow.models.baseoperator import BaseOperator
- from airflow.models.mappedoperator import MappedOperator
- from airflow.utils.dag_edges import dag_edges
- from airflow.utils.state import State
- from airflow.utils.task_group import TaskGroup
- if TYPE_CHECKING:
- from airflow.models import TaskInstance
- from airflow.models.dag import DAG
- from airflow.models.taskmixin import DependencyMixin
- from airflow.serialization.dag_dependency import DagDependency
- def _refine_color(color: str):
- """
- Convert color in #RGB (12 bits) format to #RRGGBB (32 bits), if it possible.
- Otherwise, it returns the original value. Graphviz does not support colors in #RGB format.
- :param color: Text representation of color
- :return: Refined representation of color
- """
- if len(color) == 4 and color[0] == "#":
- color_r = color[1]
- color_g = color[2]
- color_b = color[3]
- return "#" + color_r + color_r + color_g + color_g + color_b + color_b
- return color
def _draw_task(
    task: MappedOperator | BaseOperator,
    parent_graph: graphviz.Digraph,
    states_by_task_id: dict[Any, Any] | None,
) -> None:
    """
    Add one task to ``parent_graph`` as a filled, rounded rectangle node.

    When ``states_by_task_id`` is non-empty, both colors are derived from the state stored
    for ``task.task_id`` (looked up with ``dict.get``, so a task missing from the mapping
    passes ``None`` on to ``State``). Otherwise the task's own ``ui_fgcolor`` and
    ``ui_color`` are used.

    :param task: Operator to draw
    :param parent_graph: Graph (or subgraph) that receives the node
    :param states_by_task_id: Optional mapping of task id to state
    """
    if states_by_task_id:
        task_state = states_by_task_id.get(task.task_id)
        fg_color, bg_color = State.color_fg(task_state), State.color(task_state)
    else:
        fg_color, bg_color = task.ui_fgcolor, task.ui_color

    node_attrs = {
        "label": task.label,
        "shape": "rectangle",
        "style": "filled,rounded",
        # Colors go through _refine_color because Graphviz rejects the short #RGB form.
        "color": _refine_color(fg_color),
        "fillcolor": _refine_color(bg_color),
    }
    parent_graph.node(task.task_id, _attributes=node_attrs)
def _draw_task_group(
    task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
) -> None:
    """
    Draw the join nodes of ``task_group`` and then every child it contains.

    A small, unlabelled circular join node is emitted for the upstream side when the group
    has any upstream group ids or task ids, and likewise for the downstream side. The
    children are then drawn through ``_draw_nodes`` in ``node_id`` order, with a falsy
    ``node_id`` sorting as the empty string.

    :param task_group: Group whose joins and children are drawn
    :param parent_graph: Graph (or subgraph) that receives the nodes
    :param states_by_task_id: Optional mapping of task id to state, handed on to the children
    """

    def _draw_join(join_id) -> None:
        # Both joins share exactly the same look; only the node id differs.
        parent_graph.node(
            join_id,
            _attributes={
                "label": "",
                "shape": "circle",
                "style": "filled,rounded",
                "color": _refine_color(task_group.ui_fgcolor),
                "fillcolor": _refine_color(task_group.ui_color),
                "width": "0.2",
                "height": "0.2",
            },
        )

    # Joins are only drawn on a side that actually has something connected to it.
    if task_group.upstream_group_ids or task_group.upstream_task_ids:
        _draw_join(task_group.upstream_join_id)
    if task_group.downstream_group_ids or task_group.downstream_task_ids:
        _draw_join(task_group.downstream_join_id)

    # Sorting by node id keeps the emitted order independent of the children dict's order.
    ordered_children = sorted(task_group.children.values(), key=lambda child: child.node_id or "")
    for child in ordered_children:
        _draw_nodes(child, parent_graph, states_by_task_id)
def _draw_nodes(
    node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
) -> None:
    """
    Recursively draw ``node`` (a task or a task group) on ``parent_graph``.

    Operators are drawn as single nodes. A root ``TaskGroup`` is drawn directly on
    ``parent_graph``; any other ``TaskGroup`` is drawn inside its own
    ``cluster_<group_id>`` subgraph.

    :param node: Operator or TaskGroup to draw
    :param parent_graph: Graph (or subgraph) that receives the drawing
    :param states_by_task_id: Optional mapping of task id to state
    :raises AirflowException: If ``node`` is neither an operator nor a ``TaskGroup``
    """
    if isinstance(node, (BaseOperator, MappedOperator)):
        _draw_task(node, parent_graph, states_by_task_id)
        return

    if not isinstance(node, TaskGroup):
        raise AirflowException(f"The node {node} should be TaskGroup and is not")

    if node.is_root:
        # The root TaskGroup needs no background of its own.
        _draw_task_group(node, parent_graph, states_by_task_id)
        return

    with parent_graph.subgraph(name=f"cluster_{node.group_id}") as cluster:
        cluster.attr(
            **{
                "shape": "rectangle",
                "style": "filled",
                "color": _refine_color(node.ui_fgcolor),
                # Partially transparent CornflowerBlue
                "fillcolor": "#6495ed7f",
                "label": node.label,
            }
        )
        _draw_task_group(node, cluster, states_by_task_id)
def render_dag_dependencies(deps: dict[str, list[DagDependency]]) -> graphviz.Digraph:
    """
    Render DAG dependencies to a DOT object.

    For every dependency, a subgraph named and labelled after its dictionary key is opened
    and two edges are added to it: ``source -> dependency_id`` and
    ``dependency_id -> target``.

    :param deps: Mapping from a key (presumably the DAG id; it is used as the subgraph
        name and label) to the list of dependencies stored under it
    :return: Graphviz object
    :raises AirflowException: If the graphviz python package could not be imported
    """
    if not graphviz:
        raise AirflowException(
            "Could not import graphviz. Install the graphviz python package to fix this error."
        )

    dot = graphviz.Digraph(graph_attr={"rankdir": "LR"})
    for dag_key, dag_deps in deps.items():
        for dependency in dag_deps:
            subgraph_attrs = {"rankdir": "LR", "labelloc": "t", "label": dag_key}
            with dot.subgraph(name=dag_key, graph_attr=subgraph_attrs) as dep_graph:
                # Route the dependency through its own id: source -> dependency -> target.
                dep_graph.edge(dependency.source, dependency.dependency_id)
                dep_graph.edge(dependency.dependency_id, dependency.target)
    return dot
def render_dag(dag: DAG, tis: list[TaskInstance] | None = None) -> graphviz.Digraph:
    """
    Render the DAG object to the DOT object.

    If a task instance list is passed, it is turned into a task-id-to-state mapping that
    the node drawing uses to pick colors by task state.

    :param dag: DAG that will be rendered.
    :param tis: List of task instances
    :return: Graphviz object
    :raises AirflowException: If the graphviz python package could not be imported
    """
    if not graphviz:
        raise AirflowException(
            "Could not import graphviz. Install the graphviz python package to fix this error."
        )

    graph_attrs = {
        # Fall back to left-to-right layout when the DAG has no orientation set.
        "rankdir": dag.orientation or "LR",
        "labelloc": "t",
        "label": dag.dag_id,
    }
    dot = graphviz.Digraph(dag.dag_id, graph_attr=graph_attrs)

    states_by_task_id = None if tis is None else {ti.task_id: ti.state for ti in tis}
    _draw_nodes(dag.task_group, dot, states_by_task_id)

    for edge in dag_edges(dag):
        source_id, target_id = edge["source_id"], edge["target_id"]
        # The label is optional; when none is set this is None and is passed through as-is.
        label = dag.get_edge_info(source_id, target_id).get("label")
        dot.edge(source_id, target_id, label)
    return dot
|