about summary refs log tree commit diff
path: root/R2R/r2r/base/providers/base_provider.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/providers/base_provider.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/base/providers/base_provider.py')
-rwxr-xr-xR2R/r2r/base/providers/base_provider.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/base_provider.py b/R2R/r2r/base/providers/base_provider.py
new file mode 100755
index 00000000..8ee8d56a
--- /dev/null
+++ b/R2R/r2r/base/providers/base_provider.py
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod, abstractproperty
+from typing import Any, Optional, Type
+
+from pydantic import BaseModel
+
+
+class ProviderConfig(BaseModel, ABC):
+    """A base provider configuration class"""
+
+    extra_fields: dict[str, Any] = {}
+    provider: Optional[str] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+        ignore_extra = True
+
+    @abstractmethod
+    def validate(self) -> None:
+        pass
+
+    @classmethod
+    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+        base_args = cls.__fields__.keys()
+        filtered_kwargs = {
+            k: v if v != "None" else None
+            for k, v in kwargs.items()
+            if k in base_args
+        }
+        instance = cls(**filtered_kwargs)
+        for k, v in kwargs.items():
+            if k not in base_args:
+                instance.extra_fields[k] = v
+        return instance
+
+    @abstractproperty
+    @property
+    def supported_providers(self) -> list[str]:
+        """Define a list of supported providers."""
+        pass
+
+
+class Provider(ABC):
+    """A base provider class to provide a common interface for all providers."""
+
+    def __init__(self, config: Optional[ProviderConfig] = None):
+        if config:
+            config.validate()
+        self.config = config