Diffstat (limited to 'R2R/r2r/providers/llms/litellm/base_litellm.py')
 R2R/r2r/providers/llms/litellm/base_litellm.py | 147 +++++++++++++++++++++++
 1 file changed, 147 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/providers/llms/litellm/base_litellm.py b/R2R/r2r/providers/llms/litellm/base_litellm.py
new file mode 100755
index 00000000..581cce9a
--- /dev/null
+++ b/R2R/r2r/providers/llms/litellm/base_litellm.py
@@ -0,0 +1,147 @@
+import logging
+from typing import Any, Generator, Union
+
+from r2r.base import (
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    LLMConfig,
+    LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class LiteLLM(LLMProvider):
+ """A concrete class for creating LiteLLM models."""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        try:
+            from litellm import acompletion, completion
+
+            self.litellm_completion = completion
+            self.litellm_acompletion = acompletion
+        except ImportError as e:
+            raise ImportError(
+                "`litellm` is required to run LiteLLM. Please install it with `pip install litellm`."
+            ) from e
+        super().__init__(config)
+
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        """Synchronously get a non-streaming completion."""
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `get_completion` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        """Synchronously get a streaming completion, yielding chunks."""
+        if not generation_config.stream:
+            raise ValueError(
+                "Stream must be set to True to use the `get_completion_stream` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def extract_content(self, response: LLMChatCompletion) -> str:
+        """Extract the text content of the first choice from a completion."""
+        return response.choices[0].message.content
+
+    def _get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[
+        LLMChatCompletion, Generator[LLMChatCompletionChunk, None, None]
+    ]:
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        response = self.litellm_completion(**args)
+
+        if not generation_config.stream:
+            return LLMChatCompletion(**response.dict())
+        else:
+            return self._get_chat_completion(response)
+
+    def _get_chat_completion(
+        self,
+        response: Any,
+    ) -> Generator[LLMChatCompletionChunk, None, None]:
+        """Wrap each streamed part in an LLMChatCompletionChunk."""
+        for part in response:
+            yield LLMChatCompletionChunk(**part.dict())
+
+    def _get_base_args(
+        self,
+        generation_config: GenerationConfig,
+    ) -> dict:
+        """Get the base arguments for the LiteLLM API."""
+        args = {
+            "model": generation_config.model,
+            "temperature": generation_config.temperature,
+            "top_p": generation_config.top_p,
+            "stream": generation_config.stream,
+            # TODO - Cap this to avoid errors when exceeding the maximum allowable context
+            "max_tokens": generation_config.max_tokens_to_sample,
+        }
+        return args
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        """Asynchronously get a non-streaming completion."""
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `aget_completion` method."
+            )
+        return await self._aget_completion(
+            messages, generation_config, **kwargs
+        )
+
+    async def _aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        """Asynchronously get a completion from the LiteLLM API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion and wrap it to match the sync code path
+        response = await self.litellm_acompletion(**args)
+        return LLMChatCompletion(**response.dict())
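
Usage sketch: the snippet below shows how the provider above might be driven. It is a minimal sketch, not part of the diff; the `LLMConfig` and `GenerationConfig` constructor arguments are assumptions inferred from the field names the diff reads (`model`, `temperature`, `top_p`, `stream`, `max_tokens_to_sample`), so check the `r2r.base` definitions for the real signatures.

    # Hypothetical usage -- constructor arguments are assumptions, not a confirmed r2r API.
    from r2r.base import LLMConfig
    from r2r.base.abstractions.llm import GenerationConfig

    config = LLMConfig(provider="litellm")  # assumed field name
    llm = LiteLLM(config)

    messages = [{"role": "user", "content": "What is R2R?"}]

    # Non-streaming completion
    gen = GenerationConfig(
        model="gpt-3.5-turbo",
        temperature=0.7,
        top_p=1.0,
        stream=False,
        max_tokens_to_sample=256,
    )
    response = llm.get_completion(messages, gen)
    print(llm.extract_content(response))

    # Streaming completion: same config with stream=True, consumed chunk by chunk.
    # Chunk payloads are assumed to follow the OpenAI streaming shape
    # (choices[0].delta.content), which is what litellm emits.
    gen_stream = GenerationConfig(
        model="gpt-3.5-turbo",
        temperature=0.7,
        top_p=1.0,
        stream=True,
        max_tokens_to_sample=256,
    )
    for chunk in llm.get_completion_stream(messages, gen_stream):
        print(chunk.choices[0].delta.content or "", end="")

The async path is analogous: `await llm.aget_completion(messages, gen)` from inside a coroutine.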