about summary refs log tree commit diff
path: root/R2R/tests/test_vector_db.py
diff options
context:
space:
mode:
author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/tests/test_vector_db.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
Two versions of R2R are here: HEAD, master
Diffstat (limited to 'R2R/tests/test_vector_db.py')
-rwxr-xr-x  R2R/tests/test_vector_db.py  160
1 file changed, 160 insertions(+), 0 deletions(-)
diff --git a/R2R/tests/test_vector_db.py b/R2R/tests/test_vector_db.py
new file mode 100755
index 00000000..023145ce
--- /dev/null
+++ b/R2R/tests/test_vector_db.py
@@ -0,0 +1,160 @@
+import random
+
+import pytest
+from dotenv import load_dotenv
+
+from r2r import (
+    Vector,
+    VectorDBConfig,
+    VectorDBProvider,
+    VectorEntry,
+    generate_id_from_label,
+)
+from r2r.providers.vector_dbs import PGVectorDB
+
+load_dotenv()
+
+
+# Sample vector entries
def generate_random_vector_entry(id: str, dimension: int) -> VectorEntry:
    """Build a VectorEntry with `dimension` uniform-random components in [0, 1).

    The entry's UUID is derived deterministically from the label via
    ``generate_id_from_label``, and its metadata carries ``value_<label>``
    so later tests can assert on it.

    NOTE: the parameter name ``id`` shadows the builtin; it is kept for
    signature compatibility with existing callers.
    """
    components = [random.random() for _ in range(dimension)]
    metadata = {"key": f"value_{id}"}
    return VectorEntry(
        id=generate_id_from_label(id),
        vector=Vector(components),
        metadata=metadata,
    )
