cursor.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright Amethyst Reese
  2. # Licensed under the MIT license
  3. import sqlite3
  4. from collections.abc import AsyncIterator, Iterable
  5. from typing import Any, Callable, Optional, TYPE_CHECKING
  6. if TYPE_CHECKING:
  7. from .core import Connection
  8. class Cursor:
  9. def __init__(self, conn: "Connection", cursor: sqlite3.Cursor) -> None:
  10. self.iter_chunk_size = conn._iter_chunk_size
  11. self._conn = conn
  12. self._cursor = cursor
  13. def __aiter__(self) -> AsyncIterator[sqlite3.Row]:
  14. """The cursor proxy is also an async iterator."""
  15. return self._fetch_chunked()
  16. async def _fetch_chunked(self):
  17. while True:
  18. rows = await self.fetchmany(self.iter_chunk_size)
  19. if not rows:
  20. return
  21. for row in rows:
  22. yield row
  23. async def _execute(self, fn, *args, **kwargs):
  24. """Execute the given function on the shared connection's thread."""
  25. return await self._conn._execute(fn, *args, **kwargs)
  26. async def execute(
  27. self, sql: str, parameters: Optional[Iterable[Any]] = None
  28. ) -> "Cursor":
  29. """Execute the given query."""
  30. if parameters is None:
  31. parameters = []
  32. await self._execute(self._cursor.execute, sql, parameters)
  33. return self
  34. async def executemany(
  35. self, sql: str, parameters: Iterable[Iterable[Any]]
  36. ) -> "Cursor":
  37. """Execute the given multiquery."""
  38. await self._execute(self._cursor.executemany, sql, parameters)
  39. return self
  40. async def executescript(self, sql_script: str) -> "Cursor":
  41. """Execute a user script."""
  42. await self._execute(self._cursor.executescript, sql_script)
  43. return self
  44. async def fetchone(self) -> Optional[sqlite3.Row]:
  45. """Fetch a single row."""
  46. return await self._execute(self._cursor.fetchone)
  47. async def fetchmany(self, size: Optional[int] = None) -> Iterable[sqlite3.Row]:
  48. """Fetch up to `cursor.arraysize` number of rows."""
  49. args: tuple[int, ...] = ()
  50. if size is not None:
  51. args = (size,)
  52. return await self._execute(self._cursor.fetchmany, *args)
  53. async def fetchall(self) -> Iterable[sqlite3.Row]:
  54. """Fetch all remaining rows."""
  55. return await self._execute(self._cursor.fetchall)
  56. async def close(self) -> None:
  57. """Close the cursor."""
  58. await self._execute(self._cursor.close)
  59. @property
  60. def rowcount(self) -> int:
  61. return self._cursor.rowcount
  62. @property
  63. def lastrowid(self) -> Optional[int]:
  64. return self._cursor.lastrowid
  65. @property
  66. def arraysize(self) -> int:
  67. return self._cursor.arraysize
  68. @arraysize.setter
  69. def arraysize(self, value: int) -> None:
  70. self._cursor.arraysize = value
  71. @property
  72. def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...]:
  73. return self._cursor.description
  74. @property
  75. def row_factory(self) -> Optional[Callable[[sqlite3.Cursor, sqlite3.Row], object]]:
  76. return self._cursor.row_factory
  77. @row_factory.setter
  78. def row_factory(self, factory: Optional[type]) -> None:
  79. self._cursor.row_factory = factory
  80. @property
  81. def connection(self) -> sqlite3.Connection:
  82. return self._cursor.connection
  83. async def __aenter__(self):
  84. return self
  85. async def __aexit__(self, exc_type, exc_val, exc_tb):
  86. await self.close()