aggregates.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. """
  2. SQLAlchemy-Utils provides a way of automatically calculating aggregate values
  3. of related models and saving them to the parent model.
  4. This solution is inspired by RoR counter cache,
  5. `counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
  6. Why?
  7. ----
  8. Many times you may have situations where you need to calculate dynamically some
  9. aggregate value for given model. Some simple examples include:
  10. - Number of products in a catalog
  11. - Average rating for movie
  12. - Latest forum post
  13. - Total price of orders for given customer
  14. Now all these aggregates can be elegantly implemented with SQLAlchemy
  15. column_property_ function. However when your data grows calculating these
  16. values on the fly might start to hurt the performance of your application. The
  17. more aggregates you are using the more performance penalty you get.
  18. This module provides a way of calculating these values automatically and
  19. efficiently at the time of modification rather than on the fly.
  20. Features
  21. --------
  22. * Automatically updates aggregate columns when aggregated values change
  23. * Supports aggregate values through arbitrary number levels of relations
  24. * Highly optimized: uses single query per transaction per aggregate column
  25. * Aggregated columns can be of any data type and use any selectable scalar
  26. expression
  27. .. _column_property:
  28. https://docs.sqlalchemy.org/en/latest/orm/mapped_sql_expr.html#using-column-property
  29. .. _counter_culture: https://github.com/magnusvk/counter_culture
  30. .. _stackoverflow reply by Michael Bayer:
  31. https://stackoverflow.com/a/13765857/520932
  32. Simple aggregates
  33. -----------------
  34. ::
  35. from sqlalchemy_utils import aggregated
  36. class Thread(Base):
  37. __tablename__ = 'thread'
  38. id = sa.Column(sa.Integer, primary_key=True)
  39. name = sa.Column(sa.Unicode(255))
  40. @aggregated('comments', sa.Column(sa.Integer))
  41. def comment_count(self):
  42. return sa.func.count('1')
  43. comments = sa.orm.relationship(
  44. 'Comment',
  45. backref='thread'
  46. )
  47. class Comment(Base):
  48. __tablename__ = 'comment'
  49. id = sa.Column(sa.Integer, primary_key=True)
  50. content = sa.Column(sa.UnicodeText)
  51. thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
  52. thread = Thread(name='SQLAlchemy development')
  53. thread.comments.append(Comment(content='Going good!'))
  54. thread.comments.append(Comment(content='Great new features!'))
  55. session.add(thread)
  56. session.commit()
  57. thread.comment_count # 2
  58. Custom aggregate expressions
  59. ----------------------------
  60. An aggregate expression can be virtually any SQL expression, not just a simple
  61. function taking one parameter. You can try things such as subqueries and
  62. different kinds of functions.
  63. In the following example we have a Catalog of products where each catalog
  64. knows the net worth of its products.
  65. ::
  66. from sqlalchemy_utils import aggregated
  67. class Catalog(Base):
  68. __tablename__ = 'catalog'
  69. id = sa.Column(sa.Integer, primary_key=True)
  70. name = sa.Column(sa.Unicode(255))
  71. @aggregated('products', sa.Column(sa.Integer))
  72. def net_worth(self):
  73. return sa.func.sum(Product.price)
  74. products = sa.orm.relationship('Product')
  75. class Product(Base):
  76. __tablename__ = 'product'
  77. id = sa.Column(sa.Integer, primary_key=True)
  78. name = sa.Column(sa.Unicode(255))
  79. price = sa.Column(sa.Numeric)
  80. catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
  81. Now the net_worth column of Catalog model will be automatically updated whenever:
  82. * A new product is added to the catalog
  83. * A product is deleted from the catalog
  84. * The price of catalog product is changed
  85. ::
  86. from decimal import Decimal
  87. product1 = Product(name='Some product', price=Decimal(1000))
  88. product2 = Product(name='Some other product', price=Decimal(500))
  89. catalog = Catalog(
  90. name='My first catalog',
  91. products=[
  92. product1,
  93. product2
  94. ]
  95. )
  96. session.add(catalog)
  97. session.commit()
  98. session.refresh(catalog)
  99. catalog.net_worth # 1500
  100. session.delete(product2)
  101. session.commit()
  102. session.refresh(catalog)
  103. catalog.net_worth # 1000
  104. product1.price = 2000
  105. session.commit()
  106. session.refresh(catalog)
  107. catalog.net_worth # 2000
  108. Multiple aggregates per class
  109. -----------------------------
  110. Sometimes you may need to define multiple aggregate values for same class. If
  111. you need to define lots of relationships pointing to same class, remember to
  112. define the relationships as viewonly when possible.
  113. ::
  114. from sqlalchemy_utils import aggregated
  115. class Customer(Base):
  116. __tablename__ = 'customer'
  117. id = sa.Column(sa.Integer, primary_key=True)
  118. name = sa.Column(sa.Unicode(255))
  119. @aggregated('orders', sa.Column(sa.Integer))
  120. def orders_sum(self):
  121. return sa.func.sum(Order.price)
  122. @aggregated('invoiced_orders', sa.Column(sa.Integer))
  123. def invoiced_orders_sum(self):
  124. return sa.func.sum(Order.price)
  125. orders = sa.orm.relationship('Order')
  126. invoiced_orders = sa.orm.relationship(
  127. 'Order',
  128. primaryjoin=
  129. 'sa.and_(Order.customer_id == Customer.id, Order.invoiced)',
  130. viewonly=True
  131. )
  132. class Order(Base):
  133. __tablename__ = 'order'
  134. id = sa.Column(sa.Integer, primary_key=True)
  135. name = sa.Column(sa.Unicode(255))
  136. price = sa.Column(sa.Numeric)
  137. invoiced = sa.Column(sa.Boolean, default=False)
  138. customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))
  139. Many-to-Many aggregates
  140. -----------------------
  141. Aggregate expressions also support many-to-many relationships. Typical use
  142. cases include things such as:
  143. 1. Friend count of a user
  144. 2. Group count where given user belongs to
  145. ::
  146. user_group = sa.Table('user_group', Base.metadata,
  147. sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
  148. sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
  149. )
  150. class User(Base):
  151. __tablename__ = 'user'
  152. id = sa.Column(sa.Integer, primary_key=True)
  153. name = sa.Column(sa.Unicode(255))
  154. @aggregated('groups', sa.Column(sa.Integer, default=0))
  155. def group_count(self):
  156. return sa.func.count('1')
  157. groups = sa.orm.relationship(
  158. 'Group',
  159. backref='users',
  160. secondary=user_group
  161. )
  162. class Group(Base):
  163. __tablename__ = 'group'
  164. id = sa.Column(sa.Integer, primary_key=True)
  165. name = sa.Column(sa.Unicode(255))
  166. user = User(name='John Matrix')
  167. user.groups = [Group(name='Group A'), Group(name='Group B')]
  168. session.add(user)
  169. session.commit()
  170. session.refresh(user)
  171. user.group_count # 2
  172. Multi-level aggregates
  173. ----------------------
  174. Aggregates can span across multiple relationships. In the following example
  175. each Catalog has a net_worth which is the sum of all products in all
  176. categories.
  177. ::
  178. from sqlalchemy_utils import aggregated
  179. class Catalog(Base):
  180. __tablename__ = 'catalog'
  181. id = sa.Column(sa.Integer, primary_key=True)
  182. name = sa.Column(sa.Unicode(255))
  183. @aggregated('categories.products', sa.Column(sa.Integer))
  184. def net_worth(self):
  185. return sa.func.sum(Product.price)
  186. categories = sa.orm.relationship('Category')
  187. class Category(Base):
  188. __tablename__ = 'category'
  189. id = sa.Column(sa.Integer, primary_key=True)
  190. name = sa.Column(sa.Unicode(255))
  191. catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
  192. products = sa.orm.relationship('Product')
  193. class Product(Base):
  194. __tablename__ = 'product'
  195. id = sa.Column(sa.Integer, primary_key=True)
  196. name = sa.Column(sa.Unicode(255))
  197. price = sa.Column(sa.Numeric)
  198. category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
  199. Examples
  200. --------
  201. Average movie rating
  202. ^^^^^^^^^^^^^^^^^^^^
  203. ::
  204. from sqlalchemy_utils import aggregated
  205. class Movie(Base):
  206. __tablename__ = 'movie'
  207. id = sa.Column(sa.Integer, primary_key=True)
  208. name = sa.Column(sa.Unicode(255))
  209. @aggregated('ratings', sa.Column(sa.Numeric))
  210. def avg_rating(self):
  211. return sa.func.avg(Rating.stars)
  212. ratings = sa.orm.relationship('Rating')
  213. class Rating(Base):
  214. __tablename__ = 'rating'
  215. id = sa.Column(sa.Integer, primary_key=True)
  216. stars = sa.Column(sa.Integer)
  217. movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id))
  218. movie = Movie(name='Terminator 2')
  219. movie.ratings.append(Rating(stars=5))
  220. movie.ratings.append(Rating(stars=4))
  221. movie.ratings.append(Rating(stars=3))
  222. session.add(movie)
  223. session.commit()
  224. movie.avg_rating # 4
  225. TODO
  226. ----
  227. * Special consideration should be given to `deadlocks`_.
  228. .. _deadlocks:
  229. https://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
  230. """
  231. from collections import defaultdict
  232. from weakref import WeakKeyDictionary
  233. import sqlalchemy as sa
  234. import sqlalchemy.event
  235. import sqlalchemy.orm
  236. from sqlalchemy.ext.declarative import declared_attr
  237. from sqlalchemy.sql.functions import _FunctionGenerator
  238. from .compat import _select_args, get_scalar_subquery
  239. from .functions.orm import get_column_key
  240. from .relationships import (
  241. chained_join,
  242. path_to_relationships,
  243. select_correlated_expression
  244. )
