diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/main/config.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/main/config.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/main/config.py | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/main/config.py b/.venv/lib/python3.12/site-packages/core/main/config.py new file mode 100644 index 00000000..f49b4041 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/main/config.py @@ -0,0 +1,213 @@ +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments +import logging +import os +from enum import Enum +from typing import Any, Optional + +import toml +from pydantic import BaseModel + +from ..base.abstractions import GenerationConfig +from ..base.agent.agent import RAGAgentConfig # type: ignore +from ..base.providers import AppConfig +from ..base.providers.auth import AuthConfig +from ..base.providers.crypto import CryptoConfig +from ..base.providers.database import DatabaseConfig +from ..base.providers.email import EmailConfig +from ..base.providers.embedding import EmbeddingConfig +from ..base.providers.ingestion import IngestionConfig +from ..base.providers.llm import CompletionConfig +from ..base.providers.orchestration import OrchestrationConfig +from ..base.utils import deep_update + +logger = logging.getLogger() + + +class R2RConfig: + current_file_path = os.path.dirname(__file__) + config_dir_root = os.path.join(current_file_path, "..", "configs") + default_config_path = os.path.join( + current_file_path, "..", "..", "r2r", "r2r.toml" + ) + + CONFIG_OPTIONS: dict[str, Optional[str]] = {} + for file_ in os.listdir(config_dir_root): + if file_.endswith(".toml"): + CONFIG_OPTIONS[file_.removesuffix(".toml")] = os.path.join( + config_dir_root, file_ + ) + CONFIG_OPTIONS["default"] = None + + REQUIRED_KEYS: dict[str, list] = { + "app": [], + "completion": ["provider"], + "crypto": ["provider"], + "email": ["provider"], + "auth": ["provider"], + "embedding": [ + "provider", + "base_model", + "base_dimension", + "batch_size", + "add_title_as_prefix", + ], + "completion_embedding": [ + "provider", + "base_model", + "base_dimension", + "batch_size", + "add_title_as_prefix", + ], + # TODO - deprecated, remove + "ingestion": ["provider"], + "logging": ["provider", "log_table"], + "database": ["provider"], + "agent": ["generation_config"], + "orchestration": ["provider"], + } + + app: AppConfig + auth: AuthConfig + completion: CompletionConfig + crypto: CryptoConfig + database: DatabaseConfig + embedding: EmbeddingConfig + completion_embedding: EmbeddingConfig + email: EmailConfig + ingestion: IngestionConfig + agent: RAGAgentConfig + orchestration: OrchestrationConfig + + def __init__(self, config_data: dict[str, Any]): + """ + :param config_data: dictionary of configuration parameters + :param base_path: base path when a relative path is specified for the prompts directory + """ + # Load the default configuration + default_config = self.load_default_config() + + # Override the default configuration with the passed configuration + default_config = deep_update(default_config, config_data) + + # Validate and set the configuration + for section, keys in R2RConfig.REQUIRED_KEYS.items(): + # Check the keys when provider is set + # TODO - remove after deprecation + if section in ["graph", "file"] and section not in default_config: + continue + if "provider" in default_config[section] and ( + default_config[section]["provider"] is not None + and default_config[section]["provider"] != "None" + and default_config[section]["provider"] != "null" + ): + self._validate_config_section(default_config, section, keys) + setattr(self, section, default_config[section]) + + self.app = AppConfig.create(**self.app) # type: ignore + self.auth = AuthConfig.create(**self.auth, app=self.app) # type: ignore + self.completion = CompletionConfig.create( + **self.completion, app=self.app + ) # type: ignore + self.crypto = CryptoConfig.create(**self.crypto, app=self.app) # type: ignore + self.email = EmailConfig.create(**self.email, app=self.app) # type: ignore + self.database = DatabaseConfig.create(**self.database, app=self.app) # type: ignore + self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app) # type: ignore + self.completion_embedding = EmbeddingConfig.create( + **self.completion_embedding, app=self.app + ) # type: ignore + self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app) # type: ignore + self.agent = RAGAgentConfig.create(**self.agent, app=self.app) # type: ignore + self.orchestration = OrchestrationConfig.create( + **self.orchestration, app=self.app + ) # type: ignore + + IngestionConfig.set_default(**self.ingestion.dict()) + + # override GenerationConfig defaults + if self.completion.generation_config: + GenerationConfig.set_default( + **self.completion.generation_config.dict() + ) + + def _validate_config_section( + self, config_data: dict[str, Any], section: str, keys: list + ): + if section not in config_data: + raise ValueError(f"Missing '{section}' section in config") + if missing_keys := [ + key for key in keys if key not in config_data[section] + ]: + raise ValueError( + f"Missing required keys in '{section}' config: {', '.join(missing_keys)}" + ) + + @classmethod + def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig": + if config_path is None: + config_path = R2RConfig.default_config_path + + # Load configuration from TOML file + with open(config_path, encoding="utf-8") as f: + config_data = toml.load(f) + + return cls(config_data) + + def to_toml(self): + config_data = {} + for section in R2RConfig.REQUIRED_KEYS.keys(): + section_data = self._serialize_config(getattr(self, section)) + if isinstance(section_data, dict): + # Remove app from nested configs before serializing + section_data.pop("app", None) + config_data[section] = section_data + return toml.dumps(config_data) + + @classmethod + def load_default_config(cls) -> dict: + with open(R2RConfig.default_config_path, encoding="utf-8") as f: + return toml.load(f) + + @staticmethod + def _serialize_config(config_section: Any): + """Serialize config section while excluding internal state.""" + if isinstance(config_section, dict): + return { + R2RConfig._serialize_key(k): R2RConfig._serialize_config(v) + for k, v in config_section.items() + if k != "app" # Exclude app from serialization + } + elif isinstance(config_section, (list, tuple)): + return [ + R2RConfig._serialize_config(item) for item in config_section + ] + elif isinstance(config_section, Enum): + return config_section.value + elif isinstance(config_section, BaseModel): + data = config_section.model_dump(exclude_none=True) + data.pop("app", None) # Remove app from the serialized data + return R2RConfig._serialize_config(data) + else: + return config_section + + @staticmethod + def _serialize_key(key: Any) -> str: + return key.value if isinstance(key, Enum) else str(key) + + @classmethod + def load( + cls, + config_name: Optional[str] = None, + config_path: Optional[str] = None, + ) -> "R2RConfig": + if config_path and config_name: + raise ValueError( + f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}" + ) + + if config_path := os.getenv("R2R_CONFIG_PATH") or config_path: + return cls.from_toml(config_path) + + config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default" + if config_name not in R2RConfig.CONFIG_OPTIONS: + raise ValueError(f"Invalid config name: {config_name}") + return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name]) |