import asyncio import uuid import pytest from r2r import ( AsyncPipe, AsyncState, Prompt, Vector, VectorEntry, VectorSearchRequest, VectorSearchResult, VectorType, generate_id_from_label, ) @pytest.fixture(scope="session", autouse=True) def event_loop_policy(): asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) @pytest.fixture(scope="function") def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() asyncio.set_event_loop(None) @pytest.fixture(scope="session", autouse=True) async def cleanup_tasks(): yield for task in asyncio.all_tasks(): if task is not asyncio.current_task(): task.cancel() try: await task except asyncio.CancelledError: pass @pytest.mark.asyncio async def test_async_state_update_and_get(): state = AsyncState() outer_key = "test_key" values = {"inner_key": "value"} await state.update(outer_key, values) result = await state.get(outer_key, "inner_key") assert result == "value" @pytest.mark.asyncio async def test_async_state_delete(): state = AsyncState() outer_key = "test_key" values = {"inner_key": "value"} await state.update(outer_key, values) await state.delete(outer_key, "inner_key") result = await state.get(outer_key, "inner_key") assert result == {}, "Expect empty result after deletion" class MockAsyncPipe(AsyncPipe): async def _run_logic(self, input, state, run_id, *args, **kwargs): yield "processed" @pytest.mark.asyncio async def test_async_pipe_run(): pipe = MockAsyncPipe() async def list_to_generator(lst): for item in lst: yield item input = pipe.Input(message=list_to_generator(["test"])) state = AsyncState() try: async_generator = await pipe.run(input, state) results = [result async for result in async_generator] assert results == ["processed"] except asyncio.CancelledError: pass # Task cancelled as expected def test_prompt_initialization_and_formatting(): prompt = Prompt( name="greet", template="Hello, {name}!", input_types={"name": "str"} ) formatted = prompt.format_prompt({"name": "Alice"}) assert formatted == "Hello, Alice!" def test_prompt_missing_input(): prompt = Prompt( name="greet", template="Hello, {name}!", input_types={"name": "str"} ) with pytest.raises(ValueError): prompt.format_prompt({}) def test_prompt_invalid_input_type(): prompt = Prompt( name="greet", template="Hello, {name}!", input_types={"name": "int"} ) with pytest.raises(TypeError): prompt.format_prompt({"name": "Alice"}) def test_search_request_with_optional_filters(): request = VectorSearchRequest( query="test", limit=10, filters={"category": "books"} ) assert request.query == "test" assert request.limit == 10 assert request.filters == {"category": "books"} def test_search_result_to_string(): result = VectorSearchResult( id=generate_id_from_label("1"), score=9.5, metadata={"author": "John Doe"}, ) result_str = str(result) assert ( result_str == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})" ) def test_search_result_repr(): result = VectorSearchResult( id=generate_id_from_label("1"), score=9.5, metadata={"author": "John Doe"}, ) assert ( repr(result) == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})" ) def test_vector_fixed_length_validation(): with pytest.raises(ValueError): Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=3) def test_vector_entry_serialization(): vector = Vector(data=[1.0, 2.0], type=VectorType.FIXED, length=2) entry_id = uuid.uuid4() entry = VectorEntry( id=entry_id, vector=vector, metadata={"key": uuid.uuid4()} ) serializable = entry.to_serializable() assert serializable["id"] == str(entry_id) assert serializable["vector"] == [1.0, 2.0] assert isinstance( serializable["metadata"]["key"], str ) # Check UUID conversion to string