"""A module for creating OpenAI model abstractions."""
import logging
import os
from typing import Union
from r2r.base import (
LLMChatCompletion,
LLMChatCompletionChunk,
LLMConfig,
LLMProvider,
)
from r2r.base.abstractions.llm import GenerationConfig
logger = logging.getLogger(__name__)
class OpenAILLM(LLMProvider):
"""A concrete class for creating OpenAI models."""
def __init__(
self,
config: LLMConfig,
*args,
**kwargs,
) -> None:
if not isinstance(config, LLMConfig):
            raise ValueError(
                "The provided config must be an instance of LLMConfig."
            )
try:
            from openai import AsyncOpenAI, OpenAI  # noqa
except ImportError:
            raise ImportError(
                "The `openai` package is required to run `OpenAILLM`. Please install it with `pip install openai`."
            )
if config.provider != "openai":
            raise ValueError(
                "OpenAILLM must be initialized with a config whose provider is `openai`."
            )
if not os.getenv("OPENAI_API_KEY"):
raise ValueError(
"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
)
super().__init__(config)
self.config: LLMConfig = config
        # The synchronous client backs `get_completion`/`get_completion_stream`;
        # the asynchronous client backs `aget_completion`. Both read the
        # OPENAI_API_KEY environment variable checked above.
        self.client = OpenAI()
        self.async_client = AsyncOpenAI()
def get_completion(
self,
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> LLMChatCompletion:
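        """Get a non-streaming chat completion from the OpenAI API."""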
if generation_config.stream:
raise ValueError(
"Stream must be set to False to use the `get_completion` method."
)
return self._get_completion(messages, generation_config, **kwargs)
def get_completion_stream(
self,
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> LLMChatCompletionChunk:
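        """Get a streaming chat completion from the OpenAI API.

        A minimal usage sketch (assumes a `GenerationConfig` constructed with
        `stream=True`; the exact constructor arguments live in
        `r2r.base.abstractions.llm`):

            for chunk in llm.get_completion_stream(messages, stream_config):
                ...  # each chunk is an LLMChatCompletionChunk
        """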
if not generation_config.stream:
raise ValueError(
"Stream must be set to True to use the `get_completion_stream` method."
)
return self._get_completion(messages, generation_config, **kwargs)
def _get_completion(
self,
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
"""Get a completion from the OpenAI API based on the provided messages."""
# Create a dictionary with the default arguments
args = self._get_base_args(generation_config)
args["messages"] = messages
# Conditionally add the 'functions' argument if it's not None
if generation_config.functions is not None:
args["functions"] = generation_config.functions
args = {**args, **kwargs}
# Create the chat completion
return self.client.chat.completions.create(**args)
def _get_base_args(
self,
generation_config: GenerationConfig,
) -> dict:
"""Get the base arguments for the OpenAI API."""
args = {
"model": generation_config.model,
"temperature": generation_config.temperature,
"top_p": generation_config.top_p,
"stream": generation_config.stream,
            # TODO - Cap this value to avoid errors when the request exceeds the model's maximum context length.
"max_tokens": generation_config.max_tokens_to_sample,
}
return args
async def aget_completion(
self,
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> LLMChatCompletion:
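        """Asynchronously get a non-streaming chat completion from the OpenAI API."""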
if generation_config.stream:
raise ValueError(
"Stream must be set to False to use the `aget_completion` method."
)
return await self._aget_completion(
messages, generation_config, **kwargs
)
async def _aget_completion(
self,
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> Union[LLMChatCompletion, LLMChatCompletionChunk]:
"""Asynchronously get a completion from the OpenAI API based on the provided messages."""
# Create a dictionary with the default arguments
args = self._get_base_args(generation_config)
args["messages"] = messages
# Conditionally add the 'functions' argument if it's not None
if generation_config.functions is not None:
args["functions"] = generation_config.functions
args = {**args, **kwargs}
# Create the chat completion
        return await self.async_client.chat.completions.create(**args)
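

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the provider itself. It assumes
# that `LLMConfig` and `GenerationConfig` accept the keyword arguments shown
# below (the attribute names match those read in `_get_base_args`, but the
# exact constructor signatures live in `r2r.base`), that OPENAI_API_KEY is
# set, and that the chosen model is available on your account.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LLMConfig(provider="openai")  # assumed constructor signature
    llm = OpenAILLM(config)

    messages = [{"role": "user", "content": "Say hello in one sentence."}]
    generation_config = GenerationConfig(
        model="gpt-4",  # hypothetical model choice
        temperature=0.1,
        top_p=1.0,
        stream=False,
        max_tokens_to_sample=64,
    )

    completion = llm.get_completion(messages, generation_config)
    print(completion)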