+
+
# Shared sample data reused by every test below: 100 random 3-dimensional
# entries labelled id_0 .. id_99.
dimension = 3
num_entries = 100
sample_entries = [
    generate_random_vector_entry(f"id_{idx}", dimension)
    for idx in range(num_entries)
]
+
+
+# Fixture for PGVectorDB
@pytest.fixture
def pg_vector_db():
    """Yield a PGVectorDB bound to a throwaway, randomly named collection.

    The collection name is randomized so repeated or concurrent runs do not
    collide on the same table. The collection is dropped again in teardown.

    Fix: the teardown now sits in a ``try``/``finally`` around the ``yield``,
    so the temporary collection is dropped even if an exception is thrown
    into the generator or it is closed early — a bare statement after
    ``yield`` is skipped in those cases, leaking collections in the DB.
    """
    random_collection_name = (
        f"test_collection_{random.randint(0, 1_000_000_000)}"
    )
    config = VectorDBConfig.create(
        provider="pgvector", vecs_collection=random_collection_name
    )
    db = PGVectorDB(config)
    db.initialize_collection(dimension=dimension)
    try:
        yield db
    finally:
        # Teardown: always drop the temporary collection.
        db.vx.delete_collection(
            db.config.extra_fields.get("vecs_collection", None)
        )
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_get_metadatas(request, db_fixture):
    """All 100 distinct metadata values must survive a DB round trip."""
    db = request.getfixturevalue(db_fixture)
    for entry in sample_entries:
        db.upsert(entry)

    unique_metadatas = db.get_metadatas(metadata_fields=["key"])
    # Set comprehension instead of set([listcomp]) — same values, no
    # throwaway intermediate list (flake8-comprehensions C403).
    unique_values = {ele["key"] for ele in unique_metadatas}
    assert len(unique_values) == num_entries
    assert all(f"value_id_{i}" in unique_values for i in range(num_entries))
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_db_initialization(request, db_fixture):
    """The fixture must yield a concrete VectorDBProvider implementation."""
    provider = request.getfixturevalue(db_fixture)
    assert isinstance(provider, VectorDBProvider)
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_db_copy_and_search(request, db_fixture):
    """Searching with an entry's own vector returns it with score ~1.0.

    NOTE(review): this test is currently identical to
    ``test_db_upsert_and_search`` — consider making it exercise a real
    copy path or removing one of the two.
    """
    db = request.getfixturevalue(db_fixture)
    entry = sample_entries[0]
    db.upsert(entry)
    hits = db.search(query_vector=entry.vector.data)
    assert len(hits) == 1
    top = hits[0]
    assert top.id == entry.id
    assert top.score == pytest.approx(1.0, rel=1e-3)
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_db_upsert_and_search(request, db_fixture):
    """An upserted entry is findable by its own vector with a near-exact score."""
    db = request.getfixturevalue(db_fixture)
    first = sample_entries[0]
    db.upsert(first)

    matches = db.search(query_vector=first.vector.data)

    assert len(matches) == 1
    assert matches[0].id == first.id
    # Cosine/identity match on the exact stored vector should score ~1.0.
    assert matches[0].score == pytest.approx(1.0, rel=1e-3)
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_imperfect_match(request, db_fixture):
    """A slightly perturbed query still finds the entry, but scores below 1.0."""
    db = request.getfixturevalue(db_fixture)
    entry = sample_entries[0]
    db.upsert(entry)

    # Shift every component by 0.1 so the query no longer matches exactly.
    perturbed = [component + 0.1 for component in entry.vector.data]
    hits = db.search(query_vector=perturbed)

    assert len(hits) == 1
    assert hits[0].id == entry.id
    assert hits[0].score < 1.0
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_bulk_insert_and_search(request, db_fixture):
    """With all entries inserted, a limited search ranks the exact match first."""
    db = request.getfixturevalue(db_fixture)
    for sample in sample_entries:
        db.upsert(sample)

    target = sample_entries[0]
    top_five = db.search(query_vector=target.vector.data, limit=5)

    assert len(top_five) == 5
    # The exact vector should be the best hit with a near-perfect score.
    assert top_five[0].id == target.id
    assert top_five[0].score == pytest.approx(1.0, rel=1e-3)
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_search_with_filters(request, db_fixture):
    """A metadata filter narrows the result set to the single matching entry."""
    db = request.getfixturevalue(db_fixture)
    for sample in sample_entries:
        db.upsert(sample)

    target = sample_entries[0]
    wanted_value = target.metadata["key"]
    hits = db.search(
        query_vector=target.vector.data,
        filters={"key": wanted_value},
    )

    assert len(hits) == 1
    assert hits[0].id == target.id
    assert hits[0].metadata["key"] == wanted_value
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_delete_by_metadata(request, db_fixture):
    """Deleting by metadata value removes every entry carrying that value."""
    db = request.getfixturevalue(db_fixture)
    for sample in sample_entries:
        db.upsert(sample)

    doomed_value = sample_entries[0].metadata["key"]
    db.delete_by_metadata(
        metadata_fields=["key"], metadata_values=[doomed_value]
    )

    # The deleted value must not appear in any remaining search result.
    survivors = db.search(query_vector=sample_entries[0].vector.data)
    assert all(hit.metadata["key"] != doomed_value for hit in survivors)
+
+
@pytest.mark.parametrize("db_fixture", ["pg_vector_db"])
def test_upsert(request, db_fixture):
    """Upserting an existing id replaces both its vector and its metadata."""
    db = request.getfixturevalue(db_fixture)
    original = sample_entries[0]
    db.upsert(original)

    replacement_vector = [0.5, 0.5, 0.5]
    replacement = VectorEntry(
        id=original.id,
        vector=Vector(replacement_vector),
        metadata={"key": "new_value"},
    )
    db.upsert(replacement)

    hits = db.search(query_vector=replacement_vector)
    assert len(hits) == 1
    # Same id as before, but the stored metadata reflects the replacement.
    assert hits[0].id == original.id
    assert hits[0].metadata["key"] == "new_value"