test_dialect.py 13 KB


  1. #! coding: utf-8
  2. # testing/suite/test_dialect.py
  3. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  4. # <see AUTHORS file>
  5. #
  6. # This module is part of SQLAlchemy and is released under
  7. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  8. from . import testing
  9. from .. import assert_raises
  10. from .. import config
  11. from .. import engines
  12. from .. import eq_
  13. from .. import fixtures
  14. from .. import ne_
  15. from .. import provide_metadata
  16. from ..config import requirements
  17. from ..provision import set_default_schema_on_connection
  18. from ..schema import Column
  19. from ..schema import Table
  20. from ... import bindparam
  21. from ... import event
  22. from ... import exc
  23. from ... import Integer
  24. from ... import literal_column
  25. from ... import select
  26. from ... import String
  27. from ...util import compat
  class ExceptionTest(fixtures.TablesTest):
      """Test basic exception wrapping.

      DBAPIs vary a lot in exception behavior so to actually anticipate
      specific exceptions from real round trips, we need to be conservative.
      """

      run_deletes = "each"

      __backend__ = True

      @classmethod
      def define_tables(cls, metadata):
          # autoincrement is disabled so the test can insert the same explicit
          # primary key twice and provoke a duplicate-key error
          Table(
              "manual_pk",
              metadata,
              Column("id", Integer, primary_key=True, autoincrement=False),
              Column("data", String(50)),
          )

      @requirements.duplicate_key_raises_integrity_error
      def test_integrity_error(self):
          """A duplicate primary key INSERT surfaces as exc.IntegrityError."""
          with config.db.connect() as conn:
              trans = conn.begin()
              conn.execute(
                  self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
              )

              assert_raises(
                  exc.IntegrityError,
                  conn.execute,
                  self.tables.manual_pk.insert(),
                  {"id": 1, "data": "d1"},
              )

              trans.rollback()

      def test_exception_with_non_ascii(self):
          """A failing statement is wrapped as exc.DBAPIError, its message
          embeds the original DBAPI message, and str() yields a native str."""
          with config.db.connect() as conn:
              try:
                  # try to create an error message that likely has non-ascii
                  # characters in the DBAPI's message string. unfortunately
                  # there's no way to make this happen with some drivers like
                  # mysqlclient, pymysql. this at least does produce a non-
                  # ascii error message for cx_oracle, psycopg2
                  conn.execute(select(literal_column(u"méil")))
                  assert False, "statement was expected to raise DBAPIError"
              except exc.DBAPIError as err:
                  err_str = str(err)

                  # the wrapped DBAPI message must appear in our message
                  assert str(err.orig) in err_str

                  # str() must return the interpreter's native string type
                  # (byte str on Py2k, unicode str on Py3k); the check is the
                  # same isinstance() on both, so no version branch is needed
                  assert isinstance(err_str, str)
  class IsolationLevelTest(fixtures.TestBase):
      """Exercise the isolation levels reported by
      ``requirements.get_isolation_levels(config)`` against a live backend."""

      __backend__ = True

      __requires__ = ("isolation_level",)

      def _get_non_default_isolation_level(self):
          """Return a supported level other than the default and AUTOCOMMIT.

          Candidates are sorted so the same level is picked on every run;
          popping from a set of strings picks an arbitrary member that can
          change between processes, which makes failures hard to reproduce.
          Calls ``config.skip_test`` when no such level exists.
          """
          levels = requirements.get_isolation_levels(config)
          default = levels["default"]
          supported = levels["supported"]

          candidates = sorted(
              set(supported).difference(["AUTOCOMMIT", default])
          )
          if candidates:
              return candidates[0]
          config.skip_test("no non-default isolation level available")

      def test_default_isolation_level(self):
          eq_(
              config.db.dialect.default_isolation_level,
              requirements.get_isolation_levels(config)["default"],
          )

      def test_non_default_isolation_level(self):
          """Switch a connection to a non-default level, then reset it."""
          non_default = self._get_non_default_isolation_level()

          with config.db.connect() as conn:
              existing = conn.get_isolation_level()

              ne_(existing, non_default)

              conn.execution_options(isolation_level=non_default)

              eq_(conn.get_isolation_level(), non_default)

              # resetting the DBAPI connection must bring back the level
              # observed before the switch
              conn.dialect.reset_isolation_level(conn.connection)

              eq_(conn.get_isolation_level(), existing)

      def test_all_levels(self):
          """Every supported non-AUTOCOMMIT level sticks across a
          begin/rollback, and does not leak into a later connection."""
          levels = requirements.get_isolation_levels(config)

          all_levels = levels["supported"]

          # sorted for a reproducible iteration order across runs
          for level in sorted(set(all_levels).difference(["AUTOCOMMIT"])):
              with config.db.connect() as conn:
                  conn.execution_options(isolation_level=level)

                  eq_(conn.get_isolation_level(), level)

                  trans = conn.begin()
                  trans.rollback()

                  eq_(conn.get_isolation_level(), level)

              with config.db.connect() as conn:
                  eq_(
                      conn.get_isolation_level(),
                      levels["default"],
                  )
  class AutocommitIsolationTest(fixtures.TablesTest):
      """Check that rows written under AUTOCOMMIT survive a rollback, and
      that leaving AUTOCOMMIT restores normal transactional behavior."""

      run_deletes = "each"

      __requires__ = ("autocommit",)

      __backend__ = True

      @classmethod
      def define_tables(cls, metadata):
          Table(
              "some_table",
              metadata,
              Column("id", Integer, primary_key=True, autoincrement=False),
              Column("data", String(50)),
              test_needs_acid=True,
          )

      def _test_conn_autocommits(self, conn, autocommit):
          """Insert a row in a transaction, roll it back, and assert the row
          is still present exactly when ``autocommit`` is true.

          The table is emptied at the end so the helper can run repeatedly
          inside a single test.
          """
          some_table = self.tables.some_table

          trans = conn.begin()
          conn.execute(some_table.insert(), {"id": 1, "data": "some data"})
          trans.rollback()

          # with autocommit the INSERT is expected to be committed already,
          # leaving the rollback above nothing to undo
          if autocommit:
              expected = 1
          else:
              expected = None
          eq_(conn.scalar(select(some_table.c.id)), expected)

          with conn.begin():
              conn.execute(some_table.delete())

      def test_autocommit_on(self, connection_no_trans):
          base_conn = connection_no_trans
          autocommit_conn = base_conn.execution_options(
              isolation_level="AUTOCOMMIT"
          )
          self._test_conn_autocommits(autocommit_conn, True)

          # once the DBAPI connection's isolation level is reset, the
          # original Connection must behave transactionally again
          autocommit_conn.dialect.reset_isolation_level(
              autocommit_conn.connection
          )
          self._test_conn_autocommits(base_conn, False)

      def test_autocommit_off(self, connection_no_trans):
          self._test_conn_autocommits(connection_no_trans, False)

      def test_turn_autocommit_off_via_default_iso_level(
          self, connection_no_trans
      ):
          conn = connection_no_trans.execution_options(
              isolation_level="AUTOCOMMIT"
          )
          self._test_conn_autocommits(conn, True)

          # switching to the backend's default level turns autocommit off
          default_level = requirements.get_isolation_levels(config)[
              "default"
          ]
          conn.execution_options(isolation_level=default_level)
          self._test_conn_autocommits(conn, False)
  class EscapingTest(fixtures.TestBase):
      @provide_metadata
      def test_percent_sign_round_trip(self):
          """Round-trip values containing single and doubled percent signs.

          Values go in as bound parameters and are matched back through
          inline SQL literals, confirming the DBAPI treats escaped and
          unescaped percent signs the same way the compiler does.
          """
          table = Table("t", self.metadata, Column("data", String(50)))
          table.create(config.db)

          # (bound value to INSERT, the same value written as a SQL literal)
          cases = [
              ("some % value", "'some % value'"),
              ("some %% other value", "'some %% other value'"),
          ]

          with config.db.begin() as conn:
              for value, _ in cases:
                  conn.execute(table.insert(), dict(data=value))

              for value, sql_literal in cases:
                  found = conn.scalar(
                      select(table.c.data).where(
                          table.c.data == literal_column(sql_literal)
                      )
                  )
                  eq_(found, value)
  class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
      """Switch the default schema from a "connect" event listener and check
      when the dialect's ``default_schema_name`` reflects the switch."""

      __backend__ = True

      __requires__ = ("default_schema_name_switch",)

      def _engine_with_schema_switch(self, insert=False):
          """Return a testing engine whose "connect" listener sets the
          default schema of each new DBAPI connection to
          ``config.test_schema``.

          ``insert`` is forwarded to the listener registration; the tests
          below show ``default_schema_name`` only picks up the switch when
          it is true.
          """
          eng = engines.testing_engine()

          def on_connect(dbapi_connection, *arg):
              set_default_schema_on_connection(
                  config, dbapi_connection, config.test_schema
              )

          event.listen(eng, "connect", on_connect, insert=insert)
          return eng

      def test_control_case(self):
          expected = config.db.dialect.default_schema_name

          eng = engines.testing_engine()
          with eng.connect():
              pass

          eq_(eng.dialect.default_schema_name, expected)

      def test_wont_work_wo_insert(self):
          original = config.db.dialect.default_schema_name
          eng = self._engine_with_schema_switch()

          with eng.connect() as conn:
              # the live connection itself has been switched...
              eq_(
                  eng.dialect._get_default_schema_name(conn),
                  config.test_schema,
              )

          # ...but the dialect-level default still holds the old value
          eq_(eng.dialect.default_schema_name, original)

      def test_schema_change_on_connect(self):
          eng = self._engine_with_schema_switch(insert=True)

          with eng.connect() as conn:
              eq_(
                  eng.dialect._get_default_schema_name(conn),
                  config.test_schema,
              )

          eq_(eng.dialect.default_schema_name, config.test_schema)

      def test_schema_change_works_w_transactions(self):
          eng = self._engine_with_schema_switch(insert=True)

          with eng.connect() as conn:
              trans = conn.begin()
              eq_(
                  eng.dialect._get_default_schema_name(conn),
                  config.test_schema,
              )
              trans.rollback()

              # rolling back must not undo the connection's schema switch
              eq_(
                  eng.dialect._get_default_schema_name(conn),
                  config.test_schema,
              )

          eq_(eng.dialect.default_schema_name, config.test_schema)
  class FutureWeCanSetDefaultSchemaWEventsTest(
      fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
  ):
      """Re-run the default-schema event tests with
      ``fixtures.FutureEngineMixin`` mixed in."""
  class DifficultParametersTest(fixtures.TestBase):
      """Round-trip column names and bound parameter names containing
      characters that need escaping (percent signs, colons, parentheses,
      dots, slashes, spaces, brackets, question marks, leading digits or
      underscores)."""

      __backend__ = True

      tough_parameters = testing.combinations(
          ("boring",),
          ("per cent",),
          ("per % cent",),
          ("%percent",),
          ("par(ens)",),
          ("percent%(ens)yah",),
          ("col:ons",),
          ("_starts_with_underscore",),
          ("dot.s",),
          ("more :: %colons%",),
          ("_name",),
          ("___name",),
          ("[BracketsAndCase]",),
          ("42numbers",),
          ("percent%signs",),
          ("has spaces",),
          ("/slashes/",),
          ("more/slashes",),
          ("q?marks",),
          ("1param",),
          ("1col:on",),
          argnames="paramname",
      )

      @tough_parameters
      @config.requirements.unusual_column_name_characters
      def test_round_trip_same_named_column(
          self, paramname, connection, metadata
      ):
          """Use ``paramname`` as a column name and as the bound parameter
          name in INSERT, WHERE, explicit bindparam and expanding IN."""
          name = paramname

          t = Table(
              "t",
              metadata,
              Column("id", Integer, primary_key=True),
              Column(name, String(50), nullable=False),
          )

          # table is created
          t.create(connection)

          # automatic param generated by insert
          connection.execute(t.insert().values({"id": 1, name: "some name"}))

          # automatic param generated by criteria, plus selecting the column
          stmt = select(t.c[name]).where(t.c[name] == "some name")

          eq_(connection.scalar(stmt), "some name")

          # use the name in a param explicitly
          stmt = select(t.c[name]).where(t.c[name] == bindparam(name))

          row = connection.execute(stmt, {name: "some name"}).first()

          # name works as the key from cursor.description
          eq_(row._mapping[name], "some name")

          # use expanding IN
          stmt = select(t.c[name]).where(
              t.c[name].in_(["some name", "some other_name"])
          )

          row = connection.execute(stmt).first()

          # check the fetched value so the expanding IN is verified, not
          # merely executed without error
          eq_(row._mapping[name], "some name")

      @testing.fixture
      def multirow_fixture(self, metadata, connection):
          """Create ``mytable`` with four rows (myid 1-4, names "a"-"d")
          and yield the Table."""
          mytable = Table(
              "mytable",
              metadata,
              Column("myid", Integer),
              Column("name", String(50)),
              Column("desc", String(50)),
          )
          mytable.create(connection)

          connection.execute(
              mytable.insert(),
              [
                  {"myid": 1, "name": "a", "desc": "a_desc"},
                  {"myid": 2, "name": "b", "desc": "b_desc"},
                  {"myid": 3, "name": "c", "desc": "c_desc"},
                  {"myid": 4, "name": "d", "desc": "d_desc"},
              ],
          )
          yield mytable

      @tough_parameters
      def test_standalone_bindparam_escape(
          self, paramname, connection, multirow_fixture
      ):
          """A standalone bindparam with a difficult name: the value "c"
          supplied at execution replaces the default "x" and finds myid 3."""
          tbl1 = multirow_fixture
          stmt = select(tbl1.c.myid).where(
              tbl1.c.name == bindparam(paramname, value="x")
          )
          res = connection.scalar(stmt, {paramname: "c"})
          eq_(res, 3)

      @tough_parameters
      def test_standalone_bindparam_escape_expanding(
          self, paramname, connection, multirow_fixture
      ):
          """Same as above for an expanding IN parameter: the execution-time
          list ["d", "a"] replaces the default and yields myids [1, 4]."""
          tbl1 = multirow_fixture
          stmt = (
              select(tbl1.c.myid)
              .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
              .order_by(tbl1.c.myid)
          )

          res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
          eq_(res, [1, 4])