util.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126
  1. # sql/util.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """High level utilities which build upon other modules here.
  8. """
  9. from collections import deque
  10. from itertools import chain
  11. from . import coercions
  12. from . import operators
  13. from . import roles
  14. from . import visitors
  15. from .annotation import _deep_annotate # noqa
  16. from .annotation import _deep_deannotate # noqa
  17. from .annotation import _shallow_annotate # noqa
  18. from .base import _expand_cloned
  19. from .base import _from_objects
  20. from .base import ColumnSet
  21. from .ddl import sort_tables # noqa
  22. from .elements import _find_columns # noqa
  23. from .elements import _label_reference
  24. from .elements import _textual_label_reference
  25. from .elements import BindParameter
  26. from .elements import ColumnClause
  27. from .elements import ColumnElement
  28. from .elements import Grouping
  29. from .elements import Label
  30. from .elements import Null
  31. from .elements import UnaryExpression
  32. from .schema import Column
  33. from .selectable import Alias
  34. from .selectable import FromClause
  35. from .selectable import FromGrouping
  36. from .selectable import Join
  37. from .selectable import ScalarSelect
  38. from .selectable import SelectBase
  39. from .selectable import TableClause
  40. from .traversals import HasCacheKey # noqa
  41. from .. import exc
  42. from .. import util
# Module-level alias built from ``Join._join_condition`` via
# ``public_factory``.  The second argument is a dotted location string;
# presumably it is the public path under which the callable is advertised
# -- confirm against ``util.langhelpers.public_factory``.
join_condition = util.langhelpers.public_factory(
    Join._join_condition, ".sql.util.join_condition"
)
  46. def find_join_source(clauses, join_to):
  47. """Given a list of FROM clauses and a selectable,
  48. return the first index and element from the list of
  49. clauses which can be joined against the selectable. returns
  50. None, None if no match is found.
  51. e.g.::
  52. clause1 = table1.join(table2)
  53. clause2 = table4.join(table5)
  54. join_to = table2.join(table3)
  55. find_join_source([clause1, clause2], join_to) == clause1
  56. """
  57. selectables = list(_from_objects(join_to))
  58. idx = []
  59. for i, f in enumerate(clauses):
  60. for s in selectables:
  61. if f.is_derived_from(s):
  62. idx.append(i)
  63. return idx
  64. def find_left_clause_that_matches_given(clauses, join_from):
  65. """Given a list of FROM clauses and a selectable,
  66. return the indexes from the list of
  67. clauses which is derived from the selectable.
  68. """
  69. selectables = list(_from_objects(join_from))
  70. liberal_idx = []
  71. for i, f in enumerate(clauses):
  72. for s in selectables:
  73. # basic check, if f is derived from s.
  74. # this can be joins containing a table, or an aliased table
  75. # or select statement matching to a table. This check
  76. # will match a table to a selectable that is adapted from
  77. # that table. With Query, this suits the case where a join
  78. # is being made to an adapted entity
  79. if f.is_derived_from(s):
  80. liberal_idx.append(i)
  81. break
  82. # in an extremely small set of use cases, a join is being made where
  83. # there are multiple FROM clauses where our target table is represented
  84. # in more than one, such as embedded or similar. in this case, do
  85. # another pass where we try to get a more exact match where we aren't
  86. # looking at adaption relationships.
  87. if len(liberal_idx) > 1:
  88. conservative_idx = []
  89. for idx in liberal_idx:
  90. f = clauses[idx]
  91. for s in selectables:
  92. if set(surface_selectables(f)).intersection(
  93. surface_selectables(s)
  94. ):
  95. conservative_idx.append(idx)
  96. break
  97. if conservative_idx:
  98. return conservative_idx
  99. return liberal_idx
  100. def find_left_clause_to_join_from(clauses, join_to, onclause):
  101. """Given a list of FROM clauses, a selectable,
  102. and optional ON clause, return a list of integer indexes from the
  103. clauses list indicating the clauses that can be joined from.
  104. The presence of an "onclause" indicates that at least one clause can
  105. definitely be joined from; if the list of clauses is of length one
  106. and the onclause is given, returns that index. If the list of clauses
  107. is more than length one, and the onclause is given, attempts to locate
  108. which clauses contain the same columns.
  109. """
  110. idx = []
  111. selectables = set(_from_objects(join_to))
  112. # if we are given more than one target clause to join
  113. # from, use the onclause to provide a more specific answer.
  114. # otherwise, don't try to limit, after all, "ON TRUE" is a valid
  115. # on clause
  116. if len(clauses) > 1 and onclause is not None:
  117. resolve_ambiguity = True
  118. cols_in_onclause = _find_columns(onclause)
  119. else:
  120. resolve_ambiguity = False
  121. cols_in_onclause = None
  122. for i, f in enumerate(clauses):
  123. for s in selectables.difference([f]):
  124. if resolve_ambiguity:
  125. if set(f.c).union(s.c).issuperset(cols_in_onclause):
  126. idx.append(i)
  127. break
  128. elif onclause is not None or Join._can_join(f, s):
  129. idx.append(i)
  130. break
  131. if len(idx) > 1:
  132. # this is the same "hide froms" logic from
  133. # Selectable._get_display_froms
  134. toremove = set(
  135. chain(*[_expand_cloned(f._hide_froms) for f in clauses])
  136. )
  137. idx = [i for i in idx if clauses[i] not in toremove]
  138. # onclause was given and none of them resolved, so assume
  139. # all indexes can match
  140. if not idx and onclause is not None:
  141. return range(len(clauses))
  142. else:
  143. return idx
def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)``.
    :param expr: expression to traverse.
    :return: ``None``; results are delivered only through ``fn``.

    """
    # comparison expressions currently being expanded, innermost first
    stack = []

    def visit(element):
        # generator: yields the "leaf" elements found beneath ``element``
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator
        ):
            stack.insert(0, element)
            # visit(element.right) is re-run for every leaf of the left
            # side, which is what produces the cartesian product
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[0], l, r)
            stack.pop(0)
            # NOTE(review): ``visit`` is a generator function, so calling
            # it without iterating the result runs none of its body; this
            # loop therefore has no effect beyond calling get_children().
            # Consuming the generators here would change which pairs are
            # reported, so confirm intent before altering it.
            for elem in element.get_children():
                visit(elem)
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    # drain the generator; only the fn() side effects matter
    list(visit(expr))
    visit = None  # remove gc cycles
  194. def find_tables(
  195. clause,
  196. check_columns=False,
  197. include_aliases=False,
  198. include_joins=False,
  199. include_selects=False,
  200. include_crud=False,
  201. ):
  202. """locate Table objects within the given expression."""
  203. tables = []
  204. _visitors = {}
  205. if include_selects:
  206. _visitors["select"] = _visitors["compound_select"] = tables.append
  207. if include_joins:
  208. _visitors["join"] = tables.append
  209. if include_aliases:
  210. _visitors["alias"] = _visitors["subquery"] = _visitors[
  211. "tablesample"
  212. ] = _visitors["lateral"] = tables.append
  213. if include_crud:
  214. _visitors["insert"] = _visitors["update"] = _visitors[
  215. "delete"
  216. ] = lambda ent: tables.append(ent.table)
  217. if check_columns:
  218. def visit_column(column):
  219. tables.append(column.table)
  220. _visitors["column"] = visit_column
  221. _visitors["table"] = tables.append
  222. visitors.traverse(clause, {}, _visitors)
  223. return tables
