about summary refs log tree commit diff
path: root/R2R/tests/test_config.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/tests/test_config.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/tests/test_config.py')
-rwxr-xr-xR2R/tests/test_config.py187
1 files changed, 187 insertions, 0 deletions
diff --git a/R2R/tests/test_config.py b/R2R/tests/test_config.py
new file mode 100755
index 00000000..5e60833c
--- /dev/null
+++ b/R2R/tests/test_config.py
@@ -0,0 +1,187 @@
+import asyncio
+import json
+from unittest.mock import Mock, mock_open, patch
+
+import pytest
+
+from r2r import DocumentType, R2RConfig
+
+
+@pytest.fixture(scope="session", autouse=True)
+def event_loop_policy():
+    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
+@pytest.fixture(scope="function")
+def event_loop():
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+    asyncio.set_event_loop(None)
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def cleanup_tasks():
+    yield
+    for task in asyncio.all_tasks():
+        if task is not asyncio.current_task():
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+
+@pytest.fixture
+def mock_bad_file():
+    mock_data = json.dumps({})
+    with patch("builtins.open", mock_open(read_data=mock_data)) as m:
+        yield m
+
+
+@pytest.fixture
+def mock_file():
+    mock_data = json.dumps(
+        {
+            "app": {"max_file_size_in_mb": 128},
+            "embedding": {
+                "provider": "example_provider",
+                "base_model": "model",
+                "base_dimension": 128,
+                "batch_size": 16,
+                "text_splitter": "default",
+            },
+            "kg": {
+                "provider": "None",
+                "batch_size": 1,
+                "text_splitter": {
+                    "type": "recursive_character",
+                    "chunk_size": 2048,
+                    "chunk_overlap": 0,
+                },
+            },
+            "eval": {"llm": {"provider": "local"}},
+            "ingestion": {"excluded_parsers": {}},
+            "completions": {"provider": "lm_provider"},
+            "logging": {
+                "provider": "local",
+                "log_table": "logs",
+                "log_info_table": "log_info",
+            },
+            "prompt": {"provider": "prompt_provider"},
+            "vector_database": {"provider": "vector_db"},
+        }
+    )
+    with patch("builtins.open", mock_open(read_data=mock_data)) as m:
+        yield m
+
+
+@pytest.mark.asyncio
+async def test_r2r_config_loading_required_keys(mock_bad_file):
+    with pytest.raises(KeyError):
+        R2RConfig.from_json("config.json")
+
+
+@pytest.mark.asyncio
+async def test_r2r_config_loading(mock_file):
+    config = R2RConfig.from_json("config.json")
+    assert (
+        config.embedding.provider == "example_provider"
+    ), "Provider should match the mock data"
+
+
+@pytest.fixture
+def mock_redis_client():
+    client = Mock()
+    return client
+
+
+def test_r2r_config_serialization(mock_file, mock_redis_client):
+    config = R2RConfig.from_json("config.json")
+    config.save_to_redis(mock_redis_client, "test_key")
+    mock_redis_client.set.assert_called_once()
+    saved_data = json.loads(mock_redis_client.set.call_args[0][1])
+    assert saved_data["app"]["max_file_size_in_mb"] == 128
+
+
+def test_r2r_config_deserialization(mock_file, mock_redis_client):
+    config_data = {
+        "app": {"max_file_size_in_mb": 128},
+        "embedding": {
+            "provider": "example_provider",
+            "base_model": "model",
+            "base_dimension": 128,
+            "batch_size": 16,
+            "text_splitter": "default",
+        },
+        "kg": {
+            "provider": "None",
+            "batch_size": 1,
+            "text_splitter": {
+                "type": "recursive_character",
+                "chunk_size": 2048,
+                "chunk_overlap": 0,
+            },
+        },
+        "eval": {"llm": {"provider": "local"}},
+        "ingestion": {"excluded_parsers": ["pdf"]},
+        "completions": {"provider": "lm_provider"},
+        "logging": {
+            "provider": "local",
+            "log_table": "logs",
+            "log_info_table": "log_info",
+        },
+        "prompt": {"provider": "prompt_provider"},
+        "vector_database": {"provider": "vector_db"},
+    }
+    mock_redis_client.get.return_value = json.dumps(config_data)
+    config = R2RConfig.load_from_redis(mock_redis_client, "test_key")
+    assert config.app["max_file_size_in_mb"] == 128
+    assert DocumentType.PDF in config.ingestion["excluded_parsers"]
+
+
+def test_r2r_config_missing_section():
+    invalid_data = {
+        "embedding": {
+            "provider": "example_provider",
+            "base_model": "model",
+            "base_dimension": 128,
+            "batch_size": 16,
+            "text_splitter": "default",
+        }
+    }
+    with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
+        with pytest.raises(KeyError):
+            R2RConfig.from_json("config.json")
+
+
+def test_r2r_config_missing_required_key():
+    invalid_data = {
+        "app": {"max_file_size_in_mb": 128},
+        "embedding": {
+            "base_model": "model",
+            "base_dimension": 128,
+            "batch_size": 16,
+            "text_splitter": "default",
+        },
+        "kg": {
+            "provider": "None",
+            "batch_size": 1,
+            "text_splitter": {
+                "type": "recursive_character",
+                "chunk_size": 2048,
+                "chunk_overlap": 0,
+            },
+        },
+        "completions": {"provider": "lm_provider"},
+        "logging": {
+            "provider": "local",
+            "log_table": "logs",
+            "log_info_table": "log_info",
+        },
+        "prompt": {"provider": "prompt_provider"},
+        "vector_database": {"provider": "vector_db"},
+    }
+    with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
+        with pytest.raises(KeyError):
+            R2RConfig.from_json("config.json")