#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import collections.abc
import contextlib
import enum
import itertools
import json
import logging
import os
import sys
import time
import warnings
from dataclasses import dataclass
from tempfile import gettempdir
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generator,
    Iterable,
    Iterator,
    Protocol,
    Sequence,
    TypeVar,
    overload,
)

import attrs
from sqlalchemy import (
    Table,
    and_,
    column,
    delete,
    exc,
    func,
    inspect,
    literal,
    or_,
    select,
    table,
    text,
    tuple_,
)

import airflow
from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import import_all_models
from airflow.utils import helpers

# TODO: remove create_session once we decide to break backward compatibility
from airflow.utils.session import NEW_SESSION, create_session, provide_session  # noqa: F401
from airflow.utils.task_instance_session import get_current_task_instance_session

if TYPE_CHECKING:
    from alembic.runtime.environment import EnvironmentContext
    from alembic.script import ScriptDirectory
    from sqlalchemy.engine import Row
    from sqlalchemy.orm import Query, Session
    from sqlalchemy.sql.elements import ClauseElement, TextClause
    from sqlalchemy.sql.selectable import Select

    from airflow.models.connection import Connection
    from airflow.typing_compat import Self

    # TODO: Import this from sqlalchemy.orm instead when switching to SQLA 2.
    # https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.MappedClassProtocol
    class MappedClassProtocol(Protocol):
        """Protocol for SQLAlchemy model base."""

        __tablename__: str


T = TypeVar("T")

log = logging.getLogger(__name__)

_REVISION_HEADS_MAP = {
    "2.0.0": "e959f08ac86c",
    "2.0.1": "82b7c48c147f",
    "2.0.2": "2e42bb497a22",
    "2.1.0": "a13f7613ad25",
    "2.1.3": "97cdd93827b8",
    "2.1.4": "ccde3e26fe78",
    "2.2.0": "7b2661a43ba3",
    "2.2.3": "be2bfac3da23",
    "2.2.4": "587bdf053233",
    "2.3.0": "b1b348e02d07",
    "2.3.1": "1de7bc13c950",
    "2.3.2": "3c94c427fdf6",
    "2.3.3": "f5fcbda3e651",
    "2.4.0": "ecb43d2a1842",
    "2.4.2": "b0d31815b5a6",
    "2.4.3": "e07f49787c9d",
    "2.5.0": "290244fb8b83",
    "2.6.0": "98ae134e6fff",
    "2.6.2": "c804e5c76e3e",
    "2.7.0": "405de8318b3a",
    "2.8.0": "10b52ebd31f7",
    "2.8.1": "88344c1d9134",
    "2.9.0": "1949afb29106",
    "2.9.2": "686269002441",
    "2.10.0": "22ed7efa9da2",
    "2.10.3": "5f2621c13b39",
}


def _format_airflow_moved_table_name(source_table, version, category):
    return "__".join([settings.AIRFLOW_MOVED_TABLE_PREFIX, version.replace(".", "_"), category, source_table])
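

# Illustrative note (editor's sketch, not part of the upstream module): assuming
# settings.AIRFLOW_MOVED_TABLE_PREFIX is "_airflow_moved", a call such as
#   _format_airflow_moved_table_name("task_instance", "2.2", "dangling")
# would produce a table name like "_airflow_moved__2_2__dangling__task_instance".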


@provide_session
def merge_conn(conn: Connection, session: Session = NEW_SESSION):
    """Add new Connection."""
    if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)):
        session.add(conn)
        session.commit()


@provide_session
def add_default_pool_if_not_exists(session: Session = NEW_SESSION):
    """Add default pool if it does not exist."""
    from airflow.models.pool import Pool

    if not Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session):
        default_pool = Pool(
            pool=Pool.DEFAULT_POOL_NAME,
            slots=conf.getint(section="core", key="default_pool_task_slot_count"),
            description="Default pool",
            include_deferred=False,
        )
        session.add(default_pool)
        session.commit()


@provide_session
def create_default_connections(session: Session = NEW_SESSION):
    """Create default Airflow connections."""
    from airflow.models.connection import Connection

    merge_conn(
        Connection(
            conn_id="airflow_db",
            conn_type="mysql",
            host="mysql",
            login="root",
            password="",
            schema="airflow",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="athena_default",
            conn_type="athena",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="aws_default",
            conn_type="aws",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_batch_default",
            conn_type="azure_batch",
            login="<ACCOUNT_NAME>",
            password="",
            extra="""{"account_url": "<ACCOUNT_URL>"}""",
        )
    )
    merge_conn(
        Connection(
            conn_id="azure_cosmos_default",
            conn_type="azure_cosmos",
            extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_data_explorer_default",
            conn_type="azure_data_explorer",
            host="https://<CLUSTER>.kusto.windows.net",
            extra="""{"auth_method": "<AAD_APP | AAD_APP_CERT | AAD_CREDS | AAD_DEVICE>",
                "tenant": "<TENANT ID>", "certificate": "<APPLICATION PEM CERTIFICATE>",
                "thumbprint": "<APPLICATION CERTIFICATE THUMBPRINT>"}""",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_data_lake_default",
            conn_type="azure_data_lake",
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="azure_default",
            conn_type="azure",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="cassandra_default",
            conn_type="cassandra",
            host="cassandra",
            port=9042,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="databricks_default",
            conn_type="databricks",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="dingding_default",
            conn_type="http",
            host="",
            password="",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="drill_default",
            conn_type="drill",
            host="localhost",
            port=8047,
            extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="druid_broker_default",
            conn_type="druid",
            host="druid-broker",
            port=8082,
            extra='{"endpoint": "druid/v2/sql"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="druid_ingest_default",
            conn_type="druid",
            host="druid-overlord",
            port=8081,
            extra='{"endpoint": "druid/indexer/v1/task"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="elasticsearch_default",
            conn_type="elasticsearch",
            host="localhost",
            schema="http",
            port=9200,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="emr_default",
            conn_type="emr",
            extra="""
                { "Name": "default_job_flow_name",
                  "LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
                  "ReleaseLabel": "emr-4.6.0",
                  "Instances": {
                      "Ec2KeyName": "mykey",
                      "Ec2SubnetId": "somesubnet",
                      "InstanceGroups": [
                          {
                              "Name": "Master nodes",
                              "Market": "ON_DEMAND",
                              "InstanceRole": "MASTER",
                              "InstanceType": "r3.2xlarge",
                              "InstanceCount": 1
                          },
                          {
                              "Name": "Core nodes",
                              "Market": "ON_DEMAND",
                              "InstanceRole": "CORE",
                              "InstanceType": "r3.2xlarge",
                              "InstanceCount": 1
                          }
                      ],
                      "TerminationProtected": false,
                      "KeepJobFlowAliveWhenNoSteps": false
                  },
                  "Applications":[
                      { "Name": "Spark" }
                  ],
                  "VisibleToAllUsers": true,
                  "JobFlowRole": "EMR_EC2_DefaultRole",
                  "ServiceRole": "EMR_DefaultRole",
                  "Tags": [
                      {
                          "Key": "app",
                          "Value": "analytics"
                      },
                      {
                          "Key": "environment",
                          "Value": "development"
                      }
                  ]
                }
            """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="facebook_default",
            conn_type="facebook_social",
            extra="""
                { "account_id": "<AD_ACCOUNT_ID>",
                  "app_id": "<FACEBOOK_APP_ID>",
                  "app_secret": "<FACEBOOK_APP_SECRET>",
                  "access_token": "<FACEBOOK_AD_ACCESS_TOKEN>"
                }
            """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="fs_default",
            conn_type="fs",
            extra='{"path": "/"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="ftp_default",
            conn_type="ftp",
            host="localhost",
            port=21,
            login="airflow",
            password="airflow",
            extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="google_cloud_default",
            conn_type="google_cloud_platform",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="hive_cli_default",
            conn_type="hive_cli",
            port=10000,
            host="localhost",
            extra='{"use_beeline": true, "auth": ""}',
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="hiveserver2_default",
            conn_type="hiveserver2",
            host="localhost",
            schema="default",
            port=10000,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="http_default",
            conn_type="http",
            host="https://www.httpbin.org/",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="iceberg_default",
            conn_type="iceberg",
            host="https://api.iceberg.io/ws/v1",
        ),
        session,
    )
    merge_conn(Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050))
    merge_conn(
        Connection(
            conn_id="kafka_default",
            conn_type="kafka",
            extra=json.dumps({"bootstrap.servers": "broker:29092", "group.id": "my-group"}),
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="kubernetes_default",
            conn_type="kubernetes",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="kylin_default",
            conn_type="kylin",
            host="localhost",
            port=7070,
            login="ADMIN",
            password="KYLIN",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="leveldb_default",
            conn_type="leveldb",
            host="localhost",
        ),
        session,
    )
    merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session)
    merge_conn(
        Connection(
            conn_id="local_mysql",
            conn_type="mysql",
            host="localhost",
            login="airflow",
            password="airflow",
            schema="airflow",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="metastore_default",
            conn_type="hive_metastore",
            host="localhost",
            extra='{"authMechanism": "PLAIN"}',
            port=9083,
        ),
        session,
    )
    merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session)
    merge_conn(
        Connection(
            conn_id="mssql_default",
            conn_type="mssql",
            host="localhost",
            port=1433,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="mysql_default",
            conn_type="mysql",
            login="root",
            schema="airflow",
            host="mysql",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="opsgenie_default",
            conn_type="http",
            host="",
            password="",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="oracle_default",
            conn_type="oracle",
            host="localhost",
            login="root",
            password="password",
            schema="schema",
            port=1521,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="oss_default",
            conn_type="oss",
            extra="""{
                "auth_type": "AK",
                "access_key_id": "<ACCESS_KEY_ID>",
                "access_key_secret": "<ACCESS_KEY_SECRET>",
                "region": "<YOUR_OSS_REGION>"}
                """,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pig_cli_default",
            conn_type="pig_cli",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pinot_admin_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="pinot_broker_default",
            conn_type="pinot",
            host="localhost",
            port=9000,
            extra='{"endpoint": "/query", "schema": "http"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="postgres_default",
            conn_type="postgres",
            login="postgres",
            password="airflow",
            schema="airflow",
            host="postgres",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="presto_default",
            conn_type="presto",
            host="localhost",
            schema="hive",
            port=3400,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="qdrant_default",
            conn_type="qdrant",
            host="qdrant",
            port=6333,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="redis_default",
            conn_type="redis",
            host="redis",
            port=6379,
            extra='{"db": 0}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="redshift_default",
            conn_type="redshift",
            extra="""{
                "iam": true,
                "cluster_identifier": "<REDSHIFT_CLUSTER_IDENTIFIER>",
                "port": 5439,
                "profile": "default",
                "db_user": "awsuser",
                "database": "dev",
                "region": ""
            }""",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="salesforce_default",
            conn_type="salesforce",
            login="username",
            password="password",
            extra='{"security_token": "security_token"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="segment_default",
            conn_type="segment",
            extra='{"write_key": "my-segment-write-key"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sftp_default",
            conn_type="sftp",
            host="localhost",
            port=22,
            login="airflow",
            extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="spark_default",
            conn_type="spark",
            host="yarn",
            extra='{"queue": "root.default"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="sqlite_default",
            conn_type="sqlite",
            host=os.path.join(gettempdir(), "sqlite_default.db"),
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="ssh_default",
            conn_type="ssh",
            host="localhost",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="tableau_default",
            conn_type="tableau",
            host="https://tableau.server.url",
            login="user",
            password="password",
            extra='{"site_id": "my_site"}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="tabular_default",
            conn_type="tabular",
            host="https://api.tabulardata.io/ws/v1",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="teradata_default",
            conn_type="teradata",
            host="localhost",
            login="user",
            password="password",
            schema="schema",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="trino_default",
            conn_type="trino",
            host="localhost",
            schema="hive",
            port=3400,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="vertica_default",
            conn_type="vertica",
            host="localhost",
            port=5433,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="wasb_default",
            conn_type="wasb",
            extra='{"sas_token": null}',
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="webhdfs_default",
            conn_type="hdfs",
            host="localhost",
            port=50070,
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="yandexcloud_default",
            conn_type="yandexcloud",
            schema="default",
        ),
        session,
    )
    merge_conn(
        Connection(
            conn_id="ydb_default",
            conn_type="ydb",
            host="grpc://localhost",
            port=2135,
            extra={"database": "/local"},
        ),
        session,
    )


def _get_flask_db(sql_database_uri):
    from flask import Flask
    from flask_sqlalchemy import SQLAlchemy

    from airflow.www.session import AirflowDatabaseSessionInterface

    flask_app = Flask(__name__)
    flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
    flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    db = SQLAlchemy(flask_app)
    AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
    return db


def _create_db_from_orm(session):
    from alembic import command

    from airflow.models.base import Base
    from airflow.providers.fab.auth_manager.models import Model

    def _create_flask_session_tbl(sql_database_uri):
        db = _get_flask_db(sql_database_uri)
        db.create_all()

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        engine = session.get_bind().engine
        Base.metadata.create_all(engine)
        Model.metadata.create_all(engine)
        _create_flask_session_tbl(engine.url)
        # stamp the migration head
        config = _get_alembic_config()
        command.stamp(config, "head")


@provide_session
def initdb(session: Session = NEW_SESSION, load_connections: bool = True, use_migration_files: bool = False):
    """Initialize Airflow database."""
    import_all_models()

    db_exists = _get_current_revision(session)
    if db_exists or use_migration_files:
        upgradedb(session=session, use_migration_files=use_migration_files)
    else:
        _create_db_from_orm(session=session)
    if conf.getboolean("database", "LOAD_DEFAULT_CONNECTIONS") and load_connections:
        create_default_connections(session=session)
    # Add default pool & sync log_template
    add_default_pool_if_not_exists(session=session)
    synchronize_log_template(session=session)


def _get_alembic_config():
    from alembic.config import Config

    package_dir = os.path.dirname(airflow.__file__)
    directory = os.path.join(package_dir, "migrations")
    alembic_file = conf.get("database", "alembic_ini_file_path")
    if os.path.isabs(alembic_file):
        config = Config(alembic_file)
    else:
        config = Config(os.path.join(package_dir, alembic_file))
    config.set_main_option("script_location", directory.replace("%", "%%"))
    config.set_main_option("sqlalchemy.url", settings.SQL_ALCHEMY_CONN.replace("%", "%%"))
    return config


def _get_script_object(config=None) -> ScriptDirectory:
    from alembic.script import ScriptDirectory

    if not config:
        config = _get_alembic_config()
    return ScriptDirectory.from_config(config)


def _get_current_revision(session):
    from alembic.migration import MigrationContext

    conn = session.connection()
    migration_ctx = MigrationContext.configure(conn)
    return migration_ctx.get_current_revision()


def check_migrations(timeout):
    """
    Wait for all airflow migrations to complete.

    :param timeout: Timeout for the migration in seconds
    :return: None
    """
    timeout = timeout or 1  # run the loop at least once
    with _configured_alembic_environment() as env:
        context = env.get_context()
        source_heads = None
        db_heads = None
        for ticker in range(timeout):
            source_heads = set(env.script.get_heads())
            db_heads = set(context.get_current_heads())
            if source_heads == db_heads:
                return
            time.sleep(1)
            log.info("Waiting for migrations... %s second(s)", ticker)
        raise TimeoutError(
            f"There are still unapplied migrations after {timeout} seconds. Migration "
            f"Head(s) in DB: {db_heads} | Migration Head(s) in Source Code: {source_heads}"
        )
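

# Illustrative usage sketch (editor's note, not part of the upstream module): this is
# roughly how the `airflow db check-migrations` CLI command drives the function above;
# the timeout value is an arbitrary example.
#
#   check_migrations(timeout=60)  # poll up to ~60s until DB heads match the source heads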


@contextlib.contextmanager
def _configured_alembic_environment() -> Generator[EnvironmentContext, None, None]:
    from alembic.runtime.environment import EnvironmentContext

    config = _get_alembic_config()
    script = _get_script_object(config)

    with EnvironmentContext(
        config,
        script,
    ) as env, settings.engine.connect() as connection:
        alembic_logger = logging.getLogger("alembic")
        level = alembic_logger.level
        alembic_logger.setLevel(logging.WARNING)
        env.configure(connection)
        alembic_logger.setLevel(level)

        yield env


def check_and_run_migrations():
    """Check and run migrations if necessary. Only use in a tty."""
    with _configured_alembic_environment() as env:
        context = env.get_context()
        source_heads = set(env.script.get_heads())
        db_heads = set(context.get_current_heads())
        db_command = None
        command_name = None
        verb = None
    if len(db_heads) < 1:
        db_command = initdb
        command_name = "init"
        verb = "initialize"
    elif source_heads != db_heads:
        db_command = upgradedb
        command_name = "upgrade"
        verb = "upgrade"

    if sys.stdout.isatty() and verb:
        print()
        question = f"Please confirm database {verb} (or wait 4 seconds to skip it). Are you sure? [y/N]"
        try:
            answer = helpers.prompt_with_timeout(question, timeout=4, default=False)
            if answer:
                try:
                    db_command()
                    print(f"DB {verb} done")
                except Exception as error:
                    from airflow.version import version

                    print(error)
                    print(
                        "You still have unapplied migrations. "
                        f"You may need to {verb} the database by running `airflow db {command_name}`. ",
                        f"Make sure the command is run using Airflow version {version}.",
                        file=sys.stderr,
                    )
                    sys.exit(1)
        except AirflowException:
            pass
    elif source_heads != db_heads:
        from airflow.version import version

        print(
            f"ERROR: You need to {verb} the database. Please run `airflow db {command_name}`. "
            f"Make sure the command is run using Airflow version {version}.",
            file=sys.stderr,
        )
        sys.exit(1)


def _reserialize_dags(*, session: Session) -> None:
    from airflow.models.dagbag import DagBag
    from airflow.models.serialized_dag import SerializedDagModel

    session.execute(delete(SerializedDagModel).execution_options(synchronize_session=False))
    dagbag = DagBag(collect_dags=False)
    dagbag.collect_dags(only_if_updated=False)
    dagbag.sync_to_db(session=session)


@provide_session
def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
    """
    Synchronize log template configs with table.

    This checks if the last row fully matches the current config values, and
    inserts a new row if not.
    """
    # NOTE: SELECT queries in this function are INTENTIONALLY written with the
    # SQL builder style, not the ORM query API. This avoids configuring the ORM
    # unless we need to insert something, speeding up CLI in general.
    from airflow.models.tasklog import LogTemplate

    metadata = reflect_tables([LogTemplate], session)
    log_template_table: Table | None = metadata.tables.get(LogTemplate.__tablename__)

    if log_template_table is None:
        log.info("Log template table does not exist (added in 2.3.0); skipping log template sync.")
        return

    filename = conf.get("logging", "log_filename_template")
    elasticsearch_id = conf.get("elasticsearch", "log_id_template")

    stored = session.execute(
        select(
            log_template_table.c.filename,
            log_template_table.c.elasticsearch_id,
        )
        .order_by(log_template_table.c.id.desc())
        .limit(1)
    ).first()

    # If we have an empty table, and the default values exist, we will seed the
    # table with values from pre 2.3.0, so old logs will still be retrievable.
    if not stored:
        is_default_log_id = elasticsearch_id == conf.get_default_value("elasticsearch", "log_id_template")
        is_default_filename = filename == conf.get_default_value("logging", "log_filename_template")
        if is_default_log_id and is_default_filename:
            session.add(
                LogTemplate(
                    filename="{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log",
                    elasticsearch_id="{dag_id}-{task_id}-{execution_date}-{try_number}",
                )
            )

    # Before checking if the _current_ value exists, we need to check if the old config value we upgraded in
    # place exists!
    pre_upgrade_filename = conf.upgraded_values.get(("logging", "log_filename_template"), filename)
    pre_upgrade_elasticsearch_id = conf.upgraded_values.get(
        ("elasticsearch", "log_id_template"), elasticsearch_id
    )
    if pre_upgrade_filename != filename or pre_upgrade_elasticsearch_id != elasticsearch_id:
        # The previous non-upgraded value likely won't be the _latest_ value (as after we've recorded the
        # upgraded value it will be second-to-newest), so we'll have to just search, which is okay
        # as this is a table with a tiny number of rows.
        row = session.execute(
            select(log_template_table.c.id)
            .where(
                or_(
                    log_template_table.c.filename == pre_upgrade_filename,
                    log_template_table.c.elasticsearch_id == pre_upgrade_elasticsearch_id,
                )
            )
            .order_by(log_template_table.c.id.desc())
            .limit(1)
        ).first()
        if not row:
            session.add(
                LogTemplate(filename=pre_upgrade_filename, elasticsearch_id=pre_upgrade_elasticsearch_id)
            )

    if not stored or stored.filename != filename or stored.elasticsearch_id != elasticsearch_id:
        session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))


def check_conn_id_duplicates(session: Session) -> Iterable[str]:
    """
    Check unique conn_id in connection table.

    :param session: SQLAlchemy session
    """
    from airflow.models.connection import Connection

    try:
        dups = session.scalars(
            select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
        ).all()
    except (exc.OperationalError, exc.ProgrammingError):
        # fallback if the table hasn't been created yet
        session.rollback()
        return
    if dups:
        yield (
            "Seems you have non unique conn_id in connection table.\n"
            "You have to manage those duplicate connections "
            "before upgrading the database.\n"
            f"Duplicated conn_id: {dups}"
        )


def check_username_duplicates(session: Session) -> Iterable[str]:
    """
    Check unique username in User & RegisterUser table.

    :param session: SQLAlchemy session
    :rtype: str
    """
    from airflow.providers.fab.auth_manager.models import RegisterUser, User

    for model in [User, RegisterUser]:
        dups = []
        try:
            dups = session.execute(
                select(model.username)  # type: ignore[attr-defined]
                .group_by(model.username)  # type: ignore[attr-defined]
                .having(func.count() > 1)
            ).all()
        except (exc.OperationalError, exc.ProgrammingError):
            # fallback if the table hasn't been created yet
            session.rollback()
        if dups:
            yield (
                f"Seems you have mixed case usernames in {model.__table__.name} table.\n"  # type: ignore
                "You have to rename or delete those mixed case usernames "
                "before upgrading the database.\n"
                f"usernames with mixed cases: {[dup.username for dup in dups]}"
            )


def reflect_tables(tables: list[MappedClassProtocol | str] | None, session):
    """
    When running checks prior to upgrades, we use reflection to determine current state of the database.

    This function gets the current state of each table in the set of models
    provided and returns a SqlAlchemy metadata object containing them.
    """
    import sqlalchemy.schema

    bind = session.bind
    metadata = sqlalchemy.schema.MetaData()

    if tables is None:
        metadata.reflect(bind=bind, resolve_fks=False)
    else:
        for tbl in tables:
            try:
                table_name = tbl if isinstance(tbl, str) else tbl.__tablename__
                metadata.reflect(bind=bind, only=[table_name], extend_existing=True, resolve_fks=False)
            except exc.InvalidRequestError:
                continue
    return metadata


def check_table_for_duplicates(
    *, session: Session, table_name: str, uniqueness: list[str], version: str
) -> Iterable[str]:
    """
    Check table for duplicates, given a list of columns which define the uniqueness of the table.

    Usage example:

    .. code-block:: python

        def check_task_fail_for_duplicates(session):
            from airflow.models.taskfail import TaskFail

            metadata = reflect_tables([TaskFail], session)
            task_fail = metadata.tables.get(TaskFail.__tablename__)  # type: ignore
            if task_fail is None:  # table not there
                return
            if "run_id" in task_fail.columns:  # upgrade already applied
                return
            yield from check_table_for_duplicates(
                table_name=task_fail.name,
                uniqueness=["dag_id", "task_id", "execution_date"],
                session=session,
                version="2.3",
            )

    :param table_name: table name to check
    :param uniqueness: uniqueness constraint to evaluate against
    :param session: SQLAlchemy session
    """
    minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
    try:
        subquery = session.execute(
            select(minimal_table_obj, func.count().label("dupe_count"))
            .group_by(*(text(x) for x in uniqueness))
            .having(func.count() > text("1"))
            .subquery()
        )
        dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count)))
        if not dupe_count:
            # there are no duplicates; nothing to do.
            return
        log.warning("Found %s duplicates in table %s. Will attempt to move them.", dupe_count, table_name)

        metadata = reflect_tables(tables=[table_name], session=session)
        if table_name not in metadata.tables:
            yield f"Table {table_name} does not exist in the database."

        # We can't use the model here since it may differ from the db state because
        # this function is run prior to migration. Use the reflected table instead.
        table_obj = metadata.tables[table_name]

        _move_duplicate_data_to_new_table(
            session=session,
            source_table=table_obj,
            subquery=subquery,
            uniqueness=uniqueness,
            target_table_name=_format_airflow_moved_table_name(table_name, version, "duplicates"),
        )
    except (exc.OperationalError, exc.ProgrammingError):
        # fallback if `table_name` hasn't been created yet
        session.rollback()


def check_conn_type_null(session: Session) -> Iterable[str]:
    """
    Check nullable conn_type column in Connection table.

    :param session: SQLAlchemy session
    """
    from airflow.models.connection import Connection

    try:
        n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
    except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
        # fallback if the table hasn't been created yet
        session.rollback()
        return

    if n_nulls:
        yield (
            "The conn_type column in the connection "
            "table must contain content.\n"
            "Make sure you don't have null "
            "in the conn_type column.\n"
            f"Null conn_type conn_id: {n_nulls}"
        )


def _format_dangling_error(source_table, target_table, invalid_count, reason):
    noun = "row" if invalid_count == 1 else "rows"
    return (
        f"The {source_table} table has {invalid_count} {noun} {reason}, which "
        f"is invalid. We could not move them out of the way because the "
        f"{target_table} table already exists in your database. Please either "
        f"drop the {target_table} table, or manually delete the invalid rows "
        f"from the {source_table} table."
    )


def check_run_id_null(session: Session) -> Iterable[str]:
    from airflow.models.dagrun import DagRun

    metadata = reflect_tables([DagRun], session)

    # We can't use the model here since it may differ from the db state because
    # this function is run prior to migration. Use the reflected table instead.
    dagrun_table = metadata.tables.get(DagRun.__tablename__)
    if dagrun_table is None:
        return

    invalid_dagrun_filter = or_(
        dagrun_table.c.dag_id.is_(None),
        dagrun_table.c.run_id.is_(None),
        dagrun_table.c.execution_date.is_(None),
    )
    invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter))
    if invalid_dagrun_count > 0:
        dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling")
        if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names():
            yield _format_dangling_error(
                source_table=dagrun_table.name,
                target_table=dagrun_dangling_table_name,
                invalid_count=invalid_dagrun_count,
                reason="with a NULL dag_id, run_id, or execution_date",
            )
            return

        bind = session.get_bind()
        dialect_name = bind.dialect.name
        _create_table_as(
            dialect_name=dialect_name,
            source_query=dagrun_table.select(invalid_dagrun_filter),
            target_table_name=dagrun_dangling_table_name,
            source_table_name=dagrun_table.name,
            session=session,
        )
        delete = dagrun_table.delete().where(invalid_dagrun_filter)
        session.execute(delete)


def _create_table_as(
    *,
    session,
    dialect_name: str,
    source_query: Query,
    target_table_name: str,
    source_table_name: str,
):
    """
    Create a new table with rows from query.

    We have to handle CTAS differently for different dialects.
    """
    if dialect_name == "mysql":
        # MySQL with replication needs this split into two queries, so just do it for all MySQL
        # ERROR 1786 (HY000): Statement violates GTID consistency: CREATE TABLE ... SELECT.
        session.execute(text(f"CREATE TABLE {target_table_name} LIKE {source_table_name}"))
        session.execute(
            text(
                f"INSERT INTO {target_table_name} {source_query.selectable.compile(bind=session.get_bind())}"
            )
        )
    else:
        # Postgres and SQLite both support the same "CREATE TABLE a AS SELECT ..." syntax
        select_table = source_query.selectable.compile(bind=session.get_bind())
        session.execute(text(f"CREATE TABLE {target_table_name} AS {select_table}"))
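

# Illustrative note (editor's sketch, not part of the upstream module): for a hypothetical
# moved-rows table, the two code paths above emit SQL of roughly this shape:
#
#   -- MySQL (two statements, to stay GTID-consistent):
#   CREATE TABLE _airflow_moved__2_2__dangling__task_instance LIKE task_instance;
#   INSERT INTO _airflow_moved__2_2__dangling__task_instance SELECT ... FROM task_instance ...;
#
#   -- Postgres / SQLite (single CTAS statement):
#   CREATE TABLE _airflow_moved__2_2__dangling__task_instance AS SELECT ... FROM task_instance ...;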


def _move_dangling_data_to_new_table(
    session, source_table: Table, source_query: Query, target_table_name: str
):
    bind = session.get_bind()
    dialect_name = bind.dialect.name

    # First: Create moved rows from new table
    log.debug("running CTAS for table %s", target_table_name)
    _create_table_as(
        dialect_name=dialect_name,
        source_query=source_query,
        target_table_name=target_table_name,
        source_table_name=source_table.name,
        session=session,
    )
    session.commit()

    target_table = source_table.to_metadata(source_table.metadata, name=target_table_name)
    log.debug("checking whether rows were moved for table %s", target_table_name)
    moved_rows_exist_query = select(1).select_from(target_table).limit(1)
    first_moved_row = session.execute(moved_rows_exist_query).all()
    session.commit()

    if not first_moved_row:
        log.debug("no rows moved; dropping %s", target_table_name)
        # no bad rows were found; drop moved rows table.
        target_table.drop(bind=session.get_bind(), checkfirst=True)
    else:
        log.debug("rows moved; purging from %s", source_table.name)
        if dialect_name == "sqlite":
            pk_cols = source_table.primary_key.columns
            delete = source_table.delete().where(
                tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery())
            )
        else:
            delete = source_table.delete().where(
                and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
            )
        log.debug(delete.compile())
        session.execute(delete)
        session.commit()

    log.debug("exiting move function")


def _dangling_against_dag_run(session, source_table, dag_run):
    """Given a source table, generate a query that selects every row with no matching dag_run (i.e. dangling rows)."""
    source_to_dag_run_join_cond = and_(
        source_table.c.dag_id == dag_run.c.dag_id,
        source_table.c.execution_date == dag_run.c.execution_date,
    )
    return (
        select(*(c.label(c.name) for c in source_table.c))
        .join(dag_run, source_to_dag_run_join_cond, isouter=True)
        .where(dag_run.c.dag_id.is_(None))
    )
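

# Illustrative note (editor's sketch, not part of the upstream module): for a source table
# such as task_instance, the query built above has roughly this shape:
#
#   SELECT task_instance.* FROM task_instance
#   LEFT OUTER JOIN dag_run
#     ON task_instance.dag_id = dag_run.dag_id
#    AND task_instance.execution_date = dag_run.execution_date
#   WHERE dag_run.dag_id IS NULL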


def _dangling_against_task_instance(session, source_table, dag_run, task_instance):
    """
    Given a source table, generate a query that selects every row with no valid task instance (i.e. dangling rows).

    This is used to identify rows that need to be removed from tables prior to adding a TI fk.

    Since this check is applied prior to running the migrations, we have to use different
    query logic depending on which revision the database is at.
    """
    if "run_id" not in task_instance.c:
        # db is < 2.2.0
        dr_join_cond = and_(
            source_table.c.dag_id == dag_run.c.dag_id,
            source_table.c.execution_date == dag_run.c.execution_date,
        )
        ti_join_cond = and_(
            dag_run.c.dag_id == task_instance.c.dag_id,
            dag_run.c.execution_date == task_instance.c.execution_date,
            source_table.c.task_id == task_instance.c.task_id,
        )
    else:
        # db is 2.2.0 <= version < 2.3.0
        dr_join_cond = and_(
            source_table.c.dag_id == dag_run.c.dag_id,
            source_table.c.execution_date == dag_run.c.execution_date,
        )
        ti_join_cond = and_(
            dag_run.c.dag_id == task_instance.c.dag_id,
            dag_run.c.run_id == task_instance.c.run_id,
            source_table.c.task_id == task_instance.c.task_id,
        )

    return (
        select(*(c.label(c.name) for c in source_table.c))
        .outerjoin(dag_run, dr_join_cond)
        .outerjoin(task_instance, ti_join_cond)
        .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
    )


def _move_duplicate_data_to_new_table(
    session, source_table: Table, subquery: Query, uniqueness: list[str], target_table_name: str
):
    """
    When adding a uniqueness constraint we first should ensure that there are no duplicate rows.

    This function accepts a subquery that should return one record for each row with duplicates (e.g.
    a group by with having count(*) > 1). We select from ``source_table`` getting all rows matching the
    subquery result and store in ``target_table_name``. Then to purge the duplicates from the source table,
    we do a DELETE FROM with a join to the target table (which now contains the dupes).

    :param session: sqlalchemy session for metadata db
    :param source_table: table to purge dupes from
    :param subquery: the subquery that returns the duplicate rows
    :param uniqueness: the string list of columns used to define the uniqueness for the table. used in
        building the DELETE FROM join condition.
    :param target_table_name: name of the table in which to park the duplicate rows
    """
    bind = session.get_bind()
    dialect_name = bind.dialect.name

    query = (
        select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
        .select_from(source_table)
        .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
    )

    _create_table_as(
        session=session,
        dialect_name=dialect_name,
        source_query=query,
        target_table_name=target_table_name,
        source_table_name=source_table.name,
    )

    # we must ensure that the CTAS table is created prior to the DELETE step since we have to join to it
    session.commit()

    metadata = reflect_tables([target_table_name], session)
    target_table = metadata.tables[target_table_name]
    where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))

    if dialect_name == "sqlite":
        subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
        delete = source_table.delete().where(column("ROWID").in_(subq))
    else:
        delete = source_table.delete(where_clause)

    session.execute(delete)
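

# Illustrative note (editor's sketch, not part of the upstream module): the duplicate
# handling above is, in rough SQL terms (table and column names are examples only):
#
#   CREATE TABLE <target> AS
#       SELECT source.* FROM source
#       JOIN (SELECT key_cols, COUNT(*) FROM source GROUP BY key_cols HAVING COUNT(*) > 1) dupes
#         ON source.key_cols = dupes.key_cols;
#   DELETE FROM source WHERE key_cols match a row in <target>;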
  1283. def check_bad_references(session: Session) -> Iterable[str]:
  1284. """
  1285. Go through each table and look for records that can't be mapped to a dag run.
  1286. When we find such "dangling" rows we back them up in a special table and delete them
  1287. from the main table.
  1288. Starting in Airflow 2.2, we began a process of replacing `execution_date` with `run_id` in many tables.
  1289. """
  1290. from airflow.models.dagrun import DagRun
  1291. from airflow.models.renderedtifields import RenderedTaskInstanceFields
  1292. from airflow.models.taskfail import TaskFail
  1293. from airflow.models.taskinstance import TaskInstance
  1294. from airflow.models.taskreschedule import TaskReschedule
  1295. from airflow.models.xcom import XCom
  1296. @dataclass
  1297. class BadReferenceConfig:
  1298. """
  1299. Bad reference config class.
  1300. :param bad_rows_func: function that returns subquery which determines whether bad rows exist
  1301. :param join_tables: table objects referenced in subquery
  1302. :param ref_table: information-only identifier for categorizing the missing ref
  1303. """
  1304. bad_rows_func: Callable
  1305. join_tables: list[str]
  1306. ref_table: str
  1307. missing_dag_run_config = BadReferenceConfig(
  1308. bad_rows_func=_dangling_against_dag_run,
  1309. join_tables=["dag_run"],
  1310. ref_table="dag_run",
  1311. )
  1312. missing_ti_config = BadReferenceConfig(
  1313. bad_rows_func=_dangling_against_task_instance,
  1314. join_tables=["dag_run", "task_instance"],
  1315. ref_table="task_instance",
  1316. )
  1317. models_list: list[tuple[MappedClassProtocol, str, BadReferenceConfig]] = [
  1318. (TaskInstance, "2.2", missing_dag_run_config),
  1319. (TaskReschedule, "2.2", missing_ti_config),
  1320. (RenderedTaskInstanceFields, "2.3", missing_ti_config),
  1321. (TaskFail, "2.3", missing_ti_config),
  1322. (XCom, "2.3", missing_ti_config),
  1323. ]
  1324. metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)
  1325. if (
  1326. not metadata.tables
  1327. or metadata.tables.get(DagRun.__tablename__) is None
  1328. or metadata.tables.get(TaskInstance.__tablename__) is None
  1329. ):
  1330. # Key table doesn't exist -- likely empty DB.
  1331. return
  1332. existing_table_names = set(inspect(session.get_bind()).get_table_names())
  1333. errored = False
  1334. for model, change_version, bad_ref_cfg in models_list:
  1335. log.debug("checking model %s", model.__tablename__)
  1336. # We can't use the model here since it may differ from the db state due to
  1337. # this function is run prior to migration. Use the reflected table instead.
  1338. source_table = metadata.tables.get(model.__tablename__) # type: ignore
  1339. if source_table is None:
  1340. continue
  1341. # Migration already applied, don't check again.
  1342. if "run_id" in source_table.columns:
  1343. continue
  1344. func_kwargs = {x: metadata.tables[x] for x in bad_ref_cfg.join_tables}
  1345. bad_rows_query = bad_ref_cfg.bad_rows_func(session, source_table, **func_kwargs)
  1346. dangling_table_name = _format_airflow_moved_table_name(source_table.name, change_version, "dangling")
  1347. if dangling_table_name in existing_table_names:
  1348. invalid_row_count = get_query_count(bad_rows_query, session=session)
  1349. if invalid_row_count:
  1350. yield _format_dangling_error(
  1351. source_table=source_table.name,
  1352. target_table=dangling_table_name,
  1353. invalid_count=invalid_row_count,
  1354. reason=f"without a corresponding {bad_ref_cfg.ref_table} row",
  1355. )
  1356. errored = True
  1357. continue
  1358. log.debug("moving data for table %s", source_table.name)
  1359. _move_dangling_data_to_new_table(
  1360. session,
  1361. source_table,
  1362. bad_rows_query,
  1363. dangling_table_name,
  1364. )
  1365. if errored:
  1366. session.rollback()
  1367. else:
  1368. session.commit()
  1369. @provide_session
  1370. def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
  1371. """:session: session of the sqlalchemy."""
  1372. check_functions: tuple[Callable[..., Iterable[str]], ...] = (
  1373. check_conn_id_duplicates,
  1374. check_conn_type_null,
  1375. check_run_id_null,
  1376. check_bad_references,
  1377. check_username_duplicates,
  1378. )
  1379. for check_fn in check_functions:
  1380. log.debug("running check function %s", check_fn.__name__)
  1381. yield from check_fn(session=session)
  1382. def _offline_migration(migration_func: Callable, config, revision):
  1383. with warnings.catch_warnings():
  1384. warnings.simplefilter("ignore")
  1385. logging.disable(logging.CRITICAL)
  1386. migration_func(config, revision, sql=True)
  1387. logging.disable(logging.NOTSET)
  1388. def print_happy_cat(message):
  1389. if sys.stdout.isatty():
  1390. size = os.get_terminal_size().columns
  1391. else:
  1392. size = 0
  1393. print(message.center(size))
  1394. print("""/\\_/\\""".center(size))
  1395. print("""(='_' )""".center(size))
  1396. print("""(,(") (")""".center(size))
  1397. print("""^^^""".center(size))
  1398. return


def _revision_greater(config, this_rev, base_rev):
    # Check if there is history between the revisions and the start revision
    # This ensures that the revisions are above `min_revision`
    script = _get_script_object(config)
    try:
        list(script.revision_map.iterate_revisions(upper=this_rev, lower=base_rev))
        return True
    except Exception:
        return False


def _revisions_above_min_for_offline(config, revisions) -> None:
    """
    Check that all supplied revision ids are above the minimum revision for the dialect.

    :param config: Alembic config
    :param revisions: list of Alembic revision ids
    :return: None
    """
    dbname = settings.engine.dialect.name
    if dbname == "sqlite":
        raise SystemExit("Offline migration not supported for SQLite.")
    min_version, min_revision = ("2.2.0", "7b2661a43ba3") if dbname == "mssql" else ("2.0.0", "e959f08ac86c")

    # Check if there is history between the revisions and the start revision
    # This ensures that the revisions are above `min_revision`
    for rev in revisions:
        if not _revision_greater(config, rev, min_revision):
            raise ValueError(
                f"Error while checking history for revision range {min_revision}:{rev}. "
                f"Check that {rev} is a valid revision. "
                f"For dialect {dbname!r}, supported revision for offline migration is from {min_revision} "
                f"which corresponds to Airflow {min_version}."
            )


@provide_session
def upgradedb(
    *,
    to_revision: str | None = None,
    from_revision: str | None = None,
    show_sql_only: bool = False,
    reserialize_dags: bool = True,
    session: Session = NEW_SESSION,
    use_migration_files: bool = False,
):
    """
    Upgrades the DB.

    :param to_revision: Optional Alembic revision ID to upgrade *to*.
        If omitted, upgrades to latest revision.
    :param from_revision: Optional Alembic revision ID to upgrade *from*.
        Not compatible with ``show_sql_only=False``.
    :param show_sql_only: if True, migration statements will be printed but not executed.
    :param session: sqlalchemy session with connection to Airflow metadata database
    :return: None
    """
    if from_revision and not show_sql_only:
        raise AirflowException("`from_revision` only supported with `show_sql_only=True`.")

    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. This is a critical assertion.")
    # alembic adds significant import time, so we import it lazily
    from alembic import command

    import_all_models()

    config = _get_alembic_config()

    if show_sql_only:
        if not from_revision:
            from_revision = _get_current_revision(session)

        if not to_revision:
            script = _get_script_object()
            to_revision = script.get_current_head()

        if to_revision == from_revision:
            print_happy_cat("No migrations to apply; nothing to do.")
            return

        if not _revision_greater(config, to_revision, from_revision):
            raise ValueError(
                f"Requested *to* revision {to_revision} is older than *from* revision {from_revision}. "
                "Please check your requested versions / revisions."
            )
        _revisions_above_min_for_offline(config=config, revisions=[from_revision, to_revision])

        _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
        return  # only running sql; our job is done

    errors_seen = False
    for err in _check_migration_errors(session=session):
        if not errors_seen:
            log.error("Automatic migration is not available")
            errors_seen = True
        log.error("%s", err)

    if errors_seen:
        exit(1)

    if not to_revision and not _get_current_revision(session=session) and not use_migration_files:
        # New DB; initialize and exit.
        # Don't load default connections.
        initdb(session=session, load_connections=False)
        return

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        import sqlalchemy.pool

        previous_revision = _get_current_revision(session=session)

        log.info("Creating tables")
        val = os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE")
        try:
            # Reconfigure the ORM to use _EXACTLY_ one connection, otherwise some db engines hang forever
            # trying to ALTER TABLEs
            os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1"
            settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool)
            command.upgrade(config, revision=to_revision or "heads")
        finally:
            if val is None:
                os.environ.pop("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE")
            else:
                os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val
            settings.reconfigure_orm()

        current_revision = _get_current_revision(session=session)
        if reserialize_dags and current_revision != previous_revision:
            _reserialize_dags(session=session)
        add_default_pool_if_not_exists(session=session)
        synchronize_log_template(session=session)


@provide_session
def resetdb(session: Session = NEW_SESSION, skip_init: bool = False, use_migration_files: bool = False):
    """Clear out the database."""
    if not settings.engine:
        raise RuntimeError("The settings.engine must be set. This is a critical assertion")
    log.info("Dropping tables that exist")

    import_all_models()

    connection = settings.engine.connect()

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS), connection.begin():
        drop_airflow_models(connection)
        drop_airflow_moved_tables(connection)

    if not skip_init:
        initdb(session=session, use_migration_files=use_migration_files)


@provide_session
def bootstrap_dagbag(session: Session = NEW_SESSION):
    from airflow.models.dag import DAG
    from airflow.models.dagbag import DagBag

    dagbag = DagBag()
    # Save DAGs in the ORM
    dagbag.sync_to_db(session=session)

    # Deactivate the unknown ones
    DAG.deactivate_unknown_dags(dagbag.dags.keys(), session=session)


@provide_session
def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session: Session = NEW_SESSION):
    """
    Downgrade the airflow metastore schema to a prior version.

    :param to_revision: The alembic revision to downgrade *to*.
    :param show_sql_only: if True, print sql statements but do not run them
    :param from_revision: if supplied, alembic revision to downgrade *from*. This may only
        be used in conjunction with ``show_sql_only=True`` because if we actually run the commands,
        we should only downgrade from the *current* revision.
    :param session: sqlalchemy session for connection to airflow metadata database
    """
    if from_revision and not show_sql_only:
        raise ValueError(
            "`from_revision` can't be combined with `show_sql_only=False`. When actually "
            "applying a downgrade (instead of just generating sql), we always "
            "downgrade from current revision."
        )

    if not settings.SQL_ALCHEMY_CONN:
        raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.")

    # alembic adds significant import time, so we import it lazily
    from alembic import command

    log.info("Attempting downgrade to revision %s", to_revision)
    config = _get_alembic_config()

    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
        if show_sql_only:
            log.warning("Generating sql scripts for manual migration.")
            if not from_revision:
                from_revision = _get_current_revision(session)
            revision_range = f"{from_revision}:{to_revision}"
            _offline_migration(command.downgrade, config=config, revision=revision_range)
        else:
            log.info("Applying downgrade migrations.")
            command.downgrade(config, revision=to_revision, sql=show_sql_only)


def drop_airflow_models(connection):
    """
    Drop all airflow models.

    :param connection: SQLAlchemy Connection
    :return: None
    """
    from airflow.models.base import Base
    from airflow.providers.fab.auth_manager.models import Model

    Base.metadata.drop_all(connection)
    Model.metadata.drop_all(connection)
    db = _get_flask_db(connection.engine.url)
    db.drop_all()
    # alembic adds significant import time, so we import it lazily
    from alembic.migration import MigrationContext

    migration_ctx = MigrationContext.configure(connection)
    version = migration_ctx._version
    if inspect(connection).has_table(version.name):
        version.drop(connection)


def drop_airflow_moved_tables(connection):
    from airflow.models.base import Base
    from airflow.settings import AIRFLOW_MOVED_TABLE_PREFIX

    tables = set(inspect(connection).get_table_names())
    to_delete = [Table(x, Base.metadata) for x in tables if x.startswith(AIRFLOW_MOVED_TABLE_PREFIX)]
    for tbl in to_delete:
        tbl.drop(settings.engine, checkfirst=False)
        Base.metadata.remove(tbl)


@provide_session
def check(session: Session = NEW_SESSION):
    """
    Check if the database works.

    :param session: SQLAlchemy session
    """
    session.execute(text("select 1 as is_alive;"))
    log.info("Connection successful.")


@enum.unique
class DBLocks(enum.IntEnum):
    """
    Cross-db identifiers for advisory global database locks.

    Postgres uses int64 lock ids so we use the integer value; MySQL uses names, so we
    call ``str()``, which is implemented using the ``_name_`` field.
    """

    MIGRATIONS = enum.auto()
    SCHEDULER_CRITICAL_SECTION = enum.auto()

    def __str__(self):
        return f"airflow_{self._name_}"


@contextlib.contextmanager
def create_global_lock(
    session: Session,
    lock: DBLocks,
    lock_timeout: int = 1800,
) -> Generator[None, None, None]:
    """Contextmanager that will create and teardown a global db lock."""
    conn = session.get_bind().connect()
    dialect = conn.dialect
    try:
        if dialect.name == "postgresql":
            conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
            conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
        elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
            conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})

        yield
    finally:
        if dialect.name == "postgresql":
            conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
            (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
            if not unlocked:
                raise RuntimeError("Error releasing DB lock!")
        elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
            conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
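

# Illustrative sketch (assumption: not part of the upstream module). The advisory
# lock is keyed by the DBLocks member (integer id on Postgres, string name on MySQL),
# so concurrent maintenance commands serialize on the same lock.
@provide_session
def _example_locked_statement(session: Session = NEW_SESSION) -> None:
    """Hedged example: run a statement while holding the MIGRATIONS lock."""
    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS, lock_timeout=60):
        session.execute(text("SELECT 1"))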


def compare_type(context, inspected_column, metadata_column, inspected_type, metadata_type):
    """
    Compare types between ORM and DB.

    Return False if the metadata_type is the same as the inspected_type,
    or None to allow the default implementation to compare these types.
    A return value of True means the two types do not match and should
    result in a type change operation.
    """
    if context.dialect.name == "mysql":
        from sqlalchemy import String
        from sqlalchemy.dialects import mysql

        if isinstance(inspected_type, mysql.VARCHAR) and isinstance(metadata_type, String):
            # This is a hack to get around MySQL VARCHAR collation
            # not being possible to change from utf8_bin to utf8mb3_bin.
            # We only make sure lengths are the same
            if inspected_type.length != metadata_type.length:
                return True
            return False
    return None


def compare_server_default(
    context, inspected_column, metadata_column, inspected_default, metadata_default, rendered_metadata_default
):
    """
    Compare server defaults between ORM and DB.

    Return True if the defaults are different, False if not, or None to allow the default
    implementation to compare these defaults.

    In SQLite, task_instance.map_index and task_reschedule.map_index do not compare
    accurately: sometimes they are equal, sometimes they are not. Alembic warns that this
    feature has varied accuracy depending on backends.
    See: (https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.
    environment.EnvironmentContext.configure.params.compare_server_default)
    """
    dialect_name = context.connection.dialect.name
    if dialect_name in ["sqlite"]:
        return False
    if (
        dialect_name == "mysql"
        and metadata_column.name == "pool_slots"
        and metadata_column.table.name == "task_instance"
    ):
        # We removed server_default value in ORM to avoid expensive migration
        # (it was removed in postgres DB in migration head 7b2661a43ba3 ).
        # As a side note, server default value here was only actually needed for the migration
        # where we added the column in the first place -- now that it exists and all
        # existing rows are populated with a value this server default is never used.
        return False
    return None
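

# Illustrative sketch (assumption: not part of the upstream module). The two
# comparators above are meant to be passed to Alembic's EnvironmentContext.configure()
# (typically from env.py) so autogenerate uses them when diffing column types and
# server defaults. This function only shows the wiring; it must run under Alembic.
def _example_configure_alembic_env(connection, target_metadata) -> None:
    """Hedged example of wiring the comparators into an Alembic env.py."""
    from alembic import context  # only importable when invoked by Alembic

    context.configure(
        connection=connection,
        target_metadata=target_metadata,
        compare_type=compare_type,
        compare_server_default=compare_server_default,
    )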


def get_sqla_model_classes():
    """
    Get all SQLAlchemy class mappers.

    SQLAlchemy < 1.4 does not support registry.mappers so we use
    try/except to handle it.
    """
    from airflow.models.base import Base

    try:
        return [mapper.class_ for mapper in Base.registry.mappers]
    except AttributeError:
        return Base._decl_class_registry.values()


def get_query_count(query_stmt: Select, *, session: Session) -> int:
    """
    Get count of a query.

    A SELECT COUNT() FROM is issued against the subquery built from the
    given statement. The ORDER BY clause is stripped from the statement
    since it's unnecessary for COUNT, and can impact query planning and
    degrade performance.

    :meta private:
    """
    count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
    return session.scalar(count_stmt)


def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
    """
    Check whether there is at least one row matching a query.

    A SELECT 1 FROM is issued against the subquery built from the given
    statement. The ORDER BY clause is stripped from the statement since it's
    unnecessary, and can impact query planning and degrade performance.

    :meta private:
    """
    count_stmt = select(literal(True)).select_from(query_stmt.order_by(None).subquery())
    # scalar() returns None when the subquery yields no rows; coerce to a real bool
    # so the return value matches the annotated type.
    return session.scalar(count_stmt) is not None


def exists_query(*where: ClauseElement, session: Session) -> bool:
    """
    Check whether there is at least one row matching given clauses.

    This does a SELECT 1 WHERE ... LIMIT 1 and checks the result.

    :meta private:
    """
    stmt = select(literal(True)).where(*where).limit(1)
    return session.scalar(stmt) is not None
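

# Illustrative sketch (assumption: not part of the upstream module). The three
# helpers above take an ORM SELECT (or raw clauses) plus an explicit session;
# the filter values below are hypothetical, purely to show the call shapes.
def _example_query_helpers(session: Session) -> None:
    """Hedged example of get_query_count / check_query_exists / exists_query."""
    stmt = select(DagRun).where(DagRun.state == "running").order_by(DagRun.id)
    n_running = get_query_count(stmt, session=session)  # ORDER BY is stripped before COUNT
    any_running = check_query_exists(stmt, session=session)  # SELECT 1 FROM (subquery)
    has_run = exists_query(DagRun.run_id == "manual__2024-01-01", session=session)  # SELECT 1 ... LIMIT 1
    log.debug("running=%s any=%s has=%s", n_running, any_running, has_run)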


@attrs.define(slots=True)
class LazySelectSequence(Sequence[T]):
    """
    List-like interface to lazily access a database model query.

    The intended use case is inside a task execution context, where we manage an
    active SQLAlchemy session in the background.

    This is an abstract base class. Each use case should subclass, and implement
    the following static methods:

    * ``_rebuild_select`` is called when a lazy sequence is unpickled. Since it
      is not easy to pickle SQLAlchemy constructs, this class serializes the
      SELECT statements into plain text for storage. This method is called on
      deserialization to convert the textual clause back into an ORM SELECT.
    * ``_process_row`` is called when an item is accessed. The lazy sequence
      uses ``session.execute()`` to fetch rows from the database, and this
      method should know how to process each row into a value.

    An illustrative example subclass is sketched after this class definition.

    :meta private:
    """

    _select_asc: ClauseElement
    _select_desc: ClauseElement
    _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
    _len: int | None = attrs.field(init=False, default=None)

    @classmethod
    def from_select(
        cls,
        select: Select,
        *,
        order_by: Sequence[ClauseElement],
        session: Session | None = None,
    ) -> Self:
        s1 = select
        for col in order_by:
            s1 = s1.order_by(col.asc())
        s2 = select
        for col in order_by:
            s2 = s2.order_by(col.desc())
        return cls(s1, s2, session=session or get_current_task_instance_session())

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        """
        Rebuild a textual statement into an ORM-configured SELECT statement.

        This should do something like ``select(field).from_statement(stmt)`` to
        reconfigure ORM information to the textual SQL statement.
        """
        raise NotImplementedError

    @staticmethod
    def _process_row(row: Row) -> T:
        """Process a SELECT-ed row into the end value."""
        raise NotImplementedError

    def __repr__(self) -> str:
        counter = "item" if (length := len(self)) == 1 else "items"
        return f"LazySelectSequence([{length} {counter}])"

    def __str__(self) -> str:
        counter = "item" if (length := len(self)) == 1 else "items"
        return f"LazySelectSequence([{length} {counter}])"

    def __getstate__(self) -> Any:
        # We don't want to go to the trouble of serializing SQLAlchemy objects.
        # Converting the statement into a SQL string is the best we can get.
        # The literal_binds compile argument inlines all the values into the SQL
        # string to simplify cross-process communication as much as possible.
        # Theoretically we can do the same for count(), but I think it should be
        # performant enough to calculate only that eagerly.
        s1 = str(self._select_asc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
        s2 = str(self._select_desc.compile(self._session.get_bind(), compile_kwargs={"literal_binds": True}))
        return (s1, s2, len(self))

    def __setstate__(self, state: Any) -> None:
        s1, s2, self._len = state
        self._select_asc = self._rebuild_select(text(s1))
        self._select_desc = self._rebuild_select(text(s2))
        self._session = get_current_task_instance_session()

    def __bool__(self) -> bool:
        return check_query_exists(self._select_asc, session=self._session)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, collections.abc.Sequence):
            return NotImplemented
        z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
        return all(x == y for x, y in z)

    def __reversed__(self) -> Iterator[T]:
        return iter(self._process_row(r) for r in self._session.execute(self._select_desc))

    def __iter__(self) -> Iterator[T]:
        return iter(self._process_row(r) for r in self._session.execute(self._select_asc))

    def __len__(self) -> int:
        if self._len is None:
            self._len = get_query_count(self._select_asc, session=self._session)
        return self._len

    @overload
    def __getitem__(self, key: int) -> T: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[T]: ...

    def __getitem__(self, key: int | slice) -> T | Sequence[T]:
        if isinstance(key, int):
            if key >= 0:
                stmt = self._select_asc.offset(key)
            else:
                stmt = self._select_desc.offset(-1 - key)
            if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None:
                raise IndexError(key)
            return self._process_row(row)
        elif isinstance(key, slice):
            # This implements the slicing syntax. We want to optimize negative
            # slicing (e.g. seq[-10:]) by not doing an additional COUNT query
            # if possible. We can do this unless the start and stop have
            # different signs (i.e. one is positive and another negative).
            start, stop, reverse = _coerce_slice(key)
            if start >= 0:
                if stop is None:
                    stmt = self._select_asc.offset(start)
                elif stop >= 0:
                    stmt = self._select_asc.slice(start, stop)
                else:
                    stmt = self._select_asc.slice(start, len(self) + stop)
                rows = [self._process_row(row) for row in self._session.execute(stmt)]
                if reverse:
                    rows.reverse()
            else:
                if stop is None:
                    stmt = self._select_desc.limit(-start)
                elif stop < 0:
                    stmt = self._select_desc.slice(-stop, -start)
                else:
                    stmt = self._select_desc.slice(len(self) - stop, -start)
                rows = [self._process_row(row) for row in self._session.execute(stmt)]
                if not reverse:
                    rows.reverse()
            return rows
        raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
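

# Illustrative sketch (assumption: not part of the upstream module). It shows the
# two static methods a LazySelectSequence subclass is expected to implement, per
# the class docstring above; the selected column name is hypothetical.
class _ExampleStringSequence(LazySelectSequence[str]):
    """Hedged example: lazily yield a single string column from a query."""

    @staticmethod
    def _rebuild_select(stmt: TextClause) -> Select:
        # Re-attach column information to the textual SQL restored on unpickling,
        # following the ``select(field).from_statement(stmt)`` pattern.
        from sqlalchemy import column

        return select(column("value")).from_statement(stmt)

    @staticmethod
    def _process_row(row: Row) -> str:
        # Each fetched row carries the selected column at position 0.
        return row[0]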


def _coerce_index(value: Any) -> int | None:
    """
    Check slice attribute's type and convert it to int.

    See CPython documentation on this:
    https://docs.python.org/3/reference/datamodel.html#object.__index__
    """
    if value is None or isinstance(value, int):
        return value
    if (index := getattr(value, "__index__", None)) is not None:
        return index()
    raise TypeError("slice indices must be integers or None or have an __index__ method")


def _coerce_slice(key: slice) -> tuple[int, int | None, bool]:
    """
    Check slice content and convert it for SQL.

    See CPython documentation on this:
    https://docs.python.org/3/reference/datamodel.html#slice-objects
    """
    if key.step is None or key.step == 1:
        reverse = False
    elif key.step == -1:
        reverse = True
    else:
        raise ValueError("non-trivial slice step not supported")
    return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse
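

# Illustrative sketch (assumption: not part of the upstream module) showing how
# Python slice objects map onto the (start, stop, reverse) triple consumed by
# LazySelectSequence.__getitem__.
def _example_coerce_slice_behaviour() -> None:
    """Hedged example exercising _coerce_slice on a few representative slices."""
    assert _coerce_slice(slice(None, 10)) == (0, 10, False)  # seq[:10]
    assert _coerce_slice(slice(5, None)) == (5, None, False)  # seq[5:]
    assert _coerce_slice(slice(-10, None)) == (-10, None, False)  # seq[-10:] -- negative start, no COUNT needed
    assert _coerce_slice(slice(None, None, -1)) == (0, None, True)  # seq[::-1] -- reversed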