# Tests for r2r abstractions: AsyncState, AsyncPipe, Prompt, and vector types.
import asyncio
import uuid

import pytest

from r2r import (
    AsyncPipe,
    AsyncState,
    Prompt,
    Vector,
    VectorEntry,
    VectorSearchRequest,
    VectorSearchResult,
    VectorType,
    generate_id_from_label,
)


@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    """Install asyncio's default event loop policy for the test session."""
    default_policy = asyncio.DefaultEventLoopPolicy()
    asyncio.set_event_loop_policy(default_policy)


@pytest.fixture(scope="function")
def event_loop():
    """Yield a newly created event loop for each test function.

    After the test finishes the loop is closed and the current event loop
    is reset to ``None``.
    """
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    # Teardown: release the loop, then clear the current-loop reference.
    fresh_loop.close()
    asyncio.set_event_loop(None)


@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    """Cancel leftover asyncio tasks once the test session is over.

    Every task reported by ``asyncio.all_tasks()`` other than the one
    running this teardown is cancelled and then awaited; the resulting
    ``CancelledError`` is ignored.
    """
    yield
    this_task = asyncio.current_task()
    leftovers = [t for t in asyncio.all_tasks() if t is not this_task]
    for leftover in leftovers:
        leftover.cancel()
        try:
            await leftover
        except asyncio.CancelledError:
            # Cancellation is the outcome we asked for.
            pass


@pytest.mark.asyncio
async def test_async_state_update_and_get():
    """A value stored via ``update`` is returned by ``get`` for the same keys."""
    store = AsyncState()
    await store.update("test_key", {"inner_key": "value"})
    fetched = await store.get("test_key", "inner_key")
    assert fetched == "value"


@pytest.mark.asyncio
async def test_async_state_delete():
    """After ``delete``, ``get`` for the removed inner key yields ``{}``."""
    store = AsyncState()
    await store.update("test_key", {"inner_key": "value"})
    await store.delete("test_key", "inner_key")
    remaining = await store.get("test_key", "inner_key")
    assert remaining == {}, "Expect empty result after deletion"


class MockAsyncPipe(AsyncPipe):
    """Minimal ``AsyncPipe`` subclass used as a test double."""

    async def _run_logic(self, input, state, run_id, *args, **kwargs):
        """Ignore all arguments and yield the single item ``"processed"``."""
        yield "processed"


@pytest.mark.asyncio
async def test_async_pipe_run():
    """Running ``MockAsyncPipe`` yields exactly ``["processed"]``.

    A ``CancelledError`` raised while running or draining the pipe is
    tolerated, in which case no assertion is made.
    """

    async def as_async_iter(items):
        # Adapt a plain list into an async generator for the pipe input.
        for element in items:
            yield element

    pipe = MockAsyncPipe()
    pipe_input = pipe.Input(message=as_async_iter(["test"]))
    state = AsyncState()
    try:
        stream = await pipe.run(pipe_input, state)
        collected = []
        async for chunk in stream:
            collected.append(chunk)
        assert collected == ["processed"]
    except asyncio.CancelledError:
        pass  # Task cancelled as expected


def test_prompt_initialization_and_formatting():
    """Formatting the greeting template with a name fills the placeholder."""
    greeting = Prompt(
        name="greet",
        template="Hello, {name}!",
        input_types={"name": "str"},
    )
    assert greeting.format_prompt({"name": "Alice"}) == "Hello, Alice!"


def test_prompt_missing_input():
    """Formatting with no inputs is expected to raise ``ValueError``."""
    greeting = Prompt(
        name="greet",
        template="Hello, {name}!",
        input_types={"name": "str"},
    )
    no_inputs = {}
    with pytest.raises(ValueError):
        greeting.format_prompt(no_inputs)


def test_prompt_invalid_input_type():
    """Passing a string where ``"int"`` is declared should raise ``TypeError``."""
    greeting = Prompt(
        name="greet",
        template="Hello, {name}!",
        input_types={"name": "int"},
    )
    wrong_typed_inputs = {"name": "Alice"}
    with pytest.raises(TypeError):
        greeting.format_prompt(wrong_typed_inputs)


def test_search_request_with_optional_filters():
    """Constructor arguments are exposed unchanged as request attributes."""
    search = VectorSearchRequest(
        query="test",
        limit=10,
        filters={"category": "books"},
    )
    assert search.query == "test"
    assert search.limit == 10
    assert search.filters == {"category": "books"}


def test_search_result_to_string():
    """``str()`` of a search result matches the expected one-line rendering."""
    hit = VectorSearchResult(
        id=generate_id_from_label("1"),
        score=9.5,
        metadata={"author": "John Doe"},
    )
    rendered = str(hit)
    expected = f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
    assert rendered == expected


def test_search_result_repr():
    """``repr()`` of a search result matches the expected one-line rendering."""
    hit = VectorSearchResult(
        id=generate_id_from_label("1"),
        score=9.5,
        metadata={"author": "John Doe"},
    )
    rendered = repr(hit)
    expected = f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
    assert rendered == expected


def test_vector_fixed_length_validation():
    """A FIXED vector with two values but ``length=3`` should raise ``ValueError``."""
    with pytest.raises(ValueError):
        Vector(
            data=[1.0, 2.0],
            type=VectorType.FIXED,
            length=3,
        )


def test_vector_entry_serialization():
    """``to_serializable`` returns the id as a string, the raw vector data,
    and a string for a UUID metadata value."""
    fixed_vector = Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=2)
    entry_id = uuid.uuid4()
    record = VectorEntry(
        id=entry_id,
        vector=fixed_vector,
        metadata={"key": uuid.uuid4()},
    )
    payload = record.to_serializable()
    assert payload["id"] == str(entry_id)
    assert payload["vector"] == [1.0, 2.0]
    # The UUID stored in metadata should come back as a string.
    metadata_value = payload["metadata"]["key"]
    assert isinstance(metadata_value, str)