author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/openai/openai.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/openai/openai.py')

-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/openai/openai.py  2870
1 file changed, 2870 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/openai.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/openai.py
new file mode 100644
index 00000000..deb70b48
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/openai.py
@@ -0,0 +1,2870 @@
import time
import types
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Coroutine,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Union,
    cast,
)
from urllib.parse import urlparse

import httpx
import openai
from openai import AsyncOpenAI, OpenAI
from openai.types.beta.assistant_deleted import AssistantDeleted
from openai.types.file_deleted import FileDeleted
from pydantic import BaseModel
from typing_extensions import overload

import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.types.utils import (
    EmbeddingResponse,
    ImageResponse,
    LiteLLMBatch,
    ModelResponse,
    ModelResponseStream,
)
from litellm.utils import (
    CustomStreamWrapper,
    ProviderConfigManager,
    convert_to_model_response_object,
)

from ...types.llms.openai import *
from ..base import BaseLLM
from .chat.o_series_transformation import OpenAIOSeriesConfig
from .common_utils import (
    BaseOpenAILLM,
    OpenAIError,
    drop_params_from_unprocessable_entity_error,
)

openaiOSeriesConfig = OpenAIOSeriesConfig()


class MistralEmbeddingConfig:
    """
    Reference: https://docs.mistral.ai/api/#operation/createEmbedding
    """

    def __init__(
        self,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "encoding_format",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "encoding_format":
                optional_params["encoding_format"] = value
        return optional_params

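# --- Editor-added illustrative sketch; not part of the upstream file. ---
# MistralEmbeddingConfig only forwards `encoding_format` and silently drops
# every other non-default parameter. The parameter values are hypothetical.
def _example_mistral_param_mapping() -> None:
    config = MistralEmbeddingConfig()
    mapped = config.map_openai_params(
        non_default_params={"encoding_format": "float", "user": "abc"},  # "user" is unsupported
        optional_params={},
    )
    # only encoding_format survives the mapping
    assert mapped == {"encoding_format": "float"}
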
class OpenAIConfig(BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion. OpenAI has now deprecated this parameter in favor of `max_completion_tokens`, and it is not compatible with o1 series models.

    - `max_completion_tokens` (integer or null): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.

    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """
        This function returns the list of supported openai parameters for a given OpenAI Model

        - If O1 model, returns O1 supported params
        - If gpt-audio model, returns gpt-audio supported params
        - Else, returns gpt supported params

        Args:
            model (str): OpenAI model

        Returns:
            list: List of supported openai parameters
        """
        if openaiOSeriesConfig.is_model_o_series_model(model=model):
            return openaiOSeriesConfig.get_supported_openai_params(model=model)
        elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model):
            return litellm.openAIGPTAudioConfig.get_supported_openai_params(model=model)
        else:
            return litellm.openAIGPTConfig.get_supported_openai_params(model=model)

    def _map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        return messages

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """ """
        if openaiOSeriesConfig.is_model_o_series_model(model=model):
            return openaiOSeriesConfig.map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=drop_params,
            )
        elif litellm.openAIGPTAudioConfig.is_model_gpt_audio_model(model=model):
            return litellm.openAIGPTAudioConfig.map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=drop_params,
            )

        return litellm.openAIGPTConfig.map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return OpenAIError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        messages = self._transform_messages(messages=messages, model=model)
        return {"model": model, "messages": messages, **optional_params}

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:

        logging_obj.post_call(original_response=raw_response.text)
        logging_obj.model_call_details["response_headers"] = raw_response.headers
        final_response_obj = cast(
            ModelResponse,
            convert_to_model_response_object(
                response_object=raw_response.json(),
                model_response_object=model_response,
                hidden_params={"headers": raw_response.headers},
                _response_headers=dict(raw_response.headers),
            ),
        )

        return final_response_obj

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return {
            "Authorization": f"Bearer {api_key}",
            **headers,
        }

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return OpenAIChatCompletionResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

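# --- Editor-added illustrative sketch; not part of the upstream file. ---
# OpenAIConfig.transform_request merges the model, the (unchanged) messages,
# and the already-mapped optional params into the wire payload. Values here
# are hypothetical.
def _example_openai_config_transform() -> None:
    config = OpenAIConfig()
    payload = config.transform_request(
        model="gpt-4o",
        messages=[{"role": "user", "content": "hi"}],
        optional_params={"temperature": 0.2},
        litellm_params={},
        headers={},
    )
    assert payload == {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "hi"}],
        "temperature": 0.2,
    }
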
class OpenAIChatCompletionResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        {'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'finish_reason': None, 'index': 0, 'logprobs': None}], 'created': 1735763082, 'id': 'a83a2b0fbfaf4aab9c2c93cb8ba346d7', 'model': 'mistral-large', 'object': 'chat.completion.chunk'}
        """
        try:
            return ModelResponseStream(**chunk)
        except Exception as e:
            raise e

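# --- Editor-added illustrative sketch; not part of the upstream file. ---
# chunk_parser validates a raw SSE chunk dict into a ModelResponseStream,
# e.g. the sample chunk quoted in the docstring above. The constructor
# arguments below assume BaseModelResponseIterator's signature matches the
# call in get_model_response_iterator.
def _example_chunk_parse() -> None:
    iterator = OpenAIChatCompletionResponseIterator(
        streaming_response=iter([]), sync_stream=True
    )
    chunk = {
        "choices": [
            {
                "delta": {"content": "", "role": "assistant"},
                "finish_reason": None,
                "index": 0,
                "logprobs": None,
            }
        ],
        "created": 1735763082,
        "id": "a83a2b0fbfaf4aab9c2c93cb8ba346d7",
        "model": "mistral-large",
        "object": "chat.completion.chunk",
    }
    parsed = iterator.chunk_parser(chunk)
    assert parsed.model == "mistral-large"
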
class OpenAIChatCompletion(BaseLLM, BaseOpenAILLM):

    def __init__(self) -> None:
        super().__init__()

    def _set_dynamic_params_on_client(
        self,
        client: Union[OpenAI, AsyncOpenAI],
        organization: Optional[str] = None,
        max_retries: Optional[int] = None,
    ):
        if organization is not None:
            client.organization = organization
        if max_retries is not None:
            client.max_retries = max_retries

    def _get_openai_client(
        self,
        is_async: bool,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
        organization: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
        client_initialization_params: Dict = locals()
        if client is None:
            if not isinstance(max_retries, int):
                raise OpenAIError(
                    status_code=422,
                    message="max retries must be an int. Passed in value: {}".format(
                        max_retries
                    ),
                )
            cached_client = self.get_cached_openai_client(
                client_initialization_params=client_initialization_params,
                client_type="openai",
            )

            if cached_client:
                if isinstance(cached_client, OpenAI) or isinstance(
                    cached_client, AsyncOpenAI
                ):
                    return cached_client
            if is_async:
                _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=OpenAIChatCompletion._get_async_http_client(),
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                _new_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=OpenAIChatCompletion._get_sync_http_client(),
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )

            ## SAVE CACHE KEY
            self.set_cached_openai_client(
                openai_client=_new_client,
                client_initialization_params=client_initialization_params,
                client_type="openai",
            )
            return _new_client

        else:
            self._set_dynamic_params_on_client(
                client=client,
                organization=organization,
                max_retries=max_retries,
            )
            return client

    @track_llm_api_timing()
    async def make_openai_chat_completion_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ) -> Tuple[dict, BaseModel]:
        """
        Helper to:
        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
        - call chat.completions.create by default
        """
        start_time = time.time()
        try:
            raw_response = (
                await openai_aclient.chat.completions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )
            end_time = time.time()

            if hasattr(raw_response, "headers"):
                headers = dict(raw_response.headers)
            else:
                headers = {}
            response = raw_response.parse()
            return headers, response
        except openai.APITimeoutError as e:
            end_time = time.time()
            time_delta = round(end_time - start_time, 2)
            e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
            raise e
        except Exception as e:
            raise e

response - {}".format( + e, raw_response, type(raw_response) + ) + ) + else: + raise e + + def mock_streaming( + self, + response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + model: str, + stream_options: Optional[dict] = None, + ) -> CustomStreamWrapper: + completion_stream = MockResponseIterator(model_response=response) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + stream_options=stream_options, + ) + + return streaming_response + + def completion( # type: ignore # noqa: PLR0915 + self, + model_response: ModelResponse, + timeout: Union[float, httpx.Timeout], + optional_params: dict, + litellm_params: dict, + logging_obj: Any, + model: Optional[str] = None, + messages: Optional[list] = None, + print_verbose: Optional[Callable] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + dynamic_params: Optional[bool] = None, + azure_ad_token: Optional[str] = None, + acompletion: bool = False, + logger_fn=None, + headers: Optional[dict] = None, + custom_prompt_dict: dict = {}, + client=None, + organization: Optional[str] = None, + custom_llm_provider: Optional[str] = None, + drop_params: Optional[bool] = None, + ): + + super().completion() + try: + fake_stream: bool = False + inference_params = optional_params.copy() + stream_options: Optional[dict] = inference_params.pop( + "stream_options", None + ) + stream: Optional[bool] = inference_params.pop("stream", False) + provider_config: Optional[BaseConfig] = None + + if custom_llm_provider is not None and model is not None: + provider_config = ProviderConfigManager.get_provider_chat_config( + model=model, provider=LlmProviders(custom_llm_provider) + ) + + if provider_config: + fake_stream = provider_config.should_fake_stream( + model=model, custom_llm_provider=custom_llm_provider, stream=stream + ) + + if headers: + inference_params["extra_headers"] = headers + if model is None or messages is None: + raise OpenAIError(status_code=422, message="Missing model or messages") + + if not isinstance(timeout, float) and not isinstance( + timeout, httpx.Timeout + ): + raise OpenAIError( + status_code=422, + message="Timeout needs to be a float or httpx.Timeout", + ) + + if custom_llm_provider is not None and custom_llm_provider != "openai": + model_response.model = f"{custom_llm_provider}/{model}" + + for _ in range( + 2 + ): # if call fails due to alternating messages, retry with reformatted message + + if provider_config is not None: + data = provider_config.transform_request( + model=model, + messages=messages, + optional_params=inference_params, + litellm_params=litellm_params, + headers=headers or {}, + ) + else: + data = OpenAIConfig().transform_request( + model=model, + messages=messages, + optional_params=inference_params, + litellm_params=litellm_params, + headers=headers or {}, + ) + try: + max_retries = data.pop("max_retries", 2) + if acompletion is True: + if stream is True and fake_stream is False: + return self.async_streaming( + logging_obj=logging_obj, + headers=headers, + data=data, + model=model, + api_base=api_base, + api_key=api_key, + api_version=api_version, + timeout=timeout, + client=client, + max_retries=max_retries, + organization=organization, + drop_params=drop_params, + stream_options=stream_options, + ) + else: + return self.acompletion( + data=data, + headers=headers, + model=model, + logging_obj=logging_obj, + model_response=model_response, + 
    def completion(  # type: ignore # noqa: PLR0915
        self,
        model_response: ModelResponse,
        timeout: Union[float, httpx.Timeout],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        dynamic_params: Optional[bool] = None,
        azure_ad_token: Optional[str] = None,
        acompletion: bool = False,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):

        super().completion()
        try:
            fake_stream: bool = False
            inference_params = optional_params.copy()
            stream_options: Optional[dict] = inference_params.pop(
                "stream_options", None
            )
            stream: Optional[bool] = inference_params.pop("stream", False)
            provider_config: Optional[BaseConfig] = None

            if custom_llm_provider is not None and model is not None:
                provider_config = ProviderConfigManager.get_provider_chat_config(
                    model=model, provider=LlmProviders(custom_llm_provider)
                )

            if provider_config:
                fake_stream = provider_config.should_fake_stream(
                    model=model, custom_llm_provider=custom_llm_provider, stream=stream
                )

            if headers:
                inference_params["extra_headers"] = headers
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            if not isinstance(timeout, float) and not isinstance(
                timeout, httpx.Timeout
            ):
                raise OpenAIError(
                    status_code=422,
                    message="Timeout needs to be a float or httpx.Timeout",
                )

            if custom_llm_provider is not None and custom_llm_provider != "openai":
                model_response.model = f"{custom_llm_provider}/{model}"

            for _ in range(
                2
            ):  # if call fails due to alternating messages, retry with reformatted message

                if provider_config is not None:
                    data = provider_config.transform_request(
                        model=model,
                        messages=messages,
                        optional_params=inference_params,
                        litellm_params=litellm_params,
                        headers=headers or {},
                    )
                else:
                    data = OpenAIConfig().transform_request(
                        model=model,
                        messages=messages,
                        optional_params=inference_params,
                        litellm_params=litellm_params,
                        headers=headers or {},
                    )
                try:
                    max_retries = data.pop("max_retries", 2)
                    if acompletion is True:
                        if stream is True and fake_stream is False:
                            return self.async_streaming(
                                logging_obj=logging_obj,
                                headers=headers,
                                data=data,
                                model=model,
                                api_base=api_base,
                                api_key=api_key,
                                api_version=api_version,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                                organization=organization,
                                drop_params=drop_params,
                                stream_options=stream_options,
                            )
                        else:
                            return self.acompletion(
                                data=data,
                                headers=headers,
                                model=model,
                                logging_obj=logging_obj,
                                model_response=model_response,
                                api_base=api_base,
                                api_key=api_key,
                                api_version=api_version,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                                organization=organization,
                                drop_params=drop_params,
                                fake_stream=fake_stream,
                            )
                    elif stream is True and fake_stream is False:
                        return self.streaming(
                            logging_obj=logging_obj,
                            headers=headers,
                            data=data,
                            model=model,
                            api_base=api_base,
                            api_key=api_key,
                            api_version=api_version,
                            timeout=timeout,
                            client=client,
                            max_retries=max_retries,
                            organization=organization,
                            stream_options=stream_options,
                        )
                    else:
                        if not isinstance(max_retries, int):
                            raise OpenAIError(
                                status_code=422, message="max retries must be an int"
                            )
                        openai_client: OpenAI = self._get_openai_client(  # type: ignore
                            is_async=False,
                            api_key=api_key,
                            api_base=api_base,
                            api_version=api_version,
                            timeout=timeout,
                            max_retries=max_retries,
                            organization=organization,
                            client=client,
                        )

                        ## LOGGING
                        logging_obj.pre_call(
                            input=messages,
                            api_key=openai_client.api_key,
                            additional_args={
                                "headers": headers,
                                "api_base": openai_client._base_url._uri_reference,
                                "acompletion": acompletion,
                                "complete_input_dict": data,
                            },
                        )

                        headers, response = (
                            self.make_sync_openai_chat_completion_request(
                                openai_client=openai_client,
                                data=data,
                                timeout=timeout,
                                logging_obj=logging_obj,
                            )
                        )

                        logging_obj.model_call_details["response_headers"] = headers
                        stringified_response = response.model_dump()
                        logging_obj.post_call(
                            input=messages,
                            api_key=api_key,
                            original_response=stringified_response,
                            additional_args={"complete_input_dict": data},
                        )

                        final_response_obj = convert_to_model_response_object(
                            response_object=stringified_response,
                            model_response_object=model_response,
                            _response_headers=headers,
                        )
                        if fake_stream is True:
                            return self.mock_streaming(
                                response=cast(ModelResponse, final_response_obj),
                                logging_obj=logging_obj,
                                model=model,
                                stream_options=stream_options,
                            )

                        return final_response_obj
                except openai.UnprocessableEntityError as e:
                    ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
                    if litellm.drop_params is True or drop_params is True:
                        inference_params = drop_params_from_unprocessable_entity_error(
                            e, inference_params
                        )
                    else:
                        raise e
                    # e.message
                except Exception as e:
                    if print_verbose is not None:
                        print_verbose(f"openai.py: Received openai error - {str(e)}")
                    if (
                        "Conversation roles must alternate user/assistant" in str(e)
                        or "user and assistant roles should be alternating" in str(e)
                    ) and messages is not None:
                        if print_verbose is not None:
                            print_verbose("openai.py: REFORMATS THE MESSAGE!")
                        # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
                        new_messages = []
                        for i in range(len(messages) - 1):  # type: ignore
                            new_messages.append(messages[i])
                            if messages[i]["role"] == messages[i + 1]["role"]:
                                if messages[i]["role"] == "user":
                                    new_messages.append(
                                        {"role": "assistant", "content": ""}
                                    )
                                else:
                                    new_messages.append({"role": "user", "content": ""})
                        new_messages.append(messages[-1])
                        messages = new_messages
                    elif (
                        "Last message must have role `user`" in str(e)
                    ) and messages is not None:
                        new_messages = messages
                        new_messages.append({"role": "user", "content": ""})
                        messages = new_messages
                    elif "unknown field: parameter index is not a valid field" in str(
                        e
                    ):
                        litellm.remove_index_from_tool_calls(messages=messages)
                    else:
                        raise e
        except OpenAIError as e:
            raise e
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            error_body = getattr(e, "body", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code,
                message=error_text,
                headers=error_headers,
                body=error_body,
            )

    async def acompletion(
        self,
        data: dict,
        model: str,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        timeout: Union[float, httpx.Timeout],
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
        drop_params: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        fake_stream: bool = False,
    ):
        response = None
        for _ in range(
            2
        ):  # if call fails due to alternating messages, retry with reformatted message

            try:
                openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                    is_async=True,
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )

                ## LOGGING
                logging_obj.pre_call(
                    input=data["messages"],
                    api_key=openai_aclient.api_key,
                    additional_args={
                        "headers": {
                            "Authorization": f"Bearer {openai_aclient.api_key}"
                        },
                        "api_base": openai_aclient._base_url._uri_reference,
                        "acompletion": True,
                        "complete_input_dict": data,
                    },
                )

                headers, response = await self.make_openai_chat_completion_request(
                    openai_aclient=openai_aclient,
                    data=data,
                    timeout=timeout,
                    logging_obj=logging_obj,
                )
                stringified_response = response.model_dump()

                logging_obj.post_call(
                    input=data["messages"],
                    api_key=api_key,
                    original_response=stringified_response,
                    additional_args={"complete_input_dict": data},
                )
                logging_obj.model_call_details["response_headers"] = headers
                final_response_obj = convert_to_model_response_object(
                    response_object=stringified_response,
                    model_response_object=model_response,
                    hidden_params={"headers": headers},
                    _response_headers=headers,
                )

                if fake_stream is True:
                    return self.mock_streaming(
                        response=cast(ModelResponse, final_response_obj),
                        logging_obj=logging_obj,
                        model=model,
                        stream_options=stream_options,
                    )

                return final_response_obj
            except openai.UnprocessableEntityError as e:
                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
                if litellm.drop_params is True or drop_params is True:
                    data = drop_params_from_unprocessable_entity_error(e, data)
                else:
                    raise e
                # e.message
            except Exception as e:
                exception_response = getattr(e, "response", None)
                status_code = getattr(e, "status_code", 500)
                exception_body = getattr(e, "body", None)
                error_headers = getattr(e, "headers", None)
                if error_headers is None and exception_response:
                    error_headers = getattr(exception_response, "headers", None)
                message = getattr(e, "message", str(e))

                raise OpenAIError(
                    status_code=status_code,
                    message=message,
                    headers=error_headers,
                    body=exception_body,
                )

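    # Editor-added note (illustrative; not part of the upstream file): the
    # message-reformatting branch in completion() above repairs
    # non-alternating conversations by inserting blank turns. For example,
    # the hypothetical input
    #
    #     [{"role": "user", "content": "a"}, {"role": "user", "content": "b"}]
    #
    # becomes
    #
    #     [{"role": "user", "content": "a"},
    #      {"role": "assistant", "content": ""},
    #      {"role": "user", "content": "b"}]
    #
    # before the request is retried on the second loop iteration.
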
    def streaming(
        self,
        logging_obj,
        timeout: Union[float, httpx.Timeout],
        data: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
        stream_options: Optional[dict] = None,
    ):
        data["stream"] = True
        data.update(
            self.get_stream_options(stream_options=stream_options, api_base=api_base)
        )

        openai_client: OpenAI = self._get_openai_client(  # type: ignore
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                "api_base": openai_client._base_url._uri_reference,
                "acompletion": False,
                "complete_input_dict": data,
            },
        )
        headers, response = self.make_sync_openai_chat_completion_request(
            openai_client=openai_client,
            data=data,
            timeout=timeout,
            logging_obj=logging_obj,
        )

        logging_obj.model_call_details["response_headers"] = headers
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
            _response_headers=headers,
        )
        return streamwrapper

    async def async_streaming(
        self,
        timeout: Union[float, httpx.Timeout],
        data: dict,
        model: str,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        organization: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
        drop_params: Optional[bool] = None,
        stream_options: Optional[dict] = None,
    ):
        response = None
        data["stream"] = True
        data.update(
            self.get_stream_options(stream_options=stream_options, api_base=api_base)
        )
        for _ in range(2):
            try:
                openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                    is_async=True,
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                )
                ## LOGGING
                logging_obj.pre_call(
                    input=data["messages"],
                    api_key=api_key,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                        "acompletion": True,
                        "complete_input_dict": data,
                    },
                )

                headers, response = await self.make_openai_chat_completion_request(
                    openai_aclient=openai_aclient,
                    data=data,
                    timeout=timeout,
                    logging_obj=logging_obj,
                )
                logging_obj.model_call_details["response_headers"] = headers
                streamwrapper = CustomStreamWrapper(
                    completion_stream=response,
                    model=model,
                    custom_llm_provider="openai",
                    logging_obj=logging_obj,
                    stream_options=data.get("stream_options", None),
                    _response_headers=headers,
                )
                return streamwrapper
            except openai.UnprocessableEntityError as e:
                ## check if body contains unprocessable params - related issue https://github.com/BerriAI/litellm/issues/4800
                if litellm.drop_params is True or drop_params is True:
                    data = drop_params_from_unprocessable_entity_error(e, data)
                else:
                    raise e
            except Exception as e:
                # need to exception handle here. async exceptions don't get caught in sync functions.
                if isinstance(e, OpenAIError):
                    raise e

                error_headers = getattr(e, "headers", None)
                status_code = getattr(e, "status_code", 500)
                error_response = getattr(e, "response", None)
                exception_body = getattr(e, "body", None)
                if error_headers is None and error_response:
                    error_headers = getattr(error_response, "headers", None)
                if response is not None and hasattr(response, "text"):
                    raise OpenAIError(
                        status_code=status_code,
                        message=f"{str(e)}\n\nOriginal Response: {response.text}",  # type: ignore
                        headers=error_headers,
                        body=exception_body,
                    )
                else:
                    if type(e).__name__ == "ReadTimeout":
                        raise OpenAIError(
                            status_code=408,
                            message=f"{type(e).__name__}",
                            headers=error_headers,
                            body=exception_body,
                        )
                    elif hasattr(e, "status_code"):
                        raise OpenAIError(
                            status_code=getattr(e, "status_code", 500),
                            message=str(e),
                            headers=error_headers,
                            body=exception_body,
                        )
                    else:
                        raise OpenAIError(
                            status_code=500,
                            message=f"{str(e)}",
                            headers=error_headers,
                            body=exception_body,
                        )

    def get_stream_options(
        self, stream_options: Optional[dict], api_base: Optional[str]
    ) -> dict:
        """
        Pass `stream_options` to the data dict for OpenAI requests
        """
        if stream_options is not None:
            return {"stream_options": stream_options}
        else:
            # by default litellm will include usage for openai endpoints
            if api_base is None or urlparse(api_base).hostname == "api.openai.com":
                return {"stream_options": {"include_usage": True}}
            return {}

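    # Editor-added note (illustrative; not part of the upstream file): with no
    # explicit stream_options and a default api.openai.com base URL,
    # get_stream_options opts the stream into usage reporting:
    #
    #     self.get_stream_options(stream_options=None, api_base=None)
    #     # -> {"stream_options": {"include_usage": True}}
    #
    # whereas a custom OpenAI-compatible base leaves the payload untouched:
    #
    #     self.get_stream_options(stream_options=None, api_base="https://my-proxy.example/v1")
    #     # -> {}
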
    # Embedding
    @track_llm_api_timing()
    async def make_openai_embedding_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ):
        """
        Helper to:
        - call embeddings.create.with_raw_response when litellm.return_response_headers is True
        - call embeddings.create by default
        """
        try:
            raw_response = await openai_aclient.embeddings.with_raw_response.create(
                **data, timeout=timeout
            )  # type: ignore
            headers = dict(raw_response.headers)
            response = raw_response.parse()
            return headers, response
        except Exception as e:
            raise e

    @track_llm_api_timing()
    def make_sync_openai_embedding_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        logging_obj: LiteLLMLoggingObj,
    ):
        """
        Helper to:
        - call embeddings.create.with_raw_response when litellm.return_response_headers is True
        - call embeddings.create by default
        """
        try:
            raw_response = openai_client.embeddings.with_raw_response.create(
                **data, timeout=timeout
            )  # type: ignore

            headers = dict(raw_response.headers)
            response = raw_response.parse()
            return headers, response
        except Exception as e:
            raise e

    async def aembedding(
        self,
        input: list,
        data: dict,
        model_response: EmbeddingResponse,
        timeout: float,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client: Optional[AsyncOpenAI] = None,
        max_retries=None,
    ):
        try:
            openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )
            headers, response = await self.make_openai_embedding_request(
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )
            logging_obj.model_call_details["response_headers"] = headers
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            returned_response: EmbeddingResponse = convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                response_type="embedding",
                _response_headers=headers,
            )  # type: ignore
            return returned_response
        except OpenAIError as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            raise e
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    def embedding(  # type: ignore
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        aembedding=None,
        max_retries: Optional[int] = None,
    ) -> EmbeddingResponse:
        super().embedding()
        try:
            model = model
            data = {"model": model, "input": input, **optional_params}
            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")
            ## LOGGING
            logging_obj.pre_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data, "api_base": api_base},
            )

            if aembedding is True:
                return self.aembedding(  # type: ignore
                    data=data,
                    input=input,
                    logging_obj=logging_obj,
                    model_response=model_response,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    client=client,
                    max_retries=max_retries,
                )

            openai_client: OpenAI = self._get_openai_client(  # type: ignore
                is_async=False,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            ## embedding CALL
            headers: Optional[Dict] = None
            headers, sync_embedding_response = self.make_sync_openai_embedding_request(
                openai_client=openai_client,
                data=data,
                timeout=timeout,
                logging_obj=logging_obj,
            )  # type: ignore

            ## LOGGING
            logging_obj.model_call_details["response_headers"] = headers
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=sync_embedding_response,
            )
            response: EmbeddingResponse = convert_to_model_response_object(
                response_object=sync_embedding_response.model_dump(),
                model_response_object=model_response,
                _response_headers=headers,
                response_type="embedding",
            )  # type: ignore
            return response
        except OpenAIError as e:
            raise e
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    async def aimage_generation(
        self,
        prompt: str,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        logging_obj: Any,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
    ):
        response = None
        try:

            openai_aclient = self._get_openai_client(
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation")  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def image_generation(
        self,
        model: Optional[str],
        prompt: str,
        timeout: float,
        optional_params: dict,
        logging_obj: Any,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_response: Optional[ImageResponse] = None,
        client=None,
        aimg_generation=None,
    ) -> ImageResponse:
        data = {}
        try:
            model = model
            data = {"model": model, "prompt": prompt, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")

            if aimg_generation is True:
                return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore

            openai_client: OpenAI = self._get_openai_client(  # type: ignore
                is_async=False,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=openai_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                    "api_base": openai_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            ## COMPLETION CALL
            _response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore

            response = _response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
        except OpenAIError as e:

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            raise e
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=str(e),
            )
            if hasattr(e, "status_code"):
                raise OpenAIError(
                    status_code=getattr(e, "status_code", 500), message=str(e)
                )
            else:
                raise OpenAIError(status_code=500, message=str(e))

    def audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        aspeech: Optional[bool] = None,
        client=None,
    ) -> HttpxBinaryResponseContent:

        if aspeech is not None and aspeech is True:
            return self.async_audio_speech(
                model=model,
                input=input,
                voice=voice,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                organization=organization,
                project=project,
                max_retries=max_retries,
                timeout=timeout,
                client=client,
            )  # type: ignore

        openai_client = self._get_openai_client(
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )

        response = cast(OpenAI, openai_client).audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )
        return HttpxBinaryResponseContent(response=response.response)

    async def async_audio_speech(
        self,
        model: str,
        input: str,
        voice: str,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        organization: Optional[str],
        project: Optional[str],
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        client=None,
    ) -> HttpxBinaryResponseContent:

        openai_client = cast(
            AsyncOpenAI,
            self._get_openai_client(
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            ),
        )

        response = await openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore
            input=input,
            **optional_params,
        )

        return HttpxBinaryResponseContent(response=response.response)


class OpenAIFilesAPI(BaseLLM):
    """
    OpenAI methods to support batches
    - create_file()
    - retrieve_file()
    - list_files()
    - delete_file()
    - file_content()
    - update_file()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
        received_args = locals()
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            if _is_async is True:
                openai_client = AsyncOpenAI(**data)
            else:
                openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    async def acreate_file(
        self,
        create_file_data: CreateFileRequest,
        openai_client: AsyncOpenAI,
    ) -> FileObject:
        response = await openai_client.files.create(**create_file_data)
        return response

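    # Editor-added note (illustrative; not part of the upstream file): a
    # CreateFileRequest is passed straight through to the OpenAI SDK, so a
    # typical batch-input upload looks like (hypothetical path and ID):
    #
    #     create_file_data = {
    #         "file": open("requests.jsonl", "rb"),
    #         "purpose": "batch",
    #     }
    #     file_obj = await self.acreate_file(create_file_data, openai_client)
    #     print(file_obj.id)  # e.g. "file-abc123"
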
    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_file(  # type: ignore
                create_file_data=create_file_data, openai_client=openai_client
            )
        response = openai_client.files.create(**create_file_data)
        return response

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        openai_client: AsyncOpenAI,
    ) -> HttpxBinaryResponseContent:
        response = await openai_client.files.content(**file_content_request)
        return HttpxBinaryResponseContent(response=response.response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.afile_content(  # type: ignore
                file_content_request=file_content_request,
                openai_client=openai_client,
            )
        response = cast(OpenAI, openai_client).files.content(**file_content_request)

        return HttpxBinaryResponseContent(response=response.response)

    async def aretrieve_file(
        self,
        file_id: str,
        openai_client: AsyncOpenAI,
    ) -> FileObject:
        response = await openai_client.files.retrieve(file_id=file_id)
        return response

    def retrieve_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_file(  # type: ignore
                file_id=file_id,
                openai_client=openai_client,
            )
        response = openai_client.files.retrieve(file_id=file_id)

        return response

    async def adelete_file(
        self,
        file_id: str,
        openai_client: AsyncOpenAI,
    ) -> FileDeleted:
        response = await openai_client.files.delete(file_id=file_id)
        return response

    def delete_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.adelete_file(  # type: ignore
                file_id=file_id,
                openai_client=openai_client,
            )
        response = openai_client.files.delete(file_id=file_id)

        return response

    async def alist_files(
        self,
        openai_client: AsyncOpenAI,
        purpose: Optional[str] = None,
    ):
        if isinstance(purpose, str):
            response = await openai_client.files.list(purpose=purpose)
        else:
            response = await openai_client.files.list()
        return response

    def list_files(
        self,
        _is_async: bool,
        api_base: str,
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        purpose: Optional[str] = None,
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_files(  # type: ignore
                purpose=purpose,
                openai_client=openai_client,
            )

        if isinstance(purpose, str):
            response = openai_client.files.list(purpose=purpose)
        else:
            response = openai_client.files.list()

        return response


class OpenAIBatchesAPI(BaseLLM):
    """
    OpenAI methods to support batches
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batch()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
        received_args = locals()
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            if _is_async is True:
                openai_client = AsyncOpenAI(**data)
            else:
                openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> LiteLLMBatch:
        response = await openai_client.batches.create(**create_batch_data)
        return LiteLLMBatch(**response.model_dump())

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data, openai_client=openai_client
            )
        response = cast(OpenAI, openai_client).batches.create(**create_batch_data)

        return LiteLLMBatch(**response.model_dump())

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> LiteLLMBatch:
        verbose_logger.debug("retrieving batch, args= %s", retrieve_batch_data)
        response = await openai_client.batches.retrieve(**retrieve_batch_data)
        return LiteLLMBatch(**response.model_dump())

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data, openai_client=openai_client
            )
        response = cast(OpenAI, openai_client).batches.retrieve(**retrieve_batch_data)
        return LiteLLMBatch(**response.model_dump())

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        openai_client: AsyncOpenAI,
    ) -> Batch:
        verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data)
        response = await openai_client.batches.cancel(**cancel_batch_data)
        return response

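    # Editor-added note (illustrative; not part of the upstream file): the
    # create/retrieve pair above wraps the OpenAI Batch API. A typical flow,
    # with hypothetical IDs:
    #
    #     batch = await self.acreate_batch(
    #         {
    #             "input_file_id": "file-abc123",
    #             "endpoint": "/v1/chat/completions",
    #             "completion_window": "24h",
    #         },
    #         openai_client,
    #     )
    #     batch = await self.aretrieve_batch({"batch_id": batch.id}, openai_client)
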
    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acancel_batch(  # type: ignore
                cancel_batch_data=cancel_batch_data, openai_client=openai_client
            )

        response = openai_client.batches.cancel(**cancel_batch_data)
        return response

    async def alist_batches(
        self,
        openai_client: AsyncOpenAI,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        verbose_logger.debug("listing batches, after= %s, limit= %s", after, limit)
        response = await openai_client.batches.list(after=after, limit=limit)  # type: ignore
        return response

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[OpenAI] = None,
    ):
        openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, AsyncOpenAI):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_batches(  # type: ignore
                openai_client=openai_client, after=after, limit=limit
            )
        response = openai_client.batches.list(after=after, limit=limit)  # type: ignore
        return response


class OpenAIAssistantsAPI(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> OpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            openai_client = OpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    def async_get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI] = None,
    ) -> AsyncOpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["base_url"] = v
                elif v is not None:
                    data[k] = v
            openai_client = AsyncOpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

"limit": limit, + } + if before: + request_params["before"] = before + if after: + request_params["after"] = after + + response = await openai_client.beta.assistants.list(**request_params) # type: ignore + + return response + + # fmt: off + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + aget_assistants: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: + ... + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + aget_assistants: Optional[Literal[False]], + ) -> SyncCursorPage[Assistant]: + ... + + # fmt: on + + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client=None, + aget_assistants=None, + order: Optional[str] = "desc", + limit: Optional[int] = 20, + before: Optional[str] = None, + after: Optional[str] = None, + ): + if aget_assistants is not None and aget_assistants is True: + return self.async_get_assistants( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + request_params = { + "order": order, + "limit": limit, + } + + if before: + request_params["before"] = before + if after: + request_params["after"] = after + + response = openai_client.beta.assistants.list(**request_params) # type: ignore + + return response + + # Create Assistant + async def async_create_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + create_assistant_data: dict, + ) -> Assistant: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.assistants.create(**create_assistant_data) + + return response + + def create_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + create_assistant_data: dict, + client=None, + async_create_assistants=None, + ): + if async_create_assistants is not None and async_create_assistants is True: + return self.async_create_assistants( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + create_assistant_data=create_assistant_data, + ) + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.assistants.create(**create_assistant_data) + return response + + # Delete Assistant + async def async_delete_assistant( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + 
organization: Optional[str], + client: Optional[AsyncOpenAI], + assistant_id: str, + ) -> AssistantDeleted: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.assistants.delete(assistant_id=assistant_id) + + return response + + def delete_assistant( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + assistant_id: str, + client=None, + async_delete_assistants=None, + ): + if async_delete_assistants is not None and async_delete_assistants is True: + return self.async_delete_assistant( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + assistant_id=assistant_id, + ) + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.assistants.delete(assistant_id=assistant_id) + return response + + ### MESSAGES ### + + async def a_add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI] = None, + ) -> OpenAIMessage: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + # fmt: off + + @overload + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + a_add_message: Literal[True], + ) -> Coroutine[None, None, OpenAIMessage]: + ... + + @overload + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + a_add_message: Optional[Literal[False]], + ) -> OpenAIMessage: + ... 
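A minimal usage sketch of the flag-based dispatch in the overloads above (illustrative only; the thread id is a placeholder and `OPENAI_API_KEY` is assumed to be set in the environment):
```
from litellm.llms.openai.openai import OpenAIAssistantsAPI

api = OpenAIAssistantsAPI()
msg = api.add_message(
    thread_id="thread_abc123",  # placeholder id
    message_data={"role": "user", "content": "hello"},
    api_key=None,  # None falls back to OPENAI_API_KEY via the OpenAI SDK
    api_base=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
    a_add_message=False,  # pass True to get a coroutine instead
)
print(msg.status)  # "completed" is backfilled when the API omits a status
```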
+ + # fmt: on + + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client=None, + a_add_message: Optional[bool] = None, + ): + if a_add_message is not None and a_add_message is True: + return self.a_add_message( + thread_id=thread_id, + message_data=message_data, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + async def async_get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI] = None, + ) -> AsyncCursorPage[OpenAIMessage]: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + # fmt: off + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + aget_messages: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: + ... + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + aget_messages: Optional[Literal[False]], + ) -> SyncCursorPage[OpenAIMessage]: + ... 
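Similarly, `aget_assistants`-style flags apply here: `aget_messages=True` makes `get_messages` return a coroutine yielding an `AsyncCursorPage[OpenAIMessage]`. A sketch of the async path (placeholder id; key taken from the environment):
```
import asyncio

from litellm.llms.openai.openai import OpenAIAssistantsAPI

api = OpenAIAssistantsAPI()
page = asyncio.run(
    api.get_messages(
        thread_id="thread_abc123",  # placeholder id
        api_key=None,
        api_base=None,
        timeout=600.0,
        max_retries=2,
        organization=None,
        aget_messages=True,  # dispatches to async_get_messages
    )
)
```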
+ + # fmt: on + + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client=None, + aget_messages=None, + ): + if aget_messages is not None and aget_messages is True: + return self.async_get_messages( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + ### THREADS ### + + async def async_create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + ) -> Thread: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = await openai_client.beta.threads.create(**data) # type: ignore + + return Thread(**message_thread.dict()) + + # fmt: off + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[AsyncOpenAI], + acreate_thread: Literal[True], + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[OpenAI], + acreate_thread: Optional[Literal[False]], + ) -> Thread: + ... 
+
+    # fmt: on
+
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        client=None,
+        acreate_thread=None,
+    ):
+        """
+        Here's an example:
+        ```
+        from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData
+
+        openai_api = OpenAIAssistantsAPI()  # instantiate the wrapper first
+
+        # create thread
+        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
+        openai_api.create_thread(
+            messages=[message],
+            metadata=None,
+            api_key=None,  # None falls back to OPENAI_API_KEY
+            api_base=None,
+            timeout=600.0,
+            max_retries=2,
+            organization=None,
+        )
+        ```
+        """
+        if acreate_thread is not None and acreate_thread is True:
+            return self.async_create_thread(
+                metadata=metadata,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+                messages=messages,
+            )
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        data = {}
+        if messages is not None:
+            data["messages"] = messages  # type: ignore
+        if metadata is not None:
+            data["metadata"] = metadata  # type: ignore
+
+        message_thread = openai_client.beta.threads.create(**data)  # type: ignore
+
+        return Thread(**message_thread.dict())
+
+    async def async_get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+    ) -> Thread:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+        return Thread(**response.dict())
+
+    # fmt: off
+
+    @overload
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        aget_thread: Literal[True],
+    ) -> Coroutine[None, None, Thread]:
+        ...
+
+    @overload
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        aget_thread: Optional[Literal[False]],
+    ) -> Thread:
+        ...
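Putting the two thread helpers together, a round-trip sketch under the same assumptions (env API key, synchronous path; `get_thread` is implemented just below):
```
from litellm.llms.openai.openai import OpenAIAssistantsAPI

api = OpenAIAssistantsAPI()
common = dict(
    api_key=None,  # None falls back to OPENAI_API_KEY
    api_base=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
)
thread = api.create_thread(metadata=None, messages=None, acreate_thread=False, **common)
same = api.get_thread(thread_id=thread.id, aget_thread=False, **common)
assert same.id == thread.id
```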
+ + # fmt: on + + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client=None, + aget_thread=None, + ): + if aget_thread is not None and aget_thread is True: + return self.async_get_thread( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + def delete_thread(self): + pass + + ### RUNS ### + + async def arun_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + ) -> Run: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + ) + + return response + + def async_run_thread_stream( + self, + client: AsyncOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: + data: Dict[str, Any] = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + def run_thread_stream( + self, + client: OpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AssistantStreamManager[AssistantEventHandler]: + data: Dict[str, Any] = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + # fmt: off + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + stream: Optional[bool], + tools: 
Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client,
+        arun_thread: Literal[True],
+        event_handler: Optional[AssistantEventHandler],
+    ) -> Coroutine[None, None, Run]:
+        ...
+
+    @overload
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[Dict],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client,
+        arun_thread: Optional[Literal[False]],
+        event_handler: Optional[AssistantEventHandler],
+    ) -> Run:
+        ...
+
+    # fmt: on
+
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[Dict],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client=None,
+        arun_thread=None,
+        event_handler: Optional[AssistantEventHandler] = None,
+    ):
+        if arun_thread is not None and arun_thread is True:
+            if stream is not None and stream is True:
+                _client = self.async_get_openai_client(
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+                return self.async_run_thread_stream(
+                    client=_client,
+                    thread_id=thread_id,
+                    assistant_id=assistant_id,
+                    additional_instructions=additional_instructions,
+                    instructions=instructions,
+                    metadata=metadata,
+                    model=model,
+                    tools=tools,
+                    event_handler=event_handler,
+                )
+            return self.arun_thread(
+                thread_id=thread_id,
+                assistant_id=assistant_id,
+                additional_instructions=additional_instructions,
+                instructions=instructions,
+                metadata=metadata,
+                model=model,
+                stream=stream,
+                tools=tools,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        if stream is not None and stream is True:
+            return self.run_thread_stream(
+                client=openai_client,
+                thread_id=thread_id,
+                assistant_id=assistant_id,
+                additional_instructions=additional_instructions,
+                instructions=instructions,
+                metadata=metadata,
+                model=model,
+                tools=tools,
+                event_handler=event_handler,
+            )
+
+        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
+        )
+
+        return response
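Finally, a sketch of driving `run_thread` end to end (ids are placeholders; `stream=False` takes the create-and-poll path shown above, while `stream=True` would return a stream manager instead):
```
from litellm.llms.openai.openai import OpenAIAssistantsAPI

api = OpenAIAssistantsAPI()
run = api.run_thread(
    thread_id="thread_abc123",  # placeholder ids
    assistant_id="asst_abc123",
    additional_instructions=None,
    instructions=None,
    metadata=None,
    model=None,  # None keeps the assistant's configured model
    stream=False,
    tools=None,
    api_key=None,  # None falls back to OPENAI_API_KEY
    api_base=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
)
print(run.status)
```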