import json
import logging
import os
from enum import Enum
from typing import Any, Optional

from ...base.abstractions.document import DocumentType
from ...base.abstractions.llm import GenerationConfig
from ...base.logging.kv_logger import LoggingConfig
from ...base.providers.embedding_provider import EmbeddingConfig
from ...base.providers.eval_provider import EvalConfig
from ...base.providers.kg_provider import KGConfig
from ...base.providers.llm_provider import LLMConfig
from ...base.providers.prompt_provider import PromptConfig
from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class R2RConfig:
    """Configuration built from the default ``config.json`` plus overrides.

    On construction the default configuration file is loaded, the caller's
    sections are merged on top of it, every section named in
    ``REQUIRED_KEYS`` is stored as an attribute, and most sections are then
    replaced by the object returned from the matching ``*Config.create``
    call (``app`` and ``ingestion`` remain plain dicts).
    """

    # Keys that must be present in a section.  They are only enforced when
    # the section has a "provider" entry that is actually set (see __init__).
    REQUIRED_KEYS: dict[str, list] = {
        "app": ["max_file_size_in_mb"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "text_splitter",
        ],
        "eval": ["llm"],
        "kg": [
            "provider",
            "batch_size",
            "kg_extraction_config",
            "text_splitter",
        ],
        "ingestion": ["excluded_parsers"],
        "completions": ["provider"],
        "logging": ["provider", "log_table"],
        "prompt": ["provider"],
        "vector_database": ["provider"],
    }

    # "provider" values that mean "no provider configured".
    _UNSET_PROVIDER_VALUES = (None, "None", "null")

    app: dict[str, Any]
    embedding: EmbeddingConfig
    completions: LLMConfig
    eval: EvalConfig
    ingestion: dict[str, Any]
    kg: KGConfig
    logging: LoggingConfig
    prompt: PromptConfig
    vector_database: VectorDBConfig

    def __init__(self, config_data: dict[str, Any]):
        """Merge ``config_data`` over the defaults and build the sections.

        Args:
            config_data: Mapping of section name to section settings.  A
                section that already exists as a dict in the defaults is
                updated key by key (shallow merge); any other entry replaces
                the default outright.

        Raises:
            ValueError: If a section listed in ``REQUIRED_KEYS`` is absent
                after merging, or if a section with an active provider lacks
                one of its required keys.
        """
        merged = self.load_default_config()

        for key, override in config_data.items():
            if isinstance(merged.get(key), dict):
                merged[key].update(override)
            else:
                # New section, or the default is not a mapping: replace it.
                # Dicts are copied so the in-place edits below never leak
                # back into the caller's ``config_data``.
                merged[key] = (
                    dict(override) if isinstance(override, dict) else override
                )

        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            if section not in merged:
                raise ValueError(f"Missing '{section}' section in config")
            section_data = merged[section]
            # Required keys are only checked when a provider is really set.
            if (
                "provider" in section_data
                and section_data["provider"]
                not in self._UNSET_PROVIDER_VALUES
            ):
                self._validate_config_section(merged, section, keys)
            setattr(self, section, section_data)

        self.ingestion["excluded_parsers"] = [
            DocumentType(k) for k in self.ingestion["excluded_parsers"]
        ]

        # Override GenerationConfig defaults; a missing or null
        # "generation_config" entry means "no overrides".
        GenerationConfig.set_default(
            **(self.completions.get("generation_config") or {})
        )

        self.embedding = EmbeddingConfig.create(**self.embedding)
        self.kg = KGConfig.create(**self.kg)

        # "llm" is removed from the eval settings and, when truthy, turned
        # into an LLMConfig that is passed to EvalConfig.create separately.
        eval_llm = self.eval.pop("llm", None)
        self.eval = EvalConfig.create(
            **self.eval,
            llm=LLMConfig.create(**eval_llm) if eval_llm else None,
        )

        self.completions = LLMConfig.create(**self.completions)
        self.logging = LoggingConfig.create(**self.logging)
        self.prompt = PromptConfig.create(**self.prompt)
        self.vector_database = VectorDBConfig.create(**self.vector_database)

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ):
        """Check that ``section`` exists and contains every key in ``keys``.

        Raises:
            ValueError: If the section is absent, or if any required key is
                missing (the message lists the missing keys).
        """
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        missing = [key for key in keys if key not in config_data[section]]
        if missing:
            raise ValueError(
                f"Missing required keys in '{section}' config: {missing}"
            )

    @staticmethod
    def _default_config_path() -> str:
        """Return the path of ``config.json`` three directories above this file."""
        file_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(file_dir, "..", "..", "..", "config.json")

    @classmethod
    def from_json(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Build a config from a JSON file.

        Args:
            config_path: Path of the JSON file to read.  When ``None`` the
                default ``config.json`` is used.

        Returns:
            A new instance constructed from the parsed file contents.
        """
        if config_path is None:
            config_path = cls._default_config_path()
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(config_path, encoding="utf-8") as f:
            config_data = json.load(f)
        return cls(config_data)

    def to_json(self):
        """Return every ``REQUIRED_KEYS`` section serialized as a JSON string."""
        config_data = {
            section: self._serialize_config(getattr(self, section))
            for section in R2RConfig.REQUIRED_KEYS
        }
        return json.dumps(config_data)

    def save_to_redis(self, redis_client: Any, key: str):
        """Store the JSON form of this config under ``R2RConfig:<key>``."""
        redis_client.set(f"R2RConfig:{key}", self.to_json())

    @classmethod
    def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig":
        """Build a config from the JSON stored under ``R2RConfig:<key>``.

        Raises:
            ValueError: If ``redis_client.get`` returns ``None`` for the key.
        """
        config_data = redis_client.get(f"R2RConfig:{key}")
        if config_data is None:
            raise ValueError(
                f"Configuration not found in Redis with key '{key}'"
            )
        return cls(json.loads(config_data))

    @classmethod
    def load_default_config(cls) -> dict:
        """Parse and return the default ``config.json``."""
        with open(cls._default_config_path(), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _serialize_config(config_section: Any) -> dict:
        """Return ``config_section`` as a dict with Enum members unwrapped.

        A ``ProviderConfig`` is first converted with its ``dict()`` method;
        anything else is expected to provide ``items()``.  Enum keys and
        values are replaced by their ``.value`` at every nesting level of
        dicts and lists.
        """
        if isinstance(config_section, ProviderConfig):
            config_section = config_section.dict()
        return {
            R2RConfig._unwrap_enums(key): R2RConfig._unwrap_enums(value)
            for key, value in config_section.items()
        }

    @staticmethod
    def _unwrap_enums(value: Any) -> Any:
        """Recursively replace Enum members inside dicts and lists by ``.value``."""
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, dict):
            return {
                R2RConfig._unwrap_enums(k): R2RConfig._unwrap_enums(v)
                for k, v in value.items()
            }
        if isinstance(value, list):
            return [R2RConfig._unwrap_enums(item) for item in value]
        return value