utils.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from __future__ import annotations
  18. from collections import defaultdict
  19. from contextlib import contextmanager
  20. from sqlalchemy import text
  21. def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
  22. """
  23. Return the primary and unique constraint along with column name.
  24. Some tables like `task_instance` are missing the primary key constraint
  25. name and the name is auto-generated by the SQL server, so this function
  26. helps to retrieve any primary or unique constraint name.
  27. :param conn: sql connection object
  28. :param table_name: table name
  29. :return: a dictionary of ((constraint name, constraint type), column name) of table
  30. """
  31. query = text(
  32. f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
  33. FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
  34. JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
  35. WHERE tc.TABLE_NAME = '{table_name}' AND
  36. (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE'
  37. or UPPER(tc.CONSTRAINT_TYPE) = 'FOREIGN KEY')
  38. """
  39. )
  40. result = conn.execute(query).fetchall()
  41. constraint_dict = defaultdict(lambda: defaultdict(list))
  42. for constraint, constraint_type, col_name in result:
  43. constraint_dict[constraint_type][constraint].append(col_name)
  44. return constraint_dict
  45. @contextmanager
  46. def disable_sqlite_fkeys(op):
  47. if op.get_bind().dialect.name == "sqlite":
  48. op.execute("PRAGMA foreign_keys=off")
  49. yield op
  50. op.execute("PRAGMA foreign_keys=on")
  51. else:
  52. yield op