author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat')
 -rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py        |  90
 -rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py | 110
 2 files changed, 200 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
new file mode 100644
index 00000000..8ea19d41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
@@ -0,0 +1,90 @@
+from typing import Callable, Optional, Union
+
+import httpx
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import CustomStreamingDecoder, ModelResponse
+
+from ...openai_like.chat.handler import OpenAILikeChatHandler
+from ..common_utils import _get_api_params
+from .transformation import IBMWatsonXChatConfig
+
+watsonx_chat_transformation = IBMWatsonXChatConfig()
+
+
+class WatsonXChatHandler(OpenAILikeChatHandler):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        *,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: Optional[str],
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params: dict = {},
+        headers: Optional[dict] = None,
+        logger_fn=None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
+        fake_stream: bool = False,
+    ):
+        api_params = _get_api_params(params=optional_params)
+
+        ## UPDATE HEADERS
+        headers = watsonx_chat_transformation.validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+        ## UPDATE PAYLOAD (optional params)
+        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+            model=model,
+            api_params=api_params,
+        )
+        optional_params.update(watsonx_auth_payload)
+
+        ## GET API URL
+        api_base = watsonx_chat_transformation.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            stream=optional_params.get("stream", False),
+        )
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_endpoint=True,
+            streaming_decoder=streaming_decoder,
+        )
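
For orientation, this handler sits behind litellm's public completion API rather than being called directly. A minimal usage sketch follows; the model id, credentials, and environment variable names are illustrative assumptions based on litellm's watsonx provider docs, not values taken from this diff:

import os
import litellm

# Placeholder credentials. The env var names follow litellm's watsonx docs
# and may differ between versions; treat them as assumptions.
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # example region URL
os.environ["WATSONX_APIKEY"] = "..."
os.environ["WATSONX_PROJECT_ID"] = "..."

# Routing a "watsonx/..." chat model eventually reaches
# WatsonXChatHandler.completion(), which prepares headers, the auth payload,
# and the final URL before delegating to the OpenAI-like parent handler.
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # example model id
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=50,
)
print(response.choices[0].message.content)
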
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
new file mode 100644
index 00000000..f253da6f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
@@ -0,0 +1,110 @@
+"""
+Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
+
+Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
+"""
+
+from typing import List, Optional, Tuple, Union
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import IBMWatsonXMixin
+
+
+class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return [
+            "temperature",  # equivalent to temperature
+            "max_tokens",  # equivalent to max_new_tokens
+            "top_p",  # equivalent to top_p
+            "frequency_penalty",  # equivalent to repetition_penalty
+            "stop",  # equivalent to stop_sequences
+            "seed",  # equivalent to random_seed
+            "stream",  # equivalent to stream
+            "tools",
+            "tool_choice",  # equivalent to tool_choice + tool_choice_options
+            "logprobs",
+            "top_logprobs",
+            "n",
+            "presence_penalty",
+            "response_format",
+        ]
+
+    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
+        if tool_choice is None:
+            return False
+        if isinstance(tool_choice, str):
+            return tool_choice in ["auto", "none", "required"]
+        return False
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        ## TOOLS ##
+        _tools = non_default_params.pop("tools", None)
+        if _tools is not None:
+            # remove 'additionalProperties' from tools
+            _tools = _remove_additional_properties(_tools)
+            # remove 'strict' from tools
+            _tools = _remove_strict_from_schema(_tools)
+        if _tools is not None:
+            non_default_params["tools"] = _tools
+
+        ## TOOL CHOICE ##
+
+        _tool_choice = non_default_params.pop("tool_choice", None)
+        if self.is_tool_choice_option(_tool_choice):
+            optional_params["tool_choice_options"] = _tool_choice
+        elif _tool_choice is not None:
+            optional_params["tool_choice"] = _tool_choice
+        return super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
+        dynamic_api_key = (
+            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
+        )  # vllm does not require an api key
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.CHAT.value
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
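
The tool_choice split in map_openai_params is worth seeing in isolation: the string options "auto", "none", and "required" are forwarded as WatsonX's tool_choice_options field, while a dict naming a specific function passes through unchanged as tool_choice. A minimal standalone sketch of that mapping (the helper name map_tool_choice is hypothetical):

from typing import Optional, Union

def map_tool_choice(tool_choice: Optional[Union[str, dict]]) -> dict:
    # Mirrors is_tool_choice_option() plus the tool-choice branch of
    # map_openai_params() above, extracted from the config class.
    optional_params: dict = {}
    if isinstance(tool_choice, str) and tool_choice in ["auto", "none", "required"]:
        optional_params["tool_choice_options"] = tool_choice
    elif tool_choice is not None:
        optional_params["tool_choice"] = tool_choice
    return optional_params

print(map_tool_choice("auto"))
# {'tool_choice_options': 'auto'}
print(map_tool_choice({"type": "function", "function": {"name": "get_weather"}}))
# {'tool_choice': {'type': 'function', 'function': {'name': 'get_weather'}}}
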
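Likewise, the branching in get_complete_url can be traced with a self-contained sketch. The endpoint templates and the default API version below are assumptions standing in for the WatsonXAIEndpoint enum values and _add_api_version_to_url, neither of which appears in this diff:

# Assumed endpoint templates; the real values live in
# litellm.types.llms.watsonx.WatsonXAIEndpoint.
CHAT = "/ml/v1/text/chat"
CHAT_STREAM = "/ml/v1/text/chat_stream"
DEPLOYMENT_CHAT = "/ml/v1/deployments/{deployment_id}/text/chat"
DEPLOYMENT_CHAT_STREAM = "/ml/v1/deployments/{deployment_id}/text/chat_stream"

def build_url(base: str, model: str, stream: bool, api_version: str = "2024-03-13") -> str:
    # "deployment/<id>" models hit the per-deployment endpoints; everything
    # else hits the shared text-chat endpoint, each with a streaming variant.
    if model.startswith("deployment/"):
        deployment_id = "/".join(model.split("/")[1:])
        template = DEPLOYMENT_CHAT_STREAM if stream else DEPLOYMENT_CHAT
        endpoint = template.format(deployment_id=deployment_id)
    else:
        endpoint = CHAT_STREAM if stream else CHAT
    return base.rstrip("/") + endpoint + f"?version={api_version}"

print(build_url("https://us-south.ml.cloud.ibm.com/", "ibm/granite-13b-chat-v2", stream=False))
# https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2024-03-13
print(build_url("https://us-south.ml.cloud.ibm.com", "deployment/my-deploy-id", stream=True))
# https://us-south.ml.cloud.ibm.com/ml/v1/deployments/my-deploy-id/text/chat_stream?version=2024-03-13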