Diffstat (limited to 'R2R/r2r/providers/llms')
-rwxr-xr-x  R2R/r2r/providers/llms/__init__.py                7
-rwxr-xr-x  R2R/r2r/providers/llms/litellm/base_litellm.py  142
-rwxr-xr-x  R2R/r2r/providers/llms/openai/base_openai.py    144
3 files changed, 293 insertions, 0 deletions
diff --git a/R2R/r2r/providers/llms/__init__.py b/R2R/r2r/providers/llms/__init__.py
new file mode 100755
index 00000000..38a1c54a
--- /dev/null
+++ b/R2R/r2r/providers/llms/__init__.py
@@ -0,0 +1,7 @@
+from .litellm.base_litellm import LiteLLM
+from .openai.base_openai import OpenAILLM
+
+__all__ = [
+    "LiteLLM",
+    "OpenAILLM",
+]
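The package __init__.py above simply re-exports the two providers, so downstream code can import them from r2r.providers.llms directly. A minimal illustrative import, using nothing beyond what the file exports:

    from r2r.providers.llms import LiteLLM, OpenAILLM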
diff --git a/R2R/r2r/providers/llms/litellm/base_litellm.py b/R2R/r2r/providers/llms/litellm/base_litellm.py
new file mode 100755
index 00000000..581cce9a
--- /dev/null
+++ b/R2R/r2r/providers/llms/litellm/base_litellm.py
@@ -0,0 +1,142 @@
+import logging
+from typing import Any, Generator, Union
+
+from r2r.base import (
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    LLMConfig,
+    LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LiteLLM(LLMProvider):
+    """A concrete class for creating LiteLLM models."""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        try:
+            from litellm import acompletion, completion
+
+            self.litellm_completion = completion
+            self.litellm_acompletion = acompletion
+        except ImportError:
+            raise ImportError(
+                "Error, `litellm` is required to run a LiteLLM. Please install it using `pip install litellm`."
+            )
+        super().__init__(config)
+
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `get_completion` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        if not generation_config.stream:
+            raise ValueError(
+                "Stream must be set to True to use the `get_completion_stream` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def extract_content(self, response: LLMChatCompletion) -> str:
+        return response.choices[0].message.content
+
+    def _get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[
+        LLMChatCompletion, Generator[LLMChatCompletionChunk, None, None]
+    ]:
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        response = self.litellm_completion(**args)
+
+        if not generation_config.stream:
+            return LLMChatCompletion(**response.dict())
+        else:
+            return self._get_chat_completion(response)
+
+    def _get_chat_completion(
+        self,
+        response: Any,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        for part in response:
+            yield LLMChatCompletionChunk(**part.dict())
+
+    def _get_base_args(
+        self,
+        generation_config: GenerationConfig,
+        prompt=None,
+    ) -> dict:
+        """Get the base arguments for the LiteLLM API."""
+        args = {
+            "model": generation_config.model,
+            "temperature": generation_config.temperature,
+            "top_p": generation_config.top_p,
+            "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when we exceed the maximum allowable context
+            "max_tokens": generation_config.max_tokens_to_sample,
+        }
+        return args
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `aget_completion` method."
+            )
+        return await self._aget_completion(
+            messages, generation_config, **kwargs
+        )
+
+    async def _aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion and normalize it to an LLMChatCompletion
+        response = await self.litellm_acompletion(**args)
+        return LLMChatCompletion(**response.dict())
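For orientation, here is a hedged usage sketch of the LiteLLM provider above. The constructor arguments of LLMConfig and GenerationConfig are assumptions inferred from the fields this diff references (provider, model, temperature, top_p, stream, max_tokens_to_sample), not a documented API, and the model string is only a placeholder.

    # Hypothetical sketch; field names are inferred from the diff, not confirmed.
    from r2r.base import LLMConfig
    from r2r.base.abstractions.llm import GenerationConfig
    from r2r.providers.llms import LiteLLM

    llm = LiteLLM(LLMConfig(provider="litellm"))  # provider value is an assumption
    gen_cfg = GenerationConfig(
        model="gpt-4o-mini",        # placeholder; any model string litellm can route
        temperature=0.1,
        top_p=1.0,
        stream=False,               # get_completion requires stream=False
        max_tokens_to_sample=256,
    )
    response = llm.get_completion(
        [{"role": "user", "content": "Hello"}],
        gen_cfg,
    )
    print(llm.extract_content(response))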
diff --git a/R2R/r2r/providers/llms/openai/base_openai.py b/R2R/r2r/providers/llms/openai/base_openai.py
new file mode 100755
index 00000000..460c0f0b
--- /dev/null
+++ b/R2R/r2r/providers/llms/openai/base_openai.py
@@ -0,0 +1,144 @@
+"""A module for creating OpenAI model abstractions."""
+
+import logging
+import os
+from typing import Generator, Union
+
+from r2r.base import (
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    LLMConfig,
+    LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAILLM(LLMProvider):
+    """A concrete class for creating OpenAI models."""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        if not isinstance(config, LLMConfig):
+            raise ValueError(
+                "The provided config must be an instance of OpenAIConfig."
+            )
+        try:
+            from openai import OpenAI  # noqa
+        except ImportError:
+            raise ImportError(
+                "Error, `openai` is required to run an OpenAILLM. Please install it using `pip install openai`."
+            )
+        if config.provider != "openai":
+            raise ValueError(
+                "OpenAILLM must be initialized with config with `openai` provider."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+            )
+        super().__init__(config)
+        self.config: LLMConfig = config
+        self.client = OpenAI()
+
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `get_completion` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        if not generation_config.stream:
+            raise ValueError(
+                "Stream must be set to True to use the `get_completion_stream` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def _get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return self.client.chat.completions.create(**args)
+
+    def _get_base_args(
+        self,
+        generation_config: GenerationConfig,
+    ) -> dict:
+        """Get the base arguments for the OpenAI API."""
+
+        args = {
+            "model": generation_config.model,
+            "temperature": generation_config.temperature,
+            "top_p": generation_config.top_p,
+            "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when we exceed the maximum allowable context
+            "max_tokens": generation_config.max_tokens_to_sample,
+        }
+
+        return args
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `aget_completion` method."
+            )
+        return await self._aget_completion(
+            messages, generation_config, **kwargs
+        )
+
+    async def _aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return await self.client.chat.completions.create(**args)
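A corresponding hedged sketch of streaming with the OpenAILLM provider above. The GenerationConfig field names are again assumptions inferred from _get_base_args, the model string is a placeholder, and reading choices[0].delta.content assumes the OpenAI-style chunks that client.chat.completions.create yields when streaming; OPENAI_API_KEY must be set in the environment first.

    # Hypothetical sketch; export OPENAI_API_KEY before running.
    from r2r.base import LLMConfig
    from r2r.base.abstractions.llm import GenerationConfig
    from r2r.providers.llms import OpenAILLM

    llm = OpenAILLM(LLMConfig(provider="openai"))
    stream_cfg = GenerationConfig(
        model="gpt-4o-mini",        # placeholder model name
        temperature=0.1,
        top_p=1.0,
        stream=True,                # get_completion_stream requires stream=True
        max_tokens_to_sample=256,
    )
    for chunk in llm.get_completion_stream(
        [{"role": "user", "content": "Hello"}], stream_cfg
    ):
        delta = chunk.choices[0].delta.content  # OpenAI-style chunk shape; an assumption
        if delta:
            print(delta, end="")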