import asyncio

import pytest

from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
from r2r.providers.embeddings import (
    OpenAIEmbeddingProvider,
    SentenceTransformerEmbeddingProvider,
)


@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    # Use the default event loop policy for the whole test session.
    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())


@pytest.fixture(scope="function")
def event_loop():
    # Provide a fresh event loop per test and tear it down afterwards.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
    asyncio.set_event_loop(None)


@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    # Cancel any asyncio tasks still pending at the end of the session.
    yield
    for task in asyncio.all_tasks():
        if task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass


@pytest.fixture
def openai_provider():
    config = EmbeddingConfig(
        provider="openai",
        base_model="text-embedding-3-small",
        base_dimension=1536,
    )
    return OpenAIEmbeddingProvider(config)


def test_openai_initialization(openai_provider):
    assert isinstance(openai_provider, OpenAIEmbeddingProvider)
    assert openai_provider.base_model == "text-embedding-3-small"
    assert openai_provider.base_dimension == 1536


def test_openai_invalid_provider_initialization():
    config = EmbeddingConfig(provider="invalid_provider")
    with pytest.raises(ValueError):
        OpenAIEmbeddingProvider(config)


def test_openai_get_embedding(openai_provider):
    embedding = openai_provider.get_embedding("test text")
    assert len(embedding) == 1536
    assert isinstance(embedding, list)


@pytest.mark.asyncio
async def test_openai_async_get_embedding(openai_provider):
    try:
        embedding = await openai_provider.async_get_embedding("test text")
        assert len(embedding) == 1536
        assert isinstance(embedding, list)
    except asyncio.CancelledError:
        pass  # Task cancelled as expected


def test_openai_get_embeddings(openai_provider):
    embeddings = openai_provider.get_embeddings(["text1", "text2"])
    assert len(embeddings) == 2
    assert all(len(emb) == 1536 for emb in embeddings)


@pytest.mark.asyncio
async def test_openai_async_get_embeddings(openai_provider):
    try:
        embeddings = await openai_provider.async_get_embeddings(
            ["text1", "text2"]
        )
        assert len(embeddings) == 2
        assert all(len(emb) == 1536 for emb in embeddings)
    except asyncio.CancelledError:
        pass  # Task cancelled as expected


def test_openai_tokenize_string(openai_provider):
    tokens = openai_provider.tokenize_string(
        "test text", "text-embedding-3-small"
    )
    assert isinstance(tokens, list)
    assert all(isinstance(token, int) for token in tokens)


@pytest.fixture
def sentence_transformer_provider():
    config = EmbeddingConfig(
        provider="sentence-transformers",
        base_model="mixedbread-ai/mxbai-embed-large-v1",
        base_dimension=512,
        rerank_model="jinaai/jina-reranker-v1-turbo-en",
        rerank_dimension=384,
    )
    return SentenceTransformerEmbeddingProvider(config)


def test_sentence_transformer_initialization(sentence_transformer_provider):
    assert isinstance(
        sentence_transformer_provider, SentenceTransformerEmbeddingProvider
    )
    assert sentence_transformer_provider.do_search
    # assert sentence_transformer_provider.do_rerank


def test_sentence_transformer_invalid_provider_initialization():
    config = EmbeddingConfig(provider="invalid_provider")
    with pytest.raises(ValueError):
        SentenceTransformerEmbeddingProvider(config)


def test_sentence_transformer_get_embedding(sentence_transformer_provider):
    embedding = sentence_transformer_provider.get_embedding("test text")
    assert len(embedding) == 512
    assert isinstance(embedding, list)


def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
    embeddings = sentence_transformer_provider.get_embeddings(
        ["text1", "text2"]
    )
    assert len(embeddings) == 2
    assert all(len(emb) == 512 for emb in embeddings)


def test_sentence_transformer_rerank(sentence_transformer_provider):
    results = [
        VectorSearchResult(
            id=generate_id_from_label("x"),
            score=0.9,
            metadata={"text": "doc1"},
        ),
        VectorSearchResult(
            id=generate_id_from_label("y"),
            score=0.8,
            metadata={"text": "doc2"},
        ),
    ]
    reranked_results = sentence_transformer_provider.rerank("query", results)
    assert len(reranked_results) == 2
    assert reranked_results[0].metadata["text"] == "doc1"
    assert reranked_results[1].metadata["text"] == "doc2"


def test_sentence_transformer_tokenize_string(sentence_transformer_provider):
    with pytest.raises(ValueError):
        sentence_transformer_provider.tokenize_string("test text")