about summary refs log tree commit diff
path: root/R2R/r2r/main/assembly/config.py
diff options
context:
space:
mode:
author: S. Solomon Darnell 2025-03-28 21:52:21 -0500
committer: S. Solomon Darnell 2025-03-28 21:52:21 -0500
commit: 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree: ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/assembly/config.py
parent: cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download: gn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/main/assembly/config.py')
-rwxr-xr-xR2R/r2r/main/assembly/config.py167
1 files changed, 167 insertions, 0 deletions
diff --git a/R2R/r2r/main/assembly/config.py b/R2R/r2r/main/assembly/config.py
new file mode 100755
index 00000000..d52c4561
--- /dev/null
+++ b/R2R/r2r/main/assembly/config.py
@@ -0,0 +1,167 @@
+import json
+import logging
+import os
+from enum import Enum
+from typing import Any
+
+from ...base.abstractions.document import DocumentType
+from ...base.abstractions.llm import GenerationConfig
+from ...base.logging.kv_logger import LoggingConfig
+from ...base.providers.embedding_provider import EmbeddingConfig
+from ...base.providers.eval_provider import EvalConfig
+from ...base.providers.kg_provider import KGConfig
+from ...base.providers.llm_provider import LLMConfig
+from ...base.providers.prompt_provider import PromptConfig
+from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig
+
+logger = logging.getLogger(__name__)
+
+
+class R2RConfig:
+    """R2R configuration built from JSON defaults plus caller overrides.
+
+    The defaults are read from ``config.json`` (see ``load_default_config``),
+    the supplied ``config_data`` is merged on top section by section, required
+    keys are validated, and most sections are then converted into their
+    provider config objects via the corresponding ``*.create`` factories.
+    """
+
+    # Keys that must be present in each section. They are only enforced for
+    # sections whose "provider" is set to a non-null value (see __init__).
+    REQUIRED_KEYS: dict[str, list] = {
+        "app": ["max_file_size_in_mb"],
+        "embedding": [
+            "provider",
+            "base_model",
+            "base_dimension",
+            "batch_size",
+            "text_splitter",
+        ],
+        "eval": ["llm"],
+        "kg": [
+            "provider",
+            "batch_size",
+            "kg_extraction_config",
+            "text_splitter",
+        ],
+        "ingestion": ["excluded_parsers"],
+        "completions": ["provider"],
+        "logging": ["provider", "log_table"],
+        "prompt": ["provider"],
+        "vector_database": ["provider"],
+    }
+
+    # Provider values that mean "no provider configured".
+    _UNSET_PROVIDERS = (None, "None", "null")
+
+    # Attribute types after __init__ has converted the raw sections.
+    app: dict[str, Any]
+    embedding: EmbeddingConfig
+    eval: EvalConfig
+    kg: KGConfig
+    ingestion: dict[str, Any]
+    completions: LLMConfig
+    logging: LoggingConfig
+    prompt: PromptConfig
+    vector_database: VectorDBConfig
+
+    def __init__(self, config_data: dict[str, Any]):
+        """Merge ``config_data`` over the defaults and build the sections.
+
+        Args:
+            config_data: Mapping of section name to section settings. Dict
+                sections are merged key by key (one level deep) into the
+                default section of the same name; anything else replaces
+                the default value. The mapping itself is not modified.
+
+        Raises:
+            ValueError: If a section listed in ``REQUIRED_KEYS`` is missing,
+                or a section with a configured provider lacks required keys.
+        """
+        merged = self.load_default_config()
+
+        for key, override in config_data.items():
+            base = merged.get(key)
+            if isinstance(base, dict) and isinstance(override, dict):
+                base.update(override)
+            elif isinstance(override, dict):
+                # Copy so later in-place edits (e.g. popping "llm" from the
+                # eval section) never leak into the caller's dictionary.
+                merged[key] = dict(override)
+            else:
+                merged[key] = override
+
+        for section, keys in R2RConfig.REQUIRED_KEYS.items():
+            if section not in merged:
+                # Raises the descriptive ValueError instead of letting a
+                # bare KeyError escape from the lookup below.
+                self._validate_config_section(merged, section, keys)
+            section_config = merged[section]
+            # Required keys are only enforced when a provider is configured.
+            # TODO - Clean up robust null checks
+            if (
+                "provider" in section_config
+                and section_config["provider"]
+                not in R2RConfig._UNSET_PROVIDERS
+            ):
+                self._validate_config_section(merged, section, keys)
+            setattr(self, section, section_config)
+
+        self.ingestion["excluded_parsers"] = [
+            DocumentType(k) for k in self.ingestion["excluded_parsers"]
+        ]
+        # override GenerationConfig defaults
+        GenerationConfig.set_default(
+            **self.completions.get("generation_config", {})
+        )
+        self.embedding = EmbeddingConfig.create(**self.embedding)
+        self.kg = KGConfig.create(**self.kg)
+        # The eval section nests an LLM config that is built separately.
+        eval_llm = self.eval.pop("llm", None)
+        self.eval = EvalConfig.create(
+            **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None
+        )
+        self.completions = LLMConfig.create(**self.completions)
+        self.logging = LoggingConfig.create(**self.logging)
+        self.prompt = PromptConfig.create(**self.prompt)
+        self.vector_database = VectorDBConfig.create(**self.vector_database)
+
+    def _validate_config_section(
+        self, config_data: dict[str, Any], section: str, keys: list
+    ):
+        """Check that ``section`` exists and contains every key in ``keys``.
+
+        Raises:
+            ValueError: If the section is absent, or if any required key is
+                missing (the message lists the missing keys).
+        """
+        if section not in config_data:
+            raise ValueError(f"Missing '{section}' section in config")
+        missing = [key for key in keys if key not in config_data[section]]
+        if missing:
+            raise ValueError(
+                f"Missing required keys in '{section}' config: {missing}"
+            )
+
+    @classmethod
+    def from_json(cls, config_path: str = None) -> "R2RConfig":
+        """Build a config from a JSON file.
+
+        Args:
+            config_path: Path to the JSON file. When ``None``, the default
+                ``config.json`` used by ``load_default_config`` is loaded.
+        """
+        if config_path is None:
+            return cls(cls.load_default_config())
+
+        # JSON text is UTF-8; do not depend on the platform default encoding.
+        with open(config_path, encoding="utf-8") as f:
+            config_data = json.load(f)
+
+        return cls(config_data)
+
+    def to_json(self) -> str:
+        """Return the sections named in ``REQUIRED_KEYS`` as a JSON string."""
+        config_data = {
+            section: self._serialize_config(getattr(self, section))
+            for section in R2RConfig.REQUIRED_KEYS
+        }
+        return json.dumps(config_data)
+
+    def save_to_redis(self, redis_client: Any, key: str) -> None:
+        """Store the JSON form of this config under ``R2RConfig:<key>``."""
+        redis_client.set(f"R2RConfig:{key}", self.to_json())
+
+    @classmethod
+    def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig":
+        """Rebuild a config from the JSON stored under ``R2RConfig:<key>``.
+
+        Raises:
+            ValueError: If ``redis_client.get`` returns ``None`` for the key.
+        """
+        config_data = redis_client.get(f"R2RConfig:{key}")
+        if config_data is None:
+            raise ValueError(
+                f"Configuration not found in Redis with key '{key}'"
+            )
+        return cls(json.loads(config_data))
+
+    @classmethod
+    def load_default_config(cls) -> dict:
+        """Load ``config.json`` located three directories above this file."""
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        default_config_path = os.path.join(
+            file_dir, "..", "..", "..", "config.json"
+        )
+        # JSON text is UTF-8; do not depend on the platform default encoding.
+        with open(default_config_path, encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def _serialize_config(config_section: Any) -> dict:
+        """Return ``config_section`` as a dict with Enum members unwrapped.
+
+        ``ProviderConfig`` instances are first converted with ``.dict()``.
+        Enum members used as keys or values, at any depth inside nested
+        dicts, lists and tuples, are replaced by their ``.value`` so that
+        the result can be passed to ``json.dumps``.
+        """
+        if isinstance(config_section, ProviderConfig):
+            config_section = config_section.dict()
+
+        def _to_plain(value: Any) -> Any:
+            if isinstance(value, Enum):
+                return value.value
+            if isinstance(value, dict):
+                return {
+                    _to_plain(k): _to_plain(v) for k, v in value.items()
+                }
+            if isinstance(value, list):
+                return [_to_plain(item) for item in value]
+            if isinstance(value, tuple):
+                return tuple(_to_plain(item) for item in value)
+            return value
+
+        return {
+            _to_plain(k): _to_plain(v) for k, v in config_section.items()
+        }