aboutsummaryrefslogtreecommitdiff
path: root/R2R/tests/test_vector_db.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/tests/test_vector_db.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to 'R2R/tests/test_vector_db.py')
-rwxr-xr-xR2R/tests/test_vector_db.py160
1 files changed, 160 insertions, 0 deletions
diff --git a/R2R/tests/test_vector_db.py b/R2R/tests/test_vector_db.py
new file mode 100755
index 00000000..023145ce
--- /dev/null
+++ b/R2R/tests/test_vector_db.py
@@ -0,0 +1,160 @@
+import random
+
+import pytest
+from dotenv import load_dotenv
+
+from r2r import (
+ Vector,
+ VectorDBConfig,
+ VectorDBProvider,
+ VectorEntry,
+ generate_id_from_label,
+)
+from r2r.providers.vector_dbs import PGVectorDB
+
+load_dotenv()
+
+
+# Sample vector entries
+def generate_random_vector_entry(id: str, dimension: int) -> VectorEntry:
+ vector = [random.random() for _ in range(dimension)]
+ metadata = {"key": f"value_{id}"}
+ return VectorEntry(
+ id=generate_id_from_label(id), vector=Vector(vector), metadata=metadata
+ )
+
+
+dimension = 3
+num_entries = 100
+sample_entries = [
+ generate_random_vector_entry(f"id_{i}", dimension)
+ for i in range(num_entries)
+]
+
+
+# Fixture for PGVectorDB
+@pytest.fixture
+def pg_vector_db():
+ random_collection_name = (
+ f"test_collection_{random.randint(0, 1_000_000_000)}"
+ )
+ config = VectorDBConfig.create(
+ provider="pgvector", vecs_collection=random_collection_name
+ )
+ db = PGVectorDB(config)
+ db.initialize_collection(dimension=dimension)
+ yield db
+ # Teardown
+ db.vx.delete_collection(
+ db.config.extra_fields.get("vecs_collection", None)
+ )
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_get_metadatas(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ unique_metadatas = db.get_metadatas(metadata_fields=["key"])
+ unique_values = set([ele["key"] for ele in unique_metadatas])
+ assert len(unique_values) == num_entries
+ assert all(f"value_id_{i}" in unique_values for i in range(num_entries))
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_initialization(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ assert isinstance(db, VectorDBProvider)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_copy_and_search(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
+ results = db.search(query_vector=sample_entries[0].vector.data)
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_upsert_and_search(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
+ results = db.search(query_vector=sample_entries[0].vector.data)
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_imperfect_match(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
+ query_vector = [val + 0.1 for val in sample_entries[0].vector.data]
+ results = db.search(query_vector=query_vector)
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score < 1.0
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_bulk_insert_and_search(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ query_vector = sample_entries[0].vector.data
+ results = db.search(query_vector=query_vector, limit=5)
+ assert len(results) == 5
+ assert results[0].id == sample_entries[0].id
+ assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_search_with_filters(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ filtered_id = sample_entries[0].metadata["key"]
+ query_vector = sample_entries[0].vector.data
+ results = db.search(
+ query_vector=query_vector, filters={"key": filtered_id}
+ )
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].metadata["key"] == filtered_id
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_delete_by_metadata(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ for entry in sample_entries:
+ db.upsert(entry)
+
+ key_to_delete = sample_entries[0].metadata["key"]
+ db.delete_by_metadata(
+ metadata_fields=["key"], metadata_values=[key_to_delete]
+ )
+
+ results = db.search(query_vector=sample_entries[0].vector.data)
+ assert all(result.metadata["key"] != key_to_delete for result in results)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_upsert(request, db_fixture):
+ db = request.getfixturevalue(db_fixture)
+ db.upsert(sample_entries[0])
+ modified_entry = VectorEntry(
+ id=sample_entries[0].id,
+ vector=Vector([0.5, 0.5, 0.5]),
+ metadata={"key": "new_value"},
+ )
+ db.upsert(modified_entry)
+
+ results = db.search(query_vector=[0.5, 0.5, 0.5])
+ assert len(results) == 1
+ assert results[0].id == sample_entries[0].id
+ assert results[0].metadata["key"] == "new_value"