import asyncio
import logging
import os
import uuid
import pytest
from r2r import (
LocalKVLoggingProvider,
LoggingConfig,
PostgresKVLoggingProvider,
PostgresLoggingConfig,
RedisKVLoggingProvider,
RedisLoggingConfig,
generate_run_id,
)
# Module-level logger; not referenced by any test in this file.
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    """Install the default asyncio event-loop policy for the whole session."""
    default_policy = asyncio.DefaultEventLoopPolicy()
    asyncio.set_event_loop_policy(default_policy)
@pytest.fixture(scope="function")
def event_loop():
    """Give each test a brand-new event loop, closed and detached afterwards."""
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
    # Detach any loop from the current thread so later code must create its own.
    asyncio.set_event_loop(None)
@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    """Session teardown: cancel and drain any asyncio tasks still pending."""
    yield
    current = asyncio.current_task()
    for pending_task in asyncio.all_tasks():
        if pending_task is current:
            continue
        pending_task.cancel()
        try:
            await pending_task
        except asyncio.CancelledError:
            # Cancellation is the expected outcome here; swallow it.
            pass
@pytest.fixture(scope="function")
def local_provider():
    """Create a LocalKVLoggingProvider backed by a uniquely named SQLite file.

    A fresh file name per test keeps tests isolated from one another.
    Teardown removes the database file even if closing the provider fails,
    so a raising ``close()`` can no longer leak per-test ``*.sqlite`` files
    into the working directory.
    """
    logging_path = f"test_{uuid.uuid4()}.sqlite"
    provider = LocalKVLoggingProvider(LoggingConfig(logging_path=logging_path))
    yield provider
    try:
        provider.close()
    finally:
        # Always attempt file removal, regardless of how close() went.
        if os.path.exists(logging_path):
            os.remove(logging_path)
@pytest.mark.asyncio
async def test_local_logging(local_provider):
    """Log a single entry to the local provider and read it back.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    run_id = generate_run_id()
    await local_provider.init()
    await local_provider.log(run_id, "key", "value")
    logs = await local_provider.get_logs([run_id])
    assert len(logs) == 1
    assert logs[0]["key"] == "key"
    assert logs[0]["value"] == "value"
@pytest.mark.asyncio
async def test_multiple_log_entries(local_provider):
    """Log three entries under distinct run ids and verify each round-trips.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    run_id_2 = generate_run_id()
    await local_provider.init()
    entries = [
        (run_id_0, "key_0", "value_0"),
        (run_id_1, "key_1", "value_1"),
        (run_id_2, "key_2", "value_2"),
    ]
    for run_id, key, value in entries:
        await local_provider.log(run_id, key, value)

    logs = await local_provider.get_logs([run_id_0, run_id_1, run_id_2])
    assert len(logs) == 3
    # Match each returned log to its source entry by run id; the order of
    # the returned list is deliberately NOT asserted here.
    for log in logs:
        matching = next(e for e in entries if e[0] == log["log_id"])
        assert log["log_id"] == matching[0]
        assert log["key"] == matching[1]
        assert log["value"] == matching[2]
@pytest.mark.asyncio
async def test_log_retrieval_limit(local_provider):
    """Requesting a subset of run ids returns only that subset's logs.

    (The docstring previously claimed this exercised a ``max_logs`` limit,
    but no such parameter is passed — the test filters by run id. The former
    ``except asyncio.CancelledError: pass`` wrapper is also removed so a
    cancelled test can no longer pass silently.)
    """
    await local_provider.init()
    run_ids = []
    for i in range(10):  # Add 10 entries, one per run id.
        run_ids.append(generate_run_id())
        await local_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
    logs = await local_provider.get_logs(run_ids[:5])
    assert len(logs) == 5  # Only the 5 requested runs' logs come back.
