import asyncio
import json
from unittest.mock import Mock, mock_open, patch
import pytest
from r2r import DocumentType, R2RConfig
@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
    """Install asyncio's default event loop policy for the test session.

    Runs automatically (``autouse``) once per session. The fixture itself
    provides no value; it only sets the process-wide policy.
    """
    default_policy = asyncio.DefaultEventLoopPolicy()
    asyncio.set_event_loop_policy(default_policy)
@pytest.fixture(scope="function")
def event_loop():
    """Provide a brand-new event loop to each test function.

    Yields:
        A loop created from the currently installed event loop policy.

    After the test finishes the loop is closed and the current event loop
    is reset to ``None`` so no closed loop stays registered.
    """
    policy = asyncio.get_event_loop_policy()
    fresh_loop = policy.new_event_loop()
    yield fresh_loop
    # Teardown: close first, then unregister.
    fresh_loop.close()
    asyncio.set_event_loop(None)
@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
    """Cancel leftover asyncio tasks once the session's tests are done.

    Nothing happens before the ``yield``. On teardown, every task reported
    by ``asyncio.all_tasks()`` except the currently running one is
    cancelled and then awaited so the cancellation is fully processed.

    NOTE(review): this is an async generator registered with the plain
    ``pytest.fixture`` decorator -- confirm the async test plugin in use
    actually drives it.
    """
    yield
    for pending in asyncio.all_tasks():
        # Never cancel the task that is executing this teardown.
        if pending is asyncio.current_task():
            continue
        pending.cancel()
        try:
            await pending
        except asyncio.CancelledError:
            # Expected outcome of awaiting a task we just cancelled.
            pass
@pytest.fixture
def mock_bad_file():
    """Patch ``builtins.open`` so every opened file reads as ``{}``.

    Yields:
        The ``mock_open`` object that replaces ``open`` while the fixture
        is active. The serialized content is an empty JSON object, i.e. a
        configuration with no sections at all.
    """
    empty_payload = json.dumps({})
    fake_open = mock_open(read_data=empty_payload)
    with patch("builtins.open", fake_open) as patched_open:
        yield patched_open
@pytest.fixture
def mock_file():
    """Patch ``builtins.open`` to serve a fully populated JSON config.

    Yields:
        The ``mock_open`` object that replaces ``open`` while the fixture
        is active. Any file read returns the JSON document built below.
    """
    embedding_section = {
        "provider": "example_provider",
        "base_model": "model",
        "base_dimension": 128,
        "batch_size": 16,
        "text_splitter": "default",
    }
    kg_section = {
        "provider": "None",
        "batch_size": 1,
        "text_splitter": {
            "type": "recursive_character",
            "chunk_size": 2048,
            "chunk_overlap": 0,
        },
    }
    logging_section = {
        "provider": "local",
        "log_table": "logs",
        "log_info_table": "log_info",
    }
    # NOTE(review): "excluded_parsers" is an empty dict here, while
    # test_r2r_config_deserialization passes a list -- confirm which shape
    # the config loader expects.
    full_config = {
        "app": {"max_file_size_in_mb": 128},
        "embedding": embedding_section,
        "kg": kg_section,
        "eval": {"llm": {"provider": "local"}},
        "ingestion": {"excluded_parsers": {}},
        "completions": {"provider": "lm_provider"},
        "logging": logging_section,
        "prompt": {"provider": "prompt_provider"},
        "vector_database": {"provider": "vector_db"},
    }
    fake_open = mock_open(read_data=json.dumps(full_config))
    with patch("builtins.open", fake_open) as patched_open:
        yield patched_open
@pytest.mark.asyncio
async def test_r2r_config_loading_required_keys(mock_bad_file):
    """Loading a config file that contains ``{}`` must raise ``KeyError``.

    ``mock_bad_file`` makes every ``open`` call return an empty JSON
    object, so the path passed here is never read from disk.
    """
    config_path = "config.json"
    with pytest.raises(KeyError):
        R2RConfig.from_json(config_path)
@pytest.mark.asyncio
async def test_r2r_config_loading(mock_file):
    """A complete config loads and exposes the embedding provider.

    ``mock_file`` supplies the JSON content; the test checks that the
    provider read back matches the value in that mocked document.
    """
    loaded = R2RConfig.from_json("config.json")
    provider = loaded.embedding.provider
    assert provider == "example_provider", "Provider should match the mock data"
