# schemaobj.py
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import schema as sa_schema
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.types import Integer
from sqlalchemy.types import NULLTYPE

from .. import util
from ..util import sqla_compat

if TYPE_CHECKING:
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.elements import TextClause
    from sqlalchemy.sql.schema import CheckConstraint
    from sqlalchemy.sql.schema import ForeignKey
    from sqlalchemy.sql.schema import ForeignKeyConstraint
    from sqlalchemy.sql.schema import MetaData
    from sqlalchemy.sql.schema import PrimaryKeyConstraint
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.schema import UniqueConstraint
    from sqlalchemy.sql.type_api import TypeEngine

    from ..runtime.migration import MigrationContext
  32. class SchemaObjects:
  33. def __init__(
  34. self, migration_context: Optional[MigrationContext] = None
  35. ) -> None:
  36. self.migration_context = migration_context
  37. def primary_key_constraint(
  38. self,
  39. name: Optional[sqla_compat._ConstraintNameDefined],
  40. table_name: str,
  41. cols: Sequence[str],
  42. schema: Optional[str] = None,
  43. **dialect_kw,
  44. ) -> PrimaryKeyConstraint:
  45. m = self.metadata()
  46. columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
  47. t = sa_schema.Table(table_name, m, *columns, schema=schema)
  48. # SQLAlchemy primary key constraint name arg is wrongly typed on
  49. # the SQLAlchemy side through 2.0.5 at least
  50. p = sa_schema.PrimaryKeyConstraint(
  51. *[t.c[n] for n in cols], name=name, **dialect_kw # type: ignore
  52. )
  53. return p
  54. def foreign_key_constraint(
  55. self,
  56. name: Optional[sqla_compat._ConstraintNameDefined],
  57. source: str,
  58. referent: str,
  59. local_cols: List[str],
  60. remote_cols: List[str],
  61. onupdate: Optional[str] = None,
  62. ondelete: Optional[str] = None,
  63. deferrable: Optional[bool] = None,
  64. source_schema: Optional[str] = None,
  65. referent_schema: Optional[str] = None,
  66. initially: Optional[str] = None,
  67. match: Optional[str] = None,
  68. **dialect_kw,
  69. ) -> ForeignKeyConstraint:
  70. m = self.metadata()
  71. if source == referent and source_schema == referent_schema:
  72. t1_cols = local_cols + remote_cols
  73. else:
  74. t1_cols = local_cols
  75. sa_schema.Table(
  76. referent,
  77. m,
  78. *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
  79. schema=referent_schema,
  80. )
  81. t1 = sa_schema.Table(
  82. source,
  83. m,
  84. *[
  85. sa_schema.Column(n, NULLTYPE)
  86. for n in util.unique_list(t1_cols)
  87. ],
  88. schema=source_schema,
  89. )
  90. tname = (
  91. "%s.%s" % (referent_schema, referent)
  92. if referent_schema
  93. else referent
  94. )
  95. dialect_kw["match"] = match
  96. f = sa_schema.ForeignKeyConstraint(
  97. local_cols,
  98. ["%s.%s" % (tname, n) for n in remote_cols],
  99. name=name,
  100. onupdate=onupdate,
  101. ondelete=ondelete,
  102. deferrable=deferrable,
  103. initially=initially,
  104. **dialect_kw,
  105. )
  106. t1.append_constraint(f)
  107. return f
  108. def unique_constraint(
  109. self,
  110. name: Optional[sqla_compat._ConstraintNameDefined],
  111. source: str,
  112. local_cols: Sequence[str],
  113. schema: Optional[str] = None,
  114. **kw,
  115. ) -> UniqueConstraint:
  116. t = sa_schema.Table(
  117. source,
  118. self.metadata(),
  119. *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
  120. schema=schema,
  121. )
  122. kw["name"] = name
  123. uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
  124. # TODO: need event tests to ensure the event
  125. # is fired off here
  126. t.append_constraint(uq)
  127. return uq
  128. def check_constraint(
  129. self,
  130. name: Optional[sqla_compat._ConstraintNameDefined],
  131. source: str,
  132. condition: Union[str, TextClause, ColumnElement[Any]],
  133. schema: Optional[str] = None,
  134. **kw,
  135. ) -> Union[CheckConstraint]:
  136. t = sa_schema.Table(
  137. source,
  138. self.metadata(),
  139. sa_schema.Column("x", Integer),
  140. schema=schema,
  141. )
  142. ck = sa_schema.CheckConstraint(condition, name=name, **kw)
  143. t.append_constraint(ck)
  144. return ck
  145. def generic_constraint(
  146. self,
  147. name: Optional[sqla_compat._ConstraintNameDefined],
  148. table_name: str,
  149. type_: Optional[str],
  150. schema: Optional[str] = None,
  151. **kw,
  152. ) -> Any:
  153. t = self.table(table_name, schema=schema)
  154. types: Dict[Optional[str], Any] = {
  155. "foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
  156. [], [], name=name
  157. ),
  158. "primary": sa_schema.PrimaryKeyConstraint,
  159. "unique": sa_schema.UniqueConstraint,
  160. "check": lambda name: sa_schema.CheckConstraint("", name=name),
  161. None: sa_schema.Constraint,
  162. }
  163. try:
  164. const = types[type_]
  165. except KeyError as ke:
  166. raise TypeError(
  167. "'type' can be one of %s"
  168. % ", ".join(sorted(repr(x) for x in types))
  169. ) from ke
  170. else:
  171. const = const(name=name)
  172. t.append_constraint(const)
  173. return const
  174. def metadata(self) -> MetaData:
  175. kw = {}
  176. if (
  177. self.migration_context is not None
  178. and "target_metadata" in self.migration_context.opts
  179. ):
  180. mt = self.migration_context.opts["target_metadata"]
  181. if hasattr(mt, "naming_convention"):
  182. kw["naming_convention"] = mt.naming_convention
  183. return sa_schema.MetaData(**kw)
  184. def table(self, name: str, *columns, **kw) -> Table:
  185. m = self.metadata()
  186. cols = [
  187. sqla_compat._copy(c) if c.table is not None else c
  188. for c in columns
  189. if isinstance(c, Column)
  190. ]
  191. # these flags have already added their UniqueConstraint /
  192. # Index objects to the table, so flip them off here.
  193. # SQLAlchemy tometadata() avoids this instead by preserving the
  194. # flags and skipping the constraints that have _type_bound on them,
  195. # but for a migration we'd rather list out the constraints
  196. # explicitly.
  197. _constraints_included = kw.pop("_constraints_included", False)
  198. if _constraints_included:
  199. for c in cols:
  200. c.unique = c.index = False
  201. t = sa_schema.Table(name, m, *cols, **kw)
  202. constraints = [
  203. (
  204. sqla_compat._copy(elem, target_table=t)
  205. if getattr(elem, "parent", None) is not t
  206. and getattr(elem, "parent", None) is not None
  207. else elem
  208. )
  209. for elem in columns
  210. if isinstance(elem, (Constraint, Index))
  211. ]
  212. for const in constraints:
  213. t.append_constraint(const)
  214. for f in t.foreign_keys:
  215. self._ensure_table_for_fk(m, f)
  216. return t
  217. def column(self, name: str, type_: TypeEngine, **kw) -> Column:
  218. return sa_schema.Column(name, type_, **kw)
  219. def index(
  220. self,
  221. name: Optional[str],
  222. tablename: Optional[str],
  223. columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
  224. schema: Optional[str] = None,
  225. **kw,
  226. ) -> Index:
  227. t = sa_schema.Table(
  228. tablename or "no_table",
  229. self.metadata(),
  230. schema=schema,
  231. )
  232. kw["_table"] = t
  233. idx = sa_schema.Index(
  234. name,
  235. *[util.sqla_compat._textual_index_column(t, n) for n in columns],
  236. **kw,
  237. )
  238. return idx
  239. def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
  240. if "." in table_key:
  241. tokens = table_key.split(".")
  242. sname: Optional[str] = ".".join(tokens[0:-1])
  243. tname = tokens[-1]
  244. else:
  245. tname = table_key
  246. sname = None
  247. return (sname, tname)
  248. def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None:
  249. """create a placeholder Table object for the referent of a
  250. ForeignKey.
  251. """
  252. if isinstance(fk._colspec, str):
  253. table_key, cname = fk._colspec.rsplit(".", 1)
  254. sname, tname = self._parse_table_key(table_key)
  255. if table_key not in metadata.tables:
  256. rel_t = sa_schema.Table(tname, metadata, schema=sname)
  257. else:
  258. rel_t = metadata.tables[table_key]
  259. if cname not in rel_t.c:
  260. rel_t.append_column(sa_schema.Column(cname, NULLTYPE))