# NOTE: cgit page-navigation residue ("about summary refs log tree commit diff"); not part of the module.
"""
Defines the 'Client' class

Importing from the `vecs.client` directly is not supported.
All public classes, enums, and functions are re-exported by the top level `vecs` module.
"""

from __future__ import annotations

import logging
import time
from typing import TYPE_CHECKING, List, Optional

import sqlalchemy
from deprecated import deprecated
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool

from .adapter import Adapter
from .exc import CollectionNotFound

if TYPE_CHECKING:
    from r2r.vecs.collection import Collection

logger = logging.getLogger(__name__)


class Client:
    """
    The `vecs.Client` class serves as an interface to a PostgreSQL database with pgvector support. It facilitates
    the creation, retrieval, listing and deletion of vector collections, while managing connections to the
    database.

    A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create
    and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table.

    The `vecs.Client` class also supports usage as a context manager to ensure the connection to the database
    is properly closed after operations; it can also be used directly.

    Example usage:

        DB_CONNECTION = "postgresql://<user>:<password>@<host>:<port>/<db_name>"

        with vecs.create_client(DB_CONNECTION) as vx:
            # do some work
            pass

        # OR

        vx = vecs.create_client(DB_CONNECTION)
        # do some work
        vx.disconnect()
    """

    def __init__(
        self,
        connection_string: str,
        pool_size: int = 1,
        max_retries: int = 3,
        retry_delay: int = 1,
    ):
        """
        Create the engine and session factory, then initialize the database.

        Args:
            connection_string (str): SQLAlchemy connection URL of the database.
            pool_size (int): Size of the `QueuePool` backing the engine.
            max_retries (int): Maximum number of initialization attempts.
            retry_delay (int): Seconds to wait between failed attempts.

        Raises:
            RuntimeError: If the database could not be initialized within
                `max_retries` attempts.
        """
        self.engine = create_engine(
            connection_string,
            pool_size=pool_size,
            poolclass=QueuePool,
            pool_recycle=300,  # Recycle connections after 5 min
        )
        self.meta = MetaData(schema="vecs")
        self.Session = sessionmaker(self.engine)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.vector_version: Optional[str] = None
        try:
            self._initialize_database()
        except Exception:
            # The caller never receives this instance when __init__ raises,
            # so nobody could call disconnect(); release the pool here.
            self.engine.dispose()
            raise

    def _initialize_database(self):
        """
        Create the schema and extensions and read the pgvector version,
        retrying up to `self.max_retries` times.

        Raises:
            RuntimeError: If every attempt failed; chained to the last error.
        """
        attempts = 0
        error = None
        while attempts < self.max_retries:
            try:
                with self.Session() as sess:
                    with sess.begin():
                        self._create_schema(sess)
                        self._create_extension(sess)
                        self._get_vector_version(sess)
                return
            except Exception as e:
                attempts += 1
                error = e
                if attempts < self.max_retries:
                    logger.warning(
                        f"Database connection error: {str(e)}. Retrying in {self.retry_delay} seconds..."
                    )
                    time.sleep(self.retry_delay)
                else:
                    # No further attempt follows: do not announce a retry
                    # and do not delay the failure by another sleep.
                    logger.warning(f"Database connection error: {str(e)}.")

        error_message = f"Failed to initialize database after {self.max_retries} retries with error: {str(error)}"
        logger.error(error_message)
        # Chain the last underlying error so its traceback is preserved.
        raise RuntimeError(error_message) from error

    def _create_schema(self, sess):
        """
        Create the `vecs` schema if it does not exist (best effort).

        Failures are logged and swallowed. The session's transaction may be
        left in a failed state afterwards; `_get_vector_version` handles that.
        """
        try:
            sess.execute(text("CREATE SCHEMA IF NOT EXISTS vecs;"))
        except Exception as e:
            logger.warning(f"Failed to create schema: {str(e)}")

    def _create_extension(self, sess):
        """
        Create the `vector`, `pg_trgm` and `fuzzystrmatch` extensions if they
        do not exist (best effort).

        Failures are logged and swallowed; the first failure skips the
        remaining statements.
        """
        try:
            sess.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
            sess.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
            sess.execute(text("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;"))
        except Exception as e:
            logger.warning(f"Failed to create extension: {str(e)}")

    def _query_vector_version(self, sess):
        """Return `installed_version` of the `vector` extension via *sess*."""
        return sess.execute(
            text(
                "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
            )
        ).scalar_one()

    def _get_vector_version(self, sess):
        """
        Store the installed pgvector version in `self.vector_version`.

        If the query fails because the transaction was already aborted by an
        earlier (swallowed) statement failure, the session is rolled back and
        the query is retried once.

        Raises:
            Exception: Any other error raised by the query is logged and
                re-raised unchanged.
        """
        try:
            self.vector_version = self._query_vector_version(sess)
        except sqlalchemy.exc.InternalError as e:
            logger.error(f"Failed with internal alchemy error: {str(e)}")

            try:
                import psycopg2
            except ImportError:
                # Another DBAPI driver is in use; surface the database error
                # below instead of an unrelated ImportError.
                psycopg2 = None

            if psycopg2 is not None and isinstance(
                e.orig, psycopg2.errors.InFailedSqlTransaction
            ):
                sess.rollback()
                self.vector_version = self._query_vector_version(sess)
            else:
                raise
        except Exception as e:
            logger.error(f"Failed to retrieve vector version: {str(e)}")
            raise

    def _supports_hnsw(self):
        """
        Report whether the detected pgvector version is 0.5 or newer.

        The version is compared numerically, so e.g. "0.10.0" is correctly
        treated as newer than "0.5.0" (a string-prefix check would confuse it
        with "0.1.x").

        Returns:
            bool: True if `self.vector_version` parses to >= 0.5; False if it
            is older, unknown (`None`/empty) or not parseable.
        """
        version = self.vector_version
        if not version:
            return False

        components = version.split(".")
        try:
            major = int(components[0])
            minor = int(components[1]) if len(components) > 1 else 0
        except ValueError:
            # Unrecognized version format: be conservative.
            return False
        return (major, minor) >= (0, 5)

    def get_or_create_collection(
        self,
        name: str,
        *,
        dimension: Optional[int] = None,
        adapter: Optional[Adapter] = None,
    ) -> Collection:
        """
        Get a vector collection by name, or create it if no collection with
        *name* exists.

        Args:
            name (str): The name of the collection.

        Keyword Args:
            dimension (int): The dimensionality of the vectors in the collection.
                When omitted, the adapter's `exported_dimension` is used.
            adapter (Adapter): Optional adapter attached to the collection.

        Returns:
            Collection: The result of `Collection._create_if_not_exists()`.

        NOTE(review): errors raised by `Collection._create_if_not_exists`
        propagate to the caller; confirm their types against `Collection`.
        """
        from r2r.vecs.collection import Collection

        adapter_dimension = adapter.exported_dimension if adapter else None

        collection = Collection(
            name=name,
            dimension=dimension or adapter_dimension,  # type: ignore
            client=self,
            adapter=adapter,
        )

        return collection._create_if_not_exists()

    @deprecated("use Client.get_or_create_collection")
    def create_collection(self, name: str, dimension: int) -> Collection:
        """
        Create a new vector collection.

        Args:
            name (str): The name of the collection.
            dimension (int): The dimensionality of the vectors in the collection.

        Returns:
            Collection: The created collection.

        Raises:
            CollectionAlreadyExists: If a collection with the same name already exists
        """
        from r2r.vecs.collection import Collection

        return Collection(name, dimension, self)._create()

    @deprecated("use Client.get_or_create_collection")
    def get_collection(self, name: str) -> Collection:
        """
        Retrieve an existing vector collection.

        Args:
            name (str): The name of the collection.

        Returns:
            Collection: The retrieved collection.

        Raises:
            CollectionNotFound: If no collection with the given name exists.
        """
        from r2r.vecs.collection import Collection

        # Look the table up in the system catalogs; the type modifier of the
        # `vec` column is selected as the embedding dimension.
        query = text(
            """
        select
            relname as table_name,
            atttypmod as embedding_dim
        from
            pg_class pc
            join pg_attribute pa
                on pc.oid = pa.attrelid
        where
            pc.relnamespace = 'vecs'::regnamespace
            and pc.relkind = 'r'
            and pa.attname = 'vec'
            and not pc.relname ^@ '_'
            and pc.relname = :name
        """
        ).bindparams(name=name)
        with self.Session() as sess:
            row = sess.execute(query).fetchone()

            if row is None:
                raise CollectionNotFound(
                    "No collection found with requested name"
                )

            table_name, embedding_dim = row
            return Collection(
                table_name,
                embedding_dim,
                self,
            )

    def list_collections(self) -> List["Collection"]:
        """
        List all vector collections.

        Returns:
            list[Collection]: A list of all collections.
        """
        from r2r.vecs.collection import Collection

        return Collection._list_collections(self)

    def delete_collection(self, name: str) -> None:
        """
        Delete a vector collection.

        If no collection with requested name exists, does nothing.

        Args:
            name (str): The name of the collection.

        Returns:
            None
        """
        from r2r.vecs.collection import Collection

        # The dimension is irrelevant for dropping; -1 is a placeholder.
        Collection(name, -1, self)._drop()

    def disconnect(self) -> None:
        """
        Disconnect the client from the database.

        Returns:
            None
        """
        self.engine.dispose()
        logger.info("Disconnected from the database.")

    def __enter__(self) -> "Client":
        """
        Enable use of the 'with' statement.

        Returns:
            Client: The current instance of the Client.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Disconnect the client on exiting the 'with' statement context.

        Args:
            exc_type: The exception type, if any.
            exc_val: The exception value, if any.
            exc_tb: The traceback, if any.

        Returns:
            None: exceptions raised inside the 'with' body are not suppressed.
        """
        self.disconnect()