author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/tests/test_vector_db.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  | gn-ai-master.tar.gz
Diffstat (limited to 'R2R/tests/test_vector_db.py')
-rwxr-xr-x | R2R/tests/test_vector_db.py | 160
1 files changed, 160 insertions, 0 deletions
diff --git a/R2R/tests/test_vector_db.py b/R2R/tests/test_vector_db.py
new file mode 100755
index 00000000..023145ce
--- /dev/null
+++ b/R2R/tests/test_vector_db.py
@@ -0,0 +1,160 @@
+import random
+
+import pytest
+from dotenv import load_dotenv
+
+from r2r import (
+    Vector,
+    VectorDBConfig,
+    VectorDBProvider,
+    VectorEntry,
+    generate_id_from_label,
+)
+from r2r.providers.vector_dbs import PGVectorDB
+
+load_dotenv()
+
+
+# Sample vector entries
+def generate_random_vector_entry(id: str, dimension: int) -> VectorEntry:
+    vector = [random.random() for _ in range(dimension)]
+    metadata = {"key": f"value_{id}"}
+    return VectorEntry(
+        id=generate_id_from_label(id), vector=Vector(vector), metadata=metadata
+    )
+
+
+dimension = 3
+num_entries = 100
+sample_entries = [
+    generate_random_vector_entry(f"id_{i}", dimension)
+    for i in range(num_entries)
+]
+
+
+# Fixture for PGVectorDB
+@pytest.fixture
+def pg_vector_db():
+    random_collection_name = (
+        f"test_collection_{random.randint(0, 1_000_000_000)}"
+    )
+    config = VectorDBConfig.create(
+        provider="pgvector", vecs_collection=random_collection_name
+    )
+    db = PGVectorDB(config)
+    db.initialize_collection(dimension=dimension)
+    yield db
+    # Teardown
+    db.vx.delete_collection(
+        db.config.extra_fields.get("vecs_collection", None)
+    )
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_get_metadatas(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    unique_metadatas = db.get_metadatas(metadata_fields=["key"])
+    unique_values = set([ele["key"] for ele in unique_metadatas])
+    assert len(unique_values) == num_entries
+    assert all(f"value_id_{i}" in unique_values for i in range(num_entries))
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_initialization(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    assert isinstance(db, VectorDBProvider)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_copy_and_search(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    db.upsert(sample_entries[0])
+    results = db.search(query_vector=sample_entries[0].vector.data)
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_db_upsert_and_search(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    db.upsert(sample_entries[0])
+    results = db.search(query_vector=sample_entries[0].vector.data)
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_imperfect_match(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    db.upsert(sample_entries[0])
+    query_vector = [val + 0.1 for val in sample_entries[0].vector.data]
+    results = db.search(query_vector=query_vector)
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score < 1.0
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_bulk_insert_and_search(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    query_vector = sample_entries[0].vector.data
+    results = db.search(query_vector=query_vector, limit=5)
+    assert len(results) == 5
+    assert results[0].id == sample_entries[0].id
+    assert results[0].score == pytest.approx(1.0, rel=1e-3)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_search_with_filters(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    filtered_id = sample_entries[0].metadata["key"]
+    query_vector = sample_entries[0].vector.data
+    results = db.search(
+        query_vector=query_vector, filters={"key": filtered_id}
+    )
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].metadata["key"] == filtered_id
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_delete_by_metadata(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    for entry in sample_entries:
+        db.upsert(entry)
+
+    key_to_delete = sample_entries[0].metadata["key"]
+    db.delete_by_metadata(
+        metadata_fields=["key"], metadata_values=[key_to_delete]
+    )
+
+    results = db.search(query_vector=sample_entries[0].vector.data)
+    assert all(result.metadata["key"] != key_to_delete for result in results)
+
+
+@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
+def test_upsert(request, db_fixture):
+    db = request.getfixturevalue(db_fixture)
+    db.upsert(sample_entries[0])
+    modified_entry = VectorEntry(
+        id=sample_entries[0].id,
+        vector=Vector([0.5, 0.5, 0.5]),
+        metadata={"key": "new_value"},
+    )
+    db.upsert(modified_entry)
+
+    results = db.search(query_vector=[0.5, 0.5, 0.5])
+    assert len(results) == 1
+    assert results[0].id == sample_entries[0].id
+    assert results[0].metadata["key"] == "new_value"
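For context, the snippet below sketches the same setup/teardown lifecycle that the `pg_vector_db` fixture above performs, but as a standalone script rather than under pytest. It only uses calls that appear in the diff (`VectorDBConfig.create`, `PGVectorDB`, `initialize_collection`, `upsert`, `search`, and `vx.delete_collection`); the collection name, the single demo entry, and the assumption that a local `.env` file supplies whatever Postgres/pgvector connection settings `PGVectorDB` reads are illustrative, not taken from the commit.

```python
# Minimal sketch of the fixture's lifecycle, assuming a .env file with valid
# pgvector connection settings for PGVectorDB (the exact variable names are
# not shown in this commit).
import random

from dotenv import load_dotenv

from r2r import Vector, VectorDBConfig, VectorEntry, generate_id_from_label
from r2r.providers.vector_dbs import PGVectorDB

load_dotenv()  # connection settings are expected to come from the environment

dimension = 3
config = VectorDBConfig.create(
    provider="pgvector", vecs_collection="demo_collection"  # illustrative name
)
db = PGVectorDB(config)
db.initialize_collection(dimension=dimension)

# Upsert a single random entry, mirroring generate_random_vector_entry above.
entry = VectorEntry(
    id=generate_id_from_label("demo_id"),
    vector=Vector([random.random() for _ in range(dimension)]),
    metadata={"key": "value_demo_id"},
)
db.upsert(entry)

# Searching with the entry's own vector should return it with a near-1.0 score,
# as asserted in test_db_upsert_and_search.
results = db.search(query_vector=entry.vector.data, limit=1)
print(results[0].id, results[0].score, results[0].metadata)

# Teardown, as in the fixture: drop the collection via the underlying vecs client.
db.vx.delete_collection(db.config.extra_fields.get("vecs_collection", None))
```

Like the fixture, the sketch generates a throwaway collection and deletes it afterwards, so repeated runs do not interfere with each other or with existing data.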