@pytest.mark.asyncio
async def test_specific_run_type_retrieval(local_provider):
    """Filtering run info by log type returns only matching runs' logs.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    await local_provider.init()
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    # Info-log entries tag each run with its pipeline type.
    await local_provider.log(
        run_id_0, "pipeline_type", "search", is_info_log=True
    )
    await local_provider.log(run_id_0, "key_0", "value_0")
    await local_provider.log(
        run_id_1, "pipeline_type", "rag", is_info_log=True
    )
    await local_provider.log(run_id_1, "key_1", "value_1")

    run_info = await local_provider.get_run_info(log_type_filter="search")
    logs = await local_provider.get_logs([run.run_id for run in run_info])
    # Only the "search" run's non-info entry should be returned.
    assert len(logs) == 1
    assert logs[0]["log_id"] == run_id_0
    assert logs[0]["key"] == "key_0"
    assert logs[0]["value"] == "value_0"
@pytest.fixture(scope="function")
def postgres_provider():
    """Build a PostgresKVLoggingProvider with per-test unique table names.

    NOTE(review): there is no teardown, so the uniquely named tables are
    presumably left behind in the test database — confirm whether the
    provider (or the environment) cleans them up.
    """
    logs_suffix = str(uuid.uuid4()).replace("-", "_")
    info_suffix = str(uuid.uuid4()).replace("-", "_")
    config = PostgresLoggingConfig(
        log_table=f"logs_{logs_suffix}",
        log_info_table=f"log_info_{info_suffix}",
    )
    yield PostgresKVLoggingProvider(config)
@pytest.mark.asyncio
async def test_postgres_logging(postgres_provider):
    """Log a single entry to the Postgres provider and read it back.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    await postgres_provider.init()
    run_id = generate_run_id()
    await postgres_provider.log(run_id, "key", "value")
    logs = await postgres_provider.get_logs([run_id])
    assert len(logs) == 1
    assert logs[0]["key"] == "key"
    assert logs[0]["value"] == "value"
@pytest.mark.asyncio
async def test_postgres_multiple_log_entries(postgres_provider):
    """Log three entries under distinct run ids and verify each round-trips.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    await postgres_provider.init()
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    run_id_2 = generate_run_id()
    entries = [
        (run_id_0, "key_0", "value_0"),
        (run_id_1, "key_1", "value_1"),
        (run_id_2, "key_2", "value_2"),
    ]
    for run_id, key, value in entries:
        await postgres_provider.log(run_id, key, value)

    logs = await postgres_provider.get_logs([run_id_0, run_id_1, run_id_2])
    assert len(logs) == 3
    # Match each returned log to its source entry by run id; the order of
    # the returned list is deliberately NOT asserted here.
    for log in logs:
        matching = next(e for e in entries if e[0] == log["log_id"])
        assert log["log_id"] == matching[0]
        assert log["key"] == matching[1]
        assert log["value"] == matching[2]
@pytest.mark.asyncio
async def test_postgres_log_retrieval_limit(postgres_provider):
    """Requesting a subset of run ids returns only that subset's logs.

    (The docstring previously claimed this exercised a ``max_logs`` limit,
    but no such parameter is passed — the test filters by run id. The former
    ``except asyncio.CancelledError: pass`` wrapper is also removed so a
    cancelled test can no longer pass silently.)
    """
    await postgres_provider.init()
    run_ids = []
    for i in range(10):  # Add 10 entries, one per run id.
        run_ids.append(generate_run_id())
        await postgres_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
    logs = await postgres_provider.get_logs(run_ids[:5])
    assert len(logs) == 5  # Only the 5 requested runs' logs come back.
@pytest.mark.asyncio
async def test_postgres_specific_run_type_retrieval(postgres_provider):
    """Filtering run info by log type returns only matching runs' logs.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    await postgres_provider.init()
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    # Info-log entries tag each run with its pipeline type.
    await postgres_provider.log(
        run_id_0, "pipeline_type", "search", is_info_log=True
    )
    await postgres_provider.log(run_id_0, "key_0", "value_0")
    await postgres_provider.log(
        run_id_1, "pipeline_type", "rag", is_info_log=True
    )
    await postgres_provider.log(run_id_1, "key_1", "value_1")

    run_info = await postgres_provider.get_run_info(log_type_filter="search")
    logs = await postgres_provider.get_logs([run.run_id for run in run_info])
    # Only the "search" run's non-info entry should be returned.
    assert len(logs) == 1
    assert logs[0]["log_id"] == run_id_0
    assert logs[0]["key"] == "key_0"
    assert logs[0]["value"] == "value_0"
