# local_client.py
  1. #
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing,
  13. # software distributed under the License is distributed on an
  14. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. # KIND, either express or implied. See the License for the
  16. # specific language governing permissions and limitations
  17. # under the License.
  18. """Local client API."""
  19. from __future__ import annotations
  20. from airflow.api.client import api_client
  21. from airflow.api.common import delete_dag, trigger_dag
  22. from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api
  23. from airflow.exceptions import AirflowBadRequest, PoolNotFound
  24. from airflow.models.pool import Pool
  25. class Client(api_client.Client):
  26. """Local API client implementation."""
  27. def trigger_dag(
  28. self, dag_id, run_id=None, conf=None, execution_date=None, replace_microseconds=True
  29. ) -> dict | None:
  30. dag_run = trigger_dag.trigger_dag(
  31. dag_id=dag_id,
  32. run_id=run_id,
  33. conf=conf,
  34. execution_date=execution_date,
  35. replace_microseconds=replace_microseconds,
  36. )
  37. if dag_run:
  38. return {
  39. "conf": dag_run.conf,
  40. "dag_id": dag_run.dag_id,
  41. "dag_run_id": dag_run.run_id,
  42. "data_interval_start": dag_run.data_interval_start,
  43. "data_interval_end": dag_run.data_interval_end,
  44. "end_date": dag_run.end_date,
  45. "external_trigger": dag_run.external_trigger,
  46. "last_scheduling_decision": dag_run.last_scheduling_decision,
  47. "logical_date": dag_run.logical_date,
  48. "run_type": dag_run.run_type,
  49. "start_date": dag_run.start_date,
  50. "state": dag_run.state,
  51. }
  52. return dag_run
  53. def delete_dag(self, dag_id):
  54. count = delete_dag.delete_dag(dag_id)
  55. return f"Removed {count} record(s)"
  56. def get_pool(self, name):
  57. pool = Pool.get_pool(pool_name=name)
  58. if not pool:
  59. raise PoolNotFound(f"Pool {name} not found")
  60. return pool.pool, pool.slots, pool.description, pool.include_deferred
  61. def get_pools(self):
  62. return [(p.pool, p.slots, p.description, p.include_deferred) for p in Pool.get_pools()]
  63. def create_pool(self, name, slots, description, include_deferred):
  64. if not (name and name.strip()):
  65. raise AirflowBadRequest("Pool name shouldn't be empty")
  66. pool_name_length = Pool.pool.property.columns[0].type.length
  67. if len(name) > pool_name_length:
  68. raise AirflowBadRequest(f"pool name cannot be more than {pool_name_length} characters")
  69. try:
  70. slots = int(slots)
  71. except ValueError:
  72. raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
  73. pool = Pool.create_or_update_pool(
  74. name=name, slots=slots, description=description, include_deferred=include_deferred
  75. )
  76. return pool.pool, pool.slots, pool.description
  77. def delete_pool(self, name):
  78. pool = Pool.delete_pool(name=name)
  79. return pool.pool, pool.slots, pool.description
  80. def get_lineage(self, dag_id, execution_date):
  81. lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
  82. return lineage