123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Copyright Amethyst Reese
- # Licensed under the MIT license
- import sqlite3
- from collections.abc import AsyncIterator, Iterable
- from typing import Any, Callable, Optional, TYPE_CHECKING
- if TYPE_CHECKING:
- from .core import Connection
- class Cursor:
- def __init__(self, conn: "Connection", cursor: sqlite3.Cursor) -> None:
- self.iter_chunk_size = conn._iter_chunk_size
- self._conn = conn
- self._cursor = cursor
- def __aiter__(self) -> AsyncIterator[sqlite3.Row]:
- """The cursor proxy is also an async iterator."""
- return self._fetch_chunked()
- async def _fetch_chunked(self):
- while True:
- rows = await self.fetchmany(self.iter_chunk_size)
- if not rows:
- return
- for row in rows:
- yield row
- async def _execute(self, fn, *args, **kwargs):
- """Execute the given function on the shared connection's thread."""
- return await self._conn._execute(fn, *args, **kwargs)
- async def execute(
- self, sql: str, parameters: Optional[Iterable[Any]] = None
- ) -> "Cursor":
- """Execute the given query."""
- if parameters is None:
- parameters = []
- await self._execute(self._cursor.execute, sql, parameters)
- return self
- async def executemany(
- self, sql: str, parameters: Iterable[Iterable[Any]]
- ) -> "Cursor":
- """Execute the given multiquery."""
- await self._execute(self._cursor.executemany, sql, parameters)
- return self
- async def executescript(self, sql_script: str) -> "Cursor":
- """Execute a user script."""
- await self._execute(self._cursor.executescript, sql_script)
- return self
- async def fetchone(self) -> Optional[sqlite3.Row]:
- """Fetch a single row."""
- return await self._execute(self._cursor.fetchone)
- async def fetchmany(self, size: Optional[int] = None) -> Iterable[sqlite3.Row]:
- """Fetch up to `cursor.arraysize` number of rows."""
- args: tuple[int, ...] = ()
- if size is not None:
- args = (size,)
- return await self._execute(self._cursor.fetchmany, *args)
- async def fetchall(self) -> Iterable[sqlite3.Row]:
- """Fetch all remaining rows."""
- return await self._execute(self._cursor.fetchall)
- async def close(self) -> None:
- """Close the cursor."""
- await self._execute(self._cursor.close)
- @property
- def rowcount(self) -> int:
- return self._cursor.rowcount
- @property
- def lastrowid(self) -> Optional[int]:
- return self._cursor.lastrowid
- @property
- def arraysize(self) -> int:
- return self._cursor.arraysize
- @arraysize.setter
- def arraysize(self, value: int) -> None:
- self._cursor.arraysize = value
- @property
- def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...]:
- return self._cursor.description
- @property
- def row_factory(self) -> Optional[Callable[[sqlite3.Cursor, sqlite3.Row], object]]:
- return self._cursor.row_factory
- @row_factory.setter
- def row_factory(self, factory: Optional[type]) -> None:
- self._cursor.row_factory = factory
- @property
- def connection(self) -> sqlite3.Connection:
- return self._cursor.connection
- async def __aenter__(self):
- return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self.close()
|