diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/assembly/config.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/main/assembly/config.py')
-rwxr-xr-x | R2R/r2r/main/assembly/config.py | 167 |
1 files changed, 167 insertions, 0 deletions
diff --git a/R2R/r2r/main/assembly/config.py b/R2R/r2r/main/assembly/config.py new file mode 100755 index 00000000..d52c4561 --- /dev/null +++ b/R2R/r2r/main/assembly/config.py @@ -0,0 +1,167 @@ +import json +import logging +import os +from enum import Enum +from typing import Any + +from ...base.abstractions.document import DocumentType +from ...base.abstractions.llm import GenerationConfig +from ...base.logging.kv_logger import LoggingConfig +from ...base.providers.embedding_provider import EmbeddingConfig +from ...base.providers.eval_provider import EvalConfig +from ...base.providers.kg_provider import KGConfig +from ...base.providers.llm_provider import LLMConfig +from ...base.providers.prompt_provider import PromptConfig +from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig + +logger = logging.getLogger(__name__) + + +class R2RConfig: + REQUIRED_KEYS: dict[str, list] = { + "app": ["max_file_size_in_mb"], + "embedding": [ + "provider", + "base_model", + "base_dimension", + "batch_size", + "text_splitter", + ], + "eval": ["llm"], + "kg": [ + "provider", + "batch_size", + "kg_extraction_config", + "text_splitter", + ], + "ingestion": ["excluded_parsers"], + "completions": ["provider"], + "logging": ["provider", "log_table"], + "prompt": ["provider"], + "vector_database": ["provider"], + } + app: dict[str, Any] + embedding: EmbeddingConfig + completions: LLMConfig + logging: LoggingConfig + prompt: PromptConfig + vector_database: VectorDBConfig + + def __init__(self, config_data: dict[str, Any]): + # Load the default configuration + default_config = self.load_default_config() + + # Override the default configuration with the passed configuration + for key in config_data: + if key in default_config: + default_config[key].update(config_data[key]) + else: + default_config[key] = config_data[key] + + # Validate and set the configuration + for section, keys in R2RConfig.REQUIRED_KEYS.items(): + # Check the keys when provider is set + # TODO - Clean up robust null checks + if "provider" in default_config[section] and ( + default_config[section]["provider"] is not None + and default_config[section]["provider"] != "None" + and default_config[section]["provider"] != "null" + ): + self._validate_config_section(default_config, section, keys) + setattr(self, section, default_config[section]) + + self.app = self.app # for type hinting + self.ingestion = self.ingestion # for type hinting + self.ingestion["excluded_parsers"] = [ + DocumentType(k) for k in self.ingestion["excluded_parsers"] + ] + # override GenerationConfig defaults + GenerationConfig.set_default( + **self.completions.get("generation_config", {}) + ) + self.embedding = EmbeddingConfig.create(**self.embedding) + self.kg = KGConfig.create(**self.kg) + eval_llm = self.eval.pop("llm", None) + self.eval = EvalConfig.create( + **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None + ) + self.completions = LLMConfig.create(**self.completions) + self.logging = LoggingConfig.create(**self.logging) + self.prompt = PromptConfig.create(**self.prompt) + self.vector_database = VectorDBConfig.create(**self.vector_database) + + def _validate_config_section( + self, config_data: dict[str, Any], section: str, keys: list + ): + if section not in config_data: + raise ValueError(f"Missing '{section}' section in config") + if not all(key in config_data[section] for key in keys): + raise ValueError(f"Missing required keys in '{section}' config") + + @classmethod + def from_json(cls, config_path: str = None) -> "R2RConfig": + if config_path is None: + # Get the root directory of the project + file_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join( + file_dir, "..", "..", "..", "config.json" + ) + + # Load configuration from JSON file + with open(config_path) as f: + config_data = json.load(f) + + return cls(config_data) + + def to_json(self): + config_data = { + section: self._serialize_config(getattr(self, section)) + for section in R2RConfig.REQUIRED_KEYS.keys() + } + return json.dumps(config_data) + + def save_to_redis(self, redis_client: Any, key: str): + redis_client.set(f"R2RConfig:{key}", self.to_json()) + + @classmethod + def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig": + config_data = redis_client.get(f"R2RConfig:{key}") + if config_data is None: + raise ValueError( + f"Configuration not found in Redis with key '{key}'" + ) + config_data = json.loads(config_data) + # config_data["ingestion"]["selected_parsers"] = { + # DocumentType(k): v + # for k, v in config_data["ingestion"]["selected_parsers"].items() + # } + return cls(config_data) + + @classmethod + def load_default_config(cls) -> dict: + # Get the root directory of the project + file_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = os.path.join( + file_dir, "..", "..", "..", "config.json" + ) + # Load default configuration from JSON file + with open(default_config_path) as f: + return json.load(f) + + @staticmethod + def _serialize_config(config_section: Any) -> dict: + # TODO - Make this approach cleaner + if isinstance(config_section, ProviderConfig): + config_section = config_section.dict() + filtered_result = {} + for k, v in config_section.items(): + if isinstance(k, Enum): + k = k.value + if isinstance(v, dict): + formatted_v = { + k2.value if isinstance(k2, Enum) else k2: v2 + for k2, v2 in v.items() + } + v = formatted_v + filtered_result[k] = v + return filtered_result |