@pytest.fixture
def mock_redis_client():
    """Return a bare ``Mock`` used in place of a Redis client."""
    return Mock()
def test_r2r_config_serialization(mock_file, mock_redis_client):
    """Saving a config to Redis writes JSON that keeps the ``app`` values.

    The config is loaded from the mocked file, saved through the mock
    client, and the second positional argument of the single ``set`` call
    is parsed back as JSON for inspection.
    """
    loaded = R2RConfig.from_json("config.json")
    loaded.save_to_redis(mock_redis_client, "test_key")

    redis_set = mock_redis_client.set
    redis_set.assert_called_once()

    # call_args[0] holds the positional arguments of the recorded call.
    positional_args = redis_set.call_args[0]
    saved_data = json.loads(positional_args[1])
    assert saved_data["app"]["max_file_size_in_mb"] == 128
def test_r2r_config_deserialization(mock_file, mock_redis_client):
    """Loading a config from Redis restores the app and ingestion settings.

    The mock client's ``get`` returns a JSON document whose ingestion
    section excludes the ``"pdf"`` parser; the loaded config must report
    the app file-size limit and contain ``DocumentType.PDF`` among its
    excluded parsers.
    """
    embedding_section = {
        "provider": "example_provider",
        "base_model": "model",
        "base_dimension": 128,
        "batch_size": 16,
        "text_splitter": "default",
    }
    kg_section = {
        "provider": "None",
        "batch_size": 1,
        "text_splitter": {
            "type": "recursive_character",
            "chunk_size": 2048,
            "chunk_overlap": 0,
        },
    }
    logging_section = {
        "provider": "local",
        "log_table": "logs",
        "log_info_table": "log_info",
    }
    stored_config = {
        "app": {"max_file_size_in_mb": 128},
        "embedding": embedding_section,
        "kg": kg_section,
        "eval": {"llm": {"provider": "local"}},
        "ingestion": {"excluded_parsers": ["pdf"]},
        "completions": {"provider": "lm_provider"},
        "logging": logging_section,
        "prompt": {"provider": "prompt_provider"},
        "vector_database": {"provider": "vector_db"},
    }
    mock_redis_client.get.return_value = json.dumps(stored_config)

    restored = R2RConfig.load_from_redis(mock_redis_client, "test_key")

    assert restored.app["max_file_size_in_mb"] == 128
    excluded = restored.ingestion["excluded_parsers"]
    assert DocumentType.PDF in excluded
def test_r2r_config_missing_section():
    """A config containing only the ``embedding`` section raises ``KeyError``.

    ``builtins.open`` is patched locally so the loader reads the
    incomplete document instead of a real file.
    """
    embedding_only = {
        "embedding": {
            "provider": "example_provider",
            "base_model": "model",
            "base_dimension": 128,
            "batch_size": 16,
            "text_splitter": "default",
        }
    }
    fake_open = mock_open(read_data=json.dumps(embedding_only))
    with patch("builtins.open", fake_open), pytest.raises(KeyError):
        R2RConfig.from_json("config.json")
def test_r2r_config_missing_required_key():
    """A config whose ``embedding`` section lacks ``provider`` raises ``KeyError``.

    ``builtins.open`` is patched locally so the loader reads the document
    built below.

    NOTE(review): besides the missing ``provider`` key, this document also
    has no ``eval`` or ``ingestion`` section (both present in the
    ``mock_file`` fixture), so the ``KeyError`` may come from a missing
    section rather than the missing key -- confirm against the loader.
    """
    embedding_without_provider = {
        "base_model": "model",
        "base_dimension": 128,
        "batch_size": 16,
        "text_splitter": "default",
    }
    kg_section = {
        "provider": "None",
        "batch_size": 1,
        "text_splitter": {
            "type": "recursive_character",
            "chunk_size": 2048,
            "chunk_overlap": 0,
        },
    }
    logging_section = {
        "provider": "local",
        "log_table": "logs",
        "log_info_table": "log_info",
    }
    incomplete_config = {
        "app": {"max_file_size_in_mb": 128},
        "embedding": embedding_without_provider,
        "kg": kg_section,
        "completions": {"provider": "lm_provider"},
        "logging": logging_section,
        "prompt": {"provider": "prompt_provider"},
        "vector_database": {"provider": "vector_db"},
    }
    fake_open = mock_open(read_data=json.dumps(incomplete_config))
    with patch("builtins.open", fake_open), pytest.raises(KeyError):
        R2RConfig.from_json("config.json")