1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
|
import json
import logging
import os
from enum import Enum
from typing import Any, Optional

from ...base.abstractions.document import DocumentType
from ...base.abstractions.llm import GenerationConfig
from ...base.logging.kv_logger import LoggingConfig
from ...base.providers.embedding_provider import EmbeddingConfig
from ...base.providers.eval_provider import EvalConfig
from ...base.providers.kg_provider import KGConfig
from ...base.providers.llm_provider import LLMConfig
from ...base.providers.prompt_provider import PromptConfig
from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class R2RConfig:
    """Configuration object built from ``config.json`` defaults plus overrides.

    On construction the default configuration file is loaded, the caller's
    overrides are merged on top of it, and every section named in
    ``REQUIRED_KEYS`` is stored as an attribute. Most sections are then
    converted from plain dicts into their provider config objects through the
    corresponding ``*.create(**section)`` factory.
    """

    # Keys that must be present in each section. They are only enforced for
    # sections whose "provider" entry is set (see ``__init__``).
    REQUIRED_KEYS: dict[str, list] = {
        "app": ["max_file_size_in_mb"],
        "embedding": [
            "provider",
            "base_model",
            "base_dimension",
            "batch_size",
            "text_splitter",
        ],
        "eval": ["llm"],
        "kg": [
            "provider",
            "batch_size",
            "kg_extraction_config",
            "text_splitter",
        ],
        "ingestion": ["excluded_parsers"],
        "completions": ["provider"],
        "logging": ["provider", "log_table"],
        "prompt": ["provider"],
        "vector_database": ["provider"],
    }

    # "provider" values that are treated as "no provider configured".
    _UNSET_PROVIDERS = (None, "None", "null")

    app: dict[str, Any]
    ingestion: dict[str, Any]
    embedding: EmbeddingConfig
    kg: KGConfig
    eval: EvalConfig
    completions: LLMConfig
    logging: LoggingConfig
    prompt: PromptConfig
    vector_database: VectorDBConfig

    def __init__(self, config_data: dict[str, Any]):
        """Merge ``config_data`` over the defaults and build the sections.

        Args:
            config_data: Mapping of section name to override values. Dict
                overrides are merged key-by-key into the matching default
                section; anything else replaces the default outright.

        Raises:
            ValueError: If a section listed in ``REQUIRED_KEYS`` is absent
                from the merged configuration, or if a section with a
                provider set lacks one of its required keys.
        """
        merged = self._merge_overrides(self.load_default_config(), config_data)

        for section, keys in R2RConfig.REQUIRED_KEYS.items():
            # Fail with a descriptive error instead of a bare KeyError when
            # a whole section is absent.
            if section not in merged:
                raise ValueError(f"Missing '{section}' section in config")
            # NOTE(review): required keys are only checked when the section
            # has a provider set, so sections without a "provider" entry are
            # never validated here — confirm this is intended.
            if self._provider_is_set(merged[section]):
                self._validate_config_section(merged, section, keys)
            setattr(self, section, merged[section])

        self.ingestion["excluded_parsers"] = [
            DocumentType(k) for k in self.ingestion["excluded_parsers"]
        ]

        # Override GenerationConfig defaults before any LLM config is built.
        GenerationConfig.set_default(
            **self.completions.get("generation_config", {})
        )

        self.embedding = EmbeddingConfig.create(**self.embedding)
        self.kg = KGConfig.create(**self.kg)
        # The eval section carries a nested LLM config that is built
        # separately and handed to EvalConfig as an object.
        eval_llm = self.eval.pop("llm", None)
        self.eval = EvalConfig.create(
            **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None
        )
        self.completions = LLMConfig.create(**self.completions)
        self.logging = LoggingConfig.create(**self.logging)
        self.prompt = PromptConfig.create(**self.prompt)
        self.vector_database = VectorDBConfig.create(**self.vector_database)

    @staticmethod
    def _merge_overrides(
        defaults: dict[str, Any], overrides: Optional[dict[str, Any]]
    ) -> dict[str, Any]:
        """Merge ``overrides`` into ``defaults`` in place and return it.

        When both the default and the override for a key are dicts, the
        override's entries are merged into the default dict. Otherwise the
        override replaces the default. Dict overrides are shallow-copied so
        later edits to the merged config do not touch the caller's dict.
        ``None`` overrides are treated as empty.
        """
        for key, value in (overrides or {}).items():
            current = defaults.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                current.update(value)
            elif isinstance(value, dict):
                defaults[key] = dict(value)
            else:
                defaults[key] = value
        return defaults

    @staticmethod
    def _provider_is_set(section_config: Any) -> bool:
        """Return True when the section has a usable "provider" entry."""
        return (
            "provider" in section_config
            and section_config["provider"] not in R2RConfig._UNSET_PROVIDERS
        )

    def _validate_config_section(
        self, config_data: dict[str, Any], section: str, keys: list
    ):
        """Check that ``section`` exists and contains every key in ``keys``.

        Raises:
            ValueError: If the section is absent, or if any required key is
                missing (the message lists the missing keys).
        """
        if section not in config_data:
            raise ValueError(f"Missing '{section}' section in config")
        missing = [key for key in keys if key not in config_data[section]]
        if missing:
            raise ValueError(
                f"Missing required keys in '{section}' config: "
                + ", ".join(str(key) for key in missing)
            )

    @staticmethod
    def _default_config_path() -> str:
        """Return the path of ``config.json`` three directories above this file."""
        file_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(file_dir, "..", "..", "..", "config.json")

    @classmethod
    def from_json(cls, config_path: Optional[str] = None) -> "R2RConfig":
        """Build a config from a JSON file.

        Args:
            config_path: Path to the JSON file. When ``None``, the default
                ``config.json`` is used.
        """
        if config_path is None:
            config_path = cls._default_config_path()

        # Decode explicitly as UTF-8 rather than the locale's encoding.
        with open(config_path, encoding="utf-8") as f:
            config_data = json.load(f)

        return cls(config_data)

    def to_json(self):
        """Serialize every section named in ``REQUIRED_KEYS`` to a JSON string."""
        config_data = {
            section: self._serialize_config(getattr(self, section))
            for section in R2RConfig.REQUIRED_KEYS
        }
        return json.dumps(config_data)

    def save_to_redis(self, redis_client: Any, key: str):
        """Store the JSON form of this config under ``R2RConfig:<key>``."""
        redis_client.set(f"R2RConfig:{key}", self.to_json())

    @classmethod
    def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig":
        """Rebuild a config from the JSON stored under ``R2RConfig:<key>``.

        Raises:
            ValueError: If the client returns ``None`` for that key.
        """
        config_data = redis_client.get(f"R2RConfig:{key}")
        if config_data is None:
            raise ValueError(
                f"Configuration not found in Redis with key '{key}'"
            )
        return cls(json.loads(config_data))

    @classmethod
    def load_default_config(cls) -> dict:
        """Load and return the default configuration from ``config.json``."""
        with open(cls._default_config_path(), encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _serialize_config(config_section: Any) -> dict:
        """Convert a config section into a dict of plain values.

        ``ProviderConfig`` instances are first turned into dicts via their
        ``dict()`` method. Enum keys and Enum values are then replaced by
        their ``.value`` at every nesting depth, including inside lists and
        tuples.
        """
        if isinstance(config_section, ProviderConfig):
            config_section = config_section.dict()

        def _plain(value: Any) -> Any:
            # Recursively strip Enum wrappers from keys and values.
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                return {_plain(k): _plain(v) for k, v in value.items()}
            if isinstance(value, list):
                return [_plain(item) for item in value]
            if isinstance(value, tuple):
                return tuple(_plain(item) for item in value)
            return value

        return {_plain(k): _plain(v) for k, v in config_section.items()}
|