#! coding: utf-8
# testing/suite/test_dialect.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
#
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

from . import testing
from .. import assert_raises
from .. import config
from .. import engines
from .. import eq_
from .. import fixtures
from .. import ne_
from .. import provide_metadata
from ..config import requirements
from ..provision import set_default_schema_on_connection
from ..schema import Column
from ..schema import Table
from ... import bindparam
from ... import event
from ... import exc
from ... import Integer
from ... import literal_column
from ... import select
from ... import String
from ...util import compat  # noqa: F401  (kept; no longer referenced below)


class ExceptionTest(fixtures.TablesTest):
    """Test basic exception wrapping.

    DBAPIs vary a lot in exception behavior so to actually anticipate
    specific exceptions from real round trips, we need to be conservative.

    """

    run_deletes = "each"

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Define ``manual_pk``, a table with a non-autoincrement PK."""
        Table(
            "manual_pk",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
        )

    @requirements.duplicate_key_raises_integrity_error
    def test_integrity_error(self):
        """Inserting the same primary key twice raises IntegrityError."""
        with config.db.connect() as conn:

            trans = conn.begin()
            conn.execute(
                self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
            )

            assert_raises(
                exc.IntegrityError,
                conn.execute,
                self.tables.manual_pk.insert(),
                {"id": 1, "data": "d1"},
            )

            trans.rollback()

    def test_exception_with_non_ascii(self):
        """A DBAPIError for a non-ascii statement stringifies cleanly.

        The string form of the wrapping exception must contain the string
        form of the original DBAPI exception, and must be a ``str``.

        """
        with config.db.connect() as conn:
            try:
                # try to create an error message that likely has non-ascii
                # characters in the DBAPI's message string.  unfortunately
                # there's no way to make this happen with some drivers like
                # mysqlclient, pymysql.  this at least does produce a non-
                # ascii error message for cx_oracle, psycopg2
                conn.execute(select(literal_column(u"méil")))
                assert False
            except exc.DBAPIError as err:
                err_str = str(err)

                assert str(err.orig) in err_str

            # the previous Py2k / Py3k branches performed the exact same
            # check, so a single unconditional assertion is sufficient.
            assert isinstance(err_str, str)


class IsolationLevelTest(fixtures.TestBase):
    """Test reading, setting and resetting connection isolation levels."""

    __backend__ = True

    __requires__ = ("isolation_level",)

    def _get_non_default_isolation_level(self):
        """Return a supported level that is not the default or AUTOCOMMIT.

        Skips the current test when the backend reports no such level.

        """
        levels = requirements.get_isolation_levels(config)

        default = levels["default"]
        supported = levels["supported"]

        s = set(supported).difference(["AUTOCOMMIT", default])
        if s:
            return s.pop()
        else:
            config.skip_test("no non-default isolation level available")

    def test_default_isolation_level(self):
        """The dialect's default level matches the declared requirement."""
        eq_(
            config.db.dialect.default_isolation_level,
            requirements.get_isolation_levels(config)["default"],
        )

    def test_non_default_isolation_level(self):
        """A non-default level can be set and then reset on a connection."""
        non_default = self._get_non_default_isolation_level()

        with config.db.connect() as conn:
            existing = conn.get_isolation_level()

            ne_(existing, non_default)

            conn.execution_options(isolation_level=non_default)

            eq_(conn.get_isolation_level(), non_default)

            conn.dialect.reset_isolation_level(conn.connection)

            eq_(conn.get_isolation_level(), existing)

    def test_all_levels(self):
        """Every supported level except AUTOCOMMIT can be set.

        Each level must survive a begin / rollback cycle, and a freshly
        checked out connection must report the default level again.

        """
        levels = requirements.get_isolation_levels(config)

        all_levels = levels["supported"]

        for level in set(all_levels).difference(["AUTOCOMMIT"]):
            with config.db.connect() as conn:
                conn.execution_options(isolation_level=level)

                eq_(conn.get_isolation_level(), level)

                trans = conn.begin()
                trans.rollback()

                eq_(conn.get_isolation_level(), level)

            with config.db.connect() as conn:
                eq_(
                    conn.get_isolation_level(),
                    levels["default"],
                )


