#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Local client API."""

from __future__ import annotations

from airflow.api.client import api_client
from airflow.api.common import delete_dag, trigger_dag
from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api
from airflow.exceptions import AirflowBadRequest, PoolNotFound
from airflow.models.pool import Pool


class Client(api_client.Client):
    """Local API client implementation."""

    def trigger_dag(
        self, dag_id, run_id=None, conf=None, execution_date=None, replace_microseconds=True
    ) -> dict | None:
        """Trigger a dag run for ``dag_id`` and return the created run's attributes, or None."""
        dag_run = trigger_dag.trigger_dag(
            dag_id=dag_id,
            run_id=run_id,
            conf=conf,
            execution_date=execution_date,
            replace_microseconds=replace_microseconds,
        )
        if dag_run:
            return {
                "conf": dag_run.conf,
                "dag_id": dag_run.dag_id,
                "dag_run_id": dag_run.run_id,
                "data_interval_start": dag_run.data_interval_start,
                "data_interval_end": dag_run.data_interval_end,
                "end_date": dag_run.end_date,
                "external_trigger": dag_run.external_trigger,
                "last_scheduling_decision": dag_run.last_scheduling_decision,
                "logical_date": dag_run.logical_date,
                "run_type": dag_run.run_type,
                "start_date": dag_run.start_date,
                "state": dag_run.state,
            }
        return dag_run
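    # Illustrative usage (a minimal sketch): the local client talks directly to the
    # metadata database, so it assumes an initialized Airflow environment. The dag_id
    # "example_dag" and the constructor keyword arguments (taken from the base
    # ``api_client.Client``) are assumptions made for illustration only.
    #
    #   client = Client(api_base_url=None, auth=None)
    #   run = client.trigger_dag(dag_id="example_dag", conf={"key": "value"})
    #   if run:
    #       print(run["dag_run_id"], run["state"])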

    def delete_dag(self, dag_id):
        """Delete all DB records related to ``dag_id``."""
        count = delete_dag.delete_dag(dag_id)
        return f"Removed {count} record(s)"

    def get_pool(self, name):
        """Return (name, slots, description, include_deferred) for the named pool."""
        pool = Pool.get_pool(pool_name=name)
        if not pool:
            raise PoolNotFound(f"Pool {name} not found")
        return pool.pool, pool.slots, pool.description, pool.include_deferred

    def get_pools(self):
        """Return the same attribute tuple for every configured pool."""
        return [(p.pool, p.slots, p.description, p.include_deferred) for p in Pool.get_pools()]

    def create_pool(self, name, slots, description, include_deferred):
        """Create (or update) a pool after validating its name and slot count."""
        if not (name and name.strip()):
            raise AirflowBadRequest("Pool name shouldn't be empty")
        pool_name_length = Pool.pool.property.columns[0].type.length
        if len(name) > pool_name_length:
            raise AirflowBadRequest(f"pool name cannot be more than {pool_name_length} characters")
        try:
            slots = int(slots)
        except ValueError:
            raise AirflowBadRequest(f"Bad value for `slots`: {slots}")
        pool = Pool.create_or_update_pool(
            name=name, slots=slots, description=description, include_deferred=include_deferred
        )
        return pool.pool, pool.slots, pool.description

    def delete_pool(self, name):
        """Delete the named pool and return its (name, slots, description)."""
        pool = Pool.delete_pool(name=name)
        return pool.pool, pool.slots, pool.description

    def get_lineage(self, dag_id, execution_date):
        """Return lineage information for the dag run at ``execution_date``."""
        lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
        return lineage
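

# A minimal, self-contained usage sketch. It assumes a locally configured Airflow
# installation with an initialized metadata database; the pool name used below and
# the constructor keyword arguments (taken from the base ``api_client.Client``) are
# assumptions made for illustration only.
if __name__ == "__main__":
    client = Client(api_base_url=None, auth=None)

    # Create (or update) a pool, list all pools, then remove the pool again.
    print(client.create_pool(name="example_pool", slots=4, description="demo pool", include_deferred=False))
    print(client.get_pools())
    print(client.delete_pool(name="example_pool"))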