123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384 |
- # Copyright Amethyst Reese
- # Licensed under the MIT license
- """
- Core implementation of aiosqlite proxies
- """
- import asyncio
- import logging
- import sqlite3
- from collections.abc import AsyncIterator, Generator, Iterable
- from functools import partial
- from pathlib import Path
- from queue import Empty, Queue, SimpleQueue
- from threading import Thread
- from typing import Any, Callable, Literal, Optional, Union
- from warnings import warn
- from .context import contextmanager
- from .cursor import Cursor
- __all__ = ["connect", "Connection", "Cursor"]
- LOG = logging.getLogger("aiosqlite")
- IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
- def set_result(fut: asyncio.Future, result: Any) -> None:
- """Set the result of a future if it hasn't been set already."""
- if not fut.done():
- fut.set_result(result)
- def set_exception(fut: asyncio.Future, e: BaseException) -> None:
- """Set the exception of a future if it hasn't been set already."""
- if not fut.done():
- fut.set_exception(e)
- _STOP_RUNNING_SENTINEL = object()
- class Connection(Thread):
- def __init__(
- self,
- connector: Callable[[], sqlite3.Connection],
- iter_chunk_size: int,
- loop: Optional[asyncio.AbstractEventLoop] = None,
- ) -> None:
- super().__init__()
- self._running = True
- self._connection: Optional[sqlite3.Connection] = None
- self._connector = connector
- self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
- self._iter_chunk_size = iter_chunk_size
- if loop is not None:
- warn(
- "aiosqlite.Connection no longer uses the `loop` parameter",
- DeprecationWarning,
- )
- def _stop_running(self):
- self._running = False
- # PEP 661 is not accepted yet, so we cannot type a sentinel
- self._tx.put_nowait(_STOP_RUNNING_SENTINEL) # type: ignore[arg-type]
- @property
- def _conn(self) -> sqlite3.Connection:
- if self._connection is None:
- raise ValueError("no active connection")
- return self._connection
- def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
- cursor = self._conn.execute(sql, parameters)
- cursor.execute("SELECT last_insert_rowid()")
- return cursor.fetchone()
- def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
- cursor = self._conn.execute(sql, parameters)
- return cursor.fetchall()
- def run(self) -> None:
- """
- Execute function calls on a separate thread.
- :meta private:
- """
- while True:
- # Continues running until all queue items are processed,
- # even after connection is closed (so we can finalize all
- # futures)
- tx_item = self._tx.get()
- if tx_item is _STOP_RUNNING_SENTINEL:
- break
- future, function = tx_item
- try:
- LOG.debug("executing %s", function)
- result = function()
- LOG.debug("operation %s completed", function)
- future.get_loop().call_soon_threadsafe(set_result, future, result)
- except BaseException as e: # noqa B036
- LOG.debug("returning exception %s", e)
- future.get_loop().call_soon_threadsafe(set_exception, future, e)
- async def _execute(self, fn, *args, **kwargs):
- """Queue a function with the given arguments for execution."""
- if not self._running or not self._connection:
- raise ValueError("Connection closed")
- function = partial(fn, *args, **kwargs)
- future = asyncio.get_event_loop().create_future()
- self._tx.put_nowait((future, function))
- return await future
- async def _connect(self) -> "Connection":
- """Connect to the actual sqlite database."""
- if self._connection is None:
- try:
- future = asyncio.get_event_loop().create_future()
- self._tx.put_nowait((future, self._connector))
- self._connection = await future
- except BaseException:
- self._stop_running()
- self._connection = None
- raise
- return self
- def __await__(self) -> Generator[Any, None, "Connection"]:
- self.start()
- return self._connect().__await__()
- async def __aenter__(self) -> "Connection":
- return await self
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
- await self.close()
- @contextmanager
- async def cursor(self) -> Cursor:
- """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
- return Cursor(self, await self._execute(self._conn.cursor))
- async def commit(self) -> None:
- """Commit the current transaction."""
- await self._execute(self._conn.commit)
- async def rollback(self) -> None:
- """Roll back the current transaction."""
- await self._execute(self._conn.rollback)
- async def close(self) -> None:
- """Complete queued queries/cursors and close the connection."""
- if self._connection is None:
- return
- try:
- await self._execute(self._conn.close)
- except Exception:
- LOG.info("exception occurred while closing connection")
- raise
- finally:
- self._stop_running()
- self._connection = None
- @contextmanager
- async def execute(
- self, sql: str, parameters: Optional[Iterable[Any]] = None
- ) -> Cursor:
- """Helper to create a cursor and execute the given query."""
- if parameters is None:
- parameters = []
- cursor = await self._execute(self._conn.execute, sql, parameters)
- return Cursor(self, cursor)
- @contextmanager
- async def execute_insert(
- self, sql: str, parameters: Optional[Iterable[Any]] = None
- ) -> Optional[sqlite3.Row]:
- """Helper to insert and get the last_insert_rowid."""
- if parameters is None:
- parameters = []
- return await self._execute(self._execute_insert, sql, parameters)
- @contextmanager
- async def execute_fetchall(
- self, sql: str, parameters: Optional[Iterable[Any]] = None
- ) -> Iterable[sqlite3.Row]:
- """Helper to execute a query and return all the data."""
- if parameters is None:
- parameters = []
- return await self._execute(self._execute_fetchall, sql, parameters)
- @contextmanager
- async def executemany(
- self, sql: str, parameters: Iterable[Iterable[Any]]
- ) -> Cursor:
- """Helper to create a cursor and execute the given multiquery."""
- cursor = await self._execute(self._conn.executemany, sql, parameters)
- return Cursor(self, cursor)
- @contextmanager
- async def executescript(self, sql_script: str) -> Cursor:
- """Helper to create a cursor and execute a user script."""
- cursor = await self._execute(self._conn.executescript, sql_script)
- return Cursor(self, cursor)
- async def interrupt(self) -> None:
- """Interrupt pending queries."""
- return self._conn.interrupt()
- async def create_function(
- self, name: str, num_params: int, func: Callable, deterministic: bool = False
- ) -> None:
- """
- Create user-defined function that can be later used
- within SQL statements. Must be run within the same thread
- that query executions take place so instead of executing directly
- against the connection, we defer this to `run` function.
- If ``deterministic`` is true, the created function is marked as deterministic,
- which allows SQLite to perform additional optimizations. This flag is supported
- by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
- older versions.
- """
- await self._execute(
- self._conn.create_function,
- name,
- num_params,
- func,
- deterministic=deterministic,
- )
- @property
- def in_transaction(self) -> bool:
- return self._conn.in_transaction
- @property
- def isolation_level(self) -> Optional[str]:
- return self._conn.isolation_level
- @isolation_level.setter
- def isolation_level(self, value: IsolationLevel) -> None:
- self._conn.isolation_level = value
- @property
- def row_factory(self) -> Optional[type]:
- return self._conn.row_factory
- @row_factory.setter
- def row_factory(self, factory: Optional[type]) -> None:
- self._conn.row_factory = factory
- @property
- def text_factory(self) -> Callable[[bytes], Any]:
- return self._conn.text_factory
- @text_factory.setter
- def text_factory(self, factory: Callable[[bytes], Any]) -> None:
- self._conn.text_factory = factory
- @property
- def total_changes(self) -> int:
- return self._conn.total_changes
- async def enable_load_extension(self, value: bool) -> None:
- await self._execute(self._conn.enable_load_extension, value) # type: ignore
- async def load_extension(self, path: str):
- await self._execute(self._conn.load_extension, path) # type: ignore
- async def set_progress_handler(
- self, handler: Callable[[], Optional[int]], n: int
- ) -> None:
- await self._execute(self._conn.set_progress_handler, handler, n)
- async def set_trace_callback(self, handler: Callable) -> None:
- await self._execute(self._conn.set_trace_callback, handler)
- async def iterdump(self) -> AsyncIterator[str]:
- """
- Return an async iterator to dump the database in SQL text format.
- Example::
- async for line in db.iterdump():
- ...
- """
- dump_queue: Queue = Queue()
- def dumper():
- try:
- for line in self._conn.iterdump():
- dump_queue.put_nowait(line)
- dump_queue.put_nowait(None)
- except Exception:
- LOG.exception("exception while dumping db")
- dump_queue.put_nowait(None)
- raise
- fut = self._execute(dumper)
- task = asyncio.ensure_future(fut)
- while True:
- try:
- line: Optional[str] = dump_queue.get_nowait()
- if line is None:
- break
- yield line
- except Empty:
- if task.done():
- LOG.warning("iterdump completed unexpectedly")
- break
- await asyncio.sleep(0.01)
- await task
- async def backup(
- self,
- target: Union["Connection", sqlite3.Connection],
- *,
- pages: int = 0,
- progress: Optional[Callable[[int, int, int], None]] = None,
- name: str = "main",
- sleep: float = 0.250,
- ) -> None:
- """
- Make a backup of the current database to the target database.
- Takes either a standard sqlite3 or aiosqlite Connection object as the target.
- """
- if isinstance(target, Connection):
- target = target._conn
- await self._execute(
- self._conn.backup,
- target,
- pages=pages,
- progress=progress,
- name=name,
- sleep=sleep,
- )
- def connect(
- database: Union[str, Path],
- *,
- iter_chunk_size=64,
- loop: Optional[asyncio.AbstractEventLoop] = None,
- **kwargs: Any,
- ) -> Connection:
- """Create and return a connection proxy to the sqlite database."""
- if loop is not None:
- warn(
- "aiosqlite.connect() no longer uses the `loop` parameter",
- DeprecationWarning,
- )
- def connector() -> sqlite3.Connection:
- if isinstance(database, str):
- loc = database
- elif isinstance(database, bytes):
- loc = database.decode("utf-8")
- else:
- loc = str(database)
- return sqlite3.connect(loc, **kwargs)
- return Connection(connector, iter_chunk_size)
|