1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
|
"""Base classes for language model providers."""
import logging
from abc import abstractmethod
from typing import Optional
from r2r.base.abstractions.llm import GenerationConfig
from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk
from .base_provider import Provider, ProviderConfig
logger = logging.getLogger(__name__)
class LLMConfig(ProviderConfig):
    """Configuration for an LLM provider.

    Attributes:
        provider: Name of the backing provider; must be one of
            ``supported_providers``.
        generation_config: Default generation settings applied to requests.
    """

    provider: Optional[str] = None
    generation_config: Optional[GenerationConfig] = None

    def validate(self) -> None:
        """Ensure a supported provider is configured.

        Raises:
            ValueError: If ``provider`` is unset or not supported.
        """
        if not self.provider:
            raise ValueError("Provider must be set.")
        # `provider` is guaranteed truthy here, so check membership directly.
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Provider names this config accepts."""
        return ["litellm", "openai"]
class LLMProvider(Provider):
    """An abstract class to provide a common interface for LLMs.

    Concrete subclasses implement :meth:`get_completion` and
    :meth:`get_completion_stream` against a specific backend.
    """

    def __init__(
        self,
        config: LLMConfig,
    ) -> None:
        """Initialize the provider.

        Args:
            config: Provider configuration; must be an ``LLMConfig``.

        Raises:
            ValueError: If ``config`` is not an ``LLMConfig`` instance.
        """
        if not isinstance(config, LLMConfig):
            raise ValueError(
                "LLMProvider must be initialized with a `LLMConfig`."
            )
        # Lazy %-style args avoid formatting when INFO is disabled.
        logger.info("Initializing LLM provider with config: %s", config)
        super().__init__(config)

    @abstractmethod
    def get_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletion:
        """Abstract method to get a chat completion from the provider."""

    @abstractmethod
    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletionChunk:
        """Abstract method to get a completion stream from the provider."""
|