orm.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904
  1. from collections import OrderedDict
  2. from functools import partial
  3. from inspect import isclass
  4. from operator import attrgetter
  5. import sqlalchemy as sa
  6. from sqlalchemy.engine.interfaces import Dialect
  7. from sqlalchemy.ext.hybrid import hybrid_property
  8. from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty
  9. from sqlalchemy.orm.attributes import InstrumentedAttribute
  10. from sqlalchemy.orm.exc import UnmappedInstanceError
  11. try:
  12. from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity
  13. except ImportError: # SQLAlchemy <1.4
  14. from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity
  15. from sqlalchemy.orm.session import object_session
  16. from sqlalchemy.orm.util import AliasedInsp
  17. from ..utils import is_sequence
  def get_class_by_table(base, table, data=None):
      """
      Find the declarative class that is mapped to the given table.

      Candidates are the classes in the class registry of ``base`` whose
      ``__table__`` is the very same object as ``table``. Without any
      candidate ``None`` is returned and a single candidate is returned as is.
      When several classes share the table (polymorphic cases) the ``data``
      parameter is used to pick the class whose polymorphic identity matches.

      ::

          class User(Base):
              __tablename__ = 'entity'
              id = sa.Column(sa.Integer, primary_key=True)
              name = sa.Column(sa.String)


          get_class_by_table(Base, User.__table__)  # User class

      Models using single table inheritance are supported as well. The
      additional ``data`` parameter should be provided in these cases.

      ::

          class Entity(Base):
              __tablename__ = 'entity'
              id = sa.Column(sa.Integer, primary_key=True)
              name = sa.Column(sa.String)
              type = sa.Column(sa.String)
              __mapper_args__ = {
                  'polymorphic_on': type,
                  'polymorphic_identity': 'entity'
              }

          class User(Entity):
              __mapper_args__ = {
                  'polymorphic_identity': 'user'
              }


          # Entity class
          get_class_by_table(Base, Entity.__table__, {'type': 'entity'})

          # User class
          get_class_by_table(Base, Entity.__table__, {'type': 'user'})

      :param base: Declarative model base
      :param table: SQLAlchemy Table object
      :param data: Data row to determine the class in polymorphic scenarios
      :return: Declarative class or None.
      :raises ValueError:
          if several classes are mapped to the table and ``data`` is empty or
          does not match the polymorphic identity of any of them
      """
      registry = _get_class_registry(base)
      candidates = {
          klass for klass in registry.values()
          if hasattr(klass, '__table__') and klass.__table__ is table
      }

      if not candidates:
          return None
      if len(candidates) == 1:
          return candidates.pop()

      # From here on several classes share the table, hence the data row is
      # the only way to tell them apart.
      if not data:
          raise ValueError(
              "Multiple declarative classes found for table '{}'. "
              "Please provide data parameter for this function to be able "
              "to determine polymorphic scenarios.".format(
                  table.name
              )
          )

      for klass in candidates:
          mapper = sa.inspect(klass)
          discriminator = mapper.polymorphic_on.name
          if (
              discriminator in data and
              data[discriminator] == mapper.polymorphic_identity
          ):
              return klass

      raise ValueError(
          "Multiple declarative classes found for table '{}'. Given "
          "data row does not match any polymorphic identity of the "
          "found classes.".format(
              table.name
          )
      )
  def get_type(expr):
      """
      Return the type associated with given Column, InstrumentedAttribute,
      ColumnProperty, RelationshipProperty or other similar SQLAlchemy
      construct.

      Anything exposing a ``type`` attribute yields that attribute. Column
      properties yield the type of their first column and relationship
      properties yield the class of the related mapper.

      :param expr:
          SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
          similar SA construct.
      :raises TypeError: if no type can be determined for ``expr``

      ::

          class User(Base):
              __tablename__ = 'user'
              id = sa.Column(sa.Integer, primary_key=True)
              name = sa.Column(sa.String)


          class Article(Base):
              __tablename__ = 'article'
              id = sa.Column(sa.Integer, primary_key=True)
              author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
              author = sa.orm.relationship(User)


          get_type(User.__table__.c.name)  # sa.String()
          get_type(User.name)  # sa.String()
          get_type(User.name.property)  # sa.String()

          get_type(Article.author)  # User


      .. versionadded:: 0.30.9
      """
      if hasattr(expr, 'type'):
          return expr.type

      # Instrumented attributes without a ``type`` are unwrapped to the
      # underlying mapper property before the property checks below.
      prop = expr.property if isinstance(expr, InstrumentedAttribute) else expr

      if isinstance(prop, ColumnProperty):
          return prop.columns[0].type
      if isinstance(prop, RelationshipProperty):
          return prop.mapper.class_
      raise TypeError("Couldn't inspect type.")
  def cast_if(expression, type_):
      """
      Produce a CAST expression but only if given expression is not of given
      type already.

      Assume we have a model with two fields id (Integer) and name (String).

      ::

          import sqlalchemy as sa
          from sqlalchemy_utils import cast_if


          cast_if(User.id, sa.Integer)    # "user".id
          cast_if(User.name, sa.String)   # "user".name
          cast_if(User.id, sa.String)     # CAST("user".id AS TEXT)


      This function supports scalar values as well.

      ::

          cast_if(1, sa.Integer)          # 1
          cast_if('text', sa.String)      # 'text'
          cast_if(1, sa.String)           # CAST(1 AS TEXT)


      :param expression:
          A SQL expression, such as a ColumnElement expression or a Python
          string which will be coerced into a bound literal value.
      :param type_:
          A TypeEngine class or instance indicating the type to which the CAST
          should apply.
      :return:
          ``expression`` itself when it already is of the requested type,
          otherwise ``sa.cast(expression, type_)``.

      .. versionadded:: 0.30.14
      """
      try:
          expr_type = get_type(expression)
      except TypeError:
          # ``expression`` is not a SQLAlchemy construct; treat it as a plain
          # Python value and compare it against the python type of ``type_``.
          # ``type_`` may be given as an instance, which is not callable, so
          # it is only instantiated when it is a class.
          type_instance = type_() if isclass(type_) else type_
          subject = expression
          check_type = type_instance.python_type
      else:
          # ``isinstance`` needs a class as its second argument, hence an
          # instance given as ``type_`` is reduced to its class.
          subject = expr_type
          check_type = type_ if isclass(type_) else type(type_)

      if isinstance(subject, check_type):
          return expression
      return sa.cast(expression, type_)
  def get_column_key(model, column):
      """
      Return the key under which given column is mapped in given model.

      The mapper is asked for the property of the column first. If the mapper
      does not know the column object, the mapper columns are searched for a
      column having the same name and the same table.

      :param model: SQLAlchemy declarative model object
      :param column: SQLAlchemy Column object
      :raises sqlalchemy.orm.exc.UnmappedColumnError:
          if no key could be found for the column

      ::

          class User(Base):
              __tablename__ = 'user'
              id = sa.Column(sa.Integer, primary_key=True)
              name = sa.Column('_name', sa.String)


          get_column_key(User, User.__table__.c._name)  # 'name'

      .. versionadded:: 0.26.5

      .. versionchanged:: 0.27.11
          Throws UnmappedColumnError instead of ValueError when no property
          was found for given column. This is consistent with how SQLAlchemy
          works.
      """
      mapper = sa.inspect(model)
      try:
          prop = mapper.get_property_by_column(column)
      except sa.orm.exc.UnmappedColumnError:
          # Fall back to matching by column name and table identity.
          for key, candidate in mapper.columns.items():
              if candidate.table is column.table and (
                  candidate.name == column.name
              ):
                  return key
      else:
          return prop.key

      raise sa.orm.exc.UnmappedColumnError(
          'No column %s is configured on mapper %s...' %
          (column, mapper)
      )
  def get_mapper(mixed):
      """
      Return related SQLAlchemy Mapper for given SQLAlchemy object.

      :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object

      ::

          from sqlalchemy_utils import get_mapper


          get_mapper(User)

          get_mapper(User())

          get_mapper(User.__table__)

          get_mapper(User.__mapper__)

          get_mapper(sa.orm.aliased(User))

          get_mapper(sa.orm.aliased(User.__table__))


      Raises:
          ValueError: if multiple mappers or no mapper at all were found for
              given table

      .. versionadded:: 0.26.1
      """
      # Query entities are unwrapped to their expression, columns to their
      # table.
      if isinstance(mixed, (_MapperEntity, _ColumnEntity)):
          mixed = mixed.expr
      elif isinstance(mixed, sa.Column):
          mixed = mixed.table

      if isinstance(mixed, sa.orm.Mapper):
          return mixed
      if isinstance(mixed, sa.orm.util.AliasedClass):
          return sa.inspect(mixed).mapper
      if isinstance(mixed, sa.sql.selectable.Alias):
          mixed = mixed.element
      if isinstance(mixed, AliasedInsp):
          return mixed.mapper
      if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
          mixed = mixed.class_

      if isinstance(mixed, sa.Table):
          if hasattr(mapperlib, '_all_registries'):
              known_mappers = {
                  mapper
                  for registry in mapperlib._all_registries()
                  for mapper in registry.mappers
              }
          else:  # SQLAlchemy <1.4
              known_mappers = mapperlib._mapper_registry

          matches = [
              candidate for candidate in known_mappers
              if mixed in candidate.tables
          ]
          if not matches:
              raise ValueError(
                  "Could not get mapper for table '%s'." % mixed.name
              )
          if len(matches) > 1:
              raise ValueError(
                  "Multiple mappers found for table '%s'." % mixed.name
              )
          return matches[0]

      # Instances are inspected through their class.
      return sa.inspect(mixed if isclass(mixed) else type(mixed))
  def get_bind(obj):
      """
      Return the bind for given SQLAlchemy Engine / Connection / declarative
      model object.

      Objects having a ``bind`` attribute yield that attribute. For other
      objects the bind of the session the object belongs to is used, and
      objects that are not mapped instances are taken as the bind themselves.

      :param obj: SQLAlchemy Engine / Connection / declarative model object
      :raises TypeError: if the resolved bind has no ``execute`` attribute

      ::

          from sqlalchemy_utils import get_bind


          get_bind(session)  # Connection object

          get_bind(user)

      """
      if hasattr(obj, 'bind'):
          candidate = obj.bind
      else:
          try:
              session = object_session(obj)
          except UnmappedInstanceError:
              candidate = obj
          else:
              candidate = session.bind

      if hasattr(candidate, 'execute'):
          return candidate
      raise TypeError(
          'This method accepts only Session, Engine, Connection and '
          'declarative model objects.'
      )
  def get_primary_keys(mixed):
      """
      Return an OrderedDict of all primary keys for given Table object,
      declarative class or declarative class instance.

      The dictionary maps each key of :func:`get_columns` to its column,
      restricted to the columns flagged as primary key.

      :param mixed:
          SA Table object, SA declarative class or SA declarative class
          instance

      ::

          get_primary_keys(User)

          get_primary_keys(User())

          get_primary_keys(User.__table__)

          get_primary_keys(User.__mapper__)

          get_primary_keys(sa.orm.aliased(User))

          get_primary_keys(sa.orm.aliased(User.__table__))


      .. versionchanged:: 0.25.3
          Made the function return an ordered dictionary instead of generator.
          This change was made to support primary key aliases.

          Renamed this function to 'get_primary_keys', formerly 'primary_keys'

      .. seealso:: :func:`get_columns`
      """
      primary_keys = OrderedDict()
      for key, column in get_columns(mixed).items():
          if column.primary_key:
              primary_keys[key] = column
      return primary_keys
  def get_tables(mixed):
      """
      Return the tables associated with given SQLAlchemy object.

      Tables and columns yield a one element list. Instrumented attributes
      yield the tables of their parent mapper. Everything else is resolved to
      a mapper first; if that mapper has polymorphic mappers their tables are
      concatenated, otherwise the tables of the mapper itself are returned.

      Let's say we have three classes which use joined table inheritance
      TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.

      ::

          get_tables(Article)  # [Table('article', ...), Table('text_item')]

          get_tables(Article())

          get_tables(Article.__mapper__)


      If the TextItem entity is using with_polymorphic='*' then this function
      returns all child tables (article and blog_post) as well.

      ::

          get_tables(TextItem)  # [Table('text_item', ...), ...]


      .. versionadded:: 0.26.0

      :param mixed:
          SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute
          or a SA Alias object wrapping any of these objects.
      """
      if isinstance(mixed, sa.Table):
          return [mixed]
      if isinstance(mixed, sa.Column):
          return [mixed.table]
      if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
          return mixed.parent.tables
      if isinstance(mixed, _ColumnEntity):
          mixed = mixed.expr

      mapper = get_mapper(mixed)
      submappers = get_polymorphic_mappers(mapper)
      if not submappers:
          return mapper.tables

      # Concatenate the table lists of every polymorphic mapper in order.
      collected = []
      for submapper in submappers:
          collected = collected + submapper.tables
      return collected
  def get_columns(mixed):
      """
      Return a collection of all Column objects for given SQLAlchemy
      object.

      The type of the collection depends on the type of the object to return
      the columns from.

      ::

          get_columns(User)

          get_columns(User())

          get_columns(User.__table__)

          get_columns(User.__mapper__)

          get_columns(sa.orm.aliased(User))

          get_columns(sa.orm.aliased(User.__table__))


      :param mixed:
          SA Table object, SA Mapper, SA declarative class, SA declarative
          class instance or an alias of any of these objects
      """
      if isinstance(mixed, sa.sql.selectable.Selectable):
          try:
              selected = mixed.selected_columns
          except AttributeError:  # SQLAlchemy <1.4
              return mixed.c
          return selected
      if isinstance(mixed, sa.orm.util.AliasedClass):
          aliased_mapper = sa.inspect(mixed).mapper
          return aliased_mapper.columns
      if isinstance(mixed, sa.orm.Mapper):
          return mixed.columns
      if isinstance(mixed, InstrumentedAttribute):
          return mixed.property.columns
      if isinstance(mixed, ColumnProperty):
          return mixed.columns
      if isinstance(mixed, sa.Column):
          return [mixed]

      # Declarative instances are inspected through their class.
      target = mixed if isclass(mixed) else mixed.__class__
      return sa.inspect(target).columns
  def table_name(obj):
      """
      Return table name of given target, declarative class or the
      table name where the declarative attribute is bound to.

      The ``class_`` attribute of ``obj`` is used when present, otherwise
      ``obj`` itself. ``__tablename__`` takes precedence over
      ``__table__.name``; ``None`` is returned when neither is available.
      """
      target = getattr(obj, 'class_', obj)
      lookups = (
          attrgetter('__tablename__'),
          attrgetter('__table__.name'),
      )
      for lookup in lookups:
          try:
              return lookup(target)
          except AttributeError:
              continue
      return None
  def getattrs(obj, attrs):
      """
      Return a ``map`` iterator yielding the value of each attribute of
      ``obj`` named in ``attrs``.

      :param obj: Any object
      :param attrs: Iterable of attribute names
      """
      return map(lambda name: getattr(obj, name), attrs)
  def quote(mixed, ident):
      """
      Conditionally quote an identifier.

      The dialect is either ``mixed`` itself or the dialect of the bind
      resolved from ``mixed``; quoting is delegated to the identifier
      preparer of that dialect.

      ::

          from sqlalchemy_utils import quote


          engine = create_engine('sqlite:///:memory:')

          quote(engine, 'order')
          # '"order"'

          quote(engine, 'some_other_identifier')
          # 'some_other_identifier'


      :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
      :param ident: identifier to conditionally quote
      """
      dialect = mixed if isinstance(mixed, Dialect) else get_bind(mixed).dialect
      preparer = dialect.preparer(dialect)
      return preparer.quote(ident)
  387. def _get_query_compile_state(query):
  388. if hasattr(query, '_compile_state'):
  389. return query._compile_state()
  390. else: # SQLAlchemy <1.4
  391. return query
  def get_polymorphic_mappers(mixed):
      """
      Return the polymorphic mappers of given inspection object.

      For an ``AliasedInsp`` this is its ``with_polymorphic_mappers``, for
      anything else the values of its ``polymorphic_map``.
      """
      if not isinstance(mixed, AliasedInsp):
          return mixed.polymorphic_map.values()
      return mixed.with_polymorphic_mappers
  def get_descriptor(entity, attr):
      """
      Return the descriptor named ``attr`` of given entity.

      Returns ``None`` when the entity has no descriptor with that key or
      when the matching descriptor cannot be resolved to an attribute.

      :param entity: SQLAlchemy declarative class or aliased class
      :param attr: Key of the descriptor
      """
      mapper = sa.inspect(entity)
      is_aliased = isinstance(entity, sa.orm.util.AliasedClass)

      for key, descriptor in get_all_descriptors(mapper).items():
          if attr != key:
              continue

          prop = getattr(descriptor, 'property', None)
          if isinstance(prop, ColumnProperty):
              if not is_aliased:
                  # If the property belongs to a class that uses
                  # polymorphic inheritance we have to take into account
                  # situations where the attribute exists in child class
                  # but not in parent class.
                  return getattr(prop.parent.class_, attr)
              for column in mapper.selectable.c:
                  if column.key == attr:
                      return column
          else:
              # Handle synonyms, relationship properties and hybrid
              # properties
              if is_aliased:
                  return getattr(entity, attr)
              try:
                  return getattr(mapper.class_, attr)
              except AttributeError:
                  pass
      return None
  def get_all_descriptors(expr):
      """
      Return the descriptors of given expression keyed by name.

      Selectables yield their column collection. For other objects the ORM
      descriptors of the related mapper are returned, complemented with the
      descriptors of the polymorphic mappers that the mapper itself does not
      define.
      """
      if isinstance(expr, sa.sql.selectable.Selectable):
          return expr.c

      insp = sa.inspect(expr)
      try:
          submappers = get_polymorphic_mappers(insp)
      except sa.exc.NoInspectionAvailable:
          return get_mapper(expr).all_orm_descriptors

      descriptors = dict(get_mapper(expr).all_orm_descriptors)
      for submapper in submappers:
          for key, descriptor in submapper.all_orm_descriptors.items():
              # Descriptors already collected take precedence.
              descriptors.setdefault(key, descriptor)
      return descriptors
  def get_hybrid_properties(model):
      """
      Returns a dictionary of hybrid property keys and hybrid properties for
      given SQLAlchemy declarative model / mapper.


      Consider the following model

      ::


          from sqlalchemy.ext.hybrid import hybrid_property


          class Category(Base):
              __tablename__ = 'category'
              id = sa.Column(sa.Integer, primary_key=True)
              name = sa.Column(sa.Unicode(255))

              @hybrid_property
              def lowercase_name(self):
                  return self.name.lower()

              @lowercase_name.expression
              def lowercase_name(cls):
                  return sa.func.lower(cls.name)


      You can now easily get a list of all hybrid property names

      ::


          from sqlalchemy_utils import get_hybrid_properties


          get_hybrid_properties(Category).keys()  # ['lowercase_name']


      This function also supports aliased classes

      ::


          get_hybrid_properties(
              sa.orm.aliased(Category)
          ).keys()  # ['lowercase_name']


      .. versionchanged:: 0.26.7
          This function now returns a dictionary instead of generator

      .. versionchanged:: 0.30.15
          Added support for aliased classes

      :param model: SQLAlchemy declarative model or mapper
      """
      hybrids = {}
      for key, prop in get_mapper(model).all_orm_descriptors.items():
          if isinstance(prop, hybrid_property):
              hybrids[key] = prop
      return hybrids
  def get_declarative_base(model):
      """
      Returns the declarative base for given model class.

      The bases of the class are followed for as long as a base exposes a
      ``metadata`` attribute; the last such class is returned. A class without
      any such base is returned itself.

      :param model: SQLAlchemy declarative model
      """
      for parent in model.__bases__:
          if hasattr(parent, 'metadata'):
              return get_declarative_base(parent)
      return model
  def getdotattr(obj_or_class, dot_path, condition=None):
      """
      Allow dot-notated strings to be passed to `getattr`.

      Each path segment is resolved in turn. Sequences are traversed element
      by element and the results are flattened by one level, instrumented
      attributes are resolved through the class of their related mapper, and
      ``None`` anywhere along the path makes the function return ``None``.

      ::

          getdotattr(SubSection, 'section.document')

          getdotattr(subsection, 'section.document')


      :param obj_or_class: Any object or class
      :param dot_path: Attribute path with dot mark as separator
      :param condition:
          Optional callable applied after every path segment. Sequence values
          are filtered down to the elements for which it is truthy; for any
          other value a falsy result makes the function return ``None``.
      """
      current = obj_or_class

      for segment in str(dot_path).split('.'):
          getter = attrgetter(segment)

          if is_sequence(current):
              flattened = []
              for item in current:
                  fetched = getter(item)
                  if is_sequence(fetched):
                      flattened.extend(fetched)
                  else:
                      flattened.append(fetched)
              current = flattened
          elif isinstance(current, InstrumentedAttribute):
              current = getter(current.property.mapper.class_)
          elif current is None:
              return None
          else:
              current = getter(current)

          if condition is None:
              continue
          if is_sequence(current):
              current = [value for value in current if condition(value)]
          elif not condition(current):
              return None

      return current
  def is_deleted(obj):
      """
      Return whether given object is contained in the ``deleted`` collection
      of the session it belongs to.

      :param obj: SQLAlchemy declarative model object
      """
      session = sa.orm.object_session(obj)
      return obj in session.deleted
  def has_changes(obj, attrs=None, exclude=None):
      """
      Simple shortcut function for checking if given attributes of given
      declarative model object have changed during the session. Without
      parameters this checks if given object has any modifications.
      Additionally exclude parameter can be given to check if given object has
      any changes in any attributes other than the ones given in exclude.
      The exclude parameter is only taken into account when no attrs are
      given.


      ::


          from sqlalchemy_utils import has_changes


          user = User()

          has_changes(user, 'name')  # False

          user.name = 'someone'

          has_changes(user, 'name')  # True

          has_changes(user)  # True


      You can check multiple attributes as well.
      ::


          has_changes(user, ['age'])  # True

          has_changes(user, ['name', 'age'])  # True


      This function also supports excluding certain attributes.

      ::

          has_changes(user, exclude=['name'])  # False

          has_changes(user, exclude=['age'])  # True

      .. versionchanged:: 0.26.6
          Added support for multiple attributes and exclude parameter.

      :param obj: SQLAlchemy declarative model object
      :param attrs: Names of the attributes
      :param exclude: Names of the attributes to exclude
      """
      if not attrs:
          excluded = [] if exclude is None else exclude
          return any(
              state.history.has_changes()
              for key, state in sa.inspect(obj).attrs.items()
              if key not in excluded
          )

      if isinstance(attrs, str):
          attr_state = sa.inspect(obj).attrs.get(attrs)
          return attr_state.history.has_changes()

      # Any other truthy value is treated as an iterable of attribute names.
      return any(has_changes(obj, name) for name in attrs)
  def is_loaded(obj, prop):
      """
      Return whether or not given property of given object has been loaded.

      A property counts as loaded when it is not contained in the ``unloaded``
      collection of the inspected object.

      ::

          class Article(Base):
              __tablename__ = 'article'
              id = sa.Column(sa.Integer, primary_key=True)
              name = sa.Column(sa.String)
              content = sa.orm.deferred(sa.Column(sa.String))


          article = session.query(Article).get(5)

          # name gets loaded since its not a deferred property
          assert is_loaded(article, 'name')

          # content has not yet been loaded since its a deferred property
          assert not is_loaded(article, 'content')


      .. versionadded:: 0.27.8

      :param obj: SQLAlchemy declarative model object
      :param prop: Name of the property or InstrumentedAttribute
      """
      unloaded = sa.inspect(obj).unloaded
      return prop not in unloaded
  def identity(obj_or_class):
      """
      Return the identity of given sqlalchemy declarative model class or
      instance as a tuple.

      The tuple holds, for every primary key of :func:`get_primary_keys`, the
      attribute of that key read from ``obj_or_class``. Unlike
      ``obj._sa_instance_state.identity`` the values are read from the
      attributes, so they are available for objects that have not been
      persisted yet. For classes the identity attributes are returned.

      ::

          from sqlalchemy import inspect
          from sqlalchemy_utils import identity


          user = User(name='John Matrix')
          session.add(user)
          session.flush()  # User now has id

          identity(user)  # (1,)

          session.commit()

          identity(user)  # (1,)
          inspect(user).identity  # (1, )


      You can also use identity for classes::


          identity(User)  # (User.id, )

      .. versionadded:: 0.21.0

      :param obj_or_class: SQLAlchemy declarative model class or object
      """
      key_names = get_primary_keys(obj_or_class).keys()
      values = [getattr(obj_or_class, key_name) for key_name in key_names]
      return tuple(values)
  def naturally_equivalent(obj, obj2):
      """
      Returns whether or not two given SQLAlchemy declarative instances are
      naturally equivalent (all their non primary key properties are
      equivalent).

      The columns compared are the ones mapped on the class of ``obj``.

      ::

          from sqlalchemy_utils import naturally_equivalent


          user = User(name='someone')
          user2 = User(name='someone')

          user == user2  # False

          naturally_equivalent(user, user2)  # True


      :param obj: SQLAlchemy declarative model object
      :param obj2: SQLAlchemy declarative model object to compare with `obj`
      """
      columns = sa.inspect(obj.__class__).columns
      return all(
          getattr(obj, key) == getattr(obj2, key)
          for key, column in columns.items()
          if not column.primary_key
      )
  640. def _get_class_registry(class_):
  641. try:
  642. return class_.registry._class_registry
  643. except AttributeError: # SQLAlchemy <1.4
  644. return class_._decl_class_registry