class AutocommitIsolationTest(fixtures.TablesTest):
    """Test the ``AUTOCOMMIT`` isolation level option."""

    run_deletes = "each"

    __requires__ = ("autocommit",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Define ``some_table``, flagged with ``test_needs_acid``."""
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("data", String(50)),
            test_needs_acid=True,
        )

    def _test_conn_autocommits(self, conn, autocommit):
        """Assert whether ``conn`` commits despite an explicit rollback.

        A row is inserted inside a transaction that is then rolled back;
        the row is expected to be present only when ``autocommit`` is true.
        The table is emptied afterwards so the check can be repeated.

        """
        trans = conn.begin()
        conn.execute(
            self.tables.some_table.insert(), {"id": 1, "data": "some data"}
        )
        trans.rollback()

        eq_(
            conn.scalar(select(self.tables.some_table.c.id)),
            1 if autocommit else None,
        )

        with conn.begin():
            conn.execute(self.tables.some_table.delete())

    def test_autocommit_on(self, connection_no_trans):
        """AUTOCOMMIT takes effect, and is undone by resetting the level."""
        conn = connection_no_trans
        c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(c2, True)

        c2.dialect.reset_isolation_level(c2.connection)

        self._test_conn_autocommits(conn, False)

    def test_autocommit_off(self, connection_no_trans):
        """A connection without the option does not autocommit."""
        conn = connection_no_trans
        self._test_conn_autocommits(conn, False)

    def test_turn_autocommit_off_via_default_iso_level(
        self, connection_no_trans
    ):
        """Setting the default isolation level turns AUTOCOMMIT back off."""
        conn = connection_no_trans
        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
        self._test_conn_autocommits(conn, True)

        conn.execution_options(
            isolation_level=requirements.get_isolation_levels(config)[
                "default"
            ]
        )
        self._test_conn_autocommits(conn, False)


class EscapingTest(fixtures.TestBase):
    """Test percent sign handling between the compiler and the DBAPI."""

    @provide_metadata
    def test_percent_sign_round_trip(self):
        """test that the DBAPI accommodates for escaped / nonescaped
        percent signs in a way that matches the compiler

        """
        m = self.metadata
        t = Table("t", m, Column("data", String(50)))
        t.create(config.db)
        with config.db.begin() as conn:
            conn.execute(t.insert(), dict(data="some % value"))
            conn.execute(t.insert(), dict(data="some %% other value"))

            eq_(
                conn.scalar(
                    select(t.c.data).where(
                        t.c.data == literal_column("'some % value'")
                    )
                ),
                "some % value",
            )

            eq_(
                conn.scalar(
                    select(t.c.data).where(
                        t.c.data == literal_column("'some %% other value'")
                    )
                ),
                "some %% other value",
            )


class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
    """Test changing the default schema from a "connect" event listener."""

    __backend__ = True

    __requires__ = ("default_schema_name_switch",)

    def test_control_case(self):
        """Without a listener, a new engine reports the usual default."""
        default_schema_name = config.db.dialect.default_schema_name

        eng = engines.testing_engine()
        with eng.connect():
            pass

        eq_(eng.dialect.default_schema_name, default_schema_name)

    def test_wont_work_wo_insert(self):
        """A listener registered without ``insert=True`` is not reflected
        in ``dialect.default_schema_name``.

        The connection itself does report the test schema; only the value
        stored on the dialect keeps the original default.

        """
        default_schema_name = config.db.dialect.default_schema_name

        eng = engines.testing_engine()

        @event.listens_for(eng, "connect")
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, default_schema_name)

    def test_schema_change_on_connect(self):
        """A listener registered with ``insert=True`` is reflected in
        ``dialect.default_schema_name``.

        """
        eng = engines.testing_engine()

        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, connection_record):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, config.test_schema)

    def test_schema_change_works_w_transactions(self):
        """The schema set by the listener is still in effect after a
        transaction on the connection has been rolled back.

        """
        eng = engines.testing_engine()

        @event.listens_for(eng, "connect", insert=True)
        def on_connect(dbapi_connection, *arg):
            set_default_schema_on_connection(
                config, dbapi_connection, config.test_schema
            )

        with eng.connect() as conn:
            trans = conn.begin()
            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)
            trans.rollback()

            what_it_should_be = eng.dialect._get_default_schema_name(conn)
            eq_(what_it_should_be, config.test_schema)

        eq_(eng.dialect.default_schema_name, config.test_schema)


class FutureWeCanSetDefaultSchemaWEventsTest(
    fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
):
    """Run the default schema event tests with ``FutureEngineMixin``."""

    pass


class DifficultParametersTest(fixtures.TestBase):
    """Test column and bound parameter names with unusual characters."""

    __backend__ = True

    # names that serve as column names and / or bound parameter names
    # in the tests below
    tough_parameters = testing.combinations(
        ("boring",),
        ("per cent",),
        ("per % cent",),
        ("%percent",),
        ("par(ens)",),
        ("percent%(ens)yah",),
        ("col:ons",),
        ("_starts_with_underscore",),
        ("dot.s",),
        ("more :: %colons%",),
        ("_name",),
        ("___name",),
        ("[BracketsAndCase]",),
        ("42numbers",),
        ("percent%signs",),
        ("has spaces",),
        ("/slashes/",),
        ("more/slashes",),
        ("q?marks",),
        ("1param",),
        ("1col:on",),
        argnames="paramname",
    )

    @tough_parameters
    @config.requirements.unusual_column_name_characters
    def test_round_trip_same_named_column(
        self, paramname, connection, metadata
    ):
        """Round trip a column, and parameters, named ``paramname``."""
        name = paramname

        t = Table(
            "t",
            metadata,
            Column("id", Integer, primary_key=True),
            Column(name, String(50), nullable=False),
        )

        # table is created
        t.create(connection)

        # automatic param generated by insert
        connection.execute(t.insert().values({"id": 1, name: "some name"}))

        # automatic param generated by criteria, plus selecting the column
        stmt = select(t.c[name]).where(t.c[name] == "some name")

        eq_(connection.scalar(stmt), "some name")

        # use the name in a param explicitly
        stmt = select(t.c[name]).where(t.c[name] == bindparam(name))

        row = connection.execute(stmt, {name: "some name"}).first()

        # name works as the key from cursor.description
        eq_(row._mapping[name], "some name")

        # use expanding IN
        stmt = select(t.c[name]).where(
            t.c[name].in_(["some name", "some other_name"])
        )

        # nothing is asserted about the result here; the statement only
        # needs to execute and fetch without raising
        connection.execute(stmt).first()

    @testing.fixture
    def multirow_fixture(self, metadata, connection):
        """Create ``mytable``, insert four rows, and yield the table."""
        mytable = Table(
            "mytable",
            metadata,
            Column("myid", Integer),
            Column("name", String(50)),
            Column("desc", String(50)),
        )

        mytable.create(connection)

        connection.execute(
            mytable.insert(),
            [
                {"myid": 1, "name": "a", "desc": "a_desc"},
                {"myid": 2, "name": "b", "desc": "b_desc"},
                {"myid": 3, "name": "c", "desc": "c_desc"},
                {"myid": 4, "name": "d", "desc": "d_desc"},
            ],
        )
        yield mytable

    @tough_parameters
    def test_standalone_bindparam_escape(
        self, paramname, connection, multirow_fixture
    ):
        """An explicit bindparam named ``paramname`` accepts a value at
        execution time that overrides its default.

        """
        tbl1 = multirow_fixture
        stmt = select(tbl1.c.myid).where(
            tbl1.c.name == bindparam(paramname, value="x")
        )
        res = connection.scalar(stmt, {paramname: "c"})
        eq_(res, 3)

    @tough_parameters
    def test_standalone_bindparam_escape_expanding(
        self, paramname, connection, multirow_fixture
    ):
        """Same as the scalar case, using a list value with ``in_()``."""
        tbl1 = multirow_fixture
        stmt = (
            select(tbl1.c.myid)
            .where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
            .order_by(tbl1.c.myid)
        )

        res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
        eq_(res, [1, 4])