Diffstat (limited to 'R2R/r2r/providers/llms')
-rwxr-xr-x  R2R/r2r/providers/llms/__init__.py                7
-rwxr-xr-x  R2R/r2r/providers/llms/litellm/base_litellm.py  142
-rwxr-xr-x  R2R/r2r/providers/llms/openai/base_openai.py    144
3 files changed, 293 insertions, 0 deletions
diff --git a/R2R/r2r/providers/llms/__init__.py b/R2R/r2r/providers/llms/__init__.py
new file mode 100755
index 00000000..38a1c54a
--- /dev/null
+++ b/R2R/r2r/providers/llms/__init__.py
@@ -0,0 +1,7 @@
+from .litellm.base_litellm import LiteLLM
+from .openai.base_openai import OpenAILLM
+
+__all__ = [
+ "LiteLLM",
+ "OpenAILLM",
+]
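The new `__init__.py` re-exports both providers, so downstream code can import them from `r2r.providers.llms` directly rather than from the provider-specific modules. A minimal import sketch (assuming the `r2r` package is installed):

    # Assumes the r2r package is importable; the optional `litellm` / `openai`
    # dependencies are only needed when the corresponding provider is used.
    from r2r.providers.llms import LiteLLM, OpenAILLM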
diff --git a/R2R/r2r/providers/llms/litellm/base_litellm.py b/R2R/r2r/providers/llms/litellm/base_litellm.py
new file mode 100755
index 00000000..581cce9a
--- /dev/null
+++ b/R2R/r2r/providers/llms/litellm/base_litellm.py
@@ -0,0 +1,142 @@
+import logging
+from typing import Any, Generator, Union
+
+from r2r.base import (
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ LLMConfig,
+ LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LiteLLM(LLMProvider):
+ """A concrete class for creating LiteLLM models."""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ *args,
+ **kwargs,
+ ) -> None:
+ try:
+ from litellm import acompletion, completion
+
+ self.litellm_completion = completion
+ self.litellm_acompletion = acompletion
+ except ImportError:
+ raise ImportError(
+                "Error: `litellm` is required to run LiteLLM. Please install it using `pip install litellm`."
+ )
+ super().__init__(config)
+
+ def get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `get_completion` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Generator[LLMChatCompletionChunk, None, None]:
+ if not generation_config.stream:
+ raise ValueError(
+ "Stream must be set to True to use the `get_completion_stream` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def extract_content(self, response: LLMChatCompletion) -> str:
+ return response.choices[0].message.content
+
+ def _get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[
+ LLMChatCompletion, Generator[LLMChatCompletionChunk, None, None]
+ ]:
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ response = self.litellm_completion(**args)
+
+ if not generation_config.stream:
+ return LLMChatCompletion(**response.dict())
+ else:
+ return self._get_chat_completion(response)
+
+ def _get_chat_completion(
+ self,
+ response: Any,
+ ) -> Generator[LLMChatCompletionChunk, None, None]:
+ for part in response:
+ yield LLMChatCompletionChunk(**part.dict())
+
+ def _get_base_args(
+ self,
+ generation_config: GenerationConfig,
+ prompt=None,
+ ) -> dict:
+ """Get the base arguments for the LiteLLM API."""
+ args = {
+ "model": generation_config.model,
+ "temperature": generation_config.temperature,
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+            # TODO - Cap this to avoid potential errors when exceeding the maximum allowable context
+ "max_tokens": generation_config.max_tokens_to_sample,
+ }
+ return args
+
+ async def aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `aget_completion` method."
+ )
+ return await self._aget_completion(
+ messages, generation_config, **kwargs
+ )
+
+ async def _aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Asynchronously get a completion from LiteLLM based on the provided messages."""
+
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ # Create the chat completion
+ return await self.litellm_acompletion(**args)
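Taken together, `LiteLLM` exposes three entry points: `get_completion` for synchronous non-streaming calls, `get_completion_stream` for streaming, and `aget_completion` for asynchronous non-streaming calls; each validates `generation_config.stream` before delegating. A hypothetical usage sketch follows; the `LLMConfig` and `GenerationConfig` constructor arguments are assumptions inferred from the fields this diff reads (`model`, `temperature`, `top_p`, `stream`, `max_tokens_to_sample`, `functions`), not a verified API.

    # Hypothetical sketch: constructor arguments for LLMConfig / GenerationConfig
    # are assumed, and "gpt-3.5-turbo" is only an illustrative model name.
    from r2r.base import LLMConfig
    from r2r.base.abstractions.llm import GenerationConfig
    from r2r.providers.llms import LiteLLM

    llm = LiteLLM(LLMConfig(provider="litellm"))
    messages = [{"role": "user", "content": "Hello"}]

    # Non-streaming: stream must be False or get_completion raises ValueError.
    completion = llm.get_completion(
        messages, GenerationConfig(model="gpt-3.5-turbo", stream=False)
    )
    print(llm.extract_content(completion))

    # Streaming: stream must be True or get_completion_stream raises ValueError.
    for chunk in llm.get_completion_stream(
        messages, GenerationConfig(model="gpt-3.5-turbo", stream=True)
    ):
        ...  # each chunk is an LLMChatCompletionChunk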
diff --git a/R2R/r2r/providers/llms/openai/base_openai.py b/R2R/r2r/providers/llms/openai/base_openai.py
new file mode 100755
index 00000000..460c0f0b
--- /dev/null
+++ b/R2R/r2r/providers/llms/openai/base_openai.py
@@ -0,0 +1,144 @@
+"""A module for creating OpenAI model abstractions."""
+
+import logging
+import os
+from typing import Generator, Union
+
+from r2r.base import (
+ LLMChatCompletion,
+ LLMChatCompletionChunk,
+ LLMConfig,
+ LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAILLM(LLMProvider):
+ """A concrete class for creating OpenAI models."""
+
+ def __init__(
+ self,
+ config: LLMConfig,
+ *args,
+ **kwargs,
+ ) -> None:
+ if not isinstance(config, LLMConfig):
+ raise ValueError(
+                "The provided config must be an instance of LLMConfig."
+ )
+ try:
+ from openai import OpenAI # noqa
+ except ImportError:
+ raise ImportError(
+                "Error: `openai` is required to run OpenAILLM. Please install it using `pip install openai`."
+ )
+ if config.provider != "openai":
+ raise ValueError(
+ "OpenAILLM must be initialized with config with `openai` provider."
+ )
+ if not os.getenv("OPENAI_API_KEY"):
+ raise ValueError(
+ "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+ )
+ super().__init__(config)
+ self.config: LLMConfig = config
+ self.client = OpenAI()
+
+ def get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `get_completion` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def get_completion_stream(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+ if not generation_config.stream:
+ raise ValueError(
+ "Stream must be set to True to use the `get_completion_stream` method."
+ )
+ return self._get_completion(messages, generation_config, **kwargs)
+
+ def _get_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+ """Get a completion from the OpenAI API based on the provided messages."""
+
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ # Create the chat completion
+ return self.client.chat.completions.create(**args)
+
+ def _get_base_args(
+ self,
+ generation_config: GenerationConfig,
+ ) -> dict:
+ """Get the base arguments for the OpenAI API."""
+
+ args = {
+ "model": generation_config.model,
+ "temperature": generation_config.temperature,
+ "top_p": generation_config.top_p,
+ "stream": generation_config.stream,
+            # TODO - Cap this to avoid potential errors when exceeding the maximum allowable context
+ "max_tokens": generation_config.max_tokens_to_sample,
+ }
+
+ return args
+
+ async def aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> LLMChatCompletion:
+ if generation_config.stream:
+ raise ValueError(
+ "Stream must be set to False to use the `aget_completion` method."
+ )
+ return await self._aget_completion(
+ messages, generation_config, **kwargs
+ )
+
+ async def _aget_completion(
+ self,
+ messages: list[dict],
+ generation_config: GenerationConfig,
+ **kwargs,
+ ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+ """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+ # Create a dictionary with the default arguments
+ args = self._get_base_args(generation_config)
+
+ args["messages"] = messages
+
+ # Conditionally add the 'functions' argument if it's not None
+ if generation_config.functions is not None:
+ args["functions"] = generation_config.functions
+
+ args = {**args, **kwargs}
+ # Create the chat completion
+ return await self.client.chat.completions.create(**args)
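`OpenAILLM` mirrors the same interface but additionally requires the `openai` package, a config whose `provider` is `"openai"`, and the `OPENAI_API_KEY` environment variable. A hypothetical usage sketch under those assumptions (constructor arguments are again inferred from this diff, not verified):

    # Hypothetical sketch: assumes OPENAI_API_KEY is exported and that LLMConfig /
    # GenerationConfig accept the fields referenced in this diff.
    import asyncio

    from r2r.base import LLMConfig
    from r2r.base.abstractions.llm import GenerationConfig
    from r2r.providers.llms import OpenAILLM

    llm = OpenAILLM(LLMConfig(provider="openai"))
    messages = [{"role": "user", "content": "Hello"}]
    gen = GenerationConfig(model="gpt-3.5-turbo", stream=False)

    # Synchronous, non-streaming call.
    completion = llm.get_completion(messages, gen)

    # Asynchronous, non-streaming call.
    completion = asyncio.run(llm.aget_completion(messages, gen))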