diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/providers/base_provider.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/base/providers/base_provider.py')
-rwxr-xr-x | R2R/r2r/base/providers/base_provider.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/base_provider.py b/R2R/r2r/base/providers/base_provider.py new file mode 100755 index 00000000..8ee8d56a --- /dev/null +++ b/R2R/r2r/base/providers/base_provider.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, Optional, Type + +from pydantic import BaseModel + + +class ProviderConfig(BaseModel, ABC): + """A base provider configuration class""" + + extra_fields: dict[str, Any] = {} + provider: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + ignore_extra = True + + @abstractmethod + def validate(self) -> None: + pass + + @classmethod + def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig": + base_args = cls.__fields__.keys() + filtered_kwargs = { + k: v if v != "None" else None + for k, v in kwargs.items() + if k in base_args + } + instance = cls(**filtered_kwargs) + for k, v in kwargs.items(): + if k not in base_args: + instance.extra_fields[k] = v + return instance + + @abstractproperty + @property + def supported_providers(self) -> list[str]: + """Define a list of supported providers.""" + pass + + +class Provider(ABC): + """A base provider class to provide a common interface for all providers.""" + + def __init__(self, config: Optional[ProviderConfig] = None): + if config: + config.validate() + self.config = config |