aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/providers/llm_provider.py
blob: 9b6499a4b57ed68dbaf8d18ed97d819e02512c02 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Base classes for language model providers."""

import logging
from abc import abstractmethod
from typing import Iterator, Optional

from r2r.base.abstractions.llm import GenerationConfig

from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk
from .base_provider import Provider, ProviderConfig

logger = logging.getLogger(__name__)


class LLMConfig(ProviderConfig):
    """Configuration shared by all LLM provider implementations.

    Holds the name of the backend to use and, optionally, a default
    ``GenerationConfig``. Call ``validate`` to check that the chosen backend
    is one this config accepts.
    """

    # Identifier of the backend; must appear in ``supported_providers``.
    provider: Optional[str] = None
    # Optional default generation settings.
    generation_config: Optional[GenerationConfig] = None

    def validate(self) -> None:
        """Ensure a supported provider has been configured.

        Raises:
            ValueError: If ``provider`` is unset/falsy, or is not one of
                ``supported_providers``.
        """
        chosen = self.provider
        # Guard clause: an empty/None provider can never be valid.
        if not chosen:
            raise ValueError("Provider must be set.")
        if chosen not in self.supported_providers:
            raise ValueError(f"Provider '{chosen}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Provider identifiers accepted by ``validate``."""
        accepted = ["litellm", "openai"]
        return accepted


class LLMProvider(Provider):
    """An abstract class to provide a common interface for LLMs.

    Concrete providers implement ``get_completion`` (a single, complete
    response) and ``get_completion_stream`` (the response delivered as a
    sequence of chunks).
    """

    def __init__(
        self,
        config: LLMConfig,
    ) -> None:
        """Check the config type, log it, and hand it to ``Provider``.

        Args:
            config: The LLM configuration for this provider.

        Raises:
            ValueError: If ``config`` is not an ``LLMConfig`` instance.
        """
        if not isinstance(config, LLMConfig):
            # ValueError (rather than TypeError) is kept so existing callers
            # that catch ValueError keep working.
            raise ValueError(
                "LLMProvider must be initialized with a `LLMConfig`."
            )
        # Lazy %-style args: the config is only stringified when INFO is
        # enabled for this logger.
        # NOTE(review): this logs the whole config object; if configs can
        # carry credentials (e.g. via base-class extra fields), consider
        # redacting before logging — confirm against ProviderConfig.
        logger.info("Initializing LLM provider with config: %s", config)

        super().__init__(config)

    @abstractmethod
    def get_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletion:
        """Abstract method to get a chat completion from the provider.

        Args:
            messages: Chat messages to send (presumably role/content dicts —
                confirm against concrete providers).
            generation_config: Generation settings for this request.
            **kwargs: Provider-specific extra options.

        Returns:
            The full chat completion.
        """

    @abstractmethod
    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Iterator[LLMChatCompletionChunk]:
        """Abstract method to get a completion stream from the provider.

        Args:
            messages: Chat messages to send (presumably role/content dicts —
                confirm against concrete providers).
            generation_config: Generation settings for this request.
            **kwargs: Provider-specific extra options.

        Returns:
            An iterator over the completion's chunks. A stream yields many
            chunks, so the result is an iterator rather than a single
            ``LLMChatCompletionChunk``.
        """