Diffstat (limited to 'R2R/tests/test_embedding.py')
-rwxr-xr-x  R2R/tests/test_embedding.py  162
1 file changed, 162 insertions, 0 deletions
diff --git a/R2R/tests/test_embedding.py b/R2R/tests/test_embedding.py
new file mode 100755
index 00000000..7a3e760a
--- /dev/null
+++ b/R2R/tests/test_embedding.py
@@ -0,0 +1,162 @@
+import asyncio
+
+import pytest
+
+from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
+from r2r.providers.embeddings import (
+ OpenAIEmbeddingProvider,
+ SentenceTransformerEmbeddingProvider,
+)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture
+def openai_provider():
+ config = EmbeddingConfig(
+ provider="openai",
+ base_model="text-embedding-3-small",
+ base_dimension=1536,
+ )
+ return OpenAIEmbeddingProvider(config)
+
+
+def test_openai_initialization(openai_provider):
+ assert isinstance(openai_provider, OpenAIEmbeddingProvider)
+ assert openai_provider.base_model == "text-embedding-3-small"
+ assert openai_provider.base_dimension == 1536
+
+
+def test_openai_invalid_provider_initialization():
+ config = EmbeddingConfig(provider="invalid_provider")
+ with pytest.raises(ValueError):
+ OpenAIEmbeddingProvider(config)
+
+
+def test_openai_get_embedding(openai_provider):
+ embedding = openai_provider.get_embedding("test text")
+ assert len(embedding) == 1536
+ assert isinstance(embedding, list)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embedding(openai_provider):
+ try:
+ embedding = await openai_provider.async_get_embedding("test text")
+ assert len(embedding) == 1536
+ assert isinstance(embedding, list)
+ except asyncio.CancelledError:
+ pass # Task cancelled as expected
+
+
+def test_openai_get_embeddings(openai_provider):
+ embeddings = openai_provider.get_embeddings(["text1", "text2"])
+ assert len(embeddings) == 2
+ assert all(len(emb) == 1536 for emb in embeddings)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embeddings(openai_provider):
+ try:
+ embeddings = await openai_provider.async_get_embeddings(
+ ["text1", "text2"]
+ )
+ assert len(embeddings) == 2
+ assert all(len(emb) == 1536 for emb in embeddings)
+ except asyncio.CancelledError:
+ pass # Task cancelled as expected
+
+
+def test_openai_tokenize_string(openai_provider):
+ tokens = openai_provider.tokenize_string(
+ "test text", "text-embedding-3-small"
+ )
+ assert isinstance(tokens, list)
+ assert all(isinstance(token, int) for token in tokens)
+
+
+@pytest.fixture
+def sentence_transformer_provider():
+ config = EmbeddingConfig(
+ provider="sentence-transformers",
+ base_model="mixedbread-ai/mxbai-embed-large-v1",
+ base_dimension=512,
+ rerank_model="jinaai/jina-reranker-v1-turbo-en",
+ rerank_dimension=384,
+ )
+ return SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_initialization(sentence_transformer_provider):
+ assert isinstance(
+ sentence_transformer_provider, SentenceTransformerEmbeddingProvider
+ )
+ assert sentence_transformer_provider.do_search
+ # assert sentence_transformer_provider.do_rerank
+
+
+def test_sentence_transformer_invalid_provider_initialization():
+ config = EmbeddingConfig(provider="invalid_provider")
+ with pytest.raises(ValueError):
+ SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_get_embedding(sentence_transformer_provider):
+ embedding = sentence_transformer_provider.get_embedding("test text")
+ assert len(embedding) == 512
+ assert isinstance(embedding, list)
+
+
+def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
+ embeddings = sentence_transformer_provider.get_embeddings(
+ ["text1", "text2"]
+ )
+ assert len(embeddings) == 2
+ assert all(len(emb) == 512 for emb in embeddings)
+
+
+def test_sentence_transformer_rerank(sentence_transformer_provider):
+ results = [
+ VectorSearchResult(
+ id=generate_id_from_label("x"),
+ score=0.9,
+ metadata={"text": "doc1"},
+ ),
+ VectorSearchResult(
+ id=generate_id_from_label("y"),
+ score=0.8,
+ metadata={"text": "doc2"},
+ ),
+ ]
+ reranked_results = sentence_transformer_provider.rerank("query", results)
+ assert len(reranked_results) == 2
+ assert reranked_results[0].metadata["text"] == "doc1"
+ assert reranked_results[1].metadata["text"] == "doc2"
+
+
+def test_sentence_transformer_tokenize_string(sentence_transformer_provider):
+ with pytest.raises(ValueError):
+ sentence_transformer_provider.tokenize_string("test text")
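The providers exercised by the tests above can also be driven directly outside of pytest. The following is a minimal sketch, not part of the diff: it assumes the r2r package is installed and that OPENAI_API_KEY is set in the environment, and it reuses only the EmbeddingConfig fields and provider methods that appear in the test file (get_embedding, async_get_embeddings).

# Standalone sketch of the OpenAI provider usage covered by the tests above.
# Assumes r2r is installed and OPENAI_API_KEY is exported; the config values
# mirror the openai_provider fixture in the test file.
import asyncio

from r2r import EmbeddingConfig
from r2r.providers.embeddings import OpenAIEmbeddingProvider


async def main() -> None:
    config = EmbeddingConfig(
        provider="openai",
        base_model="text-embedding-3-small",
        base_dimension=1536,
    )
    provider = OpenAIEmbeddingProvider(config)

    # Synchronous single embedding, as in test_openai_get_embedding.
    embedding = provider.get_embedding("test text")
    print(len(embedding))  # expected: 1536

    # Async batch embeddings, as in test_openai_async_get_embeddings.
    embeddings = await provider.async_get_embeddings(["text1", "text2"])
    print(len(embeddings))  # expected: 2


if __name__ == "__main__":
    asyncio.run(main())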