Diffstat (limited to 'R2R/tests/test_config.py')
-rwxr-xr-x | R2R/tests/test_config.py | 187 |
1 file changed, 187 insertions, 0 deletions
diff --git a/R2R/tests/test_config.py b/R2R/tests/test_config.py
new file mode 100755
index 00000000..5e60833c
--- /dev/null
+++ b/R2R/tests/test_config.py
@@ -0,0 +1,187 @@
+import asyncio
+import json
+from unittest.mock import Mock, mock_open, patch
+
+import pytest
+
+from r2r import DocumentType, R2RConfig
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+    asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+    yield
+    for task in asyncio.all_tasks():
+        if task is not asyncio.current_task():
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+
+@pytest.fixture
+def mock_bad_file():
+    mock_data = json.dumps({})
+    with patch("builtins.open", mock_open(read_data=mock_data)) as m:
+        yield m
+
+
+@pytest.fixture
+def mock_file():
+    mock_data = json.dumps(
+        {
+            "app": {"max_file_size_in_mb": 128},
+            "embedding": {
+                "provider": "example_provider",
+                "base_model": "model",
+                "base_dimension": 128,
+                "batch_size": 16,
+                "text_splitter": "default",
+            },
+            "kg": {
+                "provider": "None",
+                "batch_size": 1,
+                "text_splitter": {
+                    "type": "recursive_character",
+                    "chunk_size": 2048,
+                    "chunk_overlap": 0,
+                },
+            },
+            "eval": {"llm": {"provider": "local"}},
+            "ingestion": {"excluded_parsers": {}},
+            "completions": {"provider": "lm_provider"},
+            "logging": {
+                "provider": "local",
+                "log_table": "logs",
+                "log_info_table": "log_info",
+            },
+            "prompt": {"provider": "prompt_provider"},
+            "vector_database": {"provider": "vector_db"},
+        }
+    )
+    with patch("builtins.open", mock_open(read_data=mock_data)) as m:
+        yield m
+
+
+@pytest.mark.asyncio
+async def test_r2r_config_loading_required_keys(mock_bad_file):
+    with pytest.raises(KeyError):
+        R2RConfig.from_json("config.json")
+
+
+@pytest.mark.asyncio
+async def test_r2r_config_loading(mock_file):
+    config = R2RConfig.from_json("config.json")
+    assert (
+        config.embedding.provider == "example_provider"
+    ), "Provider should match the mock data"
+
+
+@pytest.fixture
+def mock_redis_client():
+    client = Mock()
+    return client
+
+
+def test_r2r_config_serialization(mock_file, mock_redis_client):
+    config = R2RConfig.from_json("config.json")
+    config.save_to_redis(mock_redis_client, "test_key")
+    mock_redis_client.set.assert_called_once()
+    saved_data = json.loads(mock_redis_client.set.call_args[0][1])
+    assert saved_data["app"]["max_file_size_in_mb"] == 128
+
+
+def test_r2r_config_deserialization(mock_file, mock_redis_client):
+    config_data = {
+        "app": {"max_file_size_in_mb": 128},
+        "embedding": {
+            "provider": "example_provider",
+            "base_model": "model",
+            "base_dimension": 128,
+            "batch_size": 16,
+            "text_splitter": "default",
+        },
+        "kg": {
+            "provider": "None",
+            "batch_size": 1,
+            "text_splitter": {
+                "type": "recursive_character",
+                "chunk_size": 2048,
+                "chunk_overlap": 0,
+            },
+        },
+        "eval": {"llm": {"provider": "local"}},
+        "ingestion": {"excluded_parsers": ["pdf"]},
+        "completions": {"provider": "lm_provider"},
+        "logging": {
+            "provider": "local",
+            "log_table": "logs",
+            "log_info_table": "log_info",
+        },
+        "prompt": {"provider": "prompt_provider"},
+        "vector_database": {"provider": "vector_db"},
+    }
+    mock_redis_client.get.return_value = json.dumps(config_data)
+    config = R2RConfig.load_from_redis(mock_redis_client, "test_key")
+    assert config.app["max_file_size_in_mb"] == 128
+    assert DocumentType.PDF in config.ingestion["excluded_parsers"]
+
+
+def test_r2r_config_missing_section():
+    invalid_data = {
+        "embedding": {
+            "provider": "example_provider",
+            "base_model": "model",
+            "base_dimension": 128,
+            "batch_size": 16,
+            "text_splitter": "default",
+        }
+    }
+    with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
+        with pytest.raises(KeyError):
+            R2RConfig.from_json("config.json")
+
+
+def test_r2r_config_missing_required_key():
+    invalid_data = {
+        "app": {"max_file_size_in_mb": 128},
+        "embedding": {
+            "base_model": "model",
+            "base_dimension": 128,
+            "batch_size": 16,
+            "text_splitter": "default",
+        },
+        "kg": {
+            "provider": "None",
+            "batch_size": 1,
+            "text_splitter": {
+                "type": "recursive_character",
+                "chunk_size": 2048,
+                "chunk_overlap": 0,
+            },
+        },
+        "completions": {"provider": "lm_provider"},
+        "logging": {
+            "provider": "local",
+            "log_table": "logs",
+            "log_info_table": "log_info",
+        },
+        "prompt": {"provider": "prompt_provider"},
+        "vector_database": {"provider": "vector_db"},
+    }
+    with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
+        with pytest.raises(KeyError):
+            R2RConfig.from_json("config.json")