diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/providers/llm_provider.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to 'R2R/r2r/base/providers/llm_provider.py')
-rwxr-xr-x | R2R/r2r/base/providers/llm_provider.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/llm_provider.py b/R2R/r2r/base/providers/llm_provider.py new file mode 100755 index 00000000..9b6499a4 --- /dev/null +++ b/R2R/r2r/base/providers/llm_provider.py @@ -0,0 +1,66 @@ +"""Base classes for language model providers.""" + +import logging +from abc import abstractmethod +from typing import Optional + +from r2r.base.abstractions.llm import GenerationConfig + +from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk +from .base_provider import Provider, ProviderConfig + +logger = logging.getLogger(__name__) + + +class LLMConfig(ProviderConfig): + """A base LLM config class""" + + provider: Optional[str] = None + generation_config: Optional[GenerationConfig] = None + + def validate(self) -> None: + if not self.provider: + raise ValueError("Provider must be set.") + + if self.provider and self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["litellm", "openai"] + + +class LLMProvider(Provider): + """An abstract class to provide a common interface for LLMs.""" + + def __init__( + self, + config: LLMConfig, + ) -> None: + if not isinstance(config, LLMConfig): + raise ValueError( + "LLMProvider must be initialized with a `LLMConfig`." + ) + logger.info(f"Initializing LLM provider with config: {config}") + + super().__init__(config) + + @abstractmethod + def get_completion( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletion: + """Abstract method to get a chat completion from the provider.""" + pass + + @abstractmethod + def get_completion_stream( + self, + messages: list[dict], + generation_config: GenerationConfig, + **kwargs, + ) -> LLMChatCompletionChunk: + """Abstract method to get a completion stream from the provider.""" + pass |