path: root/R2R/tests/test_logging.py
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/tests/test_logging.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
Two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/tests/test_logging.py')
-rwxr-xr-x  R2R/tests/test_logging.py  360
1 file changed, 360 insertions(+), 0 deletions(-)
diff --git a/R2R/tests/test_logging.py b/R2R/tests/test_logging.py
new file mode 100755
index 00000000..cab5051d
--- /dev/null
+++ b/R2R/tests/test_logging.py
@@ -0,0 +1,360 @@
+import asyncio
+import logging
+import os
+import uuid
+
+import pytest
+
+from r2r import (
+ LocalKVLoggingProvider,
+ LoggingConfig,
+ PostgresKVLoggingProvider,
+ PostgresLoggingConfig,
+ RedisKVLoggingProvider,
+ RedisLoggingConfig,
+ generate_run_id,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
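+    """Install the default event loop policy for the whole test session."""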
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
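+    """Create a fresh event loop for each test and close it afterwards."""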
+ loop = asyncio.get_event_loop_policy().new_event_loop()
+ yield loop
+ loop.close()
+ asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
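+    """Cancel any tasks still pending once the test session ends."""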
+ yield
+ for task in asyncio.all_tasks():
+ if task is not asyncio.current_task():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture(scope="function")
+def local_provider():
+ """Fixture to create and tear down the LocalKVLoggingProvider with a unique database file."""
+ # Generate a unique file name for the SQLite database
+ unique_id = str(uuid.uuid4())
+ logging_path = f"test_{unique_id}.sqlite"
+
+ # Setup the LocalKVLoggingProvider with the unique file
+ provider = LocalKVLoggingProvider(LoggingConfig(logging_path=logging_path))
+
+ # Provide the setup provider to the test
+ yield provider
+
+ # Cleanup: Remove the SQLite file after test completes
+ provider.close()
+ if os.path.exists(logging_path):
+ os.remove(logging_path)
+
+
+@pytest.mark.asyncio
+async def test_local_logging(local_provider):
+ """Test logging and retrieving from the local logging provider."""
+ try:
+ run_id = generate_run_id()
+ await local_provider.init()
+ await local_provider.log(run_id, "key", "value")
+ logs = await local_provider.get_logs([run_id])
+ assert len(logs) == 1
+ assert logs[0]["key"] == "key"
+ assert logs[0]["value"] == "value"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_multiple_log_entries(local_provider):
+ """Test logging multiple entries and retrieving them."""
+ try:
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+ run_id_2 = generate_run_id()
+ await local_provider.init()
+
+ entries = [
+ (run_id_0, "key_0", "value_0"),
+ (run_id_1, "key_1", "value_1"),
+ (run_id_2, "key_2", "value_2"),
+ ]
+ for run_id, key, value in entries:
+ await local_provider.log(run_id, key, value)
+
+ logs = await local_provider.get_logs([run_id_0, run_id_1, run_id_2])
+ assert len(logs) == 3
+
+        # Check that each retrieved log matches the entry it was logged from
+ for log in logs:
+ selected_entry = [
+ entry for entry in entries if entry[0] == log["log_id"]
+ ][0]
+ assert log["log_id"] == selected_entry[0]
+ assert log["key"] == selected_entry[1]
+ assert log["value"] == selected_entry[2]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_log_retrieval_limit(local_provider):
+ """Test the max_logs limit parameter works correctly."""
+ try:
+ await local_provider.init()
+
+ run_ids = []
+ for i in range(10): # Add 10 entries
+ run_ids.append(generate_run_id())
+ await local_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
+
+ logs = await local_provider.get_logs(run_ids[0:5])
+ assert len(logs) == 5 # Ensure only 5 logs are returned
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_specific_run_type_retrieval(local_provider):
+ """Test retrieving logs for a specific run type works correctly."""
+ try:
+ await local_provider.init()
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+
+ await local_provider.log(
+ run_id_0, "pipeline_type", "search", is_info_log=True
+ )
+ await local_provider.log(run_id_0, "key_0", "value_0")
+ await local_provider.log(
+ run_id_1, "pipeline_type", "rag", is_info_log=True
+ )
+ await local_provider.log(run_id_1, "key_1", "value_1")
+
+ run_info = await local_provider.get_run_info(log_type_filter="search")
+ logs = await local_provider.get_logs([run.run_id for run in run_info])
+ assert len(logs) == 1
+ assert logs[0]["log_id"] == run_id_0
+ assert logs[0]["key"] == "key_0"
+ assert logs[0]["value"] == "value_0"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture(scope="function")
+def postgres_provider():
+ """Fixture to create and tear down the PostgresKVLoggingProvider."""
+ log_table = f"logs_{str(uuid.uuid4()).replace('-', '_')}"
+ log_info_table = f"log_info_{str(uuid.uuid4()).replace('-', '_')}"
+
+ provider = PostgresKVLoggingProvider(
+ PostgresLoggingConfig(
+ log_table=log_table, log_info_table=log_info_table
+ )
+ )
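+    # NOTE: connection details (host, port, credentials) are assumed to come
+    # from PostgresLoggingConfig defaults or the environment.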
+ yield provider
+
+
+@pytest.mark.asyncio
+async def test_postgres_logging(postgres_provider):
+ """Test logging and retrieving from the postgres logging provider."""
+ try:
+ await postgres_provider.init()
+ run_id = generate_run_id()
+ await postgres_provider.log(run_id, "key", "value")
+ logs = await postgres_provider.get_logs([run_id])
+ assert len(logs) == 1
+ assert logs[0]["key"] == "key"
+ assert logs[0]["value"] == "value"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_postgres_multiple_log_entries(postgres_provider):
+ """Test logging multiple entries and retrieving them."""
+ try:
+ await postgres_provider.init()
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+ run_id_2 = generate_run_id()
+
+ entries = [
+ (run_id_0, "key_0", "value_0"),
+ (run_id_1, "key_1", "value_1"),
+ (run_id_2, "key_2", "value_2"),
+ ]
+ for run_id, key, value in entries:
+ await postgres_provider.log(run_id, key, value)
+
+ logs = await postgres_provider.get_logs([run_id_0, run_id_1, run_id_2])
+ assert len(logs) == 3
+
+        # Check that each retrieved log matches the entry it was logged from
+ for log in logs:
+ selected_entry = [
+ entry for entry in entries if entry[0] == log["log_id"]
+ ][0]
+ assert log["log_id"] == selected_entry[0]
+ assert log["key"] == selected_entry[1]
+ assert log["value"] == selected_entry[2]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_postgres_log_retrieval_limit(postgres_provider):
+ """Test the max_logs limit parameter works correctly."""
+ try:
+ await postgres_provider.init()
+ run_ids = []
+ for i in range(10): # Add 10 entries
+ run_ids.append(generate_run_id())
+ await postgres_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
+
+ logs = await postgres_provider.get_logs(run_ids[:5])
+ assert len(logs) == 5 # Ensure only 5 logs are returned
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_postgres_specific_run_type_retrieval(postgres_provider):
+ """Test retrieving logs for a specific run type works correctly."""
+ try:
+ await postgres_provider.init()
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+
+ await postgres_provider.log(
+ run_id_0, "pipeline_type", "search", is_info_log=True
+ )
+ await postgres_provider.log(run_id_0, "key_0", "value_0")
+ await postgres_provider.log(
+ run_id_1, "pipeline_type", "rag", is_info_log=True
+ )
+ await postgres_provider.log(run_id_1, "key_1", "value_1")
+
+ run_info = await postgres_provider.get_run_info(
+ log_type_filter="search"
+ )
+ logs = await postgres_provider.get_logs(
+ [run.run_id for run in run_info]
+ )
+ assert len(logs) == 1
+ assert logs[0]["log_id"] == run_id_0
+ assert logs[0]["key"] == "key_0"
+ assert logs[0]["value"] == "value_0"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.fixture(scope="function")
+def redis_provider():
+ """Fixture to create and tear down the RedisKVLoggingProvider."""
+ log_table = f"logs_{str(uuid.uuid4()).replace('-', '_')}"
+ log_info_table = f"log_info_{str(uuid.uuid4()).replace('-', '_')}"
+
+ provider = RedisKVLoggingProvider(
+ RedisLoggingConfig(log_table=log_table, log_info_table=log_info_table)
+ )
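+    # NOTE: assumes a reachable Redis instance, configured through
+    # RedisLoggingConfig defaults or the environment.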
+ yield provider
+ provider.close()
+
+
+@pytest.mark.asyncio
+async def test_redis_logging(redis_provider):
+ """Test logging and retrieving from the Redis logging provider."""
+ try:
+ run_id = generate_run_id()
+ await redis_provider.log(run_id, "key", "value")
+ logs = await redis_provider.get_logs([run_id])
+ assert len(logs) == 1
+ assert logs[0]["key"] == "key"
+ assert logs[0]["value"] == "value"
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_redis_multiple_log_entries(redis_provider):
+ """Test logging multiple entries and retrieving them."""
+ try:
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+ run_id_2 = generate_run_id()
+
+ entries = [
+ (run_id_0, "key_0", "value_0"),
+ (run_id_1, "key_1", "value_1"),
+ (run_id_2, "key_2", "value_2"),
+ ]
+ for run_id, key, value in entries:
+ await redis_provider.log(run_id, key, value)
+
+ logs = await redis_provider.get_logs([run_id_0, run_id_1, run_id_2])
+ assert len(logs) == 3
+
+        # Check that each retrieved log matches the entry it was logged from
+ for log in logs:
+ selected_entry = [
+ entry for entry in entries if entry[0] == log["log_id"]
+ ][0]
+ assert log["log_id"] == selected_entry[0]
+ assert log["key"] == selected_entry[1]
+ assert log["value"] == selected_entry[2]
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_redis_log_retrieval_limit(redis_provider):
+ """Test the max_logs limit parameter works correctly."""
+ try:
+ run_ids = []
+ for i in range(10): # Add 10 entries
+ run_ids.append(generate_run_id())
+ await redis_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
+
+ logs = await redis_provider.get_logs(run_ids[0:5])
+ assert len(logs) == 5 # Ensure only 5 logs are returned
+ except asyncio.CancelledError:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_redis_specific_run_type_retrieval(redis_provider):
+ """Test retrieving logs for a specific run type works correctly."""
+ try:
+ run_id_0 = generate_run_id()
+ run_id_1 = generate_run_id()
+
+ await redis_provider.log(
+ run_id_0, "pipeline_type", "search", is_info_log=True
+ )
+ await redis_provider.log(run_id_0, "key_0", "value_0")
+ await redis_provider.log(
+ run_id_1, "pipeline_type", "rag", is_info_log=True
+ )
+ await redis_provider.log(run_id_1, "key_1", "value_1")
+
+ run_info = await redis_provider.get_run_info(log_type_filter="search")
+ logs = await redis_provider.get_logs([run.run_id for run in run_info])
+ assert len(logs) == 1
+ assert logs[0]["log_id"] == run_id_0
+ assert logs[0]["key"] == "key_0"
+ assert logs[0]["value"] == "value_0"
+ except asyncio.CancelledError:
+ pass