about summary refs log tree commit diff
path: root/R2R/tests/test_embedding.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/tests/test_embedding.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/tests/test_embedding.py')
-rwxr-xr-xR2R/tests/test_embedding.py162
1 files changed, 162 insertions, 0 deletions
diff --git a/R2R/tests/test_embedding.py b/R2R/tests/test_embedding.py
new file mode 100755
index 00000000..7a3e760a
--- /dev/null
+++ b/R2R/tests/test_embedding.py
@@ -0,0 +1,162 @@
+import asyncio
+
+import pytest
+
+from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
+from r2r.providers.embeddings import (
+    OpenAIEmbeddingProvider,
+    SentenceTransformerEmbeddingProvider,
+)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+    asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+    yield
+    for task in asyncio.all_tasks():
+        if task is not asyncio.current_task():
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+
+@pytest.fixture
+def openai_provider():
+    config = EmbeddingConfig(
+        provider="openai",
+        base_model="text-embedding-3-small",
+        base_dimension=1536,
+    )
+    return OpenAIEmbeddingProvider(config)
+
+
+def test_openai_initialization(openai_provider):
+    assert isinstance(openai_provider, OpenAIEmbeddingProvider)
+    assert openai_provider.base_model == "text-embedding-3-small"
+    assert openai_provider.base_dimension == 1536
+
+
+def test_openai_invalid_provider_initialization():
+    config = EmbeddingConfig(provider="invalid_provider")
+    with pytest.raises(ValueError):
+        OpenAIEmbeddingProvider(config)
+
+
+def test_openai_get_embedding(openai_provider):
+    embedding = openai_provider.get_embedding("test text")
+    assert len(embedding) == 1536
+    assert isinstance(embedding, list)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embedding(openai_provider):
+    try:
+        embedding = await openai_provider.async_get_embedding("test text")
+        assert len(embedding) == 1536
+        assert isinstance(embedding, list)
+    except asyncio.CancelledError:
+        pass  # Task cancelled as expected
+
+
+def test_openai_get_embeddings(openai_provider):
+    embeddings = openai_provider.get_embeddings(["text1", "text2"])
+    assert len(embeddings) == 2
+    assert all(len(emb) == 1536 for emb in embeddings)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embeddings(openai_provider):
+    try:
+        embeddings = await openai_provider.async_get_embeddings(
+            ["text1", "text2"]
+        )
+        assert len(embeddings) == 2
+        assert all(len(emb) == 1536 for emb in embeddings)
+    except asyncio.CancelledError:
+        pass  # Task cancelled as expected
+
+
+def test_openai_tokenize_string(openai_provider):
+    tokens = openai_provider.tokenize_string(
+        "test text", "text-embedding-3-small"
+    )
+    assert isinstance(tokens, list)
+    assert all(isinstance(token, int) for token in tokens)
+
+
+@pytest.fixture
+def sentence_transformer_provider():
+    config = EmbeddingConfig(
+        provider="sentence-transformers",
+        base_model="mixedbread-ai/mxbai-embed-large-v1",
+        base_dimension=512,
+        rerank_model="jinaai/jina-reranker-v1-turbo-en",
+        rerank_dimension=384,
+    )
+    return SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_initialization(sentence_transformer_provider):
+    assert isinstance(
+        sentence_transformer_provider, SentenceTransformerEmbeddingProvider
+    )
+    assert sentence_transformer_provider.do_search
+    # assert sentence_transformer_provider.do_rerank
+
+
+def test_sentence_transformer_invalid_provider_initialization():
+    config = EmbeddingConfig(provider="invalid_provider")
+    with pytest.raises(ValueError):
+        SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_get_embedding(sentence_transformer_provider):
+    embedding = sentence_transformer_provider.get_embedding("test text")
+    assert len(embedding) == 512
+    assert isinstance(embedding, list)
+
+
+def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
+    embeddings = sentence_transformer_provider.get_embeddings(
+        ["text1", "text2"]
+    )
+    assert len(embeddings) == 2
+    assert all(len(emb) == 512 for emb in embeddings)
+
+
+def test_sentence_transformer_rerank(sentence_transformer_provider):
+    results = [
+        VectorSearchResult(
+            id=generate_id_from_label("x"),
+            score=0.9,
+            metadata={"text": "doc1"},
+        ),
+        VectorSearchResult(
+            id=generate_id_from_label("y"),
+            score=0.8,
+            metadata={"text": "doc2"},
+        ),
+    ]
+    reranked_results = sentence_transformer_provider.rerank("query", results)
+    assert len(reranked_results) == 2
+    assert reranked_results[0].metadata["text"] == "doc1"
+    assert reranked_results[1].metadata["text"] == "doc2"
+
+
+def test_sentence_transformer_tokenize_string(sentence_transformer_provider):
+    with pytest.raises(ValueError):
+        sentence_transformer_provider.tokenize_string("test text")