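"""Tests for R2RConfig: loading from JSON, Redis round-trips, and required-key validation."""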
import asyncio
import json
from unittest.mock import Mock, mock_open, patch

import pytest

from r2r import DocumentType, R2RConfig


@pytest.fixture(scope="session", autouse=True)
def event_loop_policy():
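    """Install the default asyncio event loop policy for the whole test session."""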
    asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())


@pytest.fixture(scope="function")
def event_loop():
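    """Provide a fresh event loop per test and close it on teardown."""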
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
    asyncio.set_event_loop(None)


@pytest.fixture(scope="session", autouse=True)
async def cleanup_tasks():
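    """Cancel any asyncio tasks still pending when the test session ends."""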
    yield
    for task in asyncio.all_tasks():
        if task is not asyncio.current_task():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass


@pytest.fixture
def mock_bad_file():
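    """Patch builtins.open so reads return an empty JSON object (an invalid config)."""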
    mock_data = json.dumps({})
    with patch("builtins.open", mock_open(read_data=mock_data)) as m:
        yield m


@pytest.fixture
def mock_file():
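    """Patch builtins.open so reads return a minimal but complete config document."""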
    mock_data = json.dumps(
        {
            "app": {"max_file_size_in_mb": 128},
            "embedding": {
                "provider": "example_provider",
                "base_model": "model",
                "base_dimension": 128,
                "batch_size": 16,
                "text_splitter": "default",
            },
            "kg": {
                "provider": "None",
                "batch_size": 1,
                "text_splitter": {
                    "type": "recursive_character",
                    "chunk_size": 2048,
                    "chunk_overlap": 0,
                },
            },
            "eval": {"llm": {"provider": "local"}},
            "ingestion": {"excluded_parsers": {}},
            "completions": {"provider": "lm_provider"},
            "logging": {
                "provider": "local",
                "log_table": "logs",
                "log_info_table": "log_info",
            },
            "prompt": {"provider": "prompt_provider"},
            "vector_database": {"provider": "vector_db"},
        }
    )
    with patch("builtins.open", mock_open(read_data=mock_data)) as m:
        yield m


@pytest.mark.asyncio
async def test_r2r_config_loading_required_keys(mock_bad_file):
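    """An empty config document is missing required sections and must raise KeyError."""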
    with pytest.raises(KeyError):
        R2RConfig.from_json("config.json")


@pytest.mark.asyncio
async def test_r2r_config_loading(mock_file):
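    """A complete config document loads and exposes the embedding provider."""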
    config = R2RConfig.from_json("config.json")
    assert (
        config.embedding.provider == "example_provider"
    ), "Provider should match the mock data"


@pytest.fixture
def mock_redis_client():
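    """A bare Mock standing in for a Redis client; set/get calls are recorded."""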
    client = Mock()
    return client


def test_r2r_config_serialization(mock_file, mock_redis_client):
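    """save_to_redis should serialize the config as JSON under the given key."""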
    config = R2RConfig.from_json("config.json")
    config.save_to_redis(mock_redis_client, "test_key")
    mock_redis_client.set.assert_called_once()
    saved_data = json.loads(mock_redis_client.set.call_args[0][1])
    assert saved_data["app"]["max_file_size_in_mb"] == 128


def test_r2r_config_deserialization(mock_file, mock_redis_client):
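    """load_from_redis should rebuild the config, mapping parser names to DocumentType."""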
    config_data = {
        "app": {"max_file_size_in_mb": 128},
        "embedding": {
            "provider": "example_provider",
            "base_model": "model",
            "base_dimension": 128,
            "batch_size": 16,
            "text_splitter": "default",
        },
        "kg": {
            "provider": "None",
            "batch_size": 1,
            "text_splitter": {
                "type": "recursive_character",
                "chunk_size": 2048,
                "chunk_overlap": 0,
            },
        },
        "eval": {"llm": {"provider": "local"}},
        "ingestion": {"excluded_parsers": ["pdf"]},
        "completions": {"provider": "lm_provider"},
        "logging": {
            "provider": "local",
            "log_table": "logs",
            "log_info_table": "log_info",
        },
        "prompt": {"provider": "prompt_provider"},
        "vector_database": {"provider": "vector_db"},
    }
    mock_redis_client.get.return_value = json.dumps(config_data)
    config = R2RConfig.load_from_redis(mock_redis_client, "test_key")
    assert config.app["max_file_size_in_mb"] == 128
    assert DocumentType.PDF in config.ingestion["excluded_parsers"]


def test_r2r_config_missing_section():
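    """A config missing entire sections (e.g. app) must raise KeyError."""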
    invalid_data = {
        "embedding": {
            "provider": "example_provider",
            "base_model": "model",
            "base_dimension": 128,
            "batch_size": 16,
            "text_splitter": "default",
        }
    }
    with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
        with pytest.raises(KeyError):
            R2RConfig.from_json("config.json")


def test_r2r_config_missing_required_key():
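    """A config section missing a required key (embedding.provider here) must raise KeyError."""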
    invalid_data = {
        "app": {"max_file_size_in_mb": 128},
        "embedding": {
            "base_model": "model",
            "base_dimension": 128,
            "batch_size": 16,
            "text_splitter": "default",
        },
        "kg": {
            "provider": "None",
            "batch_size": 1,
            "text_splitter": {
                "type": "recursive_character",
                "chunk_size": 2048,
                "chunk_overlap": 0,
            },
        },
        "completions": {"provider": "lm_provider"},
        "logging": {
            "provider": "local",
            "log_table": "logs",
            "log_info_table": "log_info",
        },
        "prompt": {"provider": "prompt_provider"},
        "vector_database": {"provider": "vector_db"},
    }
    with patch("builtins.open", mock_open(read_data=json.dumps(invalid_data))):
        with pytest.raises(KeyError):
            R2RConfig.from_json("config.json")