aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/main/config.py
blob: f49b4041c130bf5e32a4f6d6ea996854eb73480e (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
import logging
import os
from enum import Enum
from typing import Any, Optional

import toml
from pydantic import BaseModel

from ..base.abstractions import GenerationConfig
from ..base.agent.agent import RAGAgentConfig  # type: ignore
from ..base.providers import AppConfig
from ..base.providers.auth import AuthConfig
from ..base.providers.crypto import CryptoConfig
from ..base.providers.database import DatabaseConfig
from ..base.providers.email import EmailConfig
from ..base.providers.embedding import EmbeddingConfig
from ..base.providers.ingestion import IngestionConfig
from ..base.providers.llm import CompletionConfig
from ..base.providers.orchestration import OrchestrationConfig
from ..base.utils import deep_update

# Module-scoped logger so records can be filtered/configured per module
# instead of being emitted directly on the root logger.
logger = logging.getLogger(__name__)


class R2RConfig:
    """Aggregate R2R configuration assembled from TOML data.

    The packaged default file (``default_config_path``) is always loaded
    first; caller-supplied data is deep-merged on top of it, each section
    listed in ``REQUIRED_KEYS`` is checked, and the known sections are
    converted into their typed config objects.
    """

    current_file_path = os.path.dirname(__file__)
    config_dir_root = os.path.join(current_file_path, "..", "configs")
    default_config_path = os.path.join(
        current_file_path, "..", "..", "r2r", "r2r.toml"
    )

    # Named presets: every ``*.toml`` file in ``config_dir_root`` keyed by its
    # stem, plus "default", which maps to None so ``from_toml`` falls back to
    # ``default_config_path``.
    # A comprehension (instead of a class-body ``for`` loop) keeps the loop
    # variable from leaking into the class namespace as a stray attribute.
    # Only the outermost iterable is evaluated in class scope, so the body
    # must not reference class-level names like ``config_dir_root``;
    # ``DirEntry.path`` already equals ``os.path.join(config_dir_root, name)``.
    CONFIG_OPTIONS: dict[str, Optional[str]] = {
        entry.name.removesuffix(".toml"): entry.path
        for entry in os.scandir(config_dir_root)
        if entry.name.endswith(".toml")
    }
    CONFIG_OPTIONS["default"] = None

    # Section name -> keys that must be present in that section. Keys are
    # only enforced when the section names a real provider (see __init__).
    REQUIRED_KEYS: dict[str, list] = {
        "app": [],
        "completion": ["provider"],
        "crypto": ["provider"],
        "email": ["provider"],
        "auth": ["provider"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        "completion_embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "add_title_as_prefix",
        ],
        # TODO - deprecated, remove
        "ingestion": ["provider"],
        "logging": ["provider", "log_table"],
        "database": ["provider"],
        "agent": ["generation_config"],
        "orchestration": ["provider"],
    }

    app: AppConfig
    auth: AuthConfig
    completion: CompletionConfig
    crypto: CryptoConfig
    database: DatabaseConfig
    embedding: EmbeddingConfig
    completion_embedding: EmbeddingConfig
    email: EmailConfig
    ingestion: IngestionConfig
    agent: RAGAgentConfig
    orchestration: OrchestrationConfig

    def __init__(self, config_data: dict[str, Any]):
        """
        :param config_data: configuration overrides, deep-merged on top of the
            default configuration loaded from ``default_config_path``
        :raises ValueError: if a required section is missing, or if a section
            that names a provider lacks one of its ``REQUIRED_KEYS``
        """
        # Start from the packaged defaults, then let the caller's data win.
        merged = deep_update(self.load_default_config(), config_data)

        # Validate and set every section as a raw dict first.
        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            if section not in merged:
                # TODO - remove after deprecation: legacy sections that may
                # be absent without it being an error.
                if section in ("graph", "file"):
                    continue
                # Report the same error _validate_config_section would,
                # rather than an opaque KeyError from the lookup below.
                raise ValueError(f"Missing '{section}' section in config")
            section_data = merged[section]
            # Required keys are only checked when a real provider is named;
            # the strings "None"/"null" are treated like an unset provider.
            if "provider" in section_data and section_data["provider"] not in (
                None,
                "None",
                "null",
            ):
                self._validate_config_section(merged, section, keys)
            setattr(self, section, section_data)

        # Replace the raw dicts with typed config objects; every section
        # except ``app`` also receives the already-built AppConfig.
        self.app = AppConfig.create(**self.app)  # type: ignore
        self.auth = AuthConfig.create(**self.auth, app=self.app)  # type: ignore
        self.completion = CompletionConfig.create(
            **self.completion, app=self.app
        )  # type: ignore
        self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
        self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
        self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
        self.completion_embedding = EmbeddingConfig.create(
            **self.completion_embedding, app=self.app
        )  # type: ignore
        self.ingestion = IngestionConfig.create(**self.ingestion, app=self.app)  # type: ignore
        self.agent = RAGAgentConfig.create(**self.agent, app=self.app)  # type: ignore
        self.orchestration = OrchestrationConfig.create(
            **self.orchestration, app=self.app
        )  # type: ignore

        # NOTE(review): presumably registers these values as class-level
        # defaults for later IngestionConfig instances — confirm in
        # IngestionConfig.set_default.
        IngestionConfig.set_default(**self.ingestion.dict())

        # override GenerationConfig defaults
        if self.completion.generation_config:
            GenerationConfig.set_default(
                **self.completion.generation_config.dict()
            )

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ) -> None:
        """Ensure ``section`` exists in ``config_data`` and contains ``keys``.

        :raises ValueError: naming the missing section or the missing keys.
        """
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        if missing_keys := [
            key for key in keys if key not in config_data[section]
        ]:
            raise ValueError(
                f"Missing required keys in '{section}' config: {', '.join(missing_keys)}"
            )

    @classmethod
    def from_toml(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Build a config from the TOML file at ``config_path``.

        ``None`` selects ``default_config_path``. The file's contents are
        merged over the defaults by ``__init__``.
        """
        if config_path is None:
            config_path = R2RConfig.default_config_path

        # Load configuration from TOML file
        with open(config_path, encoding="utf-8") as f:
            config_data = toml.load(f)

        return cls(config_data)

    def to_toml(self) -> str:
        """Serialize every ``REQUIRED_KEYS`` section to a TOML string.

        Nested ``app`` entries are dropped from each section.
        """
        config_data = {}
        for section in R2RConfig.REQUIRED_KEYS.keys():
            section_data = self._serialize_config(getattr(self, section))
            if isinstance(section_data, dict):
                # Remove app from nested configs before serializing
                section_data.pop("app", None)
            config_data[section] = section_data
        return toml.dumps(config_data)

    @classmethod
    def load_default_config(cls) -> dict:
        """Return the parsed contents of ``default_config_path``."""
        with open(R2RConfig.default_config_path, encoding="utf-8") as f:
            return toml.load(f)

    @staticmethod
    def _serialize_config(config_section: Any) -> Any:
        """Recursively convert a config value into TOML-friendly data.

        Dicts, lists and tuples are walked; enums become their ``value``;
        pydantic models are dumped with ``exclude_none=True``. Any ``app``
        key is excluded at every level. Other values pass through unchanged.
        """
        if isinstance(config_section, dict):
            return {
                R2RConfig._serialize_key(k): R2RConfig._serialize_config(v)
                for k, v in config_section.items()
                if k != "app"  # Exclude app from serialization
            }
        elif isinstance(config_section, (list, tuple)):
            return [
                R2RConfig._serialize_config(item) for item in config_section
            ]
        elif isinstance(config_section, Enum):
            return config_section.value
        elif isinstance(config_section, BaseModel):
            data = config_section.model_dump(exclude_none=True)
            data.pop("app", None)  # Remove app from the serialized data
            return R2RConfig._serialize_config(data)
        else:
            return config_section

    @staticmethod
    def _serialize_key(key: Any) -> str:
        """Turn a dict key into a TOML key string (enum keys use ``value``)."""
        return key.value if isinstance(key, Enum) else str(key)

    @classmethod
    def load(
        cls,
        config_name: Optional[str] = None,
        config_path: Optional[str] = None,
    ) -> "R2RConfig":
        """Load a config by file path or by preset name.

        Resolution order:
        1. the ``R2R_CONFIG_PATH`` env var;
        2. ``config_path``;
        3. the ``R2R_CONFIG_NAME`` env var;
        4. ``config_name``;
        5. ``"default"``.

        :raises ValueError: if both arguments are given (checked before the
            env vars are consulted), or if the resolved name is not a key of
            ``CONFIG_OPTIONS``.
        """
        if config_path and config_name:
            raise ValueError(
                f"Cannot specify both config_path and config_name. Got: {config_path}, {config_name}"
            )

        if config_path := os.getenv("R2R_CONFIG_PATH") or config_path:
            return cls.from_toml(config_path)

        config_name = os.getenv("R2R_CONFIG_NAME") or config_name or "default"
        if config_name not in R2RConfig.CONFIG_OPTIONS:
            raise ValueError(f"Invalid config name: {config_name}")
        return cls.from_toml(R2RConfig.CONFIG_OPTIONS[config_name])