test_rowcount.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # testing/suite/test_rowcount.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. from sqlalchemy import bindparam
  8. from sqlalchemy import Column
  9. from sqlalchemy import Integer
  10. from sqlalchemy import select
  11. from sqlalchemy import String
  12. from sqlalchemy import Table
  13. from sqlalchemy import testing
  14. from sqlalchemy import text
  15. from sqlalchemy.testing import eq_
  16. from sqlalchemy.testing import fixtures
  class RowCountTest(fixtures.TablesTest):
      """Test rowcount functionality.

      Each test runs an UPDATE or DELETE against the ``employees`` fixture
      table and checks the ``rowcount`` reported on the returned result.
      The fixture holds nine rows, three in each of departments "A", "B"
      and "C".
      """

      # NOTE(review): both attributes are read by the SQLAlchemy testing
      # plugin, not by this class -- presumably they gate the class on the
      # "sane_rowcount" requirement and mark it as a per-backend test;
      # confirm against the testing plugin documentation.
      __requires__ = ("sane_rowcount",)
      __backend__ = True

      @classmethod
      def define_tables(cls, metadata):
          """Declare the ``employees`` table on ``metadata``.

          The primary key is a non-autoincrementing integer, so
          ``insert_data`` has to supply ``employee_id`` explicitly.
          """
          Table(
              "employees",
              metadata,
              Column(
                  "employee_id",
                  Integer,
                  autoincrement=False,
                  primary_key=True,
              ),
              Column("name", String(50)),
              Column("department", String(1)),
          )

      @classmethod
      def insert_data(cls, connection):
          """Insert the nine fixture rows and keep them on ``cls.data``.

          ``cls.data`` is a list of ``(name, department)`` tuples; each
          row's ``employee_id`` is its position in that list (0 through 8).
          """
          cls.data = data = [
              ("Angela", "A"),
              ("Andrew", "A"),
              ("Anand", "A"),
              ("Bob", "B"),
              ("Bobette", "B"),
              ("Buffy", "B"),
              ("Charlie", "C"),
              ("Cynthia", "C"),
              ("Chris", "C"),
          ]

          employees_table = cls.tables.employees
          connection.execute(
              employees_table.insert(),
              [
                  {"employee_id": i, "name": n, "department": d}
                  for i, (n, d) in enumerate(data)
              ],
          )

      def test_basic(self, connection):
          """Sanity check: the fixture rows read back in insertion order."""
          employees_table = self.tables.employees
          s = select(
              employees_table.c.name, employees_table.c.department
          ).order_by(employees_table.c.employee_id)
          rows = connection.execute(s).fetchall()

          eq_(rows, self.data)

      def test_update_rowcount1(self, connection):
          """UPDATE that changes every matched row reports three rows."""
          employees_table = self.tables.employees

          # WHERE matches 3, 3 rows changed
          department = employees_table.c.department
          r = connection.execute(
              employees_table.update().where(department == "C"),
              {"department": "Z"},
          )
          # eq_() rather than a bare assert, consistent with the other tests
          # in this class.
          eq_(r.rowcount, 3)

      def test_update_rowcount2(self, connection):
          """UPDATE that rewrites the same value still reports three rows.

          The expected rowcount is the number of rows matched by the WHERE
          clause, even though no column value actually changes.
          """
          employees_table = self.tables.employees

          # WHERE matches 3, 0 rows changed
          department = employees_table.c.department
          r = connection.execute(
              employees_table.update().where(department == "C"),
              {"department": "C"},
          )
          eq_(r.rowcount, 3)

      @testing.requires.sane_rowcount_w_returning
      def test_update_rowcount_return_defaults(self, connection):
          """UPDATE built with ``return_defaults()`` reports three rows."""
          employees_table = self.tables.employees

          department = employees_table.c.department
          stmt = (
              employees_table.update()
              .where(department == "C")
              .values(name=department + "Z")
              .return_defaults()
          )

          r = connection.execute(stmt)
          eq_(r.rowcount, 3)

      def test_raw_sql_rowcount(self, connection):
          """Plain SQL string via ``exec_driver_sql`` reports three rows."""
          # test issue #3622, make sure eager rowcount is called for a
          # plain SQL string passed to exec_driver_sql()
          result = connection.exec_driver_sql(
              "update employees set department='Z' where department='C'"
          )
          eq_(result.rowcount, 3)

      def test_text_rowcount(self, connection):
          """``text()`` UPDATE reports three rows."""
          # test issue #3622, make sure eager rowcount is called for text
          result = connection.execute(
              text("update employees set department='Z' where department='C'")
          )
          eq_(result.rowcount, 3)

      def test_delete_rowcount(self, connection):
          """DELETE of one department reports three rows."""
          employees_table = self.tables.employees

          # WHERE matches 3, 3 rows deleted
          department = employees_table.c.department
          r = connection.execute(
              employees_table.delete().where(department == "C")
          )
          eq_(r.rowcount, 3)

      @testing.requires.sane_multi_rowcount
      def test_multi_update_rowcount(self, connection):
          """UPDATE executed with three parameter sets reports two rows.

          "Bob" and "Cynthia" are in the fixture data, "nonexistent" is not,
          so the expected rowcount across the three parameter sets is 2.
          """
          employees_table = self.tables.employees
          stmt = (
              employees_table.update()
              .where(employees_table.c.name == bindparam("emp_name"))
              .values(department="C")
          )

          r = connection.execute(
              stmt,
              [
                  {"emp_name": "Bob"},
                  {"emp_name": "Cynthia"},
                  {"emp_name": "nonexistent"},
              ],
          )

          eq_(r.rowcount, 2)

      @testing.requires.sane_multi_rowcount
      def test_multi_delete_rowcount(self, connection):
          """DELETE executed with three parameter sets reports two rows.

          "Bob" and "Cynthia" are in the fixture data, "nonexistent" is not,
          so the expected rowcount across the three parameter sets is 2.
          """
          employees_table = self.tables.employees

          stmt = employees_table.delete().where(
              employees_table.c.name == bindparam("emp_name")
          )

          r = connection.execute(
              stmt,
              [
                  {"emp_name": "Bob"},
                  {"emp_name": "Cynthia"},
                  {"emp_name": "nonexistent"},
              ],
          )

          eq_(r.rowcount, 2)