about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/aiosqlite/core.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiosqlite/core.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiosqlite/core.py394
1 files changed, 394 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiosqlite/core.py b/.venv/lib/python3.12/site-packages/aiosqlite/core.py
new file mode 100644
index 00000000..58c3fece
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiosqlite/core.py
@@ -0,0 +1,394 @@
+# Copyright 2022 Amethyst Reese
+# Licensed under the MIT license
+
+"""
+Core implementation of aiosqlite proxies
+"""
+
+import asyncio
+import logging
+import sqlite3
+from functools import partial
+from pathlib import Path
+from queue import Empty, Queue, SimpleQueue
+from threading import Thread
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Generator,
+    Iterable,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+from warnings import warn
+
+from .context import contextmanager
+from .cursor import Cursor
+
+__all__ = ["connect", "Connection", "Cursor"]
+
+LOG = logging.getLogger("aiosqlite")
+
+
+IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
+
+
+def set_result(fut: asyncio.Future, result: Any) -> None:
+    """Set the result of a future if it hasn't been set already."""
+    if not fut.done():
+        fut.set_result(result)
+
+
+def set_exception(fut: asyncio.Future, e: BaseException) -> None:
+    """Set the exception of a future if it hasn't been set already."""
+    if not fut.done():
+        fut.set_exception(e)
+
+
+_STOP_RUNNING_SENTINEL = object()
+
+
+class Connection(Thread):
+    def __init__(
+        self,
+        connector: Callable[[], sqlite3.Connection],
+        iter_chunk_size: int,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
+        super().__init__()
+        self._running = True
+        self._connection: Optional[sqlite3.Connection] = None
+        self._connector = connector
+        self._tx: SimpleQueue[Tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
+        self._iter_chunk_size = iter_chunk_size
+
+        if loop is not None:
+            warn(
+                "aiosqlite.Connection no longer uses the `loop` parameter",
+                DeprecationWarning,
+            )
+
+    def _stop_running(self):
+        self._running = False
+        # PEP 661 is not accepted yet, so we cannot type a sentinel
+        self._tx.put_nowait(_STOP_RUNNING_SENTINEL)  # type: ignore[arg-type]
+
+    @property
+    def _conn(self) -> sqlite3.Connection:
+        if self._connection is None:
+            raise ValueError("no active connection")
+
+        return self._connection
+
+    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
+        cursor = self._conn.execute(sql, parameters)
+        cursor.execute("SELECT last_insert_rowid()")
+        return cursor.fetchone()
+
+    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
+        cursor = self._conn.execute(sql, parameters)
+        return cursor.fetchall()
+
+    def run(self) -> None:
+        """
+        Execute function calls on a separate thread.
+
+        :meta private:
+        """
+        while True:
+            # Continues running until all queue items are processed,
+            # even after connection is closed (so we can finalize all
+            # futures)
+
+            tx_item = self._tx.get()
+            if tx_item is _STOP_RUNNING_SENTINEL:
+                break
+
+            future, function = tx_item
+
+            try:
+                LOG.debug("executing %s", function)
+                result = function()
+                LOG.debug("operation %s completed", function)
+                future.get_loop().call_soon_threadsafe(set_result, future, result)
+            except BaseException as e:  # noqa B036
+                LOG.debug("returning exception %s", e)
+                future.get_loop().call_soon_threadsafe(set_exception, future, e)
+
+    async def _execute(self, fn, *args, **kwargs):
+        """Queue a function with the given arguments for execution."""
+        if not self._running or not self._connection:
+            raise ValueError("Connection closed")
+
+        function = partial(fn, *args, **kwargs)
+        future = asyncio.get_event_loop().create_future()
+
+        self._tx.put_nowait((future, function))
+
+        return await future
+
+    async def _connect(self) -> "Connection":
+        """Connect to the actual sqlite database."""
+        if self._connection is None:
+            try:
+                future = asyncio.get_event_loop().create_future()
+                self._tx.put_nowait((future, self._connector))
+                self._connection = await future
+            except Exception:
+                self._stop_running()
+                self._connection = None
+                raise
+
+        return self
+
+    def __await__(self) -> Generator[Any, None, "Connection"]:
+        self.start()
+        return self._connect().__await__()
+
+    async def __aenter__(self) -> "Connection":
+        return await self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        await self.close()
+
+    @contextmanager
+    async def cursor(self) -> Cursor:
+        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
+        return Cursor(self, await self._execute(self._conn.cursor))
+
+    async def commit(self) -> None:
+        """Commit the current transaction."""
+        await self._execute(self._conn.commit)
+
+    async def rollback(self) -> None:
+        """Roll back the current transaction."""
+        await self._execute(self._conn.rollback)
+
+    async def close(self) -> None:
+        """Complete queued queries/cursors and close the connection."""
+
+        if self._connection is None:
+            return
+
+        try:
+            await self._execute(self._conn.close)
+        except Exception:
+            LOG.info("exception occurred while closing connection")
+            raise
+        finally:
+            self._stop_running()
+            self._connection = None
+
+    @contextmanager
+    async def execute(
+        self, sql: str, parameters: Optional[Iterable[Any]] = None
+    ) -> Cursor:
+        """Helper to create a cursor and execute the given query."""
+        if parameters is None:
+            parameters = []
+        cursor = await self._execute(self._conn.execute, sql, parameters)
+        return Cursor(self, cursor)
+
+    @contextmanager
+    async def execute_insert(
+        self, sql: str, parameters: Optional[Iterable[Any]] = None
+    ) -> Optional[sqlite3.Row]:
+        """Helper to insert and get the last_insert_rowid."""
+        if parameters is None:
+            parameters = []
+        return await self._execute(self._execute_insert, sql, parameters)
+
+    @contextmanager
+    async def execute_fetchall(
+        self, sql: str, parameters: Optional[Iterable[Any]] = None
+    ) -> Iterable[sqlite3.Row]:
+        """Helper to execute a query and return all the data."""
+        if parameters is None:
+            parameters = []
+        return await self._execute(self._execute_fetchall, sql, parameters)
+
+    @contextmanager
+    async def executemany(
+        self, sql: str, parameters: Iterable[Iterable[Any]]
+    ) -> Cursor:
+        """Helper to create a cursor and execute the given multiquery."""
+        cursor = await self._execute(self._conn.executemany, sql, parameters)
+        return Cursor(self, cursor)
+
+    @contextmanager
+    async def executescript(self, sql_script: str) -> Cursor:
+        """Helper to create a cursor and execute a user script."""
+        cursor = await self._execute(self._conn.executescript, sql_script)
+        return Cursor(self, cursor)
+
+    async def interrupt(self) -> None:
+        """Interrupt pending queries."""
+        return self._conn.interrupt()
+
+    async def create_function(
+        self, name: str, num_params: int, func: Callable, deterministic: bool = False
+    ) -> None:
+        """
+        Create user-defined function that can be later used
+        within SQL statements. Must be run within the same thread
+        that query executions take place so instead of executing directly
+        against the connection, we defer this to `run` function.
+
+        If ``deterministic`` is true, the created function is marked as deterministic,
+        which allows SQLite to perform additional optimizations. This flag is supported
+        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
+        older versions.
+        """
+        await self._execute(
+            self._conn.create_function,
+            name,
+            num_params,
+            func,
+            deterministic=deterministic,
+        )
+
+    @property
+    def in_transaction(self) -> bool:
+        return self._conn.in_transaction
+
+    @property
+    def isolation_level(self) -> Optional[str]:
+        return self._conn.isolation_level
+
+    @isolation_level.setter
+    def isolation_level(self, value: IsolationLevel) -> None:
+        self._conn.isolation_level = value
+
+    @property
+    def row_factory(self) -> Optional[Type]:
+        return self._conn.row_factory
+
+    @row_factory.setter
+    def row_factory(self, factory: Optional[Type]) -> None:
+        self._conn.row_factory = factory
+
+    @property
+    def text_factory(self) -> Callable[[bytes], Any]:
+        return self._conn.text_factory
+
+    @text_factory.setter
+    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
+        self._conn.text_factory = factory
+
+    @property
+    def total_changes(self) -> int:
+        return self._conn.total_changes
+
+    async def enable_load_extension(self, value: bool) -> None:
+        await self._execute(self._conn.enable_load_extension, value)  # type: ignore
+
+    async def load_extension(self, path: str):
+        await self._execute(self._conn.load_extension, path)  # type: ignore
+
+    async def set_progress_handler(
+        self, handler: Callable[[], Optional[int]], n: int
+    ) -> None:
+        await self._execute(self._conn.set_progress_handler, handler, n)
+
+    async def set_trace_callback(self, handler: Callable) -> None:
+        await self._execute(self._conn.set_trace_callback, handler)
+
+    async def iterdump(self) -> AsyncIterator[str]:
+        """
+        Return an async iterator to dump the database in SQL text format.
+
+        Example::
+
+            async for line in db.iterdump():
+                ...
+
+        """
+        dump_queue: Queue = Queue()
+
+        def dumper():
+            try:
+                for line in self._conn.iterdump():
+                    dump_queue.put_nowait(line)
+                dump_queue.put_nowait(None)
+
+            except Exception:
+                LOG.exception("exception while dumping db")
+                dump_queue.put_nowait(None)
+                raise
+
+        fut = self._execute(dumper)
+        task = asyncio.ensure_future(fut)
+
+        while True:
+            try:
+                line: Optional[str] = dump_queue.get_nowait()
+                if line is None:
+                    break
+                yield line
+
+            except Empty:
+                if task.done():
+                    LOG.warning("iterdump completed unexpectedly")
+                    break
+
+                await asyncio.sleep(0.01)
+
+        await task
+
+    async def backup(
+        self,
+        target: Union["Connection", sqlite3.Connection],
+        *,
+        pages: int = 0,
+        progress: Optional[Callable[[int, int, int], None]] = None,
+        name: str = "main",
+        sleep: float = 0.250,
+    ) -> None:
+        """
+        Make a backup of the current database to the target database.
+
+        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
+        """
+        if isinstance(target, Connection):
+            target = target._conn
+
+        await self._execute(
+            self._conn.backup,
+            target,
+            pages=pages,
+            progress=progress,
+            name=name,
+            sleep=sleep,
+        )
+
+
+def connect(
+    database: Union[str, Path],
+    *,
+    iter_chunk_size=64,
+    loop: Optional[asyncio.AbstractEventLoop] = None,
+    **kwargs: Any,
+) -> Connection:
+    """Create and return a connection proxy to the sqlite database."""
+
+    if loop is not None:
+        warn(
+            "aiosqlite.connect() no longer uses the `loop` parameter",
+            DeprecationWarning,
+        )
+
+    def connector() -> sqlite3.Connection:
+        if isinstance(database, str):
+            loc = database
+        elif isinstance(database, bytes):
+            loc = database.decode("utf-8")
+        else:
+            loc = str(database)
+
+        return sqlite3.connect(loc, **kwargs)
+
+    return Connection(connector, iter_chunk_size)