aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/providers/base_provider.py
blob: 8ee8d56a73ba540fb8929a5085a524ac96c8fda5 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Optional, Type

from pydantic import BaseModel


class ProviderConfig(BaseModel, ABC):
    """A base provider configuration class"""

    extra_fields: dict[str, Any] = {}
    provider: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True
        ignore_extra = True

    @abstractmethod
    def validate(self) -> None:
        pass

    @classmethod
    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
        base_args = cls.__fields__.keys()
        filtered_kwargs = {
            k: v if v != "None" else None
            for k, v in kwargs.items()
            if k in base_args
        }
        instance = cls(**filtered_kwargs)
        for k, v in kwargs.items():
            if k not in base_args:
                instance.extra_fields[k] = v
        return instance

    @abstractproperty
    @property
    def supported_providers(self) -> list[str]:
        """Define a list of supported providers."""
        pass


class Provider(ABC):
    """A base provider class to provide a common interface for all providers."""

    def __init__(self, config: Optional[ProviderConfig] = None):
        if config:
            config.validate()
        self.config = config