about summary refs log tree commit diff
path: root/R2R/r2r/vecs/client.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/vecs/client.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/vecs/client.py')
-rwxr-xr-xR2R/r2r/vecs/client.py313
1 files changed, 313 insertions, 0 deletions
diff --git a/R2R/r2r/vecs/client.py b/R2R/r2r/vecs/client.py
new file mode 100755
index 00000000..6259f1d8
--- /dev/null
+++ b/R2R/r2r/vecs/client.py
@@ -0,0 +1,313 @@
+"""
+Defines the 'Client' class
+
+Importing from `vecs.client` directly is not supported.
+All public classes, enums, and functions are re-exported by the top level `vecs` module.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import TYPE_CHECKING, List, Optional
+
+import sqlalchemy
+from deprecated import deprecated
+from sqlalchemy import MetaData, create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import QueuePool
+
+from .adapter import Adapter
+from .exc import CollectionNotFound
+
+if TYPE_CHECKING:
+    from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class Client:
+    """
+    The `vecs.Client` class serves as an interface to a PostgreSQL database with pgvector support. It facilitates
+    the creation, retrieval, listing and deletion of vector collections, while managing connections to the
+    database.
+
+    A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create
+    and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table.
+
+    The `vecs.Client` class also supports usage as a context manager to ensure the connection to the database
+    is properly closed after operations; alternatively, it can be used directly.
+
+    Example usage:
+
+        DB_CONNECTION = "postgresql://<user>:<password>@<host>:<port>/<db_name>"
+
+        with vecs.create_client(DB_CONNECTION) as vx:
+            # do some work
+            pass
+
+        # OR
+
+        vx = vecs.create_client(DB_CONNECTION)
+        # do some work
+        vx.disconnect()
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        pool_size: int = 1,
+        max_retries: int = 3,
+        retry_delay: int = 1,
+    ):
+        self.engine = create_engine(
+            connection_string,
+            pool_size=pool_size,
+            poolclass=QueuePool,
+            pool_recycle=300,  # Recycle connections after 5 min
+        )
+        self.meta = MetaData(schema="vecs")
+        self.Session = sessionmaker(self.engine)
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self.vector_version: Optional[str] = None
+        self._initialize_database()
+
+    def _initialize_database(self):
+        retries = 0
+        error = None
+        while retries < self.max_retries:
+            try:
+                with self.Session() as sess:
+                    with sess.begin():
+                        self._create_schema(sess)
+                        self._create_extension(sess)
+                        self._get_vector_version(sess)
+                return
+            except Exception as e:
+                logger.warning(
+                    f"Database connection error: {str(e)}. Retrying in {self.retry_delay} seconds..."
+                )
+                retries += 1
+                time.sleep(self.retry_delay)
+                error = e
+
+        error_message = f"Failed to initialize database after {self.max_retries} retries with error: {str(error)}"
+        logger.error(error_message)
+        raise RuntimeError(error_message)
+
+    def _create_schema(self, sess):
+        try:
+            sess.execute(text("CREATE SCHEMA IF NOT EXISTS vecs;"))
+        except Exception as e:
+            logger.warning(f"Failed to create schema: {str(e)}")
+
+    def _create_extension(self, sess):
+        try:
+            sess.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+            sess.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
+            sess.execute(text("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;"))
+        except Exception as e:
+            logger.warning(f"Failed to create extension: {str(e)}")
+
+    def _get_vector_version(self, sess):
+        try:
+            self.vector_version = sess.execute(
+                text(
+                    "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
+                )
+            ).scalar_one()
+        except sqlalchemy.exc.InternalError as e:
+            logger.error(f"Failed with internal alchemy error: {str(e)}")
+
+            import psycopg2
+
+            if isinstance(e.orig, psycopg2.errors.InFailedSqlTransaction):
+                sess.rollback()
+                self.vector_version = sess.execute(
+                    text(
+                        "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
+                    )
+                ).scalar_one()
+            else:
+                raise e
+        except Exception as e:
+            logger.error(f"Failed to retrieve vector version: {str(e)}")
+            raise e
+
+    def _supports_hnsw(self):
+        return (
+            not self.vector_version.startswith("0.4")
+            and not self.vector_version.startswith("0.3")
+            and not self.vector_version.startswith("0.2")
+            and not self.vector_version.startswith("0.1")
+            and not self.vector_version.startswith("0.0")
+        )
+
+    def get_or_create_collection(
+        self,
+        name: str,
+        *,
+        dimension: Optional[int] = None,
+        adapter: Optional[Adapter] = None,
+    ) -> Collection:
+        """
+        Get a vector collection by name, or create it if no collection with
+        *name* exists.
+
+        Args:
+            name (str): The name of the collection.
+
+        Keyword Args:
+            dimension (int): The dimensionality of the vectors in the collection.
+            adapter (Adapter): Optional adapter; its `exported_dimension` is used when `dimension` is omitted.
+
+        Returns:
+            Collection: The existing or newly created collection.
+
+        Raises:
+            CollectionAlreadyExists: If a collection with the same name already exists
+        """
+        from r2r.vecs.collection import Collection
+
+        adapter_dimension = adapter.exported_dimension if adapter else None
+
+        collection = Collection(
+            name=name,
+            dimension=dimension or adapter_dimension,  # type: ignore
+            client=self,
+            adapter=adapter,
+        )
+
+        return collection._create_if_not_exists()
+
+    @deprecated("use Client.get_or_create_collection")
+    def create_collection(self, name: str, dimension: int) -> Collection:
+        """
+        Create a new vector collection.
+
+        Args:
+            name (str): The name of the collection.
+            dimension (int): The dimensionality of the vectors in the collection.
+
+        Returns:
+            Collection: The created collection.
+
+        Raises:
+            CollectionAlreadyExists: If a collection with the same name already exists
+        """
+        from r2r.vecs.collection import Collection
+
+        return Collection(name, dimension, self)._create()
+
+    @deprecated("use Client.get_or_create_collection")
+    def get_collection(self, name: str) -> Collection:
+        """
+        Retrieve an existing vector collection.
+
+        Args:
+            name (str): The name of the collection.
+
+        Returns:
+            Collection: The retrieved collection.
+
+        Raises:
+            CollectionNotFound: If no collection with the given name exists.
+        """
+        from r2r.vecs.collection import Collection
+
+        query = text(
+            f"""
+        select
+            relname as table_name,
+            atttypmod as embedding_dim
+        from
+            pg_class pc
+            join pg_attribute pa
+                on pc.oid = pa.attrelid
+        where
+            pc.relnamespace = 'vecs'::regnamespace
+            and pc.relkind = 'r'
+            and pa.attname = 'vec'
+            and not pc.relname ^@ '_'
+            and pc.relname = :name
+        """
+        ).bindparams(name=name)
+        with self.Session() as sess:
+            query_result = sess.execute(query).fetchone()
+
+            if query_result is None:
+                raise CollectionNotFound(
+                    "No collection found with requested name"
+                )
+
+            name, dimension = query_result
+            return Collection(
+                name,
+                dimension,
+                self,
+            )
+
+    def list_collections(self) -> List["Collection"]:
+        """
+        List all vector collections.
+
+        Returns:
+            list[Collection]: A list of all collections.
+        """
+        from r2r.vecs.collection import Collection
+
+        return Collection._list_collections(self)
+
+    def delete_collection(self, name: str) -> None:
+        """
+        Delete a vector collection.
+
+        If no collection with requested name exists, does nothing.
+
+        Args:
+            name (str): The name of the collection.
+
+        Returns:
+            None
+        """
+        from r2r.vecs.collection import Collection
+
+        Collection(name, -1, self)._drop()
+        return
+
+    def disconnect(self) -> None:
+        """
+        Disconnect the client from the database.
+
+        Returns:
+            None
+        """
+        self.engine.dispose()
+        logger.info("Disconnected from the database.")
+        return
+
+    def __enter__(self) -> "Client":
+        """
+        Enable use of the 'with' statement.
+
+        Returns:
+            Client: The current instance of the Client.
+        """
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Disconnect the client on exiting the 'with' statement context.
+
+        Args:
+            exc_type: The exception type, if any.
+            exc_val: The exception value, if any.
+            exc_tb: The traceback, if any.
+
+        Returns:
+            None
+        """
+        self.disconnect()
+        return