aboutsummaryrefslogtreecommitdiff
import json
import logging
import os
from enum import Enum
from typing import Any

from ...base.abstractions.document import DocumentType
from ...base.abstractions.llm import GenerationConfig
from ...base.logging.kv_logger import LoggingConfig
from ...base.providers.embedding_provider import EmbeddingConfig
from ...base.providers.eval_provider import EvalConfig
from ...base.providers.kg_provider import KGConfig
from ...base.providers.llm_provider import LLMConfig
from ...base.providers.prompt_provider import PromptConfig
from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig

logger = logging.getLogger(__name__)


class R2RConfig:
    """R2R configuration built from the bundled ``config.json`` plus overrides.

    The default configuration file is loaded first and then overridden
    section by section with caller-supplied data. Every section named in
    ``REQUIRED_KEYS`` becomes an attribute; most are converted into their
    typed provider-config objects at the end of ``__init__``.
    """

    # Keys each section must contain. A section is only checked against its
    # list when it declares a non-null ``provider`` (see ``_provider_is_set``).
    REQUIRED_KEYS: dict[str, list] = {
        "app": ["max_file_size_in_mb"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "text_splitter",
        ],
        "eval": ["llm"],
        "kg": [
            "provider",
            "batch_size",
            "kg_extraction_config",
            "text_splitter",
        ],
        "ingestion": ["excluded_parsers"],
        "completions": ["provider"],
        "logging": ["provider", "log_table"],
        "prompt": ["provider"],
        "vector_database": ["provider"],
    }
    app: dict[str, Any]
    embedding: EmbeddingConfig
    eval: EvalConfig
    kg: KGConfig
    ingestion: dict[str, Any]
    completions: LLMConfig
    logging: LoggingConfig
    prompt: PromptConfig
    vector_database: VectorDBConfig

    def __init__(self, config_data: dict[str, Any]):
        """Build the configuration from the defaults overridden by ``config_data``.

        Args:
            config_data: Mapping of section name to section settings. A dict
                section is merged key-by-key (shallowly) into the matching
                default section; any other value replaces the default section.
                The caller's dictionaries are not modified.

        Raises:
            ValueError: If a section listed in ``REQUIRED_KEYS`` is absent from
                the merged configuration, or if a section whose provider is set
                lacks one of its required keys.
        """
        default_config = self.load_default_config()

        # Shallow per-section merge of the overrides into the defaults.
        for section, overrides in config_data.items():
            current = default_config.get(section)
            if isinstance(current, dict) and isinstance(overrides, dict):
                current.update(overrides)
            elif isinstance(overrides, dict):
                # Copy so the in-place edits below (popping ``eval.llm``,
                # rewriting ``ingestion.excluded_parsers``) never touch the
                # caller's dictionary.
                default_config[section] = dict(overrides)
            else:
                default_config[section] = overrides

        # Validate each known section and expose it as an attribute.
        for section, required in self.REQUIRED_KEYS.items():
            # Check presence first so a missing section is reported as a
            # ValueError rather than surfacing as a bare KeyError.
            if section not in default_config:
                raise ValueError(f"Missing '{section}' section in config")
            section_config = default_config[section]
            if self._provider_is_set(section_config):
                self._validate_config_section(default_config, section, required)
            setattr(self, section, section_config)

        self.ingestion["excluded_parsers"] = [
            DocumentType(parser) for parser in self.ingestion["excluded_parsers"]
        ]
        # Override GenerationConfig defaults with the completions section's
        # generation settings (called on the class, not on an instance).
        GenerationConfig.set_default(
            **self.completions.get("generation_config", {})
        )
        self.embedding = EmbeddingConfig.create(**self.embedding)
        self.kg = KGConfig.create(**self.kg)
        # The eval section nests its own LLM config; build that separately.
        eval_llm = self.eval.pop("llm", None)
        self.eval = EvalConfig.create(
            **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None
        )
        self.completions = LLMConfig.create(**self.completions)
        self.logging = LoggingConfig.create(**self.logging)
        self.prompt = PromptConfig.create(**self.prompt)
        self.vector_database = VectorDBConfig.create(**self.vector_database)

    @staticmethod
    def _provider_is_set(section_config: Any) -> bool:
        """Return True when ``section_config`` declares a real (non-null) provider."""
        if not isinstance(section_config, dict):
            return False
        # A disabled provider may be spelled as JSON null or as "None"/"null".
        return section_config.get("provider") not in (None, "None", "null")

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ):
        """Ensure ``config_data[section]`` exists and contains every key in ``keys``.

        Raises:
            ValueError: If the section is absent, or if any required key is
                missing (the message lists the missing keys).
        """
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        missing = [key for key in keys if key not in config_data[section]]
        if missing:
            raise ValueError(
                f"Missing required keys in '{section}' config: {missing}"
            )

    @staticmethod
    def _default_config_path() -> str:
        """Return the path of ``config.json`` three directories above this module."""
        file_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(file_dir, "..", "..", "..", "config.json")

    @classmethod
    def from_json(cls, config_path: "str | None" = None) -> "R2RConfig":
        """Create a config from a JSON file.

        Args:
            config_path: Path to the JSON file; defaults to the bundled
                ``config.json``.
        """
        if config_path is None:
            config_path = cls._default_config_path()
        # JSON is UTF-8 by spec; do not depend on the platform locale.
        with open(config_path, encoding="utf-8") as f:
            config_data = json.load(f)
        return cls(config_data)

    def to_json(self):
        """Serialize every ``REQUIRED_KEYS`` section to a JSON string."""
        config_data = {
            section: self._serialize_config(getattr(self, section))
            for section in R2RConfig.REQUIRED_KEYS
        }
        return json.dumps(config_data)

    def save_to_redis(self, redis_client: Any, key: str):
        """Store ``to_json()`` under ``R2RConfig:<key>`` via ``redis_client.set``."""
        redis_client.set(f"R2RConfig:{key}", self.to_json())

    @classmethod
    def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig":
        """Rebuild a config previously stored with ``save_to_redis``.

        Raises:
            ValueError: If nothing is stored under ``R2RConfig:<key>``.
        """
        config_data = redis_client.get(f"R2RConfig:{key}")
        if config_data is None:
            raise ValueError(
                f"Configuration not found in Redis with key '{key}'"
            )
        # json.loads accepts both str and bytes payloads.
        return cls(json.loads(config_data))

    @classmethod
    def load_default_config(cls) -> dict:
        """Load and return the bundled default ``config.json`` as a dict."""
        with open(cls._default_config_path(), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """Recursively replace Enum members (as keys or values) with ``.value``.

        Tuples are returned as lists, which ``json.dumps`` encodes identically.
        """
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, dict):
            return {
                R2RConfig._to_jsonable(k): R2RConfig._to_jsonable(v)
                for k, v in value.items()
            }
        if isinstance(value, (list, tuple)):
            return [R2RConfig._to_jsonable(item) for item in value]
        return value

    @staticmethod
    def _serialize_config(config_section: Any) -> dict:
        """Convert a config section (dict or ProviderConfig) into a JSON-ready dict.

        Provider configs are first turned into plain dicts via ``.dict()``;
        Enum members at any depth (e.g. ``excluded_parsers``) become their
        ``.value`` so ``json.dumps`` does not depend on Enum mixin types.
        """
        if isinstance(config_section, ProviderConfig):
            config_section = config_section.dict()
        return {
            R2RConfig._to_jsonable(k): R2RConfig._to_jsonable(v)
            for k, v in config_section.items()
        }