"""
Defines the `Client` class.
Importing from `vecs.client` directly is not supported.
All public classes, enums, and functions are re-exported by the top level `vecs` module.
"""
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING, List, Optional
import sqlalchemy
from deprecated import deprecated
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
from .adapter import Adapter
from .exc import CollectionNotFound
if TYPE_CHECKING:
    # Only needed for annotations; the methods below import Collection lazily,
    # presumably to avoid a circular import at runtime.
    from r2r.vecs.collection import Collection

# Logger named after this module.
logger = logging.getLogger(name=__name__)
class Client:
    """
    The `vecs.Client` class serves as an interface to a PostgreSQL database with pgvector support. It facilitates
    the creation, retrieval, listing and deletion of vector collections, while managing connections to the
    database.

    A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create
    and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table.

    The `vecs.Client` class also supports usage as a context manager, which ensures the connection to the
    database is properly closed after operations; it can also be used directly.

    Example usage:

        DB_CONNECTION = "postgresql://<user>:<password>@<host>:<port>/<db_name>"

        with vecs.create_client(DB_CONNECTION) as vx:
            # do some work
            pass

        # OR

        vx = vecs.create_client(DB_CONNECTION)
        # do some work
        vx.disconnect()
    """

    def __init__(
        self,
        connection_string: str,
        pool_size: int = 1,
        max_retries: int = 3,
        retry_delay: int = 1,
    ):
        """
        Create a pooled engine for *connection_string* and prepare the database.

        Args:
            connection_string (str): SQLAlchemy database URL.
            pool_size (int): Size of the QueuePool backing the engine.
            max_retries (int): Number of initialization attempts before giving up.
            retry_delay (int): Seconds to sleep between initialization attempts.

        Raises:
            RuntimeError: If the database could not be initialized after
                `max_retries` attempts.
        """
        self.engine = create_engine(
            connection_string,
            pool_size=pool_size,
            poolclass=QueuePool,
            pool_recycle=300,  # Recycle connections after 5 min
        )
        self.meta = MetaData(schema="vecs")
        self.Session = sessionmaker(self.engine)
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Set by _get_vector_version; None when pgvector is available but not
        # installed (pg_available_extensions.installed_version is NULL).
        self.vector_version: Optional[str] = None
        try:
            self._initialize_database()
        except Exception:
            # The caller never receives this instance, so it cannot call
            # disconnect(); release the pool here instead of leaking it.
            self.engine.dispose()
            raise

    def _initialize_database(self) -> None:
        """
        Create the `vecs` schema and required extensions, then read the pgvector version.

        All work for one attempt runs in a single transaction. The attempt is
        retried up to `self.max_retries` times, sleeping `self.retry_delay`
        seconds between attempts (but not after the final one).

        Raises:
            RuntimeError: If every attempt fails; the last underlying error is
                chained as the cause.
        """
        error: Optional[Exception] = None
        for attempt in range(1, self.max_retries + 1):
            try:
                with self.Session() as sess:
                    with sess.begin():
                        self._create_schema(sess)
                        self._create_extension(sess)
                        self._get_vector_version(sess)
                return
            except Exception as e:
                error = e
                if attempt < self.max_retries:
                    logger.warning(
                        f"Database connection error: {str(e)}. Retrying in {self.retry_delay} seconds..."
                    )
                    time.sleep(self.retry_delay)
        error_message = f"Failed to initialize database after {self.max_retries} retries with error: {str(error)}"
        logger.error(error_message)
        raise RuntimeError(error_message) from error

    def _create_schema(self, sess) -> None:
        """
        Best-effort creation of the `vecs` schema.

        The statement runs inside a SAVEPOINT so that a failure (e.g. missing
        privileges) rolls back only this statement. Without it, PostgreSQL would
        mark the enclosing transaction as aborted and every following statement
        would fail.
        """
        try:
            with sess.begin_nested():
                sess.execute(text("CREATE SCHEMA IF NOT EXISTS vecs;"))
        except Exception as e:
            logger.warning(f"Failed to create schema: {str(e)}")

    def _create_extension(self, sess) -> None:
        """
        Best-effort creation of the `vector`, `pg_trgm` and `fuzzystrmatch` extensions.

        Each extension gets its own SAVEPOINT, so one failure neither aborts the
        enclosing transaction nor prevents the remaining extensions from being
        attempted.
        """
        for extension in ("vector", "pg_trgm", "fuzzystrmatch"):
            try:
                with sess.begin_nested():
                    # `extension` comes from the fixed tuple above, never from user input.
                    sess.execute(text(f"CREATE EXTENSION IF NOT EXISTS {extension};"))
            except Exception as e:
                logger.warning(f"Failed to create extension: {str(e)}")

    def _get_vector_version(self, sess) -> None:
        """
        Store the installed pgvector version in `self.vector_version`.

        The value is None when the extension is available but not installed.

        Raises:
            sqlalchemy.exc.NoResultFound: If pgvector is not available on the server.
            Exception: Any database error is logged and re-raised.
        """
        # The schema/extension steps use SAVEPOINTs, so the transaction cannot be
        # left in an aborted state here; no rollback-and-retry fallback is needed.
        try:
            self.vector_version = sess.execute(
                text(
                    "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
                )
            ).scalar_one()
        except Exception as e:
            logger.error(f"Failed to retrieve vector version: {str(e)}")
            raise
        if self.vector_version is None:
            logger.warning("The 'vector' extension is available but not installed.")

    def _supports_hnsw(self) -> bool:
        """
        Return True if the installed pgvector version supports HNSW indexes (0.5.0+).

        The major/minor numbers are compared numerically. For example, "0.10.0"
        is treated as newer than "0.5.0"; a plain prefix check would mistake it
        for a 0.1.x release.

        Returns:
            bool: False when no version is known (extension not installed).
        """
        import re

        version = self.vector_version
        if version is None:
            return False
        match = re.match(r"(\d+)\.(\d+)", version)
        if match is None:
            # Unrecognised format: fall back to the legacy prefix check.
            return not version.startswith(("0.0", "0.1", "0.2", "0.3", "0.4"))
        major, minor = int(match.group(1)), int(match.group(2))
        return (major, minor) >= (0, 5)

    def get_or_create_collection(
        self,
        name: str,
        *,
        dimension: Optional[int] = None,
        adapter: Optional[Adapter] = None,
    ) -> Collection:
        """
        Get a vector collection by name, or create it if no collection with
        *name* exists.

        Args:
            name (str): The name of the collection.

        Keyword Args:
            dimension (int): The dimensionality of the vectors in the collection.
                When omitted (or falsy), the adapter's `exported_dimension` is used.
            adapter (Adapter): Optional adapter attached to the collection.

        Returns:
            Collection: The existing or newly created collection.

        Raises:
            Whatever `Collection._create_if_not_exists` raises. The upstream
            docs list `CollectionAlreadyExists`, presumably for an existing
            collection with a conflicting definition (TODO confirm).
        """
        from r2r.vecs.collection import Collection

        adapter_dimension = adapter.exported_dimension if adapter else None

        collection = Collection(
            name=name,
            dimension=dimension or adapter_dimension,  # type: ignore
            client=self,
            adapter=adapter,
        )

        return collection._create_if_not_exists()

    @deprecated("use Client.get_or_create_collection")
    def create_collection(self, name: str, dimension: int) -> Collection:
        """
        Create a new vector collection.

        Args:
            name (str): The name of the collection.
            dimension (int): The dimensionality of the vectors in the collection.

        Returns:
            Collection: The created collection.

        Raises:
            CollectionAlreadyExists: If a collection with the same name already exists
        """
        from r2r.vecs.collection import Collection

        return Collection(name, dimension, self)._create()

    @deprecated("use Client.get_or_create_collection")
    def get_collection(self, name: str) -> Collection:
        """
        Retrieve an existing vector collection.

        The collection is looked up in the `vecs` schema as an ordinary table
        with a `vec` column. Names starting with `_` are excluded. The vector
        dimension is read from the column's type modifier (`atttypmod`).

        Args:
            name (str): The name of the collection.

        Returns:
            Collection: The retrieved collection.

        Raises:
            CollectionNotFound: If no collection with the given name exists.
        """
        from r2r.vecs.collection import Collection

        query = text(
            """
        select
            relname as table_name,
            atttypmod as embedding_dim
        from
            pg_class pc
            join pg_attribute pa
                on pc.oid = pa.attrelid
        where
            pc.relnamespace = 'vecs'::regnamespace
            and pc.relkind = 'r'
            and pa.attname = 'vec'
            and not pc.relname ^@ '_'
            and pc.relname = :name
        """
        ).bindparams(name=name)

        with self.Session() as sess:
            query_result = sess.execute(query).fetchone()

        if query_result is None:
            raise CollectionNotFound("No collection found with requested name")

        table_name, dimension = query_result
        return Collection(
            table_name,
            dimension,
            self,
        )

    def list_collections(self) -> List["Collection"]:
        """
        List all vector collections.

        Returns:
            list[Collection]: A list of all collections.
        """
        from r2r.vecs.collection import Collection

        return Collection._list_collections(self)

    def delete_collection(self, name: str) -> None:
        """
        Delete a vector collection.

        If no collection with requested name exists, does nothing.

        Args:
            name (str): The name of the collection.

        Returns:
            None
        """
        from r2r.vecs.collection import Collection

        # The dimension is a placeholder here; presumably `_drop` ignores it (verify).
        Collection(name, -1, self)._drop()
        return

    def disconnect(self) -> None:
        """
        Disconnect the client from the database by disposing of the engine's connection pool.

        Returns:
            None
        """
        self.engine.dispose()
        logger.info("Disconnected from the database.")
        return

    def __enter__(self) -> "Client":
        """
        Enable use of the 'with' statement.

        Returns:
            Client: The current instance of the Client.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Disconnect the client on exiting the 'with' statement context.

        Exceptions raised inside the block are not suppressed (returns None).

        Args:
            exc_type: The exception type, if any.
            exc_val: The exception value, if any.
            exc_tb: The traceback, if any.

        Returns:
            None
        """
        self.disconnect()
        return