# Module-level registry keyed by class. AggregatedAttribute.__get__ stores a
# list of (fget, relationship, column) tuples under each class, and
# AggregationManager.update_generator_registry reads them back. The mapping
# holds its keys weakly (WeakKeyDictionary).
aggregated_attrs = WeakKeyDictionary()
  246. class AggregatedAttribute(declared_attr):
  247. def __init__(
  248. self,
  249. fget,
  250. relationship,
  251. column,
  252. *args,
  253. **kwargs
  254. ):
  255. super().__init__(fget, *args, **kwargs)
  256. self.__doc__ = fget.__doc__
  257. self.column = column
  258. self.relationship = relationship
  259. def __get__(desc, self, cls):
  260. value = (desc.fget, desc.relationship, desc.column)
  261. if cls not in aggregated_attrs:
  262. aggregated_attrs[cls] = [value]
  263. else:
  264. aggregated_attrs[cls].append(value)
  265. return desc.column
  266. def local_condition(prop, objects):
  267. pairs = prop.local_remote_pairs
  268. if prop.secondary is not None:
  269. parent_column = pairs[1][0]
  270. fetched_column = pairs[1][0]
  271. else:
  272. parent_column = pairs[0][0]
  273. fetched_column = pairs[0][1]
  274. key = get_column_key(prop.mapper, fetched_column)
  275. values = []
  276. for obj in objects:
  277. try:
  278. values.append(getattr(obj, key))
  279. except sa.orm.exc.ObjectDeletedError:
  280. pass
  281. if values:
  282. return parent_column.in_(values)
  283. def aggregate_expression(expr, class_):
  284. if isinstance(expr, sa.sql.visitors.Visitable):
  285. return expr
  286. elif isinstance(expr, _FunctionGenerator):
  287. return expr(sa.sql.text('1'))
  288. else:
  289. return expr(class_)
  290. class AggregatedValue:
  291. def __init__(self, class_, attr, path, expr):
  292. self.class_ = class_
  293. self.attr = attr
  294. self.path = path
  295. self.relationships = list(
  296. reversed(path_to_relationships(path, class_))
  297. )
  298. self.expr = aggregate_expression(expr, class_)
  299. @property
  300. def aggregate_query(self):
  301. query = select_correlated_expression(
  302. self.class_,
  303. self.expr,
  304. self.path,
  305. self.relationships[0].mapper.class_
  306. )
  307. return get_scalar_subquery(query)
  308. def update_query(self, objects):
  309. table = self.class_.__table__
  310. query = table.update().values(
  311. {self.attr: self.aggregate_query}
  312. )
  313. if len(self.relationships) == 1:
  314. prop = self.relationships[-1].property
  315. condition = local_condition(prop, objects)
  316. if condition is not None:
  317. return query.where(condition)
  318. else:
  319. # Builds query such as:
  320. #
  321. # UPDATE catalog SET product_count = (aggregate_query)
  322. # WHERE id IN (
  323. # SELECT catalog_id
  324. # FROM category
  325. # INNER JOIN sub_category
  326. # ON category.id = sub_category.category_id
  327. # WHERE sub_category.id IN (product_sub_category_ids)
  328. # )
  329. property_ = self.relationships[-1].property
  330. remote_pairs = property_.local_remote_pairs
  331. local = remote_pairs[0][0]
  332. remote = remote_pairs[0][1]
  333. condition = local_condition(
  334. self.relationships[0].property,
  335. objects
  336. )
  337. if condition is not None:
  338. return query.where(
  339. local.in_(
  340. sa.select(
  341. *_select_args(remote)
  342. ).select_from(
  343. chained_join(*reversed(self.relationships))
  344. ).where(
  345. condition
  346. )
  347. )
  348. )
  349. class AggregationManager:
  350. def __init__(self):
  351. self.reset()
  352. def reset(self):
  353. self.generator_registry = defaultdict(list)
  354. def register_listeners(self):
  355. sa.event.listen(
  356. sa.orm.Mapper,
  357. 'after_configured',
  358. self.update_generator_registry
  359. )
  360. sa.event.listen(
  361. sa.orm.session.Session,
  362. 'after_flush',
  363. self.construct_aggregate_queries
  364. )
  365. def update_generator_registry(self):
  366. for class_, attrs in aggregated_attrs.items():
  367. for expr, path, column in attrs:
  368. value = AggregatedValue(
  369. class_=class_,
  370. attr=column,
  371. path=path,
  372. expr=expr(class_)
  373. )
  374. key = value.relationships[0].mapper.class_
  375. self.generator_registry[key].append(
  376. value
  377. )
  378. def construct_aggregate_queries(self, session, ctx):
  379. object_dict = defaultdict(list)
  380. for obj in session:
  381. for class_ in self.generator_registry:
  382. if isinstance(obj, class_):
  383. object_dict[class_].append(obj)
  384. for class_, objects in object_dict.items():
  385. for aggregate_value in self.generator_registry[class_]:
  386. query = aggregate_value.update_query(objects)
  387. if query is not None:
  388. session.execute(query)
# Module-level manager instance; its listeners are registered as soon as this
# module is imported.
manager = AggregationManager()
manager.register_listeners()
  391. def aggregated(
  392. relationship,
  393. column
  394. ):
  395. """
  396. Decorator that generates an aggregated attribute. The decorated function
  397. should return an aggregate select expression.
  398. :param relationship:
  399. Defines the relationship of which the aggregate is calculated from.
  400. The class needs to have given relationship in order to calculate the
  401. aggregate.
  402. :param column:
  403. SQLAlchemy Column object. The column definition of this aggregate
  404. attribute.
  405. """
  406. def wraps(func):
  407. return AggregatedAttribute(
  408. func,
  409. relationship,
  410. column
  411. )
  412. return wraps