aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/main/assembly/config.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/assembly/config.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to 'R2R/r2r/main/assembly/config.py')
-rwxr-xr-xR2R/r2r/main/assembly/config.py167
1 files changed, 167 insertions, 0 deletions
diff --git a/R2R/r2r/main/assembly/config.py b/R2R/r2r/main/assembly/config.py
new file mode 100755
index 00000000..d52c4561
--- /dev/null
+++ b/R2R/r2r/main/assembly/config.py
@@ -0,0 +1,167 @@
+import json
+import logging
+import os
+from enum import Enum
+from typing import Any
+
+from ...base.abstractions.document import DocumentType
+from ...base.abstractions.llm import GenerationConfig
+from ...base.logging.kv_logger import LoggingConfig
+from ...base.providers.embedding_provider import EmbeddingConfig
+from ...base.providers.eval_provider import EvalConfig
+from ...base.providers.kg_provider import KGConfig
+from ...base.providers.llm_provider import LLMConfig
+from ...base.providers.prompt_provider import PromptConfig
+from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig
+
+logger = logging.getLogger(__name__)
+
+
+class R2RConfig:
+ REQUIRED_KEYS: dict[str, list] = {
+ "app": ["max_file_size_in_mb"],
+ "embedding": [
+ "provider",
+ "base_model",
+ "base_dimension",
+ "batch_size",
+ "text_splitter",
+ ],
+ "eval": ["llm"],
+ "kg": [
+ "provider",
+ "batch_size",
+ "kg_extraction_config",
+ "text_splitter",
+ ],
+ "ingestion": ["excluded_parsers"],
+ "completions": ["provider"],
+ "logging": ["provider", "log_table"],
+ "prompt": ["provider"],
+ "vector_database": ["provider"],
+ }
+ app: dict[str, Any]
+ embedding: EmbeddingConfig
+ completions: LLMConfig
+ logging: LoggingConfig
+ prompt: PromptConfig
+ vector_database: VectorDBConfig
+
+ def __init__(self, config_data: dict[str, Any]):
+ # Load the default configuration
+ default_config = self.load_default_config()
+
+ # Override the default configuration with the passed configuration
+ for key in config_data:
+ if key in default_config:
+ default_config[key].update(config_data[key])
+ else:
+ default_config[key] = config_data[key]
+
+ # Validate and set the configuration
+ for section, keys in R2RConfig.REQUIRED_KEYS.items():
+ # Check the keys when provider is set
+ # TODO - Clean up robust null checks
+ if "provider" in default_config[section] and (
+ default_config[section]["provider"] is not None
+ and default_config[section]["provider"] != "None"
+ and default_config[section]["provider"] != "null"
+ ):
+ self._validate_config_section(default_config, section, keys)
+ setattr(self, section, default_config[section])
+
+ self.app = self.app # for type hinting
+ self.ingestion = self.ingestion # for type hinting
+ self.ingestion["excluded_parsers"] = [
+ DocumentType(k) for k in self.ingestion["excluded_parsers"]
+ ]
+ # override GenerationConfig defaults
+ GenerationConfig.set_default(
+ **self.completions.get("generation_config", {})
+ )
+ self.embedding = EmbeddingConfig.create(**self.embedding)
+ self.kg = KGConfig.create(**self.kg)
+ eval_llm = self.eval.pop("llm", None)
+ self.eval = EvalConfig.create(
+ **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None
+ )
+ self.completions = LLMConfig.create(**self.completions)
+ self.logging = LoggingConfig.create(**self.logging)
+ self.prompt = PromptConfig.create(**self.prompt)
+ self.vector_database = VectorDBConfig.create(**self.vector_database)
+
+ def _validate_config_section(
+ self, config_data: dict[str, Any], section: str, keys: list
+ ):
+ if section not in config_data:
+ raise ValueError(f"Missing '{section}' section in config")
+ if not all(key in config_data[section] for key in keys):
+ raise ValueError(f"Missing required keys in '{section}' config")
+
+ @classmethod
+ def from_json(cls, config_path: str = None) -> "R2RConfig":
+ if config_path is None:
+ # Get the root directory of the project
+ file_dir = os.path.dirname(os.path.abspath(__file__))
+ config_path = os.path.join(
+ file_dir, "..", "..", "..", "config.json"
+ )
+
+ # Load configuration from JSON file
+ with open(config_path) as f:
+ config_data = json.load(f)
+
+ return cls(config_data)
+
+ def to_json(self):
+ config_data = {
+ section: self._serialize_config(getattr(self, section))
+ for section in R2RConfig.REQUIRED_KEYS.keys()
+ }
+ return json.dumps(config_data)
+
+ def save_to_redis(self, redis_client: Any, key: str):
+ redis_client.set(f"R2RConfig:{key}", self.to_json())
+
+ @classmethod
+ def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig":
+ config_data = redis_client.get(f"R2RConfig:{key}")
+ if config_data is None:
+ raise ValueError(
+ f"Configuration not found in Redis with key '{key}'"
+ )
+ config_data = json.loads(config_data)
+ # config_data["ingestion"]["selected_parsers"] = {
+ # DocumentType(k): v
+ # for k, v in config_data["ingestion"]["selected_parsers"].items()
+ # }
+ return cls(config_data)
+
+ @classmethod
+ def load_default_config(cls) -> dict:
+ # Get the root directory of the project
+ file_dir = os.path.dirname(os.path.abspath(__file__))
+ default_config_path = os.path.join(
+ file_dir, "..", "..", "..", "config.json"
+ )
+ # Load default configuration from JSON file
+ with open(default_config_path) as f:
+ return json.load(f)
+
+ @staticmethod
+ def _serialize_config(config_section: Any) -> dict:
+ # TODO - Make this approach cleaner
+ if isinstance(config_section, ProviderConfig):
+ config_section = config_section.dict()
+ filtered_result = {}
+ for k, v in config_section.items():
+ if isinstance(k, Enum):
+ k = k.value
+ if isinstance(v, dict):
+ formatted_v = {
+ k2.value if isinstance(k2, Enum) else k2: v2
+ for k2, v2 in v.items()
+ }
+ v = formatted_v
+ filtered_result[k] = v
+ return filtered_result