# persistence.py — source listing; the line-number gutter and file-size
# banner from the original rendering have been removed as they were not
# part of the module's content.
  1. # orm/persistence.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. """private module containing functions used to emit INSERT, UPDATE
  8. and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
  9. mappers.
  10. The functions here are called only by the unit of work functions
  11. in unitofwork.py.
  12. """
  13. from itertools import chain
  14. from itertools import groupby
  15. import operator
  16. from . import attributes
  17. from . import evaluator
  18. from . import exc as orm_exc
  19. from . import loading
  20. from . import sync
  21. from .base import NO_VALUE
  22. from .base import state_str
  23. from .. import exc as sa_exc
  24. from .. import future
  25. from .. import sql
  26. from .. import util
  27. from ..engine import result as _result
  28. from ..sql import coercions
  29. from ..sql import expression
  30. from ..sql import operators
  31. from ..sql import roles
  32. from ..sql import select
  33. from ..sql import sqltypes
  34. from ..sql.base import _entity_namespace_key
  35. from ..sql.base import CompileState
  36. from ..sql.base import Options
  37. from ..sql.dml import DeleteDMLState
  38. from ..sql.dml import InsertDMLState
  39. from ..sql.dml import UpdateDMLState
  40. from ..sql.elements import BooleanClauseList
  41. from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
def _bulk_insert(
    mapper,
    mappings,
    session_transaction,
    isstates,
    return_defaults,
    render_nulls,
):
    """Emit INSERT statements for a collection of mappings on behalf of a
    bulk insert operation.

    :param mapper: the :class:`_orm.Mapper` whose tables receive rows.
    :param mappings: iterable of plain dictionaries, or of instance states
     when ``isstates`` is True.
    :param session_transaction: session transaction from which the
     connection is acquired.
    :param isstates: if True, ``mappings`` contains InstanceState objects
     whose ``.dict`` is extracted.
    :param return_defaults: if True, per-row bookkeeping is enabled so that
     server-generated values come back; for states, identity keys are then
     assigned at the end.
    :param render_nulls: if True, ``None`` values are rendered explicitly
     into the INSERT rather than omitted.
    """
    base_mapper = mapper.base_mapper

    # a per-instance connection callable can't be honored when all rows go
    # through a single bulk path
    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            # keep (state, dict) pairs so identity keys can be assigned below
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # materialize, as the mappings are iterated once per table
        mappings = list(mappings)

    connection = session_transaction.connection(base_mapper)
    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        # lazy generator: rewrite each collected insert record with
        # state=None (bulk mode has no per-row state tracking)
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in _collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=return_defaults,
                render_nulls=render_nulls,
            )
        )
        _emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=return_defaults,
        )
    if return_defaults and isstates:
        # assign identity keys to the states now that PK values are present
        # in the mapping dictionaries
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
            )
def _bulk_update(
    mapper, mappings, session_transaction, isstates, update_changed_only
):
    """Emit UPDATE statements for a collection of mappings on behalf of a
    bulk update operation.

    :param mapper: the :class:`_orm.Mapper` whose tables are updated.
    :param mappings: iterable of plain dictionaries, or of instance states
     when ``isstates`` is True.
    :param session_transaction: session transaction from which the
     connection is acquired.
    :param isstates: if True, ``mappings`` contains InstanceState objects.
    :param update_changed_only: if True (and ``isstates``), only attributes
     present in each state's committed_state (plus PK / version columns)
     are included in the UPDATE.
    """
    base_mapper = mapper.base_mapper

    # columns that must always be carried along so the row can be located:
    # primary key attributes, plus the version id attribute if present
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        # reduce a state's dict to changed attributes plus the search keys
        return dict(
            (k, v)
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        )

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # materialize, as the mappings are iterated once per table
        mappings = list(mappings)

    # a per-instance connection callable can't be honored in bulk mode
    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        # state=None in each record: bulk mode has no per-row state; the
        # version id, if mapped, is pulled straight from the mapping dict
        records = _collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
        )
        _emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
        )
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    :param base_mapper: base :class:`_orm.Mapper` of the inheritance
     hierarchy.
    :param states: list of InstanceState objects to persist.
    :param uowtransaction: the current UOWTransaction.
    :param single: internal flag used for the non-batch recursion below.
    """
    # if batch=false, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    # partition states into INSERT vs. UPDATE; a "row switch" (pending
    # instance replacing a deleted persistent one) is treated as UPDATE
    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

    # emit statements table-by-table in mapper-sorted order
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)
        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )
        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )
        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

    # post-statement bookkeeping for all states; the final boolean flags
    # whether the state was an UPDATE (True) or an INSERT (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )
