about summary refs log tree commit diff
path: root/R2R/r2r/providers/llms/openai
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/providers/llms/openai')
-rwxr-xr-xR2R/r2r/providers/llms/openai/base_openai.py144
1 files changed, 144 insertions, 0 deletions
diff --git a/R2R/r2r/providers/llms/openai/base_openai.py b/R2R/r2r/providers/llms/openai/base_openai.py
new file mode 100755
index 00000000..460c0f0b
--- /dev/null
+++ b/R2R/r2r/providers/llms/openai/base_openai.py
@@ -0,0 +1,144 @@
+"""A module for creating OpenAI model abstractions."""
+
+import logging
+import os
+from typing import Union
+
+from r2r.base import (
+    LLMChatCompletion,
+    LLMChatCompletionChunk,
+    LLMConfig,
+    LLMProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAILLM(LLMProvider):
+    """A concrete class for creating OpenAI models."""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        *args,
+        **kwargs,
+    ) -> None:
+        if not isinstance(config, LLMConfig):
+            raise ValueError(
+                "The provided config must be an instance of OpenAIConfig."
+            )
+        try:
+            from openai import OpenAI  # noqa
+        except ImportError:
+            raise ImportError(
+                "Error, `openai` is required to run an OpenAILLM. Please install it using `pip install openai`."
+            )
+        if config.provider != "openai":
+            raise ValueError(
+                "OpenAILLM must be initialized with config with `openai` provider."
+            )
+        if not os.getenv("OPENAI_API_KEY"):
+            raise ValueError(
+                "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+            )
+        super().__init__(config)
+        self.config: LLMConfig = config
+        self.client = OpenAI()
+
+    def get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `get_completion` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def get_completion_stream(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletionChunk:
+        if not generation_config.stream:
+            raise ValueError(
+                "Stream must be set to True to use the `get_completion_stream` method."
+            )
+        return self._get_completion(messages, generation_config, **kwargs)
+
+    def _get_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return self.client.chat.completions.create(**args)
+
+    def _get_base_args(
+        self,
+        generation_config: GenerationConfig,
+    ) -> dict:
+        """Get the base arguments for the OpenAI API."""
+
+        args = {
+            "model": generation_config.model,
+            "temperature": generation_config.temperature,
+            "top_p": generation_config.top_p,
+            "stream": generation_config.stream,
+            # TODO - We need to cap this to avoid potential errors when exceed max allowable context
+            "max_tokens": generation_config.max_tokens_to_sample,
+        }
+
+        return args
+
+    async def aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> LLMChatCompletion:
+        if generation_config.stream:
+            raise ValueError(
+                "Stream must be set to False to use the `aget_completion` method."
+            )
+        return await self._aget_completion(
+            messages, generation_config, **kwargs
+        )
+
+    async def _aget_completion(
+        self,
+        messages: list[dict],
+        generation_config: GenerationConfig,
+        **kwargs,
+    ) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
+        """Asynchronously get a completion from the OpenAI API based on the provided messages."""
+
+        # Create a dictionary with the default arguments
+        args = self._get_base_args(generation_config)
+
+        args["messages"] = messages
+
+        # Conditionally add the 'functions' argument if it's not None
+        if generation_config.functions is not None:
+            args["functions"] = generation_config.functions
+
+        args = {**args, **kwargs}
+        # Create the chat completion
+        return await self.client.chat.completions.create(**args)