def unwrap_order_by(clause):
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST

    :param clause: the ORDER BY expression to take apart.
    :return: list of the distinct column expressions found, in the order
     they were first encountered.

    """

    # ``cols`` tracks what has been emitted already so ``result`` stays
    # free of duplicates while preserving order
    cols = util.column_set()
    result = []
    # breadth-first work queue
    stack = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        # a column element that is not an ordering modifier
        # (an ordering-modifier UnaryExpression goes to the else branch)
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)
        ):
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                # strip the label (and one Grouping directly beneath it),
                # then re-queue the inner element for another look;
                # labels around scalar selects are kept as they are
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                # unwrap the reference and re-queue what it points at
                t = t.element
                stack.append(t)
                continue
            if isinstance(t, (_textual_label_reference)):
                # textual references are skipped entirely
                continue
            if t not in cols:
                cols.add(t)
                result.append(t)
        else:
            # ordering modifier or non-column element: descend into children
            for c in t.get_children():
                stack.append(c)
    return result
  262. def unwrap_label_reference(element):
  263. def replace(elem):
  264. if isinstance(elem, (_label_reference, _textual_label_reference)):
  265. return elem.element
  266. return visitors.replacement_traverse(element, {}, replace)
  267. def expand_column_list_from_order_by(collist, order_by):
  268. """Given the columns clause and ORDER BY of a selectable,
  269. return a list of column expressions that can be added to the collist
  270. corresponding to the ORDER BY, without repeating those already
  271. in the collist.
  272. """
  273. cols_already_present = set(
  274. [
  275. col.element if col._order_by_label_element is not None else col
  276. for col in collist
  277. ]
  278. )
  279. to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by]))
  280. return [col for col in to_look_for if col not in cols_already_present]
  281. def clause_is_present(clause, search):
  282. """Given a target clause and a second to search within, return True
  283. if the target is plainly present in the search without any
  284. subqueries or aliases involved.
  285. Basically descends through Joins.
  286. """
  287. for elem in surface_selectables(search):
  288. if clause == elem: # use == here so that Annotated's compare
  289. return True
  290. else:
  291. return False
  292. def tables_from_leftmost(clause):
  293. if isinstance(clause, Join):
  294. for t in tables_from_leftmost(clause.left):
  295. yield t
  296. for t in tables_from_leftmost(clause.right):
  297. yield t
  298. elif isinstance(clause, FromGrouping):
  299. for t in tables_from_leftmost(clause.element):
  300. yield t
  301. else:
  302. yield clause
  303. def surface_selectables(clause):
  304. stack = [clause]
  305. while stack:
  306. elem = stack.pop()
  307. yield elem
  308. if isinstance(elem, Join):
  309. stack.extend((elem.left, elem.right))
  310. elif isinstance(elem, FromGrouping):
  311. stack.append(elem.element)
  312. def surface_selectables_only(clause):
  313. stack = [clause]
  314. while stack:
  315. elem = stack.pop()
  316. if isinstance(elem, (TableClause, Alias)):
  317. yield elem
  318. if isinstance(elem, Join):
  319. stack.extend((elem.left, elem.right))
  320. elif isinstance(elem, FromGrouping):
  321. stack.append(elem.element)
  322. elif isinstance(elem, ColumnClause):
  323. if elem.table is not None:
  324. stack.append(elem.table)
  325. else:
  326. yield elem
  327. elif elem is not None:
  328. yield elem
  329. def extract_first_column_annotation(column, annotation_name):
  330. filter_ = (FromGrouping, SelectBase)
  331. stack = deque([column])
  332. while stack:
  333. elem = stack.popleft()
  334. if annotation_name in elem._annotations:
  335. return elem._annotations[annotation_name]
  336. for sub in elem.get_children():
  337. if isinstance(sub, filter_):
  338. continue
  339. stack.append(sub)
  340. return None
  341. def selectables_overlap(left, right):
  342. """Return True if left/right have some overlapping selectable"""
  343. return bool(
  344. set(surface_selectables(left)).intersection(surface_selectables(right))
  345. )
  346. def bind_values(clause):
  347. """Return an ordered list of "bound" values in the given clause.
  348. E.g.::
  349. >>> expr = and_(
  350. ... table.c.foo==5, table.c.foo==7
  351. ... )
  352. >>> bind_values(expr)
  353. [5, 7]
  354. """
  355. v = []
  356. def visit_bindparam(bind):
  357. v.append(bind.effective_value)
  358. visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
  359. return v
  360. def _quote_ddl_expr(element):
  361. if isinstance(element, util.string_types):
  362. element = element.replace("'", "''")
  363. return "'%s'" % element
  364. else:
  365. return repr(element)
  366. class _repr_base(object):
  367. _LIST = 0
  368. _TUPLE = 1
  369. _DICT = 2
  370. __slots__ = ("max_chars",)
  371. def trunc(self, value):
  372. rep = repr(value)
  373. lenrep = len(rep)
  374. if lenrep > self.max_chars:
  375. segment_length = self.max_chars // 2
  376. rep = (
  377. rep[0:segment_length]
  378. + (
  379. " ... (%d characters truncated) ... "
  380. % (lenrep - self.max_chars)
  381. )
  382. + rep[-segment_length:]
  383. )
  384. return rep
  385. def _repr_single_value(value):
  386. rp = _repr_base()
  387. rp.max_chars = 300
  388. return rp.trunc(value)
  389. class _repr_row(_repr_base):
  390. """Provide a string view of a row."""
  391. __slots__ = ("row",)
  392. def __init__(self, row, max_chars=300):
  393. self.row = row
  394. self.max_chars = max_chars
  395. def __repr__(self):
  396. trunc = self.trunc
  397. return "(%s%s)" % (
  398. ", ".join(trunc(value) for value in self.row),
  399. "," if len(self.row) == 1 else "",
  400. )
  401. class _repr_params(_repr_base):
  402. """Provide a string view of bound parameters.
  403. Truncates display to a given number of 'multi' parameter sets,
  404. as well as long values to a given number of characters.
  405. """
  406. __slots__ = "params", "batches", "ismulti"
  407. def __init__(self, params, batches, max_chars=300, ismulti=None):
  408. self.params = params
  409. self.ismulti = ismulti
  410. self.batches = batches
  411. self.max_chars = max_chars
  412. def __repr__(self):
  413. if self.ismulti is None:
  414. return self.trunc(self.params)
  415. if isinstance(self.params, list):
  416. typ = self._LIST
  417. elif isinstance(self.params, tuple):
  418. typ = self._TUPLE
  419. elif isinstance(self.params, dict):
  420. typ = self._DICT
  421. else:
  422. return self.trunc(self.params)
  423. if self.ismulti and len(self.params) > self.batches:
  424. msg = " ... displaying %i of %i total bound parameter sets ... "
  425. return " ".join(
  426. (
  427. self._repr_multi(self.params[: self.batches - 2], typ)[
  428. 0:-1
  429. ],
  430. msg % (self.batches, len(self.params)),
  431. self._repr_multi(self.params[-2:], typ)[1:],
  432. )
  433. )
  434. elif self.ismulti:
  435. return self._repr_multi(self.params, typ)
  436. else:
  437. return self._repr_params(self.params, typ)
  438. def _repr_multi(self, multi_params, typ):
  439. if multi_params:
  440. if isinstance(multi_params[0], list):
  441. elem_type = self._LIST
  442. elif isinstance(multi_params[0], tuple):
  443. elem_type = self._TUPLE
  444. elif isinstance(multi_params[0], dict):
  445. elem_type = self._DICT
  446. else:
  447. assert False, "Unknown parameter type %s" % (
  448. type(multi_params[0])
  449. )
  450. elements = ", ".join(
  451. self._repr_params(params, elem_type) for params in multi_params
  452. )
  453. else:
  454. elements = ""
  455. if typ == self._LIST:
  456. return "[%s]" % elements
  457. else:
  458. return "(%s)" % elements
  459. def _repr_params(self, params, typ):
  460. trunc = self.trunc
  461. if typ is self._DICT:
  462. return "{%s}" % (
  463. ", ".join(
  464. "%r: %s" % (key, trunc(value))
  465. for key, value in params.items()
  466. )
  467. )
  468. elif typ is self._TUPLE:
  469. return "(%s%s)" % (
  470. ", ".join(trunc(value) for value in params),
  471. "," if len(params) == 1 else "",
  472. )
  473. else:
  474. return "[%s]" % (", ".join(trunc(value) for value in params))
  475. def adapt_criterion_to_null(crit, nulls):
  476. """given criterion containing bind params, convert selected elements
  477. to IS NULL.
  478. """
  479. def visit_binary(binary):
  480. if (
  481. isinstance(binary.left, BindParameter)
  482. and binary.left._identifying_key in nulls
  483. ):
  484. # reverse order if the NULL is on the left side
  485. binary.left = binary.right
  486. binary.right = Null()
  487. binary.operator = operators.is_
  488. binary.negate = operators.is_not
  489. elif (
  490. isinstance(binary.right, BindParameter)
  491. and binary.right._identifying_key in nulls
  492. ):
  493. binary.right = Null()
  494. binary.operator = operators.is_
  495. binary.negate = operators.is_not
  496. return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
  497. def splice_joins(left, right, stop_on=None):
  498. if left is None:
  499. return right
  500. stack = [(right, None)]
  501. adapter = ClauseAdapter(left)
  502. ret = None
  503. while stack:
  504. (right, prevright) = stack.pop()
  505. if isinstance(right, Join) and right is not stop_on:
  506. right = right._clone()
  507. right.onclause = adapter.traverse(right.onclause)
  508. stack.append((right.left, right))
  509. else:
  510. right = adapter.traverse(right)
  511. if prevright is not None:
  512. prevright.left = right
  513. if ret is None:
  514. ret = right
  515. return ret
def reduce_columns(columns, *clauses, **kw):
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case two columns are
    only treated as equivalent when their names are equal as well.  Any
    other keyword is accepted and ignored.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    :return: a ``ColumnSet`` of the columns that were not omitted.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    columns = util.ordered_column_set(columns)

    # columns found to be redundant; subtracted from ``columns`` at the end
    omit = util.column_set()
    for col in columns:
        # every foreign key of every column in col's proxy set
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in columns:
                if c is col:
                    continue
                try:
                    # resolved inside the innermost loop, so an unresolvable
                    # foreign key is only reported when there is at least
                    # one other column to compare against
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                # col refers to c through this foreign key: col is the
                # redundant one
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    break

    if clauses:

        def visit_binary(binary):
            # only "a = b" comparisons are considered
            if binary.operator == operators.eq:
                # proxies of the columns not omitted so far; recomputed
                # per comparison since ``omit`` grows as we go
                cols = util.column_set(
                    chain(*[c.proxy_set for c in columns.difference(omit)])
                )
                if binary.left in cols and binary.right in cols:
                    # omit the last column in ``columns`` that shares
                    # lineage with the right-hand side
                    for c in reversed(columns):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return ColumnSet(columns.difference(omit))
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """traverse an expression and locate binary criterion pairs.

    :param expression: the expression to traverse.
    :param consider_as_foreign_keys: optional collection of columns; a
     comparison side found in it is placed second in the pair.
    :param consider_as_referenced_keys: optional collection of columns; a
     comparison side found in it is placed first in the pair.
    :param any_operator: when False, only comparisons whose operator is
     ``operators.eq`` are considered.
    :return: list of 2-tuples of column elements.  When neither
     collection is given, only ``Column`` pairs related through
     ``Column.references()`` are reported, as
     ``(referenced, referencing)``.
    :raises ArgumentError: if both ``consider_as_foreign_keys`` and
     ``consider_as_referenced_keys`` are given.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    def col_is(a, b):
        # structural comparison via compare(), not object identity
        # return a is b
        return a.compare(b)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        if not isinstance(binary.left, ColumnElement) or not isinstance(
            binary.right, ColumnElement
        ):
            return

        if consider_as_foreign_keys:
            # a side counts as the foreign key when it is in the given
            # collection and the other side either compares equal to it
            # or is not in the collection
            if binary.left in consider_as_foreign_keys and (
                col_is(binary.right, binary.left)
                or binary.right not in consider_as_foreign_keys
            ):
                pairs.append((binary.right, binary.left))
            elif binary.right in consider_as_foreign_keys and (
                col_is(binary.left, binary.right)
                or binary.left not in consider_as_foreign_keys
            ):
                pairs.append((binary.left, binary.right))
        elif consider_as_referenced_keys:
            # mirror image of the branch above: the side found in the
            # collection goes first
            if binary.left in consider_as_referenced_keys and (
                col_is(binary.right, binary.left)
                or binary.right not in consider_as_referenced_keys
            ):
                pairs.append((binary.left, binary.right))
            elif binary.right in consider_as_referenced_keys and (
                col_is(binary.left, binary.right)
                or binary.left not in consider_as_referenced_keys
            ):
                pairs.append((binary.right, binary.left))
        else:
            # no hints given: rely on Column.references()
            if isinstance(binary.left, Column) and isinstance(
                binary.right, Column
            ):
                if binary.left.references(binary.right):
                    pairs.append((binary.right, binary.left))
                elif binary.right.references(binary.left):
                    pairs.append((binary.left, binary.right))

    # bound after visit_binary is defined but before it is first called,
    # so the closure sees it
    pairs = []
    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
