Diffstat (limited to 'R2R/tests/test_embedding.py')
-rwxr-xr-x | R2R/tests/test_embedding.py | 162 |
1 file changed, 162 insertions, 0 deletions
diff --git a/R2R/tests/test_embedding.py b/R2R/tests/test_embedding.py
new file mode 100755
index 00000000..7a3e760a
--- /dev/null
+++ b/R2R/tests/test_embedding.py
@@ -0,0 +1,162 @@
+import asyncio
+
+import pytest
+
+from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
+from r2r.providers.embeddings import (
+    OpenAIEmbeddingProvider,
+    SentenceTransformerEmbeddingProvider,
+)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+    asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+    yield
+    for task in asyncio.all_tasks():
+        if task is not asyncio.current_task():
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+
+@pytest.fixture
+def openai_provider():
+    config = EmbeddingConfig(
+        provider="openai",
+        base_model="text-embedding-3-small",
+        base_dimension=1536,
+    )
+    return OpenAIEmbeddingProvider(config)
+
+
+def test_openai_initialization(openai_provider):
+    assert isinstance(openai_provider, OpenAIEmbeddingProvider)
+    assert openai_provider.base_model == "text-embedding-3-small"
+    assert openai_provider.base_dimension == 1536
+
+
+def test_openai_invalid_provider_initialization():
+    config = EmbeddingConfig(provider="invalid_provider")
+    with pytest.raises(ValueError):
+        OpenAIEmbeddingProvider(config)
+
+
+def test_openai_get_embedding(openai_provider):
+    embedding = openai_provider.get_embedding("test text")
+    assert len(embedding) == 1536
+    assert isinstance(embedding, list)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embedding(openai_provider):
+    try:
+        embedding = await openai_provider.async_get_embedding("test text")
+        assert len(embedding) == 1536
+        assert isinstance(embedding, list)
+    except asyncio.CancelledError:
+        pass  # Task cancelled as expected
+
+
+def test_openai_get_embeddings(openai_provider):
+    embeddings = openai_provider.get_embeddings(["text1", "text2"])
+    assert len(embeddings) == 2
+    assert all(len(emb) == 1536 for emb in embeddings)
+
+
+@pytest.mark.asyncio
+async def test_openai_async_get_embeddings(openai_provider):
+    try:
+        embeddings = await openai_provider.async_get_embeddings(
+            ["text1", "text2"]
+        )
+        assert len(embeddings) == 2
+        assert all(len(emb) == 1536 for emb in embeddings)
+    except asyncio.CancelledError:
+        pass  # Task cancelled as expected
+
+
+def test_openai_tokenize_string(openai_provider):
+    tokens = openai_provider.tokenize_string(
+        "test text", "text-embedding-3-small"
+    )
+    assert isinstance(tokens, list)
+    assert all(isinstance(token, int) for token in tokens)
+
+
+@pytest.fixture
+def sentence_transformer_provider():
+    config = EmbeddingConfig(
+        provider="sentence-transformers",
+        base_model="mixedbread-ai/mxbai-embed-large-v1",
+        base_dimension=512,
+        rerank_model="jinaai/jina-reranker-v1-turbo-en",
+        rerank_dimension=384,
+    )
+    return SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_initialization(sentence_transformer_provider):
+    assert isinstance(
+        sentence_transformer_provider, SentenceTransformerEmbeddingProvider
+    )
+    assert sentence_transformer_provider.do_search
+    # assert sentence_transformer_provider.do_rerank
+
+
+def test_sentence_transformer_invalid_provider_initialization():
+    config = EmbeddingConfig(provider="invalid_provider")
+    with pytest.raises(ValueError):
+        SentenceTransformerEmbeddingProvider(config)
+
+
+def test_sentence_transformer_get_embedding(sentence_transformer_provider):
+    embedding = sentence_transformer_provider.get_embedding("test text")
+    assert len(embedding) == 512
+    assert isinstance(embedding, list)
+
+
+def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
+    embeddings = sentence_transformer_provider.get_embeddings(
+        ["text1", "text2"]
+    )
+    assert len(embeddings) == 2
+    assert all(len(emb) == 512 for emb in embeddings)
+
+
+def test_sentence_transformer_rerank(sentence_transformer_provider):
+    results = [
+        VectorSearchResult(
+            id=generate_id_from_label("x"),
+            score=0.9,
+            metadata={"text": "doc1"},
+        ),
+        VectorSearchResult(
+            id=generate_id_from_label("y"),
+            score=0.8,
+            metadata={"text": "doc2"},
+        ),
+    ]
+    reranked_results = sentence_transformer_provider.rerank("query", results)
+    assert len(reranked_results) == 2
+    assert reranked_results[0].metadata["text"] == "doc1"
+    assert reranked_results[1].metadata["text"] == "doc2"
+
+
+def test_sentence_transformer_tokenize_string(sentence_transformer_provider):
+    with pytest.raises(ValueError):
+        sentence_transformer_provider.tokenize_string("test text")
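For orientation, a minimal sketch of how the providers exercised by these tests might be used directly, outside pytest. It assumes only the names and call signatures visible in the diff above (EmbeddingConfig, OpenAIEmbeddingProvider, get_embedding, get_embeddings); model names and dimensions are copied from the openai_provider fixture, and the credential requirement is an assumption, not something the diff shows.

# Sketch only: uses the r2r API exactly as the tests above call it.
# Presumably requires OpenAI credentials (e.g. OPENAI_API_KEY) in the environment.
from r2r import EmbeddingConfig
from r2r.providers.embeddings import OpenAIEmbeddingProvider

config = EmbeddingConfig(
    provider="openai",
    base_model="text-embedding-3-small",
    base_dimension=1536,
)
provider = OpenAIEmbeddingProvider(config)

vector = provider.get_embedding("hello world")          # per the tests: a list of 1536 floats
vectors = provider.get_embeddings(["doc a", "doc b"])   # one embedding per input string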