about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py90
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py110
2 files changed, 200 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
new file mode 100644
index 00000000..8ea19d41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
@@ -0,0 +1,90 @@
+from typing import Callable, Optional, Union
+
+import httpx
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import CustomStreamingDecoder, ModelResponse
+
+from ...openai_like.chat.handler import OpenAILikeChatHandler
+from ..common_utils import _get_api_params
+from .transformation import IBMWatsonXChatConfig
+
+watsonx_chat_transformation = IBMWatsonXChatConfig()
+
+
+class WatsonXChatHandler(OpenAILikeChatHandler):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        *,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: Optional[str],
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params: dict = {},
+        headers: Optional[dict] = None,
+        logger_fn=None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
+        fake_stream: bool = False,
+    ):
+        api_params = _get_api_params(params=optional_params)
+
+        ## UPDATE HEADERS
+        headers = watsonx_chat_transformation.validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+        ## UPDATE PAYLOAD (optional params)
+        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+            model=model,
+            api_params=api_params,
+        )
+        optional_params.update(watsonx_auth_payload)
+
+        ## GET API URL
+        api_base = watsonx_chat_transformation.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            stream=optional_params.get("stream", False),
+        )
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_endpoint=True,
+            streaming_decoder=streaming_decoder,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
new file mode 100644
index 00000000..f253da6f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
@@ -0,0 +1,110 @@
+"""
+Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
+
+Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
+"""
+
+from typing import List, Optional, Tuple, Union
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import IBMWatsonXMixin
+
+
+class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return [
+            "temperature",  # equivalent to temperature
+            "max_tokens",  # equivalent to max_new_tokens
+            "top_p",  # equivalent to top_p
+            "frequency_penalty",  # equivalent to repetition_penalty
+            "stop",  # equivalent to stop_sequences
+            "seed",  # equivalent to random_seed
+            "stream",  # equivalent to stream
+            "tools",
+            "tool_choice",  # equivalent to tool_choice + tool_choice_options
+            "logprobs",
+            "top_logprobs",
+            "n",
+            "presence_penalty",
+            "response_format",
+        ]
+
+    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
+        if tool_choice is None:
+            return False
+        if isinstance(tool_choice, str):
+            return tool_choice in ["auto", "none", "required"]
+        return False
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        ## TOOLS ##
+        _tools = non_default_params.pop("tools", None)
+        if _tools is not None:
+            # remove 'additionalProperties' from tools
+            _tools = _remove_additional_properties(_tools)
+            # remove 'strict' from tools
+            _tools = _remove_strict_from_schema(_tools)
+        if _tools is not None:
+            non_default_params["tools"] = _tools
+
+        ## TOOL CHOICE ##
+
+        _tool_choice = non_default_params.pop("tool_choice", None)
+        if self.is_tool_choice_option(_tool_choice):
+            optional_params["tool_choice_options"] = _tool_choice
+        elif _tool_choice is not None:
+            optional_params["tool_choice"] = _tool_choice
+        return super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
+        dynamic_api_key = (
+            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
+        )  # vllm does not require an api key
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.CHAT.value
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url