class ClauseAdapter(visitors.ReplacingExternalTraversal):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        anonymize_labels=False,
        adapt_from_selectables=None,
    ):
        """Construct an adapter targeting ``selectable``.

        :param selectable: the selectable whose columns are substituted in.
        :param equivalents: optional mapping of column -> collection of
         equivalent columns, tried when a column has no direct
         correspondence.
        :param include_fn: optional predicate; columns for which it
         returns false are not replaced.
        :param exclude_fn: optional predicate; columns for which it
         returns true are not replaced.
        :param adapt_on_names: if true, fall back to looking up the column
         by ``.name`` in ``selectable.exported_columns``.
        :param anonymize_labels: stored in the traversal options.
        :param adapt_from_selectables: optional collection of selectables;
         when given, replacement is limited to elements that correspond
         to one of them.

        """
        # traversal never descends into the target selectable itself
        self.__traverse_options__ = {
            "stop_on": [selectable],
            "anonymize_labels": anonymize_labels,
        }
        self.selectable = selectable
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        self.equivalents = util.column_dict(equivalents or {})
        self.adapt_on_names = adapt_on_names
        self.adapt_from_selectables = adapt_from_selectables

    def _corresponding_column(
        self, col, require_embedded, _seen=util.EMPTY_SET
    ):
        """Return the column of ``self.selectable`` corresponding to
        ``col``, or ``None``.

        ``_seen`` holds the columns already expanded through
        ``self.equivalents`` on this lookup path and stops the recursion
        from revisiting them.

        """

        newcol = self.selectable.corresponding_column(
            col, require_embedded=require_embedded
        )
        if newcol is None and col in self.equivalents and col not in _seen:
            # no direct match: try each registered equivalent in turn
            for equiv in self.equivalents[col]:
                newcol = self._corresponding_column(
                    equiv,
                    require_embedded=require_embedded,
                    _seen=_seen.union([col]),
                )
                if newcol is not None:
                    return newcol
        if self.adapt_on_names and newcol is None:
            # last resort: match purely on the column name
            newcol = self.selectable.exported_columns.get(col.name)
        return newcol

    @util.preload_module("sqlalchemy.sql.functions")
    def replace(self, col, _include_singleton_constants=False):
        """Return the replacement for ``col``, or ``None`` to leave it
        to normal traversal.

        FROM clauses (other than function elements) and column elements
        are handled separately; anything else yields ``None``.

        """
        functions = util.preloaded.sql_functions

        if isinstance(col, FromClause) and not isinstance(
            col, functions.FunctionElement
        ):

            if self.selectable.is_derived_from(col):
                if self.adapt_from_selectables:
                    # for/else: give up unless one of the permitted
                    # source selectables is derived from col
                    for adp in self.adapt_from_selectables:
                        if adp.is_derived_from(col):
                            break
                    else:
                        return None
                return self.selectable
            elif isinstance(col, Alias) and isinstance(
                col.element, TableClause
            ):
                # we are a SELECT statement and not derived from an alias of a
                # table (which nonetheless may be a table our SELECT derives
                # from), so return the alias to prevent further traversal
                # or
                # we are an alias of a table and we are not derived from an
                # alias of a table (which nonetheless may be the same table
                # as ours) so, same thing
                return col
            else:
                # other cases where we are a selectable and the element
                # is another join or selectable that contains a table which our
                # selectable derives from, that we want to process
                return None

        elif not isinstance(col, ColumnElement):
            return None
        elif not _include_singleton_constants and col._is_singleton_constant:
            # dont swap out NULL, TRUE, FALSE for a label name
            # in a SQL statement that's being rewritten,
            # leave them as the constant.  This is first noted in #6259,
            # however the logic to check this moved here as of #7154 so that
            # it is made specific to SQL rewriting and not all column
            # correspondence
            return None

        # an "adapt_column" annotation substitutes the column that is
        # actually looked up
        if "adapt_column" in col._annotations:
            col = col._annotations["adapt_column"]

        if self.adapt_from_selectables and col not in self.equivalents:
            # for/else: give up unless col corresponds to a column of one
            # of the permitted source selectables
            for adp in self.adapt_from_selectables:
                if adp.c.corresponding_column(col, False) is not None:
                    break
            else:
                return None

        if self.include_fn and not self.include_fn(col):
            return None
        elif self.exclude_fn and self.exclude_fn(col):
            return None
        else:
            return self._corresponding_column(col, True)
  744. class ColumnAdapter(ClauseAdapter):
  745. """Extends ClauseAdapter with extra utility functions.
  746. Key aspects of ColumnAdapter include:
  747. * Expressions that are adapted are stored in a persistent
  748. .columns collection; so that an expression E adapted into
  749. an expression E1, will return the same object E1 when adapted
  750. a second time. This is important in particular for things like
  751. Label objects that are anonymized, so that the ColumnAdapter can
  752. be used to present a consistent "adapted" view of things.
    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This is because "annotated" items all have the same hash identity as
      their parent.
  757. * "wrapping" capability is added, so that the replacement of an expression
  758. E can proceed through a series of adapters. This differs from the
  759. visitor's "chaining" feature in that the resulting object is passed
  760. through all replacing functions unconditionally, rather than stopping
  761. at the first one that returns non-None.
    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario; see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
  767. """
  768. def __init__(
  769. self,
  770. selectable,
  771. equivalents=None,
  772. adapt_required=False,
  773. include_fn=None,
  774. exclude_fn=None,
  775. adapt_on_names=False,
  776. allow_label_resolve=True,
  777. anonymize_labels=False,
  778. adapt_from_selectables=None,
  779. ):
  780. ClauseAdapter.__init__(
  781. self,
  782. selectable,
  783. equivalents,
  784. include_fn=include_fn,
  785. exclude_fn=exclude_fn,
  786. adapt_on_names=adapt_on_names,
  787. anonymize_labels=anonymize_labels,
  788. adapt_from_selectables=adapt_from_selectables,
  789. )
  790. self.columns = util.WeakPopulateDict(self._locate_col)
  791. if self.include_fn or self.exclude_fn:
  792. self.columns = self._IncludeExcludeMapping(self, self.columns)
  793. self.adapt_required = adapt_required
  794. self.allow_label_resolve = allow_label_resolve
  795. self._wrap = None
  796. class _IncludeExcludeMapping(object):
  797. def __init__(self, parent, columns):
  798. self.parent = parent
  799. self.columns = columns
  800. def __getitem__(self, key):
  801. if (
  802. self.parent.include_fn and not self.parent.include_fn(key)
  803. ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
  804. if self.parent._wrap:
  805. return self.parent._wrap.columns[key]
  806. else:
  807. return key
  808. return self.columns[key]
  809. def wrap(self, adapter):
  810. ac = self.__class__.__new__(self.__class__)
  811. ac.__dict__.update(self.__dict__)
  812. ac._wrap = adapter
  813. ac.columns = util.WeakPopulateDict(ac._locate_col)
  814. if ac.include_fn or ac.exclude_fn:
  815. ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
  816. return ac
  817. def traverse(self, obj):
  818. return self.columns[obj]
  819. adapt_clause = traverse
  820. adapt_list = ClauseAdapter.copy_and_process
  821. def adapt_check_present(self, col):
  822. newcol = self.columns[col]
  823. if newcol is col and self._corresponding_column(col, True) is None:
  824. return None
  825. return newcol
  826. def _locate_col(self, col):
  827. # both replace and traverse() are overly complicated for what
  828. # we are doing here and we would do better to have an inlined
  829. # version that doesn't build up as much overhead. the issue is that
  830. # sometimes the lookup does in fact have to adapt the insides of
  831. # say a labeled scalar subquery. However, if the object is an
  832. # Immutable, i.e. Column objects, we can skip the "clone" /
  833. # "copy internals" part since those will be no-ops in any case.
  834. # additionally we want to catch singleton objects null/true/false
  835. # and make sure they are adapted as well here.
  836. if col._is_immutable:
  837. for vis in self.visitor_iterator:
  838. c = vis.replace(col, _include_singleton_constants=True)
  839. if c is not None:
  840. break
  841. else:
  842. c = col
  843. else:
  844. c = ClauseAdapter.traverse(self, col)
  845. if self._wrap:
  846. c2 = self._wrap._locate_col(c)
  847. if c2 is not None:
  848. c = c2
  849. if self.adapt_required and c is col:
  850. return None
  851. c._allow_label_resolve = self.allow_label_resolve
  852. return c
  853. def __getstate__(self):
  854. d = self.__dict__.copy()
  855. del d["columns"]
  856. return d
  857. def __setstate__(self, state):
  858. self.__dict__.update(state)
  859. self.columns = util.WeakPopulateDict(self._locate_col)
  860. def _offset_or_limit_clause(element, name=None, type_=None):
  861. """Convert the given value to an "offset or limit" clause.
  862. This handles incoming integers and converts to an expression; if
  863. an expression is already given, it is passed through.
  864. """
  865. return coercions.expect(
  866. roles.LimitOffsetRole, element, name=name, type_=type_
  867. )
  868. def _offset_or_limit_clause_asint_if_possible(clause):
  869. """Return the offset or limit clause as a simple integer if possible,
  870. else return the clause.
  871. """
  872. if clause is None:
  873. return None
  874. if hasattr(clause, "_limit_offset_value"):
  875. value = clause._limit_offset_value
  876. return util.asint(value)
  877. else:
  878. return clause
  879. def _make_slice(limit_clause, offset_clause, start, stop):
  880. """Compute LIMIT/OFFSET in terms of slice start/end"""
  881. # for calculated limit/offset, try to do the addition of
  882. # values to offset in Python, however if a SQL clause is present
  883. # then the addition has to be on the SQL side.
  884. if start is not None and stop is not None:
  885. offset_clause = _offset_or_limit_clause_asint_if_possible(
  886. offset_clause
  887. )
  888. if offset_clause is None:
  889. offset_clause = 0
  890. if start != 0:
  891. offset_clause = offset_clause + start
  892. if offset_clause == 0:
  893. offset_clause = None
  894. else:
  895. offset_clause = _offset_or_limit_clause(offset_clause)
  896. limit_clause = _offset_or_limit_clause(stop - start)
  897. elif start is None and stop is not None:
  898. limit_clause = _offset_or_limit_clause(stop)
  899. elif start is not None and stop is None:
  900. offset_clause = _offset_or_limit_clause_asint_if_possible(
  901. offset_clause
  902. )
  903. if offset_clause is None:
  904. offset_clause = 0
  905. if start != 0:
  906. offset_clause = offset_clause + start
  907. if offset_clause == 0:
  908. offset_clause = None
  909. else:
  910. offset_clause = _offset_or_limit_clause(offset_clause)
  911. return limit_clause, offset_clause