@pytest.fixture(scope="function")
def redis_provider():
    """Build a RedisKVLoggingProvider with per-test unique table names."""
    logs_suffix = str(uuid.uuid4()).replace("-", "_")
    info_suffix = str(uuid.uuid4()).replace("-", "_")
    config = RedisLoggingConfig(
        log_table=f"logs_{logs_suffix}",
        log_info_table=f"log_info_{info_suffix}",
    )
    provider = RedisKVLoggingProvider(config)
    yield provider
    provider.close()
@pytest.mark.asyncio
async def test_redis_logging(redis_provider):
    """Log a single entry to the Redis provider and read it back.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    run_id = generate_run_id()
    await redis_provider.log(run_id, "key", "value")
    logs = await redis_provider.get_logs([run_id])
    assert len(logs) == 1
    assert logs[0]["key"] == "key"
    assert logs[0]["value"] == "value"
@pytest.mark.asyncio
async def test_redis_multiple_log_entries(redis_provider):
    """Log three entries under distinct run ids and verify each round-trips.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    run_id_2 = generate_run_id()
    entries = [
        (run_id_0, "key_0", "value_0"),
        (run_id_1, "key_1", "value_1"),
        (run_id_2, "key_2", "value_2"),
    ]
    for run_id, key, value in entries:
        await redis_provider.log(run_id, key, value)

    logs = await redis_provider.get_logs([run_id_0, run_id_1, run_id_2])
    assert len(logs) == 3
    # Match each returned log to its source entry by run id; the order of
    # the returned list is deliberately NOT asserted here.
    for log in logs:
        matching = next(e for e in entries if e[0] == log["log_id"])
        assert log["log_id"] == matching[0]
        assert log["key"] == matching[1]
        assert log["value"] == matching[2]
@pytest.mark.asyncio
async def test_redis_log_retrieval_limit(redis_provider):
    """Requesting a subset of run ids returns only that subset's logs.

    (The docstring previously claimed this exercised a ``max_logs`` limit,
    but no such parameter is passed — the test filters by run id. The former
    ``except asyncio.CancelledError: pass`` wrapper is also removed so a
    cancelled test can no longer pass silently.)
    """
    run_ids = []
    for i in range(10):  # Add 10 entries, one per run id.
        run_ids.append(generate_run_id())
        await redis_provider.log(run_ids[-1], f"key_{i}", f"value_{i}")
    logs = await redis_provider.get_logs(run_ids[:5])
    assert len(logs) == 5  # Only the 5 requested runs' logs come back.
@pytest.mark.asyncio
async def test_redis_specific_run_type_retrieval(redis_provider):
    """Filtering run info by log type returns only matching runs' logs.

    The former ``except asyncio.CancelledError: pass`` wrapper could let a
    cancelled test pass silently; cancellation now surfaces as an error.
    """
    run_id_0 = generate_run_id()
    run_id_1 = generate_run_id()
    # Info-log entries tag each run with its pipeline type.
    await redis_provider.log(
        run_id_0, "pipeline_type", "search", is_info_log=True
    )
    await redis_provider.log(run_id_0, "key_0", "value_0")
    await redis_provider.log(
        run_id_1, "pipeline_type", "rag", is_info_log=True
    )
    await redis_provider.log(run_id_1, "key_1", "value_1")

    run_info = await redis_provider.get_run_info(log_type_filter="search")
    logs = await redis_provider.get_logs([run.run_id for run in run_info])
    # Only the "search" run's non-info entry should be returned.
    assert len(logs) == 1
    assert logs[0]["log_id"] == run_id_0
    assert logs[0]["key"] == "key_0"
    assert logs[0]["value"] == "value_0"