about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/watsonx
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/watsonx
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py90
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py110
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py291
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py3
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py391
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py112
6 files changed, 997 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
new file mode 100644
index 00000000..8ea19d41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
@@ -0,0 +1,90 @@
+from typing import Callable, Optional, Union
+
+import httpx
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import CustomStreamingDecoder, ModelResponse
+
+from ...openai_like.chat.handler import OpenAILikeChatHandler
+from ..common_utils import _get_api_params
+from .transformation import IBMWatsonXChatConfig
+
+watsonx_chat_transformation = IBMWatsonXChatConfig()
+
+
+class WatsonXChatHandler(OpenAILikeChatHandler):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        *,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: Optional[str],
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params: dict = {},
+        headers: Optional[dict] = None,
+        logger_fn=None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
+        fake_stream: bool = False,
+    ):
+        api_params = _get_api_params(params=optional_params)
+
+        ## UPDATE HEADERS
+        headers = watsonx_chat_transformation.validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+        ## UPDATE PAYLOAD (optional params)
+        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+            model=model,
+            api_params=api_params,
+        )
+        optional_params.update(watsonx_auth_payload)
+
+        ## GET API URL
+        api_base = watsonx_chat_transformation.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            stream=optional_params.get("stream", False),
+        )
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_endpoint=True,
+            streaming_decoder=streaming_decoder,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
new file mode 100644
index 00000000..f253da6f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
@@ -0,0 +1,110 @@
+"""
+Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
+
+Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
+"""
+
+from typing import List, Optional, Tuple, Union
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import IBMWatsonXMixin
+
+
+class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return [
+            "temperature",  # equivalent to temperature
+            "max_tokens",  # equivalent to max_new_tokens
+            "top_p",  # equivalent to top_p
+            "frequency_penalty",  # equivalent to repetition_penalty
+            "stop",  # equivalent to stop_sequences
+            "seed",  # equivalent to random_seed
+            "stream",  # equivalent to stream
+            "tools",
+            "tool_choice",  # equivalent to tool_choice + tool_choice_options
+            "logprobs",
+            "top_logprobs",
+            "n",
+            "presence_penalty",
+            "response_format",
+        ]
+
+    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
+        if tool_choice is None:
+            return False
+        if isinstance(tool_choice, str):
+            return tool_choice in ["auto", "none", "required"]
+        return False
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        ## TOOLS ##
+        _tools = non_default_params.pop("tools", None)
+        if _tools is not None:
+            # remove 'additionalProperties' from tools
+            _tools = _remove_additional_properties(_tools)
+            # remove 'strict' from tools
+            _tools = _remove_strict_from_schema(_tools)
+        if _tools is not None:
+            non_default_params["tools"] = _tools
+
+        ## TOOL CHOICE ##
+
+        _tool_choice = non_default_params.pop("tool_choice", None)
+        if self.is_tool_choice_option(_tool_choice):
+            optional_params["tool_choice_options"] = _tool_choice
+        elif _tool_choice is not None:
+            optional_params["tool_choice"] = _tool_choice
+        return super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
+        dynamic_api_key = (
+            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
+        )  # vllm does not require an api key
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.CHAT.value
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
new file mode 100644
index 00000000..4916cd1c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
@@ -0,0 +1,291 @@
+from typing import Dict, List, Optional, Union, cast
+
+import httpx
+
+import litellm
+from litellm import verbose_logger
+from litellm.caching import InMemoryCache
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials
+
+
+class WatsonXAIError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[Dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+iam_token_cache = InMemoryCache()
+
+
+def get_watsonx_iam_url():
+    return (
+        get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
+    )
+
+
+def generate_iam_token(api_key=None, **params) -> str:
+    result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore
+
+    if result is None:
+        headers = {}
+        headers["Content-Type"] = "application/x-www-form-urlencoded"
+        if api_key is None:
+            api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
+        if api_key is None:
+            raise ValueError("API key is required")
+        headers["Accept"] = "application/json"
+        data = {
+            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
+            "apikey": api_key,
+        }
+        iam_token_url = get_watsonx_iam_url()
+        verbose_logger.debug(
+            "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
+            iam_token_url,
+            headers,
+            data,
+        )
+        response = litellm.module_level_client.post(
+            url=iam_token_url, data=data, headers=headers
+        )
+        response.raise_for_status()
+        json_data = response.json()
+
+        result = json_data["access_token"]
+        iam_token_cache.set_cache(
+            key=api_key,
+            value=result,
+            ttl=json_data["expires_in"] - 10,  # leave some buffer
+        )
+
+    return cast(str, result)
+
+
+def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
+    if token is not None:
+        return token
+    token = generate_iam_token(api_key)
+    return token
+
+
+def _get_api_params(
+    params: dict,
+) -> WatsonXAPIParams:
+    """
+    Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
+    """
+    # Load auth variables from params
+    project_id = params.pop(
+        "project_id", params.pop("watsonx_project", None)
+    )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
+    space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
+    region_name = params.pop("region_name", params.pop("region", None))
+    if region_name is None:
+        region_name = params.pop(
+            "watsonx_region_name", params.pop("watsonx_region", None)
+        )  # consistent with how vertex ai + aws regions are accepted
+
+    # Load auth variables from environment variables
+    if project_id is None:
+        project_id = (
+            get_secret_str("WATSONX_PROJECT_ID")
+            or get_secret_str("WX_PROJECT_ID")
+            or get_secret_str("PROJECT_ID")
+        )
+    if region_name is None:
+        region_name = (
+            get_secret_str("WATSONX_REGION")
+            or get_secret_str("WX_REGION")
+            or get_secret_str("REGION")
+        )
+    if space_id is None:
+        space_id = (
+            get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
+            or get_secret_str("WATSONX_SPACE_ID")
+            or get_secret_str("WX_SPACE_ID")
+            or get_secret_str("SPACE_ID")
+        )
+
+    if project_id is None:
+        raise WatsonXAIError(
+            status_code=401,
+            message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
+        )
+
+    return WatsonXAPIParams(
+        project_id=project_id,
+        space_id=space_id,
+        region_name=region_name,
+    )
+
+
+def convert_watsonx_messages_to_prompt(
+    model: str,
+    messages: List[AllMessageValues],
+    provider: str,
+    custom_prompt_dict: Dict,
+) -> str:
+    # handle anthropic prompts and amazon titan prompts
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_dict = custom_prompt_dict[model]
+        prompt = ptf.custom_prompt(
+            messages=messages,
+            role_dict=model_prompt_dict.get(
+                "role_dict", model_prompt_dict.get("roles")
+            ),
+            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
+            bos_token=model_prompt_dict.get("bos_token", ""),
+            eos_token=model_prompt_dict.get("eos_token", ""),
+        )
+        return prompt
+    elif provider == "ibm-mistralai":
+        prompt = ptf.mistral_instruct_pt(messages=messages)
+    else:
+        prompt: str = ptf.prompt_factory(  # type: ignore
+            model=model, messages=messages, custom_llm_provider="watsonx"
+        )
+    return prompt
+
+
+# Mixin class for shared IBM Watson X functionality
+class IBMWatsonXMixin:
+    def validate_environment(
+        self,
+        headers: Dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Dict:
+        default_headers = {
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+
+        if "Authorization" in headers:
+            return {**default_headers, **headers}
+        token = cast(
+            Optional[str],
+            optional_params.get("token") or get_secret_str("WATSONX_TOKEN"),
+        )
+        if token:
+            headers["Authorization"] = f"Bearer {token}"
+        elif zen_api_key := get_secret_str("WATSONX_ZENAPIKEY"):
+            headers["Authorization"] = f"ZenApiKey {zen_api_key}"
+        else:
+            token = _generate_watsonx_token(api_key=api_key, token=token)
+            # build auth headers
+            headers["Authorization"] = f"Bearer {token}"
+        return {**default_headers, **headers}
+
+    def _get_base_url(self, api_base: Optional[str]) -> str:
+        url = (
+            api_base
+            or get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
+            or get_secret_str("WATSONX_URL")
+            or get_secret_str("WX_URL")
+            or get_secret_str("WML_URL")
+        )
+
+        if url is None:
+            raise WatsonXAIError(
+                status_code=401,
+                message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
+            )
+        return url
+
+    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
+        api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
+        url = url + f"?version={api_version}"
+
+        return url
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return WatsonXAIError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    @staticmethod
+    def get_watsonx_credentials(
+        optional_params: dict, api_key: Optional[str], api_base: Optional[str]
+    ) -> WatsonXCredentials:
+        api_key = (
+            api_key
+            or optional_params.pop("apikey", None)
+            or get_secret_str("WATSONX_APIKEY")
+            or get_secret_str("WATSONX_API_KEY")
+            or get_secret_str("WX_API_KEY")
+        )
+
+        api_base = (
+            api_base
+            or optional_params.pop(
+                "url",
+                optional_params.pop("api_base", optional_params.pop("base_url", None)),
+            )
+            or get_secret_str("WATSONX_API_BASE")
+            or get_secret_str("WATSONX_URL")
+            or get_secret_str("WX_URL")
+            or get_secret_str("WML_URL")
+        )
+
+        wx_credentials = optional_params.pop(
+            "wx_credentials",
+            optional_params.pop(
+                "watsonx_credentials", None
+            ),  # follow {provider}_credentials, same as vertex ai
+        )
+
+        token: Optional[str] = None
+
+        if wx_credentials is not None:
+            api_base = wx_credentials.get("url", api_base)
+            api_key = wx_credentials.get(
+                "apikey", wx_credentials.get("api_key", api_key)
+            )
+            token = wx_credentials.get(
+                "token",
+                wx_credentials.get(
+                    "watsonx_token", None
+                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+            )
+        if api_key is None or not isinstance(api_key, str):
+            raise WatsonXAIError(
+                status_code=401,
+                message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
+            )
+        if api_base is None or not isinstance(api_base, str):
+            raise WatsonXAIError(
+                status_code=401,
+                message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
+            )
+        return WatsonXCredentials(
+            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
+        )
+
+    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
+        payload: dict = {}
+        if model.startswith("deployment/"):
+            if api_params["space_id"] is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            payload["space_id"] = api_params["space_id"]
+            return payload
+        payload["model_id"] = model
+        payload["project_id"] = api_params["project_id"]
+        return payload
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py
new file mode 100644
index 00000000..2a57ddcf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py
@@ -0,0 +1,3 @@
+"""
+Watsonx uses the llm_http_handler.py to handle the requests.
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py
new file mode 100644
index 00000000..f414354e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py
@@ -0,0 +1,391 @@
+import time
+from datetime import datetime
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)
+
+import httpx
+
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
+from litellm.utils import map_finish_reason
+
+from ...base_llm.chat.transformation import BaseConfig
+from ..common_utils import (
+    IBMWatsonXMixin,
+    WatsonXAIError,
+    _get_api_params,
+    convert_watsonx_messages_to_prompt,
+)
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
+    """
+    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
+    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
+
+    Supported params for all available watsonx.ai foundational models.
+
+    - `decoding_method` (str): One of "greedy" or "sample"
+
+    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.
+
+    - `max_new_tokens` (integer): Maximum length of the generated tokens.
+
+    - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated.
+
+    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
+
+    - `stop_sequences` (string[]): list of strings to use as stop sequences.
+
+    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.
+
+    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.
+
+    - `repetition_penalty` (float): token repetition penalty during text generation.
+
+    - `truncate_input_tokens` (integer): Truncate input tokens to this length.
+
+    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.
+
+    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.
+
+    - `random_seed` (integer): Random seed for text generation.
+
+    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.
+
+    - `stream` (bool): If True, the model will return a stream of responses.
+    """
+
+    decoding_method: Optional[str] = "sample"
+    temperature: Optional[float] = None
+    max_new_tokens: Optional[int] = None  # litellm.max_tokens
+    min_new_tokens: Optional[int] = None
+    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
+    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+    truncate_input_tokens: Optional[int] = None
+    include_stop_sequences: Optional[bool] = False
+    return_options: Optional[Dict[str, bool]] = None
+    random_seed: Optional[int] = None  # e.g 42
+    moderations: Optional[dict] = None
+    stream: Optional[bool] = False
+
+    def __init__(
+        self,
+        decoding_method: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_new_tokens: Optional[int] = None,
+        min_new_tokens: Optional[int] = None,
+        length_penalty: Optional[dict] = None,
+        stop_sequences: Optional[List[str]] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        truncate_input_tokens: Optional[int] = None,
+        include_stop_sequences: Optional[bool] = None,
+        return_options: Optional[dict] = None,
+        random_seed: Optional[int] = None,
+        moderations: Optional[dict] = None,
+        stream: Optional[bool] = None,
+        **kwargs,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def is_watsonx_text_param(self, param: str) -> bool:
+        """
+        Determine if user passed in a watsonx.ai text generation param
+        """
+        text_generation_params = [
+            "decoding_method",
+            "max_new_tokens",
+            "min_new_tokens",
+            "length_penalty",
+            "stop_sequences",
+            "top_k",
+            "repetition_penalty",
+            "truncate_input_tokens",
+            "include_stop_sequences",
+            "return_options",
+            "random_seed",
+            "moderations",
+            "decoding_method",
+            "min_tokens",
+        ]
+
+        return param in text_generation_params
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "temperature",  # equivalent to temperature
+            "max_tokens",  # equivalent to max_new_tokens
+            "top_p",  # equivalent to top_p
+            "frequency_penalty",  # equivalent to repetition_penalty
+            "stop",  # equivalent to stop_sequences
+            "seed",  # equivalent to random_seed
+            "stream",  # equivalent to stream
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: Dict,
+        optional_params: Dict,
+        model: str,
+        drop_params: bool,
+    ) -> Dict:
+        extra_body = {}
+        for k, v in non_default_params.items():
+            if k == "max_tokens":
+                optional_params["max_new_tokens"] = v
+            elif k == "stream":
+                optional_params["stream"] = v
+            elif k == "temperature":
+                optional_params["temperature"] = v
+            elif k == "top_p":
+                optional_params["top_p"] = v
+            elif k == "frequency_penalty":
+                optional_params["repetition_penalty"] = v
+            elif k == "seed":
+                optional_params["random_seed"] = v
+            elif k == "stop":
+                optional_params["stop_sequences"] = v
+            elif k == "decoding_method":
+                extra_body["decoding_method"] = v
+            elif k == "min_tokens":
+                extra_body["min_new_tokens"] = v
+            elif k == "top_k":
+                extra_body["top_k"] = v
+            elif k == "truncate_input_tokens":
+                extra_body["truncate_input_tokens"] = v
+            elif k == "length_penalty":
+                extra_body["length_penalty"] = v
+            elif k == "time_limit":
+                extra_body["time_limit"] = v
+            elif k == "return_options":
+                extra_body["return_options"] = v
+
+        if extra_body:
+            optional_params["extra_body"] = extra_body
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {
+            "project": "watsonx_project",
+            "region_name": "watsonx_region_name",
+            "token": "watsonx_token",
+        }
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
+        """
+        return [
+            "eu-de",
+            "eu-gb",
+        ]
+
+    def get_us_regions(self) -> List[str]:
+        """
+        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
+        """
+        return [
+            "us-south",
+        ]
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        headers: Dict,
+    ) -> Dict:
+        provider = model.split("/")[0]
+        prompt = convert_watsonx_messages_to_prompt(
+            model=model,
+            messages=messages,
+            provider=provider,
+            custom_prompt_dict={},
+        )
+        extra_body_params = optional_params.pop("extra_body", {})
+        optional_params.update(extra_body_params)
+        watsonx_api_params = _get_api_params(params=optional_params)
+
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
+
+        # init the payload to the text generation call
+        payload = {
+            "input": prompt,
+            "moderations": optional_params.pop("moderations", {}),
+            "parameters": optional_params,
+            **watsonx_auth_payload,
+        }
+
+        return payload
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=raw_response.text,
+        )
+
+        json_resp = raw_response.json()
+
+        if "results" not in json_resp:
+            raise WatsonXAIError(
+                status_code=500,
+                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+            )
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        generated_text = json_resp["results"][0]["generated_text"]
+        prompt_tokens = json_resp["results"][0]["input_token_count"]
+        completion_tokens = json_resp["results"][0]["generated_token_count"]
+        model_response.choices[0].message.content = generated_text  # type: ignore
+        model_response.choices[0].finish_reason = map_finish_reason(
+            json_resp["results"][0]["stop_reason"]
+        )
+        if json_resp.get("created_at"):
+            model_response.created = int(
+                datetime.fromisoformat(json_resp["created_at"]).timestamp()
+            )
+        else:
+            model_response.created = int(time.time())
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            # deployment models are passed in as 'deployment/<deployment_id>'
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
+                if stream
+                else WatsonXAIEndpoint.TEXT_GENERATION
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return WatsonxTextCompletionResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class WatsonxTextCompletionResponseIterator(BaseModelResponseIterator):
+    # def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+    #     return self.chunk_parser(json.loads(str_line))
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            results = chunk.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
+                is_finished = finish_reason != "not_finished"
+
+                return GenericStreamingChunk(
+                    text=text,
+                    is_finished=is_finished,
+                    finish_reason=finish_reason,
+                    usage=ChatCompletionUsageBlock(
+                        prompt_tokens=results[0].get("input_token_count", 0),
+                        completion_tokens=results[0].get("generated_token_count", 0),
+                        total_tokens=results[0].get("input_token_count", 0)
+                        + results[0].get("generated_token_count", 0),
+                    ),
+                )
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="stop",
+                usage=None,
+            )
+        except Exception as e:
+            raise e
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py
new file mode 100644
index 00000000..359137ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py
@@ -0,0 +1,112 @@
+"""
+Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
+"""
+
+from typing import Optional
+
+import httpx
+
+from litellm.llms.base_llm.embedding.transformation import (
+    BaseEmbeddingConfig,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllEmbeddingInputValues
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+from litellm.types.utils import EmbeddingResponse, Usage
+
+from ..common_utils import IBMWatsonXMixin, _get_api_params
+
+
+class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
+    def get_supported_openai_params(self, model: str) -> list:
+        return []
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        return optional_params
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: AllEmbeddingInputValues,
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        watsonx_api_params = _get_api_params(params=optional_params)
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
+
+        return {
+            "inputs": input,
+            "parameters": optional_params,
+            **watsonx_auth_payload,
+        }
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
+        if model.startswith("deployment/"):
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str],
+        request_data: dict,
+        optional_params: dict,
+        litellm_params: dict,
+    ) -> EmbeddingResponse:
+        logging_obj.post_call(
+            original_response=raw_response.text,
+        )
+        json_resp = raw_response.json()
+        if model_response is None:
+            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": result["embedding"],
+                }
+            )
+        model_response.object = "list"
+        model_response.data = embedding_response
+        input_tokens = json_resp.get("input_token_count", 0)
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
+        )
+        return model_response