123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from __future__ import annotations
- from collections import defaultdict
- from contextlib import contextmanager
- from sqlalchemy import text
def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
    """
    Return the primary key, unique and foreign key constraints of a table along with their columns.

    Some tables like `task_instance` are missing the primary key constraint
    name and the name is auto-generated by the SQL server, so this function
    helps to retrieve any primary key, unique or foreign key constraint name.

    :param conn: sql connection object
    :param table_name: table name
    :return: a nested mapping ``{constraint_type: {constraint_name: [column_name, ...]}}``
        where ``constraint_type`` is one of ``"PRIMARY KEY"``, ``"UNIQUE"`` or
        ``"FOREIGN KEY"``. Column names are listed in the order the server returns
        them. Both levels are ``defaultdict`` instances, so looking up a missing
        key yields an empty mapping/list instead of raising ``KeyError``.
    """
    # The table name is passed as a bound parameter rather than interpolated
    # into the SQL text, so names containing quotes cannot break (or inject
    # into) the statement.
    # Constraint names are only unique per schema, hence the join also matches
    # on CONSTRAINT_SCHEMA to avoid picking up columns of a same-named
    # constraint living in another schema.
    query = text(
        """
        SELECT tc.CONSTRAINT_NAME, tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
        FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
        JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu
          ON ccu.CONSTRAINT_SCHEMA = tc.CONSTRAINT_SCHEMA
         AND ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
        WHERE tc.TABLE_NAME = :table_name
          AND UPPER(tc.CONSTRAINT_TYPE) IN ('PRIMARY KEY', 'UNIQUE', 'FOREIGN KEY')
        """
    ).bindparams(table_name=table_name)
    result = conn.execute(query).fetchall()
    # Group rows as constraint_type -> constraint_name -> [column names];
    # multi-column constraints produce one row per column.
    constraint_dict = defaultdict(lambda: defaultdict(list))
    for constraint, constraint_type, col_name in result:
        constraint_dict[constraint_type][constraint].append(col_name)
    return constraint_dict
@contextmanager
def disable_sqlite_fkeys(op):
    """
    Temporarily turn off SQLite foreign key enforcement for the duration of a ``with`` block.

    When the bound connection's dialect is not ``sqlite`` this is a no-op.
    In every case ``op`` itself is yielded back to the ``with`` statement.

    :param op: object exposing ``get_bind()`` and ``execute()`` (presumably the
        alembic ``op`` proxy inside a migration — confirm against callers)
    """
    if op.get_bind().dialect.name != "sqlite":
        yield op
        return
    op.execute("PRAGMA foreign_keys=off")
    try:
        yield op
    finally:
        # Re-enable even when the body raises; otherwise the connection would be
        # left with foreign key checks silently switched off.
        op.execute("PRAGMA foreign_keys=on")
|