import asyncio
import contextlib
import logging
import os
import uuid

import pytest

from r2r import (
    LocalKVLoggingProvider,
    LoggingConfig,
    PostgresKVLoggingProvider,
    PostgresLoggingConfig,
    RedisKVLoggingProvider,
    RedisLoggingConfig,
    generate_run_id,
)

logger = logging.getLogger(__name__)

# NOTE(review): every test body below runs under
# ``contextlib.suppress(asyncio.CancelledError)``, which is equivalent to the
# ``try: ... except asyncio.CancelledError: pass`` this suite has always used.
# A side effect is that a test cancelled part-way through is reported as
# passed; confirm whether this workaround is still required.


@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    """Install the default asyncio event loop policy for the whole session."""
    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())


@pytest.fixture(scope="function")
def event_loop():
    """Provide a fresh event loop for each test and close it afterwards."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
    # Unset the current event loop so nothing keeps referring to the closed one.
    asyncio.set_event_loop(None)


@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    """At session end, cancel and await every task other than the current one.

    NOTE(review): this async fixture is declared with plain ``pytest.fixture``;
    whether its body runs depends on the pytest-asyncio mode configured for
    the project — confirm.
    """
    yield
    for task in asyncio.all_tasks():
        if task is not asyncio.current_task():
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task


def _unique_name(prefix):
    """Return ``prefix`` plus a fresh UUID with dashes replaced by underscores."""
    return f"{prefix}_{str(uuid.uuid4()).replace('-', '_')}"


# ---------------------------------------------------------------------------
# Provider-agnostic checks. Each takes a provider that is ready to accept
# ``log`` calls, so the local, Postgres and Redis suites below exercise exactly
# the same contract.
# ---------------------------------------------------------------------------


async def _check_single_log(provider):
    """Log one key/value pair under a new run id and read it back."""
    run_id = generate_run_id()
    await provider.log(run_id, "key", "value")
    logs = await provider.get_logs([run_id])
    assert len(logs) == 1
    assert logs[0]["key"] == "key"
    assert logs[0]["value"] == "value"


async def _check_multiple_log_entries(provider):
    """Log one entry for each of three runs and verify every entry comes back.

    Logs are matched to entries by ``log_id`` rather than by position, so the
    check does not depend on the order in which ``get_logs`` returns rows.
    """
    entries = [
        (generate_run_id(), f"key_{i}", f"value_{i}") for i in range(3)
    ]
    for run_id, key, value in entries:
        await provider.log(run_id, key, value)

    logs = await provider.get_logs([run_id for run_id, _, _ in entries])
    assert len(logs) == 3

    for log in logs:
        entry = next(
            (
                candidate
                for candidate in entries
                if candidate[0] == log["log_id"]
            ),
            None,
        )
        assert entry is not None, f"unexpected log_id {log['log_id']!r}"
        _, key, value = entry
        assert log["key"] == key
        assert log["value"] == value

    # Together with len(logs) == 3, this means every run appears exactly once
    # (e.g. the same log returned three times no longer passes).
    for run_id, _, _ in entries:
        assert any(log["log_id"] == run_id for log in logs), run_id


async def _check_subset_retrieval(provider):
    """Log ten runs, request only the first five, and expect five logs."""
    run_ids = []
    for i in range(10):
        run_ids.append(generate_run_id())
        await provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
    logs = await provider.get_logs(run_ids[:5])
    assert len(logs) == 5


async def _check_run_type_filter(provider):
    """Verify ``get_run_info(log_type_filter=...)`` selects only matching runs.

    Two runs are tagged through info logs with ``pipeline_type`` "search" and
    "rag"; filtering on "search" must yield exactly the first run's log.
    """
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    await provider.log(run_id_0, "pipeline_type", "search", is_info_log=True)
    await provider.log(run_id_0, "key_0", "value_0")
    await provider.log(run_id_1, "pipeline_type", "rag", is_info_log=True)
    await provider.log(run_id_1, "key_1", "value_1")

    run_info = await provider.get_run_info(log_type_filter="search")
    logs = await provider.get_logs([run.run_id for run in run_info])
    assert len(logs) == 1
    assert logs[0]["log_id"] == run_id_0
    assert logs[0]["key"] == "key_0"
    assert logs[0]["value"] == "value_0"


# ---------------------------------------------------------------------------
# Local (SQLite) provider
# ---------------------------------------------------------------------------


@pytest.fixture(scope="function")
def local_provider():
    """Yield a LocalKVLoggingProvider backed by a unique SQLite file.

    After the test, the provider is closed and the file is removed.
    """
    # A unique file name keeps tests from sharing a database.
    logging_path = f"test_{uuid.uuid4()}.sqlite"
    provider = LocalKVLoggingProvider(LoggingConfig(logging_path=logging_path))
    yield provider
    provider.close()
    # The file may not exist if the provider never created it.
    with contextlib.suppress(FileNotFoundError):
        os.remove(logging_path)


@pytest.mark.asyncio
async def test_local_logging(local_provider):
    """Test logging and retrieving from the local logging provider."""
    with contextlib.suppress(asyncio.CancelledError):
        await local_provider.init()
        await _check_single_log(local_provider)


@pytest.mark.asyncio
async def test_multiple_log_entries(local_provider):
    """Test logging multiple entries and retrieving them."""
    with contextlib.suppress(asyncio.CancelledError):
        await local_provider.init()
        await _check_multiple_log_entries(local_provider)


@pytest.mark.asyncio
async def test_log_retrieval_limit(local_provider):
    """Test that get_logs returns only the logs of the requested runs."""
    with contextlib.suppress(asyncio.CancelledError):
        await local_provider.init()
        await _check_subset_retrieval(local_provider)


@pytest.mark.asyncio
async def test_specific_run_type_retrieval(local_provider):
    """Test retrieving logs for a specific run type works correctly."""
    with contextlib.suppress(asyncio.CancelledError):
        await local_provider.init()
        await _check_run_type_filter(local_provider)


# ---------------------------------------------------------------------------
# Postgres provider
# ---------------------------------------------------------------------------


@pytest.fixture(scope="function")
def postgres_provider():
    """Yield a PostgresKVLoggingProvider that writes to uniquely named tables.

    NOTE(review): unlike the local and Redis fixtures, this fixture performs no
    teardown, so any tables the provider creates are left behind — confirm
    whether the provider offers a close/drop API that should be called here.
    """
    provider = PostgresKVLoggingProvider(
        PostgresLoggingConfig(
            log_table=_unique_name("logs"),
            log_info_table=_unique_name("log_info"),
        )
    )
    yield provider


@pytest.mark.asyncio
async def test_postgres_logging(postgres_provider):
    """Test logging and retrieving from the postgres logging provider."""
    with contextlib.suppress(asyncio.CancelledError):
        await postgres_provider.init()
        await _check_single_log(postgres_provider)


@pytest.mark.asyncio
async def test_postgres_multiple_log_entries(postgres_provider):
    """Test logging multiple entries and retrieving them."""
    with contextlib.suppress(asyncio.CancelledError):
        await postgres_provider.init()
        await _check_multiple_log_entries(postgres_provider)


@pytest.mark.asyncio
async def test_postgres_log_retrieval_limit(postgres_provider):
    """Test that get_logs returns only the logs of the requested runs."""
    with contextlib.suppress(asyncio.CancelledError):
        await postgres_provider.init()
        await _check_subset_retrieval(postgres_provider)


@pytest.mark.asyncio
async def test_postgres_specific_run_type_retrieval(postgres_provider):
    """Test retrieving logs for a specific run type works correctly."""
    with contextlib.suppress(asyncio.CancelledError):
        await postgres_provider.init()
        await _check_run_type_filter(postgres_provider)


# ---------------------------------------------------------------------------
# Redis provider
#
# NOTE(review): the Redis tests never call ``init()``, unlike the local and
# Postgres tests — confirm the Redis provider needs no initialisation.
# ---------------------------------------------------------------------------


@pytest.fixture(scope="function")
def redis_provider():
    """Yield a RedisKVLoggingProvider using unique table names; close it after."""
    provider = RedisKVLoggingProvider(
        RedisLoggingConfig(
            log_table=_unique_name("logs"),
            log_info_table=_unique_name("log_info"),
        )
    )
    yield provider
    provider.close()


@pytest.mark.asyncio
async def test_redis_logging(redis_provider):
    """Test logging and retrieving from the Redis logging provider."""
    with contextlib.suppress(asyncio.CancelledError):
        await _check_single_log(redis_provider)


@pytest.mark.asyncio
async def test_redis_multiple_log_entries(redis_provider):
    """Test logging multiple entries and retrieving them."""
    with contextlib.suppress(asyncio.CancelledError):
        await _check_multiple_log_entries(redis_provider)


@pytest.mark.asyncio
async def test_redis_log_retrieval_limit(redis_provider):
    """Test that get_logs returns only the logs of the requested runs."""
    with contextlib.suppress(asyncio.CancelledError):
        await _check_subset_retrieval(redis_provider)


@pytest.mark.asyncio
async def test_redis_specific_run_type_retrieval(redis_provider):
    """Test retrieving logs for a specific run type works correctly."""
    with contextlib.suppress(asyncio.CancelledError):
        await _check_run_type_filter(redis_provider)