# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import annotations from collections import defaultdict from contextlib import contextmanager from sqlalchemy import text def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]: """ Return the primary and unique constraint along with column name. Some tables like `task_instance` are missing the primary key constraint name and the name is auto-generated by the SQL server, so this function helps to retrieve any primary or unique constraint name. :param conn: sql connection object :param table_name: table name :return: a dictionary of ((constraint name, constraint type), column name) of table """ query = text( f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME WHERE tc.TABLE_NAME = '{table_name}' AND (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE' or UPPER(tc.CONSTRAINT_TYPE) = 'FOREIGN KEY') """ ) result = conn.execute(query).fetchall() constraint_dict = defaultdict(lambda: defaultdict(list)) for constraint, constraint_type, col_name in result: constraint_dict[constraint_type][constraint].append(col_name) return constraint_dict @contextmanager def disable_sqlite_fkeys(op): if op.get_bind().dialect.name == "sqlite": op.execute("PRAGMA foreign_keys=off") yield op op.execute("PRAGMA foreign_keys=on") else: yield op