priority_strategy.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Priority weight strategies for task scheduling."""
  19. from __future__ import annotations
  20. from abc import ABC, abstractmethod
  21. from typing import TYPE_CHECKING, Any
  22. from airflow.exceptions import AirflowException
  23. if TYPE_CHECKING:
  24. from airflow.models.taskinstance import TaskInstance
  25. class PriorityWeightStrategy(ABC):
  26. """
  27. Priority weight strategy interface.
  28. This feature is experimental and subject to change at any time.
  29. Currently, we don't serialize the priority weight strategy parameters. This means that
  30. the priority weight strategy must be stateless, but you can add class attributes, and
  31. create multiple subclasses with different attributes values if you need to create
  32. different versions of the same strategy.
  33. """
  34. @abstractmethod
  35. def get_weight(self, ti: TaskInstance):
  36. """Get the priority weight of a task."""
  37. ...
  38. @classmethod
  39. def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
  40. """
  41. Deserialize a priority weight strategy from data.
  42. This is called when a serialized DAG is deserialized. ``data`` will be whatever
  43. was returned by ``serialize`` during DAG serialization. The default
  44. implementation constructs the priority weight strategy without any arguments.
  45. """
  46. return cls(**data) # type: ignore[call-arg]
  47. def serialize(self) -> dict[str, Any]:
  48. """
  49. Serialize the priority weight strategy for JSON encoding.
  50. This is called during DAG serialization to store priority weight strategy information
  51. in the database. This should return a JSON-serializable dict that will be fed into
  52. ``deserialize`` when the DAG is deserialized. The default implementation returns
  53. an empty dict.
  54. """
  55. return {}
  56. def __eq__(self, other: object) -> bool:
  57. """Equality comparison."""
  58. if not isinstance(other, type(self)):
  59. return False
  60. return self.serialize() == other.serialize()
  61. class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
  62. """Priority weight strategy that uses the task's priority weight directly."""
  63. def get_weight(self, ti: TaskInstance):
  64. if TYPE_CHECKING:
  65. assert ti.task
  66. return ti.task.priority_weight
  67. class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
  68. """Priority weight strategy that uses the sum of the priority weights of all downstream tasks."""
  69. def get_weight(self, ti: TaskInstance) -> int:
  70. if TYPE_CHECKING:
  71. assert ti.task
  72. dag = ti.task.get_dag()
  73. if dag is None:
  74. return ti.task.priority_weight
  75. return ti.task.priority_weight + sum(
  76. dag.task_dict[task_id].priority_weight
  77. for task_id in ti.task.get_flat_relative_ids(upstream=False)
  78. )
  79. class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
  80. """Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""
  81. def get_weight(self, ti: TaskInstance):
  82. if TYPE_CHECKING:
  83. assert ti.task
  84. dag = ti.task.get_dag()
  85. if dag is None:
  86. return ti.task.priority_weight
  87. return ti.task.priority_weight + sum(
  88. dag.task_dict[task_id].priority_weight for task_id in ti.task.get_flat_relative_ids(upstream=True)
  89. )
  90. airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
  91. "absolute": _AbsolutePriorityWeightStrategy,
  92. "downstream": _DownstreamPriorityWeightStrategy,
  93. "upstream": _UpstreamPriorityWeightStrategy,
  94. }
  95. airflow_priority_weight_strategies_classes = {
  96. cls: name for name, cls in airflow_priority_weight_strategies.items()
  97. }
  98. def validate_and_load_priority_weight_strategy(
  99. priority_weight_strategy: str | PriorityWeightStrategy | None,
  100. ) -> PriorityWeightStrategy:
  101. """
  102. Validate and load a priority weight strategy.
  103. Returns the priority weight strategy if it is valid, otherwise raises an exception.
  104. :param priority_weight_strategy: The priority weight strategy to validate and load.
  105. :meta private:
  106. """
  107. from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
  108. from airflow.utils.module_loading import qualname
  109. if priority_weight_strategy is None:
  110. return _AbsolutePriorityWeightStrategy()
  111. if isinstance(priority_weight_strategy, str):
  112. if priority_weight_strategy in airflow_priority_weight_strategies:
  113. return airflow_priority_weight_strategies[priority_weight_strategy]()
  114. priority_weight_strategy_class = priority_weight_strategy
  115. else:
  116. priority_weight_strategy_class = qualname(priority_weight_strategy)
  117. loaded_priority_weight_strategy = _get_registered_priority_weight_strategy(priority_weight_strategy_class)
  118. if loaded_priority_weight_strategy is None:
  119. raise AirflowException(f"Unknown priority strategy {priority_weight_strategy_class}")
  120. return loaded_priority_weight_strategy()