def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    :param base_mapper: base :class:`_orm.Mapper` of the hierarchy.
    :param states: states participating in the post-update.
    :param uowtransaction: the current UOWTransaction.
    :param post_update_cols: columns to be included in the UPDATE.
    """
    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # augment each record with the committed version id value, when a
        # version column is configured; restrict to states mapped to this
        # table
        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                mapper._get_committed_state_attr_by_column(
                    state, state_dict, mapper.version_id_col
                )
                if mapper.version_id_col is not None
                else None,
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )
def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    :param base_mapper: base :class:`_orm.Mapper` of the hierarchy.
    :param states: states to be deleted.
    :param uowtransaction: the current UOWTransaction.
    """
    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # delete in reverse table order, so dependent (child) tables are
    # cleared before their parents
    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        elif mapper.inherits and mapper.passive_deletes:
            # rely on the database to cascade the delete to inherited tables
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    # fire after_delete events once all DELETE statements have been emitted
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields tuples of ``(state, dict_, mapper, connection, has_identity,
    row_switch, update_version_id)``.
    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        # for an UPDATE (or row switch) with versioning, capture the
        # committed version id so the UPDATE can match the row on it
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
  373. def _organize_states_for_post_update(base_mapper, states, uowtransaction):
  374. """Make an initial pass across a set of states for UPDATE
  375. corresponding to post_update.
  376. This includes obtaining key information for each state
  377. including its dictionary, mapper, the connection to use for
  378. the execution per state.
  379. """
  380. return _connections_for_states(base_mapper, uowtransaction, states)
  381. def _organize_states_for_delete(base_mapper, states, uowtransaction):
  382. """Make an initial pass across a set of states for DELETE.
  383. This includes calling out before_delete and obtaining
  384. key information for each state including its dictionary,
  385. mapper, the connection to use for the execution per state.
  386. """
  387. for state, dict_, mapper, connection in _connections_for_states(
  388. base_mapper, uowtransaction, states
  389. ):
  390. mapper.dispatch.before_delete(mapper, connection, state)
  391. if mapper.version_id_col is not None:
  392. update_version_id = mapper._get_committed_state_attr_by_column(
  393. state, dict_, mapper.version_id_col
  394. )
  395. else:
  396. update_version_id = None
  397. yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
    table,
    states_to_insert,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    Yields tuples of ``(state, state_dict, params, mapper, connection,
    value_params, has_all_pks, has_all_defaults)`` for each state mapped
    to the given table.
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        # params: plain column-key -> value; value_params: column ->
        # SQL expression, rendered inline rather than as a bind
        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]
        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                # omit None unless the column evaluates None specially or
                # explicit NULL rendering was requested
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                # SQL expression value; route through value_params
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT. This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = mapper._server_default_cols[table].issubset(
                    params
                )
            else:
                has_all_defaults = True
        else:
            # pure bulk mode skips the per-row bookkeeping entirely
            has_all_defaults = has_all_pks = True

        # seed the version id column for a brand-new row
        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
