import asyncio
import pytest
from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
from r2r.providers.embeddings import (
OpenAIEmbeddingProvider,
SentenceTransformerEmbeddingProvider,
)
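

# Pin the default event loop policy for the whole test session so per-test
# loops are created consistently.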
@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
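

# Give each test its own event loop and close it afterwards so no state
# leaks between async tests.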
@pytest.fixture(scope="function")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
asyncio.set_event_loop(None)
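

# At the end of the session, cancel any asyncio tasks that are still pending
# so the test run shuts down cleanly.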
@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
yield
for task in asyncio.all_tasks():
if task is not asyncio.current_task():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
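

# OpenAI embedding provider for text-embedding-3-small; the embedding tests
# below assume a valid OPENAI_API_KEY is available in the environment.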
@pytest.fixture
def openai_provider():
config = EmbeddingConfig(
provider="openai",
base_model="text-embedding-3-small",
base_dimension=1536,
)
return OpenAIEmbeddingProvider(config)
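

# The provider should expose the configured model name and dimension.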
def test_openai_initialization(openai_provider):
assert isinstance(openai_provider, OpenAIEmbeddingProvider)
assert openai_provider.base_model == "text-embedding-3-small"
assert openai_provider.base_dimension == 1536
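

# An unrecognized provider name in the config must be rejected.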
def test_openai_invalid_provider_initialization():
config = EmbeddingConfig(provider="invalid_provider")
with pytest.raises(ValueError):
OpenAIEmbeddingProvider(config)
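

# A single text should embed to a 1536-dimensional vector.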
def test_openai_get_embedding(openai_provider):
embedding = openai_provider.get_embedding("test text")
assert len(embedding) == 1536
assert isinstance(embedding, list)
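

# Async variant of the single-text embedding test; cancellation from the
# session-level cleanup is tolerated.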
@pytest.mark.asyncio
async def test_openai_async_get_embedding(openai_provider):
try:
embedding = await openai_provider.async_get_embedding("test text")
assert len(embedding) == 1536
assert isinstance(embedding, list)
    except asyncio.CancelledError:
        pass  # Tolerate cancellation triggered by the session-level task cleanup
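

# Embedding a batch of texts returns one vector per input.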
def test_openai_get_embeddings(openai_provider):
embeddings = openai_provider.get_embeddings(["text1", "text2"])
assert len(embeddings) == 2
assert all(len(emb) == 1536 for emb in embeddings)
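

# Async variant of the batch embedding test.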
@pytest.mark.asyncio
async def test_openai_async_get_embeddings(openai_provider):
try:
embeddings = await openai_provider.async_get_embeddings(
["text1", "text2"]
)
assert len(embeddings) == 2
assert all(len(emb) == 1536 for emb in embeddings)
    except asyncio.CancelledError:
        pass  # Tolerate cancellation triggered by the session-level task cleanup
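

# Tokenization should return a list of integer token ids for the given model.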
def test_openai_tokenize_string(openai_provider):
tokens = openai_provider.tokenize_string(
"test text", "text-embedding-3-small"
)
assert isinstance(tokens, list)
assert all(isinstance(token, int) for token in tokens)
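

# Local sentence-transformers provider with a separate rerank model configured.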
@pytest.fixture
def sentence_transformer_provider():
config = EmbeddingConfig(
provider="sentence-transformers",
base_model="mixedbread-ai/mxbai-embed-large-v1",
base_dimension=512,
rerank_model="jinaai/jina-reranker-v1-turbo-en",
rerank_dimension=384,
)
return SentenceTransformerEmbeddingProvider(config)
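

# The provider should initialize with search enabled.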
def test_sentence_transformer_initialization(sentence_transformer_provider):
assert isinstance(
sentence_transformer_provider, SentenceTransformerEmbeddingProvider
)
assert sentence_transformer_provider.do_search
# assert sentence_transformer_provider.do_rerank
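

# An unrecognized provider name in the config must be rejected here as well.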
def test_sentence_transformer_invalid_provider_initialization():
config = EmbeddingConfig(provider="invalid_provider")
with pytest.raises(ValueError):
SentenceTransformerEmbeddingProvider(config)
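

# A single text should embed to a vector of the configured 512 dimensions.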
def test_sentence_transformer_get_embedding(sentence_transformer_provider):
embedding = sentence_transformer_provider.get_embedding("test text")
assert len(embedding) == 512
assert isinstance(embedding, list)
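

# Batch embedding with the local model returns one 512-dimensional vector per input.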
def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
embeddings = sentence_transformer_provider.get_embeddings(
["text1", "text2"]
)
assert len(embeddings) == 2
assert all(len(emb) == 512 for emb in embeddings)
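

# Reranking should keep both results; for this query/document pair the
# original ordering is expected to be preserved.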
def test_sentence_transformer_rerank(sentence_transformer_provider):
results = [
VectorSearchResult(
id=generate_id_from_label("x"),
score=0.9,
metadata={"text": "doc1"},
),
VectorSearchResult(
id=generate_id_from_label("y"),
score=0.8,
metadata={"text": "doc2"},
),
]
reranked_results = sentence_transformer_provider.rerank("query", results)
assert len(reranked_results) == 2
assert reranked_results[0].metadata["text"] == "doc1"
assert reranked_results[1].metadata["text"] == "doc2"
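

# tokenize_string is not supported by this provider and should raise ValueError.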
def test_sentence_transformer_tokenize_string(sentence_transformer_provider):
with pytest.raises(ValueError):
sentence_transformer_provider.tokenize_string("test text")