aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/vecs/client.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/vecs/client.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to 'R2R/r2r/vecs/client.py')
-rwxr-xr-xR2R/r2r/vecs/client.py313
1 files changed, 313 insertions, 0 deletions
diff --git a/R2R/r2r/vecs/client.py b/R2R/r2r/vecs/client.py
new file mode 100755
index 00000000..6259f1d8
--- /dev/null
+++ b/R2R/r2r/vecs/client.py
@@ -0,0 +1,313 @@
+"""
+Defines the 'Client' class
+
+Importing from the `vecs.client` directly is not supported.
+All public classes, enums, and functions are re-exported by the top level `vecs` module.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import TYPE_CHECKING, List, Optional
+
+import sqlalchemy
+from deprecated import deprecated
+from sqlalchemy import MetaData, create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import QueuePool
+
+from .adapter import Adapter
+from .exc import CollectionNotFound
+
+if TYPE_CHECKING:
+ from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class Client:
+ """
+ The `vecs.Client` class serves as an interface to a PostgreSQL database with pgvector support. It facilitates
+ the creation, retrieval, listing and deletion of vector collections, while managing connections to the
+ database.
+
+ A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create
+ and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table.
+
+ The `vecs.Client` class can be also supports usage as a context manager to ensure the connection to the database
+ is properly closed after operations, or used directly.
+
+ Example usage:
+
+ DB_CONNECTION = "postgresql://<user>:<password>@<host>:<port>/<db_name>"
+
+ with vecs.create_client(DB_CONNECTION) as vx:
+ # do some work
+ pass
+
+ # OR
+
+ vx = vecs.create_client(DB_CONNECTION)
+ # do some work
+ vx.disconnect()
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ pool_size: int = 1,
+ max_retries: int = 3,
+ retry_delay: int = 1,
+ ):
+ self.engine = create_engine(
+ connection_string,
+ pool_size=pool_size,
+ poolclass=QueuePool,
+ pool_recycle=300, # Recycle connections after 5 min
+ )
+ self.meta = MetaData(schema="vecs")
+ self.Session = sessionmaker(self.engine)
+ self.max_retries = max_retries
+ self.retry_delay = retry_delay
+ self.vector_version: Optional[str] = None
+ self._initialize_database()
+
+ def _initialize_database(self):
+ retries = 0
+ error = None
+ while retries < self.max_retries:
+ try:
+ with self.Session() as sess:
+ with sess.begin():
+ self._create_schema(sess)
+ self._create_extension(sess)
+ self._get_vector_version(sess)
+ return
+ except Exception as e:
+ logger.warning(
+ f"Database connection error: {str(e)}. Retrying in {self.retry_delay} seconds..."
+ )
+ retries += 1
+ time.sleep(self.retry_delay)
+ error = e
+
+ error_message = f"Failed to initialize database after {self.max_retries} retries with error: {str(error)}"
+ logger.error(error_message)
+ raise RuntimeError(error_message)
+
+ def _create_schema(self, sess):
+ try:
+ sess.execute(text("CREATE SCHEMA IF NOT EXISTS vecs;"))
+ except Exception as e:
+ logger.warning(f"Failed to create schema: {str(e)}")
+
+ def _create_extension(self, sess):
+ try:
+ sess.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+ sess.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
+ sess.execute(text("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;"))
+ except Exception as e:
+ logger.warning(f"Failed to create extension: {str(e)}")
+
+ def _get_vector_version(self, sess):
+ try:
+ self.vector_version = sess.execute(
+ text(
+ "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
+ )
+ ).scalar_one()
+ except sqlalchemy.exc.InternalError as e:
+ logger.error(f"Failed with internal alchemy error: {str(e)}")
+
+ import psycopg2
+
+ if isinstance(e.orig, psycopg2.errors.InFailedSqlTransaction):
+ sess.rollback()
+ self.vector_version = sess.execute(
+ text(
+ "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
+ )
+ ).scalar_one()
+ else:
+ raise e
+ except Exception as e:
+ logger.error(f"Failed to retrieve vector version: {str(e)}")
+ raise e
+
+ def _supports_hnsw(self):
+ return (
+ not self.vector_version.startswith("0.4")
+ and not self.vector_version.startswith("0.3")
+ and not self.vector_version.startswith("0.2")
+ and not self.vector_version.startswith("0.1")
+ and not self.vector_version.startswith("0.0")
+ )
+
+ def get_or_create_collection(
+ self,
+ name: str,
+ *,
+ dimension: Optional[int] = None,
+ adapter: Optional[Adapter] = None,
+ ) -> Collection:
+ """
+ Get a vector collection by name, or create it if no collection with
+ *name* exists.
+
+ Args:
+ name (str): The name of the collection.
+
+ Keyword Args:
+ dimension (int): The dimensionality of the vectors in the collection.
+ pipeline (int): The dimensionality of the vectors in the collection.
+
+ Returns:
+ Collection: The created collection.
+
+ Raises:
+ CollectionAlreadyExists: If a collection with the same name already exists
+ """
+ from r2r.vecs.collection import Collection
+
+ adapter_dimension = adapter.exported_dimension if adapter else None
+
+ collection = Collection(
+ name=name,
+ dimension=dimension or adapter_dimension, # type: ignore
+ client=self,
+ adapter=adapter,
+ )
+
+ return collection._create_if_not_exists()
+
+ @deprecated("use Client.get_or_create_collection")
+ def create_collection(self, name: str, dimension: int) -> Collection:
+ """
+ Create a new vector collection.
+
+ Args:
+ name (str): The name of the collection.
+ dimension (int): The dimensionality of the vectors in the collection.
+
+ Returns:
+ Collection: The created collection.
+
+ Raises:
+ CollectionAlreadyExists: If a collection with the same name already exists
+ """
+ from r2r.vecs.collection import Collection
+
+ return Collection(name, dimension, self)._create()
+
+ @deprecated("use Client.get_or_create_collection")
+ def get_collection(self, name: str) -> Collection:
+ """
+ Retrieve an existing vector collection.
+
+ Args:
+ name (str): The name of the collection.
+
+ Returns:
+ Collection: The retrieved collection.
+
+ Raises:
+ CollectionNotFound: If no collection with the given name exists.
+ """
+ from r2r.vecs.collection import Collection
+
+ query = text(
+ f"""
+ select
+ relname as table_name,
+ atttypmod as embedding_dim
+ from
+ pg_class pc
+ join pg_attribute pa
+ on pc.oid = pa.attrelid
+ where
+ pc.relnamespace = 'vecs'::regnamespace
+ and pc.relkind = 'r'
+ and pa.attname = 'vec'
+ and not pc.relname ^@ '_'
+ and pc.relname = :name
+ """
+ ).bindparams(name=name)
+ with self.Session() as sess:
+ query_result = sess.execute(query).fetchone()
+
+ if query_result is None:
+ raise CollectionNotFound(
+ "No collection found with requested name"
+ )
+
+ name, dimension = query_result
+ return Collection(
+ name,
+ dimension,
+ self,
+ )
+
+ def list_collections(self) -> List["Collection"]:
+ """
+ List all vector collections.
+
+ Returns:
+ list[Collection]: A list of all collections.
+ """
+ from r2r.vecs.collection import Collection
+
+ return Collection._list_collections(self)
+
+ def delete_collection(self, name: str) -> None:
+ """
+ Delete a vector collection.
+
+ If no collection with requested name exists, does nothing.
+
+ Args:
+ name (str): The name of the collection.
+
+ Returns:
+ None
+ """
+ from r2r.vecs.collection import Collection
+
+ Collection(name, -1, self)._drop()
+ return
+
+ def disconnect(self) -> None:
+ """
+ Disconnect the client from the database.
+
+ Returns:
+ None
+ """
+ self.engine.dispose()
+ logger.info("Disconnected from the database.")
+ return
+
+ def __enter__(self) -> "Client":
+ """
+ Enable use of the 'with' statement.
+
+ Returns:
+ Client: The current instance of the Client.
+ """
+
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Disconnect the client on exiting the 'with' statement context.
+
+ Args:
+ exc_type: The exception type, if any.
+ exc_val: The exception value, if any.
+ exc_tb: The traceback, if any.
+
+ Returns:
+ None
+ """
+ self.disconnect()
+ return