def _collect_update_commands(
    uowtransaction, table, states_to_update, bulk=False
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement. Includes some tricky scenarios where the primary
    key of an object might have been changed.

    Yields tuples of ``(state, state_dict, params, mapper, connection,
    value_params, has_all_defaults, has_all_pks)``.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        # value_params: column -> SQL expression value (rendered inline)
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = dict(
                (propkey_to_col[propkey].key, state_dict[propkey])
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            )
            has_all_defaults = True
        else:
            params = {}
            # only attributes with committed history are candidates
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = (
                    mapper._server_onupdate_default_cols[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params

            # the WHERE criteria matches on the *old* version id value
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                # generate the next version id for the SET clause
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing changed for this table; skip the state
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = dict(
                (propkey_to_col[propkey]._label, state_dict.get(propkey))
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            )
        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the PK value was set by a cascade; match the row
                        # on the new value
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]

                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields ``(state, state_dict, mapper, connection, params)`` only for
    states that actually have changed data for this table.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # PK columns become the WHERE criteria, keyed on the
                # column's label
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                # include the column only if its attribute has new history
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):
                col = mapper.version_id_col
                # match the row on the current version id value
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    # generate the next version id for the SET clause
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params
  684. def _collect_delete_commands(
  685. base_mapper, uowtransaction, table, states_to_delete
  686. ):
  687. """Identify values to use in DELETE statements for a list of
  688. states to be deleted."""
  689. for (
  690. state,
  691. state_dict,
  692. mapper,
  693. connection,
  694. update_version_id,
  695. ) in states_to_delete:
  696. if table not in mapper._pks_by_table:
  697. continue
  698. params = {}
  699. for col in mapper._pks_by_table[table]:
  700. params[
  701. col.key
  702. ] = value = mapper._get_committed_state_attr_by_column(
  703. state, state_dict, col
  704. )
  705. if value is None:
  706. raise orm_exc.FlushError(
  707. "Can't delete from table %s "
  708. "using NULL for primary "
  709. "key value on column %s" % (table, col)
  710. )
  711. if (
  712. update_version_id is not None
  713. and mapper.version_id_col in mapper._cols_by_table[table]
  714. ):
  715. params[mapper.version_id_col.key] = update_version_id
  716. yield params, connection
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    bookkeeping=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    Records are grouped by (connection, param keys, presence of SQL-value
    params, has_all_defaults, has_all_pks) so that compatible rows can be
    batched into a single executemany() where safe; otherwise rows are
    emitted one statement at a time.  When ``bookkeeping`` is True,
    _postfetch() is invoked per state to refresh server-generated values.
    """
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt():
        # build the canonical UPDATE once per (mapper, table); WHERE
        # criteria binds use col._label so they don't collide with the
        # SET-clause parameter names
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)
        return stmt

    cached_stmt = base_mapper._memo(("update", table), update_stmt)

    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt
        return_defaults = False

        # add RETURNING / return_defaults when we must read PKs or
        # server-side defaults (incl. version id) back from the UPDATE
        if not has_all_pks:
            statement = statement.return_defaults()
            return_defaults = True
        elif (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults
        ):
            statement = statement.return_defaults()
            return_defaults = True
        elif mapper.version_id_col is not None:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        # dialects report rowcount reliability differently when
        # RETURNING is in play
        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if not return_defaults
            else connection.dialect.supports_sane_rowcount_returning
        )

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = has_all_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values present: must render per-row statements
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection._execute_20(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = assert_singlerow
        else:
            if not allow_multirow:
                # one execute per record so each rowcount (and each
                # RETURNING row) can be attributed to its state
                check_rowcount = assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # batched executemany() path
                multiparams = [rec[2] for rec in records]

                check_rowcount = assert_multirow or (
                    assert_singlerow and len(multiparams) == 1
                )

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )

        if check_rowcount:
            if rows != len(records):
                # a row went missing: concurrent modification / stale data
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    bookkeeping=True,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Chooses between a plain executemany() (when no server-generated values
    need to come back), an executemany() with RETURNING (when the dialect
    supports it), and row-at-a-time execution; per-state bookkeeping via
    _postfetch() / _postfetch_bulk_save() runs when ``bookkeeping`` is True.
    """
    cached_stmt = base_mapper._memo(("insert", table), table.insert)

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    for (
        (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):
        statement = cached_stmt

        # NOTE: "and" binds tighter than "or" here; the parenthesized
        # defaults condition is further qualified by has_all_pks / hasvalue
        if (
            not bookkeeping
            or (
                has_all_defaults
                or not base_mapper.eager_defaults
                or not connection.dialect.implicit_returning
            )
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, c.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            last_inserted_params,
                            value_params,
                            False,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )
                    else:
                        # bulk-save path: no InstanceState, only a dict
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back.

            records = list(records)
            if (
                not hasvalue
                and connection.dialect.insert_executemany_returning
                and len(records) > 1
            ):
                do_executemany = True
            else:
                do_executemany = False

            if not has_all_defaults and base_mapper.eager_defaults:
                statement = statement.return_defaults()
            elif mapper.version_id_col is not None:
                statement = statement.return_defaults(mapper.version_id_col)
            elif do_executemany:
                statement = statement.return_defaults(*table.primary_key)

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                if bookkeeping:
                    # zip_longest so a short RETURNING result surfaces as
                    # inserted_primary_key is None rather than truncating
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in util.zip_longest(
                        records,
                        c.context.compiled_parameters,
                        c.inserted_primary_key_rows,
                        c.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows. we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING. Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                c,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # single-row executions; required when SQL-expression
                # values are present or executemany RETURNING is
                # unavailable
                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection._execute_20(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection._execute_20(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked. Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk

                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                result.returned_defaults
                                if not result.context.executemany
                                else None,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Statements are grouped by (connection, parameter keys) but executed in
    the original record order, batching via executemany() only when the
    dialect's rowcount support allows version verification.
    """
    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        # always fetch the version id back when versioning is in use
        if mapper.version_id_col is not None:
            stmt = stmt.return_defaults(mapper.version_id_col)

        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if mapper.version_id_col is None
            else connection.dialect.supports_sane_rowcount_returning
        )
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = not needs_version_id or assert_multirow

        if not allow_multirow:
            # per-row execution so each version check is verifiable
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection._execute_20(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Rows are grouped by connection; deletes execute as one executemany()
    unless versioning requires individually verifiable rowcounts.
    """
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # DELETE binds use plain col.key (no SET clause to collide with)
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        rows_matched = -1  # -1 means "rowcount unknown"; skips verification
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection._execute_20(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection._execute_20(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                # NOTE(review): this message still says "prevent this
                # warning" although this branch raises — confirm whether
                # the wording is intentional
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    Also expires read-only attributes that may have changed server-side,
    eagerly loads server defaults / the version id when configured, and
    validates that a server-generated version value actually arrived.
    """
    for state, state_dict, mapper, connection, has_identity in states:

        if mapper._readonly_props:
            # expire read-only attributes whose value could have been
            # altered by the flush (e.g. column_property expressions)
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols. Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        if base_mapper.eager_defaults:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            # emit a SELECT to refresh just the needed attributes
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading.load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            # server-side versioning must have produced a value by now
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )
  1365. def _postfetch_post_update(
  1366. mapper, uowtransaction, table, state, dict_, result, params
  1367. ):
  1368. if uowtransaction.is_deleted(state):
  1369. return
  1370. prefetch_cols = result.context.compiled.prefetch
  1371. postfetch_cols = result.context.compiled.postfetch
  1372. if (
  1373. mapper.version_id_col is not None
  1374. and mapper.version_id_col in mapper._cols_by_table[table]
  1375. ):
  1376. prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
  1377. refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
  1378. if refresh_flush:
  1379. load_evt_attrs = []
  1380. for c in prefetch_cols:
  1381. if c.key in params and c in mapper._columntoproperty:
  1382. dict_[mapper._columntoproperty[c].key] = params[c.key]
  1383. if refresh_flush:
  1384. load_evt_attrs.append(mapper._columntoproperty[c].key)
  1385. if refresh_flush and load_evt_attrs:
  1386. mapper.class_manager.dispatch.refresh_flush(
  1387. state, uowtransaction, load_evt_attrs
  1388. )
  1389. if postfetch_cols:
  1390. state._expire_attributes(
  1391. state.dict,
  1392. [
  1393. mapper._columntoproperty[c].key
  1394. for c in postfetch_cols
  1395. if c in mapper._columntoproperty
  1396. ],
  1397. )
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    Values available client-side (RETURNING rows, prefetched parameters)
    are copied directly into the instance dict; the remainder are expired
    so they refresh on next access.  Fires the refresh_flush event and
    finally synchronizes newly generated key values to dependent tables.
    """
    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often.  would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
  1483. def _postfetch_bulk_save(mapper, dict_, table):
  1484. for m, equated_pairs in mapper._table_to_equated[table]:
  1485. sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
  1486. def _connections_for_states(base_mapper, uowtransaction, states):
  1487. """Return an iterator of (state, state.dict, mapper, connection).
  1488. The states are sorted according to _sort_states, then paired
  1489. with the connection they should be using for the given
  1490. unit of work transaction.
  1491. """
  1492. # if session has a connection callable,
  1493. # organize individual states with the connection
  1494. # to use for update
  1495. if uowtransaction.session.connection_callable:
  1496. connection_callable = uowtransaction.session.connection_callable
  1497. else:
  1498. connection = uowtransaction.transaction.connection(base_mapper)
  1499. connection_callable = None
  1500. for state in _sort_states(base_mapper, states):
  1501. if connection_callable:
  1502. connection = connection_callable(base_mapper, state.obj())
  1503. mapper = state.manager.mapper
  1504. yield state, state.dict, mapper, connection
  1505. def _sort_states(mapper, states):
  1506. pending = set(states)
  1507. persistent = set(s for s in pending if s.key is not None)
  1508. pending.difference_update(persistent)
  1509. try:
  1510. persistent_sorted = sorted(
  1511. persistent, key=mapper._persistent_sortkey_fn
  1512. )
  1513. except TypeError as err:
  1514. util.raise_(
  1515. sa_exc.InvalidRequestError(
  1516. "Could not sort objects by primary key; primary key "
  1517. "values must be sortable in Python (was: %s)" % err
  1518. ),
  1519. replace_context=err,
  1520. )
  1521. return (
  1522. sorted(pending, key=operator.attrgetter("insert_order"))
  1523. + persistent_sorted
  1524. )
# shared immutable empty mapping, used as a safe class-level default for
# the Options attributes below
_EMPTY_DICT = util.immutabledict()
  1526. class BulkUDCompileState(CompileState):
    class default_update_options(Options):
        """Default execution options for bulk ORM UPDATE/DELETE.

        Instances travel through the execution phases; updated copies
        are produced via ``update_options += {...}``.
        """

        # session synchronization strategy: "evaluate", "fetch", or False
        _synchronize_session = "evaluate"
        # whether session._autoflush() runs before execution
        _autoflush = True
        # mapper the statement acts upon; set from the plugin_subject
        _subject_mapper = None
        # NOTE(review): the remaining attributes appear to be populated by
        # the _do_pre/_do_post synchronize steps — confirm against those
        # methods before relying on these descriptions
        _resolved_values = _EMPTY_DICT
        _resolved_keys_as_propnames = _EMPTY_DICT
        _value_evaluators = _EMPTY_DICT
        _matched_objects = None
        _matched_rows = None
        _refresh_identity_token = None
    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_reentrant_invoke,
    ):
        """Prepare a bulk ORM UPDATE/DELETE before execution.

        Resolves the update options from execution options, validates the
        synchronize_session strategy, sets up bind arguments, autoflushes,
        and runs the pre-synchronize step for the "evaluate" or "fetch"
        strategies.  Returns (statement, execution_options) with the
        resolved options stored under "_sa_orm_update_options".
        """
        if is_reentrant_invoke:
            # already processed on the outer invocation; pass through
            return statement, execution_options

        (
            update_options,
            execution_options,
        ) = BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {"synchronize_session"},
            execution_options,
            statement._execution_options,
        )

        # NOTE(review): this local deliberately shadows the module-level
        # ``sync`` import within this method's scope
        sync = update_options._synchronize_session
        if sync is not None:
            if sync not in ("evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'evaluate', 'fetch', False"
                )

        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            bind_arguments["mapper"] = plugin_subject.mapper

        update_options += {"_subject_mapper": plugin_subject.mapper}

        if update_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"synchronize_session": update_options._synchronize_session}
        )

        # this stage of the execution is called before the do_orm_execute event
        # hook.  meaning for an extension like horizontal sharding, this step
        # happens before the extension splits out into multiple backends and
        # runs only once.  if we do pre_sync_fetch, we execute a SELECT
        # statement, which the horizontal sharding extension splits amongst the
        # shards and combines the results together.

        if update_options._synchronize_session == "evaluate":
            update_options = cls._do_pre_synchronize_evaluate(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )
        elif update_options._synchronize_session == "fetch":
            update_options = cls._do_pre_synchronize_fetch(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )
    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        """Finalize a bulk ORM UPDATE/DELETE result by running the
        post-synchronize step matching the configured strategy, then
        return the result unchanged."""

        # this stage of the execution is called after the
        # do_orm_execute event hook.  meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.

        update_options = execution_options["_sa_orm_update_options"]
        if update_options._synchronize_session == "evaluate":
            cls._do_post_synchronize_evaluate(session, result, update_options)
        elif update_options._synchronize_session == "fetch":
            cls._do_post_synchronize_fetch(session, result, update_options)

        return result
    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in
        the table being updated or deleted, produce additional WHERE
        criteria such that only the appropriate subtypes are selected
        from the total results.

        Additionally, add WHERE criteria originating from
        LoaderCriteriaOptions collected from the statement.

        Returns a tuple of criteria expressions, adapted through the
        aliased-class adapter when ext_info is an aliased class.
        """

        return_crit = ()

        adapter = ext_info._adapter if ext_info.is_aliased_class else None

        # criteria contributed by LoaderCriteriaOption entries keyed to
        # this mapper
        if (
            "additional_entity_criteria",
            ext_info.mapper,
        ) in global_attributes:
            return_crit += tuple(
                ae._resolve_where_criteria(ext_info)
                for ae in global_attributes[
                    ("additional_entity_criteria", ext_info.mapper)
                ]
                if ae.include_aliases or ae.entity is ext_info
            )

        # single-table-inheritance discriminator criteria, when present
        if ext_info.mapper._single_table_criterion is not None:
            return_crit += (ext_info.mapper._single_table_criterion,)

        if adapter:
            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)

        return return_crit
@classmethod
def _do_pre_synchronize_evaluate(
    cls,
    session,
    statement,
    params,
    execution_options,
    bind_arguments,
    update_options,
):
    # Locate objects already present in the session that match the
    # UPDATE/DELETE criteria by evaluating the WHERE clause in Python
    # against the identity map, before the DML is emitted.  Returns
    # update_options augmented with "_matched_objects",
    # "_value_evaluators" and "_resolved_keys_as_propnames".
    mapper = update_options._subject_mapper
    target_cls = mapper.class_

    value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT

    try:
        evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
        # accumulate all WHERE criteria: statement criteria plus any
        # criteria contributed by with_loader_criteria() options
        crit = ()
        if statement._where_criteria:
            crit += statement._where_criteria

        global_attributes = {}
        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(global_attributes)

        if global_attributes:
            crit += cls._adjust_for_extra_criteria(
                global_attributes, mapper
            )

        if crit:
            eval_condition = evaluator_compiler.process(*crit)
        else:
            # no criteria at all; every candidate object matches
            def eval_condition(obj):
                return True

    except evaluator.UnevaluatableError as err:
        # the criteria cannot be evaluated in Python; the user must
        # select a different synchronize_session strategy
        util.raise_(
            sa_exc.InvalidRequestError(
                'Could not evaluate current criteria in Python: "%s". '
                "Specify 'fetch' or False for the "
                "synchronize_session execution option." % err
            ),
            from_=err,
        )

    if statement.__visit_name__ == "lambda_element":
        # ._resolved is called on every LambdaElement in order to
        # generate the cache key, so this access does not add
        # additional expense
        effective_statement = statement._resolved
    else:
        effective_statement = statement

    if effective_statement.__visit_name__ == "update":
        # for UPDATE, also compile per-attribute evaluators for the SET
        # values so matched objects can be refreshed in place
        resolved_values = cls._get_resolved_values(
            mapper, effective_statement
        )
        value_evaluators = {}
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                # SET value can't be computed in Python; the attribute
                # will be expired instead during post-synchronize
                pass
            else:
                value_evaluators[key] = _evaluator

    # TODO: detect when the where clause is a trivial primary key match.
    matched_objects = [
        state.obj()
        for state in session.identity_map.all_states()
        if state.mapper.isa(mapper)
        and not state.expired
        and eval_condition(state.obj())
        and (
            update_options._refresh_identity_token is None
            # TODO: coverage for the case where horizontal sharding
            # invokes an update() or delete() given an explicit identity
            # token up front
            or state.identity_token
            == update_options._refresh_identity_token
        )
    ]
    return update_options + {
        "_matched_objects": matched_objects,
        "_value_evaluators": value_evaluators,
        "_resolved_keys_as_propnames": resolved_keys_as_propnames,
    }
  1743. @classmethod
  1744. def _get_resolved_values(cls, mapper, statement):
  1745. if statement._multi_values:
  1746. return []
  1747. elif statement._ordered_values:
  1748. return list(statement._ordered_values)
  1749. elif statement._values:
  1750. return list(statement._values.items())
  1751. else:
  1752. return []
  1753. @classmethod
  1754. def _resolved_keys_as_propnames(cls, mapper, resolved_values):
  1755. values = []
  1756. for k, v in resolved_values:
  1757. if isinstance(k, attributes.QueryableAttribute):
  1758. values.append((k.key, v))
  1759. continue
  1760. elif hasattr(k, "__clause_element__"):
  1761. k = k.__clause_element__()
  1762. if mapper and isinstance(k, expression.ColumnElement):
  1763. try:
  1764. attr = mapper._columntoproperty[k]
  1765. except orm_exc.UnmappedColumnError:
  1766. pass
  1767. else:
  1768. values.append((attr.key, v))
  1769. else:
  1770. raise sa_exc.InvalidRequestError(
  1771. "Invalid expression type: %r" % k
  1772. )
  1773. return values
  1774. @classmethod
  1775. def _do_pre_synchronize_fetch(
  1776. cls,
  1777. session,
  1778. statement,
  1779. params,
  1780. execution_options,
  1781. bind_arguments,
  1782. update_options,
  1783. ):
  1784. mapper = update_options._subject_mapper
  1785. select_stmt = (
  1786. select(*(mapper.primary_key + (mapper.select_identity_token,)))
  1787. .select_from(mapper)
  1788. .options(*statement._with_options)
  1789. )
  1790. select_stmt._where_criteria = statement._where_criteria
  1791. def skip_for_full_returning(orm_context):
  1792. bind = orm_context.session.get_bind(**orm_context.bind_arguments)
  1793. if bind.dialect.full_returning:
  1794. return _result.null_result()
  1795. else:
  1796. return None
  1797. result = session.execute(
  1798. select_stmt,
  1799. params,
  1800. execution_options,
  1801. bind_arguments,
  1802. _add_event=skip_for_full_returning,
  1803. )
  1804. matched_rows = result.fetchall()
  1805. value_evaluators = _EMPTY_DICT
  1806. if statement.__visit_name__ == "lambda_element":
  1807. # ._resolved is called on every LambdaElement in order to
  1808. # generate the cache key, so this access does not add
  1809. # additional expense
  1810. effective_statement = statement._resolved
  1811. else:
  1812. effective_statement = statement
  1813. if effective_statement.__visit_name__ == "update":
  1814. target_cls = mapper.class_
  1815. evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
  1816. resolved_values = cls._get_resolved_values(
  1817. mapper, effective_statement
  1818. )
  1819. resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
  1820. mapper, resolved_values
  1821. )
  1822. resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
  1823. mapper, resolved_values
  1824. )
  1825. value_evaluators = {}
  1826. for key, value in resolved_keys_as_propnames:
  1827. try:
  1828. _evaluator = evaluator_compiler.process(
  1829. coercions.expect(roles.ExpressionElementRole, value)
  1830. )
  1831. except evaluator.UnevaluatableError:
  1832. pass
  1833. else:
  1834. value_evaluators[key] = _evaluator
  1835. else:
  1836. resolved_keys_as_propnames = _EMPTY_DICT
  1837. return update_options + {
  1838. "_value_evaluators": value_evaluators,
  1839. "_matched_rows": matched_rows,
  1840. "_resolved_keys_as_propnames": resolved_keys_as_propnames,
  1841. }
  1842. class ORMDMLState:
  1843. @classmethod
  1844. def get_entity_description(cls, statement):
  1845. ext_info = statement.table._annotations["parententity"]
  1846. mapper = ext_info.mapper
  1847. if ext_info.is_aliased_class:
  1848. _label_name = ext_info.name
  1849. else:
  1850. _label_name = mapper.class_.__name__
  1851. return {
  1852. "name": _label_name,
  1853. "type": mapper.class_,
  1854. "expr": ext_info.entity,
  1855. "entity": ext_info.entity,
  1856. "table": mapper.local_table,
  1857. }
  1858. @classmethod
  1859. def get_returning_column_descriptions(cls, statement):
  1860. def _ent_for_col(c):
  1861. return c._annotations.get("parententity", None)
  1862. def _attr_for_col(c, ent):
  1863. if ent is None:
  1864. return c
  1865. proxy_key = c._annotations.get("proxy_key", None)
  1866. if not proxy_key:
  1867. return c
  1868. else:
  1869. return getattr(ent.entity, proxy_key, c)
  1870. return [
  1871. {
  1872. "name": c.key,
  1873. "type": c.type,
  1874. "expr": _attr_for_col(c, ent),
  1875. "aliased": ent.is_aliased_class,
  1876. "entity": ent.entity,
  1877. }
  1878. for c, ent in [
  1879. (c, _ent_for_col(c)) for c in statement._all_selected_columns
  1880. ]
  1881. ]
  1882. @CompileState.plugin_for("orm", "insert")
  1883. class ORMInsert(ORMDMLState, InsertDMLState):
  1884. @classmethod
  1885. def orm_pre_session_exec(
  1886. cls,
  1887. session,
  1888. statement,
  1889. params,
  1890. execution_options,
  1891. bind_arguments,
  1892. is_reentrant_invoke,
  1893. ):
  1894. bind_arguments["clause"] = statement
  1895. try:
  1896. plugin_subject = statement._propagate_attrs["plugin_subject"]
  1897. except KeyError:
  1898. assert False, "statement had 'orm' plugin but no plugin_subject"
  1899. else:
  1900. bind_arguments["mapper"] = plugin_subject.mapper
  1901. return (
  1902. statement,
  1903. util.immutabledict(execution_options),
  1904. )
  1905. @classmethod
  1906. def orm_setup_cursor_result(
  1907. cls,
  1908. session,
  1909. statement,
  1910. params,
  1911. execution_options,
  1912. bind_arguments,
  1913. result,
  1914. ):
  1915. return result
@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        # Build compile state for an ORM-enabled UPDATE: resolve the
        # target mapper, re-target the statement at the mapper's local
        # table, fold in extra criteria, and optionally add RETURNING of
        # the primary key for synchronize_session="fetch".
        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        self._resolved_values = cls._get_resolved_values(mapper, statement)

        # collect criteria contributed by with_loader_criteria() options
        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        if not statement._preserve_parameter_order and statement._values:
            self._resolved_values = dict(self._resolved_values)

        # clone the statement shell so the original remains unmodified,
        # then point it at the mapper's local table
        new_stmt = sql.Update.__new__(sql.Update)
        new_stmt.__dict__.update(statement.__dict__)
        new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations
        if (
            compiler._annotations.get("synchronize_session", None) == "fetch"
            and compiler.dialect.full_returning
        ):
            if new_stmt._returning:
                raise sa_exc.InvalidRequestError(
                    "Can't use synchronize_session='fetch' "
                    "with explicit returning()"
                )
            new_stmt = new_stmt.returning(*mapper.primary_key)

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        return self

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator):
        # Convert ORM-level SET (key, value) pairs into Core-level
        # pairs, resolving string names and instrumented attributes
        # through the mapper's entity namespace.
        plugin_subject = statement._propagate_attrs["plugin_subject"]

        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        if not plugin_subject or not plugin_subject.mapper:
            # no ORM mapper available; defer entirely to Core behavior
            return core_get_crud_kv_pairs(statement, kv_iterator)

        mapper = plugin_subject.mapper

        values = []

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, util.string_types):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    # string key not known to the mapper; pass the pair
                    # through as a plain expression
                    values.append(
                        (
                            k,
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            ),
                        )
                    )
                else:
                    # mapped attribute; may expand to multiple column
                    # assignments (e.g. composites)
                    values.extend(
                        core_get_crud_kv_pairs(
                            statement, desc._bulk_update_tuples(v)
                        )
                    )
            elif "entity_namespace" in k._annotations:
                # annotated expression referring back to a mapped
                # attribute; resolve it via its entity namespace
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                values.extend(
                    core_get_crud_kv_pairs(
                        statement, attr._bulk_update_tuples(v)
                    )
                )
            else:
                # plain column expression; pass through
                values.append(
                    (
                        k,
                        coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        ),
                    )
                )
        return values

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):
        # Apply the UPDATE's effects to the in-session objects that were
        # matched by Python evaluation of the criteria: refresh
        # attributes whose new value could be computed, expire the rest.
        states = set()
        evaluated_keys = list(update_options._value_evaluators.keys())

        values = update_options._resolved_keys_as_propnames
        attrib = set(k for k, v in values)

        for obj in update_options._matched_objects:

            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            # the evaluated states were gathered across all identity tokens.
            # however the post_sync events are called per identity token,
            # so filter.
            if (
                update_options._refresh_identity_token is not None
                and state.identity_token
                != update_options._refresh_identity_token
            ):
                continue

            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = update_options._value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            # persist the evaluated values as the committed state
            state._commit(dict_, list(to_evaluate))

            # SET keys whose value could not be evaluated in Python are
            # expired so they will be re-fetched on next access
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)

    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        # Apply the UPDATE's effects to in-session objects identified by
        # the pre-fetched (or RETURNING-supplied) primary key rows.
        target_mapper = update_options._subject_mapper

        states = set()
        evaluated_keys = list(update_options._value_evaluators.keys())

        if result.returns_rows:
            # RETURNING was used; append the refresh identity token to
            # match the shape of pre-fetched rows
            matched_rows = [
                tuple(row) + (update_options._refresh_identity_token,)
                for row in result.all()
            ]
        else:
            matched_rows = update_options._matched_rows

        # resolve each (primary key, identity token) row to an object
        # currently present in the identity map
        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._refresh_identity_token is None
                or identity_token == update_options._refresh_identity_token
            ]
            if identity_key in session.identity_map
        ]

        values = update_options._resolved_keys_as_propnames
        attrib = set(k for k, v in values)

        for obj in objs:
            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = update_options._value_evaluators[key](obj)
            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # expire SET keys whose value could not be evaluated
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)
  2091. @CompileState.plugin_for("orm", "delete")
  2092. class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState):
  2093. @classmethod
  2094. def create_for_statement(cls, statement, compiler, **kw):
  2095. self = cls.__new__(cls)
  2096. ext_info = statement.table._annotations["parententity"]
  2097. self.mapper = mapper = ext_info.mapper
  2098. self.extra_criteria_entities = {}
  2099. extra_criteria_attributes = {}
  2100. for opt in statement._with_options:
  2101. if opt._is_criteria_option:
  2102. opt.get_global_criteria(extra_criteria_attributes)
  2103. new_crit = cls._adjust_for_extra_criteria(
  2104. extra_criteria_attributes, mapper
  2105. )
  2106. if new_crit:
  2107. statement = statement.where(*new_crit)
  2108. if (
  2109. mapper
  2110. and compiler._annotations.get("synchronize_session", None)
  2111. == "fetch"
  2112. and compiler.dialect.full_returning
  2113. ):
  2114. statement = statement.returning(*mapper.primary_key)
  2115. DeleteDMLState.__init__(self, statement, compiler, **kw)
  2116. return self
  2117. @classmethod
  2118. def _do_post_synchronize_evaluate(cls, session, result, update_options):
  2119. session._remove_newly_deleted(
  2120. [
  2121. attributes.instance_state(obj)
  2122. for obj in update_options._matched_objects
  2123. ]
  2124. )
  2125. @classmethod
  2126. def _do_post_synchronize_fetch(cls, session, result, update_options):
  2127. target_mapper = update_options._subject_mapper
  2128. if result.returns_rows:
  2129. matched_rows = [
  2130. tuple(row) + (update_options._refresh_identity_token,)
  2131. for row in result.all()
  2132. ]
  2133. else:
  2134. matched_rows = update_options._matched_rows
  2135. for row in matched_rows:
  2136. primary_key = row[0:-1]
  2137. identity_token = row[-1]
  2138. # TODO: inline this and call remove_newly_deleted
  2139. # once
  2140. identity_key = target_mapper.identity_key_from_primary_key(
  2141. list(primary_key),
  2142. identity_token=identity_token,
  2143. )
  2144. if identity_key in session.identity_map:
  2145. session._remove_newly_deleted(
  2146. [
  2147. attributes.instance_state(
  2148. session.identity_map[identity_key]
  2149. )
  2150. ]
  2151. )