asyncmy.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # dialects/mysql/asyncmy.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
  3. # file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+asyncmy
    :name: asyncmy
    :dbapi: asyncmy
    :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
    :url: https://github.com/long2ice/asyncmy

Using a special asyncio mediation layer, the asyncmy dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.

This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine
    engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")

"""  # noqa
import collections

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asynccontextmanager
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
  29. class AsyncAdapt_asyncmy_cursor:
  30. server_side = False
  31. __slots__ = (
  32. "_adapt_connection",
  33. "_connection",
  34. "await_",
  35. "_cursor",
  36. "_rows",
  37. )
  38. def __init__(self, adapt_connection):
  39. self._adapt_connection = adapt_connection
  40. self._connection = adapt_connection._connection
  41. self.await_ = adapt_connection.await_
  42. cursor = self._connection.cursor()
  43. self._cursor = self.await_(cursor.__aenter__())
  44. self._rows = []
  45. @property
  46. def description(self):
  47. return self._cursor.description
  48. @property
  49. def rowcount(self):
  50. return self._cursor.rowcount
  51. @property
  52. def arraysize(self):
  53. return self._cursor.arraysize
  54. @arraysize.setter
  55. def arraysize(self, value):
  56. self._cursor.arraysize = value
  57. @property
  58. def lastrowid(self):
  59. return self._cursor.lastrowid
  60. def close(self):
  61. # note we aren't actually closing the cursor here,
  62. # we are just letting GC do it. to allow this to be async
  63. # we would need the Result to change how it does "Safe close cursor".
  64. # MySQL "cursors" don't actually have state to be "closed" besides
  65. # exhausting rows, which we already have done for sync cursor.
  66. # another option would be to emulate aiosqlite dialect and assign
  67. # cursor only if we are doing server side cursor operation.
  68. self._rows[:] = []
  69. def execute(self, operation, parameters=None):
  70. return self.await_(self._execute_async(operation, parameters))
  71. def executemany(self, operation, seq_of_parameters):
  72. return self.await_(
  73. self._executemany_async(operation, seq_of_parameters)
  74. )
  75. async def _execute_async(self, operation, parameters):
  76. async with self._adapt_connection._mutex_and_adapt_errors():
  77. if parameters is None:
  78. result = await self._cursor.execute(operation)
  79. else:
  80. result = await self._cursor.execute(operation, parameters)
  81. if not self.server_side:
  82. # asyncmy has a "fake" async result, so we have to pull it out
  83. # of that here since our default result is not async.
  84. # we could just as easily grab "_rows" here and be done with it
  85. # but this is safer.
  86. self._rows = list(await self._cursor.fetchall())
  87. return result
  88. async def _executemany_async(self, operation, seq_of_parameters):
  89. async with self._adapt_connection._mutex_and_adapt_errors():
  90. return await self._cursor.executemany(operation, seq_of_parameters)
  91. def setinputsizes(self, *inputsizes):
  92. pass
  93. def __iter__(self):
  94. while self._rows:
  95. yield self._rows.pop(0)
  96. def fetchone(self):
  97. if self._rows:
  98. return self._rows.pop(0)
  99. else:
  100. return None
  101. def fetchmany(self, size=None):
  102. if size is None:
  103. size = self.arraysize
  104. retval = self._rows[0:size]
  105. self._rows[:] = self._rows[size:]
  106. return retval
  107. def fetchall(self):
  108. retval = self._rows[:]
  109. self._rows[:] = []
  110. return retval
  111. class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
  112. __slots__ = ()
  113. server_side = True
  114. def __init__(self, adapt_connection):
  115. self._adapt_connection = adapt_connection
  116. self._connection = adapt_connection._connection
  117. self.await_ = adapt_connection.await_
  118. cursor = self._connection.cursor(
  119. adapt_connection.dbapi.asyncmy.cursors.SSCursor
  120. )
  121. self._cursor = self.await_(cursor.__aenter__())
  122. def close(self):
  123. if self._cursor is not None:
  124. self.await_(self._cursor.close())
  125. self._cursor = None
  126. def fetchone(self):
  127. return self.await_(self._cursor.fetchone())
  128. def fetchmany(self, size=None):
  129. return self.await_(self._cursor.fetchmany(size=size))
  130. def fetchall(self):
  131. return self.await_(self._cursor.fetchall())
  132. class AsyncAdapt_asyncmy_connection(AdaptedConnection):
  133. await_ = staticmethod(await_only)
  134. __slots__ = ("dbapi", "_connection", "_execute_mutex")
  135. def __init__(self, dbapi, connection):
  136. self.dbapi = dbapi
  137. self._connection = connection
  138. self._execute_mutex = asyncio.Lock()
  139. @asynccontextmanager
  140. async def _mutex_and_adapt_errors(self):
  141. async with self._execute_mutex:
  142. try:
  143. yield
  144. except AttributeError:
  145. raise self.dbapi.InternalError(
  146. "network operation failed due to asyncmy attribute error"
  147. )
  148. def ping(self, reconnect):
  149. assert not reconnect
  150. return self.await_(self._do_ping())
  151. async def _do_ping(self):
  152. async with self._mutex_and_adapt_errors():
  153. return await self._connection.ping(False)
  154. def character_set_name(self):
  155. return self._connection.character_set_name()
  156. def autocommit(self, value):
  157. self.await_(self._connection.autocommit(value))
  158. def cursor(self, server_side=False):
  159. if server_side:
  160. return AsyncAdapt_asyncmy_ss_cursor(self)
  161. else:
  162. return AsyncAdapt_asyncmy_cursor(self)
  163. def rollback(self):
  164. self.await_(self._connection.rollback())
  165. def commit(self):
  166. self.await_(self._connection.commit())
  167. def close(self):
  168. # it's not awaitable.
  169. self._connection.close()
  170. class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
  171. __slots__ = ()
  172. await_ = staticmethod(await_fallback)
  173. def _Binary(x):
  174. """Return x as a binary type."""
  175. return bytes(x)
  176. class AsyncAdapt_asyncmy_dbapi:
  177. def __init__(self, asyncmy):
  178. self.asyncmy = asyncmy
  179. self.paramstyle = "format"
  180. self._init_dbapi_attributes()
  181. def _init_dbapi_attributes(self):
  182. for name in (
  183. "Warning",
  184. "Error",
  185. "InterfaceError",
  186. "DataError",
  187. "DatabaseError",
  188. "OperationalError",
  189. "InterfaceError",
  190. "IntegrityError",
  191. "ProgrammingError",
  192. "InternalError",
  193. "NotSupportedError",
  194. ):
  195. setattr(self, name, getattr(self.asyncmy.errors, name))
  196. STRING = util.symbol("STRING")
  197. NUMBER = util.symbol("NUMBER")
  198. BINARY = util.symbol("BINARY")
  199. DATETIME = util.symbol("DATETIME")
  200. TIMESTAMP = util.symbol("TIMESTAMP")
  201. Binary = staticmethod(_Binary)
  202. def connect(self, *arg, **kw):
  203. async_fallback = kw.pop("async_fallback", False)
  204. if util.asbool(async_fallback):
  205. return AsyncAdaptFallback_asyncmy_connection(
  206. self,
  207. await_fallback(self.asyncmy.connect(*arg, **kw)),
  208. )
  209. else:
  210. return AsyncAdapt_asyncmy_connection(
  211. self,
  212. await_only(self.asyncmy.connect(*arg, **kw)),
  213. )
  214. class MySQLDialect_asyncmy(MySQLDialect_pymysql):
  215. driver = "asyncmy"
  216. supports_statement_cache = True
  217. supports_server_side_cursors = True
  218. _sscursor = AsyncAdapt_asyncmy_ss_cursor
  219. is_async = True
  220. @classmethod
  221. def dbapi(cls):
  222. return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
  223. @classmethod
  224. def get_pool_class(cls, url):
  225. async_fallback = url.query.get("async_fallback", False)
  226. if util.asbool(async_fallback):
  227. return pool.FallbackAsyncAdaptedQueuePool
  228. else:
  229. return pool.AsyncAdaptedQueuePool
  230. def create_connect_args(self, url):
  231. return super(MySQLDialect_asyncmy, self).create_connect_args(
  232. url, _translate_args=dict(username="user", database="db")
  233. )
  234. def is_disconnect(self, e, connection, cursor):
  235. if super(MySQLDialect_asyncmy, self).is_disconnect(
  236. e, connection, cursor
  237. ):
  238. return True
  239. else:
  240. str_e = str(e).lower()
  241. return (
  242. "not connected" in str_e or "network operation failed" in str_e
  243. )
  244. def _found_rows_client_flag(self):
  245. from asyncmy.constants import CLIENT
  246. return CLIENT.FOUND_ROWS
  247. def get_driver_connection(self, connection):
  248. return connection._connection
# Module-level ``dialect`` name bound to the asyncmy dialect class;
# presumably the attribute SQLAlchemy's dialect loader looks up for
# ``mysql+asyncmy`` URLs — confirm against the registry.
dialect = MySQLDialect_asyncmy