Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx')
6 files changed, 997 insertions, 0 deletions
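The diff below adds the watsonx chat, text-generation and embedding translation layers. For orientation: the chat path resolves credentials via _get_api_params / validate_environment, injects model_id/project_id (or space_id for 'deployment/...' models) into the payload, builds the /text/chat URL, and then defers to the OpenAI-compatible handler. A minimal usage sketch follows, assuming a configured watsonx.ai project; the Granite model id and the 'watsonx/' routing prefix are assumptions for illustration, not part of this diff.

import os
import litellm

# Credentials and project, read by common_utils.py in this diff
# (WATSONX_API_BASE / WATSONX_API_KEY / WATSONX_PROJECT_ID are among the
# environment variables it checks).
os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_API_KEY"] = "..."        # IAM API key, exchanged for a bearer token
os.environ["WATSONX_PROJECT_ID"] = "..."     # injected as 'project_id' in the request payload

# Hypothetical model id; a 'deployment/<id>' model would route to the deployment endpoints instead.
response = litellm.completion(
    model="watsonx/ibm/granite-3-8b-instruct",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=50,
)
print(response.choices[0].message.content)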
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
new file mode 100644
index 00000000..8ea19d41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/handler.py
@@ -0,0 +1,90 @@
+from typing import Callable, Optional, Union
+
+import httpx
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.utils import CustomStreamingDecoder, ModelResponse
+
+from ...openai_like.chat.handler import OpenAILikeChatHandler
+from ..common_utils import _get_api_params
+from .transformation import IBMWatsonXChatConfig
+
+watsonx_chat_transformation = IBMWatsonXChatConfig()
+
+
+class WatsonXChatHandler(OpenAILikeChatHandler):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        *,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: Optional[str],
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params: dict = {},
+        headers: Optional[dict] = None,
+        logger_fn=None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
+        fake_stream: bool = False,
+    ):
+        api_params = _get_api_params(params=optional_params)
+
+        ## UPDATE HEADERS
+        headers = watsonx_chat_transformation.validate_environment(
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
+
+        ## UPDATE PAYLOAD (optional params)
+        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+            model=model,
+            api_params=api_params,
+        )
+        optional_params.update(watsonx_auth_payload)
+
+        ## GET API URL
+        api_base = watsonx_chat_transformation.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            stream=optional_params.get("stream", False),
+        )
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_endpoint=True,
+            streaming_decoder=streaming_decoder,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
new file mode 100644
index 00000000..f253da6f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/chat/transformation.py
@@ -0,0 +1,110 @@
+"""
+Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
+
+Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
+"""
+
+from typing import List, Optional, Tuple, Union
+
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import IBMWatsonXMixin
+
+
+class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return [
+            "temperature",  # equivalent to temperature
+            "max_tokens",  # equivalent to max_new_tokens
+            "top_p",  # equivalent to top_p
+            "frequency_penalty",  # equivalent to repetition_penalty
+            "stop",  # equivalent to stop_sequences
+            "seed",  # equivalent to random_seed
+            "stream",  # equivalent to stream
+            "tools",
+            "tool_choice",  # equivalent to tool_choice + tool_choice_options
+            "logprobs",
+            "top_logprobs",
+            "n",
+            "presence_penalty",
+            "response_format",
+        ]
+
+    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
+        if tool_choice is None:
+            return False
+        if isinstance(tool_choice, str):
+            return tool_choice in ["auto", "none", "required"]
+        return False
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        ## TOOLS ##
+        _tools = non_default_params.pop("tools", None)
+        if _tools is not None:
+            # remove 'additionalProperties' from tools
+            _tools = _remove_additional_properties(_tools)
+            # remove 'strict' from tools
+            _tools = _remove_strict_from_schema(_tools)
+        if _tools is not None:
+            non_default_params["tools"] = _tools
+
+        ## TOOL CHOICE ##
+
+        _tool_choice = non_default_params.pop("tool_choice", None)
+        if self.is_tool_choice_option(_tool_choice):
+            optional_params["tool_choice_options"] = _tool_choice
+        elif _tool_choice is not None:
+            optional_params["tool_choice"] = _tool_choice
+        return super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
+
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
+        dynamic_api_key = (
+            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
+        )  # vllm does not require an api key
+        return api_base, dynamic_api_key
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.CHAT_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.CHAT.value
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
new file mode 100644
index 00000000..4916cd1c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
@@ -0,0 +1,291 @@
+from typing import Dict, List, Optional, Union, cast
+
+import httpx
+
+import litellm
+from litellm import verbose_logger
+from litellm.caching import InMemoryCache
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials
+
+
+class WatsonXAIError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Optional[Union[Dict, httpx.Headers]] = None,
+    ):
+        super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+iam_token_cache = InMemoryCache()
+
+
+def get_watsonx_iam_url():
+    return (
+        get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
+    )
+
+
+def generate_iam_token(api_key=None, **params) -> str:
+    result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore
+
+    if result is None:
+        headers = {}
+        headers["Content-Type"] = "application/x-www-form-urlencoded"
+        if api_key is None:
+            api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
+        if api_key is None:
+            raise ValueError("API key is required")
+        headers["Accept"] = "application/json"
+        data = {
+            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
+            "apikey": api_key,
+        }
+        iam_token_url = get_watsonx_iam_url()
+        verbose_logger.debug(
+            "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
+            iam_token_url,
+            headers,
+            data,
+        )
+        response = litellm.module_level_client.post(
+            url=iam_token_url, data=data, headers=headers
+        )
+        response.raise_for_status()
+        json_data = response.json()
+
+        result = json_data["access_token"]
+        iam_token_cache.set_cache(
+            key=api_key,
+            value=result,
+            ttl=json_data["expires_in"] - 10,  # leave some buffer
+        )
+
+    return cast(str, result)
+
+
+def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
+    if token is not None:
+        return token
+    token = generate_iam_token(api_key)
+    return token
+
+
+def _get_api_params(
+    params: dict,
+) -> WatsonXAPIParams:
+    """
+    Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
+ """ + # Load auth variables from params + project_id = params.pop( + "project_id", params.pop("watsonx_project", None) + ) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params + space_id = params.pop("space_id", None) # watsonx.ai deployment space_id + region_name = params.pop("region_name", params.pop("region", None)) + if region_name is None: + region_name = params.pop( + "watsonx_region_name", params.pop("watsonx_region", None) + ) # consistent with how vertex ai + aws regions are accepted + + # Load auth variables from environment variables + if project_id is None: + project_id = ( + get_secret_str("WATSONX_PROJECT_ID") + or get_secret_str("WX_PROJECT_ID") + or get_secret_str("PROJECT_ID") + ) + if region_name is None: + region_name = ( + get_secret_str("WATSONX_REGION") + or get_secret_str("WX_REGION") + or get_secret_str("REGION") + ) + if space_id is None: + space_id = ( + get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID") + or get_secret_str("WATSONX_SPACE_ID") + or get_secret_str("WX_SPACE_ID") + or get_secret_str("SPACE_ID") + ) + + if project_id is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", + ) + + return WatsonXAPIParams( + project_id=project_id, + space_id=space_id, + region_name=region_name, + ) + + +def convert_watsonx_messages_to_prompt( + model: str, + messages: List[AllMessageValues], + provider: str, + custom_prompt_dict: Dict, +) -> str: + # handle anthropic prompts and amazon titan prompts + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_dict = custom_prompt_dict[model] + prompt = ptf.custom_prompt( + messages=messages, + role_dict=model_prompt_dict.get( + "role_dict", model_prompt_dict.get("roles") + ), + initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_dict.get("final_prompt_value", ""), + bos_token=model_prompt_dict.get("bos_token", ""), + eos_token=model_prompt_dict.get("eos_token", ""), + ) + return prompt + elif provider == "ibm-mistralai": + prompt = ptf.mistral_instruct_pt(messages=messages) + else: + prompt: str = ptf.prompt_factory( # type: ignore + model=model, messages=messages, custom_llm_provider="watsonx" + ) + return prompt + + +# Mixin class for shared IBM Watson X functionality +class IBMWatsonXMixin: + def validate_environment( + self, + headers: Dict, + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> Dict: + default_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + if "Authorization" in headers: + return {**default_headers, **headers} + token = cast( + Optional[str], + optional_params.get("token") or get_secret_str("WATSONX_TOKEN"), + ) + if token: + headers["Authorization"] = f"Bearer {token}" + elif zen_api_key := get_secret_str("WATSONX_ZENAPIKEY"): + headers["Authorization"] = f"ZenApiKey {zen_api_key}" + else: + token = _generate_watsonx_token(api_key=api_key, token=token) + # build auth headers + headers["Authorization"] = f"Bearer {token}" + return {**default_headers, **headers} + + def _get_base_url(self, api_base: Optional[str]) -> str: + url = ( + api_base + or get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE' + or get_secret_str("WATSONX_URL") + or 
get_secret_str("WX_URL") + or get_secret_str("WML_URL") + ) + + if url is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.", + ) + return url + + def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str: + api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION + url = url + f"?version={api_version}" + + return url + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers] + ) -> BaseLLMException: + return WatsonXAIError( + status_code=status_code, message=error_message, headers=headers + ) + + @staticmethod + def get_watsonx_credentials( + optional_params: dict, api_key: Optional[str], api_base: Optional[str] + ) -> WatsonXCredentials: + api_key = ( + api_key + or optional_params.pop("apikey", None) + or get_secret_str("WATSONX_APIKEY") + or get_secret_str("WATSONX_API_KEY") + or get_secret_str("WX_API_KEY") + ) + + api_base = ( + api_base + or optional_params.pop( + "url", + optional_params.pop("api_base", optional_params.pop("base_url", None)), + ) + or get_secret_str("WATSONX_API_BASE") + or get_secret_str("WATSONX_URL") + or get_secret_str("WX_URL") + or get_secret_str("WML_URL") + ) + + wx_credentials = optional_params.pop( + "wx_credentials", + optional_params.pop( + "watsonx_credentials", None + ), # follow {provider}_credentials, same as vertex ai + ) + + token: Optional[str] = None + + if wx_credentials is not None: + api_base = wx_credentials.get("url", api_base) + api_key = wx_credentials.get( + "apikey", wx_credentials.get("api_key", api_key) + ) + token = wx_credentials.get( + "token", + wx_credentials.get( + "watsonx_token", None + ), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..' + ) + if api_key is None or not isinstance(api_key, str): + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.", + ) + if api_base is None or not isinstance(api_base, str): + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.", + ) + return WatsonXCredentials( + api_key=api_key, api_base=api_base, token=cast(Optional[str], token) + ) + + def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict: + payload: dict = {} + if model.startswith("deployment/"): + if api_params["space_id"] is None: + raise WatsonXAIError( + status_code=401, + message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", + ) + payload["space_id"] = api_params["space_id"] + return payload + payload["model_id"] = model + payload["project_id"] = api_params["project_id"] + return payload diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py new file mode 100644 index 00000000..2a57ddcf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/handler.py @@ -0,0 +1,3 @@ +""" +Watsonx uses the llm_http_handler.py to handle the requests. 
+""" diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py new file mode 100644 index 00000000..f414354e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/completion/transformation.py @@ -0,0 +1,391 @@ +import time +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Union, +) + +import httpx + +from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator +from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock +from litellm.types.llms.watsonx import WatsonXAIEndpoint +from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage +from litellm.utils import map_finish_reason + +from ...base_llm.chat.transformation import BaseConfig +from ..common_utils import ( + IBMWatsonXMixin, + WatsonXAIError, + _get_api_params, + convert_watsonx_messages_to_prompt, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig): + """ + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation + (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) + + Supported params for all available watsonx.ai foundational models. + + - `decoding_method` (str): One of "greedy" or "sample" + + - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'. + + - `max_new_tokens` (integer): Maximum length of the generated tokens. + + - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated. + + - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + + - `stop_sequences` (string[]): list of strings to use as stop sequences. + + - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. + + - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. + + - `repetition_penalty` (float): token repetition penalty during text generation. + + - `truncate_input_tokens` (integer): Truncate input tokens to this length. + + - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match. + + - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean. + + - `random_seed` (integer): Random seed for text generation. + + - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering. + + - `stream` (bool): If True, the model will return a stream of responses. 
+ """ + + decoding_method: Optional[str] = "sample" + temperature: Optional[float] = None + max_new_tokens: Optional[int] = None # litellm.max_tokens + min_new_tokens: Optional[int] = None + length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] + top_k: Optional[int] = None + top_p: Optional[float] = None + repetition_penalty: Optional[float] = None + truncate_input_tokens: Optional[int] = None + include_stop_sequences: Optional[bool] = False + return_options: Optional[Dict[str, bool]] = None + random_seed: Optional[int] = None # e.g 42 + moderations: Optional[dict] = None + stream: Optional[bool] = False + + def __init__( + self, + decoding_method: Optional[str] = None, + temperature: Optional[float] = None, + max_new_tokens: Optional[int] = None, + min_new_tokens: Optional[int] = None, + length_penalty: Optional[dict] = None, + stop_sequences: Optional[List[str]] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + truncate_input_tokens: Optional[int] = None, + include_stop_sequences: Optional[bool] = None, + return_options: Optional[dict] = None, + random_seed: Optional[int] = None, + moderations: Optional[dict] = None, + stream: Optional[bool] = None, + **kwargs, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def is_watsonx_text_param(self, param: str) -> bool: + """ + Determine if user passed in a watsonx.ai text generation param + """ + text_generation_params = [ + "decoding_method", + "max_new_tokens", + "min_new_tokens", + "length_penalty", + "stop_sequences", + "top_k", + "repetition_penalty", + "truncate_input_tokens", + "include_stop_sequences", + "return_options", + "random_seed", + "moderations", + "decoding_method", + "min_tokens", + ] + + return param in text_generation_params + + def get_supported_openai_params(self, model: str): + return [ + "temperature", # equivalent to temperature + "max_tokens", # equivalent to max_new_tokens + "top_p", # equivalent to top_p + "frequency_penalty", # equivalent to repetition_penalty + "stop", # equivalent to stop_sequences + "seed", # equivalent to random_seed + "stream", # equivalent to stream + ] + + def map_openai_params( + self, + non_default_params: Dict, + optional_params: Dict, + model: str, + drop_params: bool, + ) -> Dict: + extra_body = {} + for k, v in non_default_params.items(): + if k == "max_tokens": + optional_params["max_new_tokens"] = v + elif k == "stream": + optional_params["stream"] = v + elif k == "temperature": + optional_params["temperature"] = v + elif k == "top_p": + optional_params["top_p"] = v + elif k == "frequency_penalty": + optional_params["repetition_penalty"] = v + elif k == "seed": + optional_params["random_seed"] = v + elif k == "stop": + optional_params["stop_sequences"] = v + elif k == "decoding_method": + extra_body["decoding_method"] = v + elif k == "min_tokens": + extra_body["min_new_tokens"] = v + elif k == "top_k": + extra_body["top_k"] = v + elif k == "truncate_input_tokens": + extra_body["truncate_input_tokens"] = v + elif k == "length_penalty": + extra_body["length_penalty"] = v + elif k == "time_limit": + extra_body["time_limit"] = v + elif k == "return_options": + extra_body["return_options"] = v + + if extra_body: + optional_params["extra_body"] = 
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {
+            "project": "watsonx_project",
+            "region_name": "watsonx_region_name",
+            "token": "watsonx_token",
+        }
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
+        """
+        return [
+            "eu-de",
+            "eu-gb",
+        ]
+
+    def get_us_regions(self) -> List[str]:
+        """
+        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
+        """
+        return [
+            "us-south",
+        ]
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        headers: Dict,
+    ) -> Dict:
+        provider = model.split("/")[0]
+        prompt = convert_watsonx_messages_to_prompt(
+            model=model,
+            messages=messages,
+            provider=provider,
+            custom_prompt_dict={},
+        )
+        extra_body_params = optional_params.pop("extra_body", {})
+        optional_params.update(extra_body_params)
+        watsonx_api_params = _get_api_params(params=optional_params)
+
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
+
+        # init the payload to the text generation call
+        payload = {
+            "input": prompt,
+            "moderations": optional_params.pop("moderations", {}),
+            "parameters": optional_params,
+            **watsonx_auth_payload,
+        }
+
+        return payload
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=raw_response.text,
+        )
+
+        json_resp = raw_response.json()
+
+        if "results" not in json_resp:
+            raise WatsonXAIError(
+                status_code=500,
+                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+            )
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        generated_text = json_resp["results"][0]["generated_text"]
+        prompt_tokens = json_resp["results"][0]["input_token_count"]
+        completion_tokens = json_resp["results"][0]["generated_token_count"]
+        model_response.choices[0].message.content = generated_text  # type: ignore
+        model_response.choices[0].finish_reason = map_finish_reason(
+            json_resp["results"][0]["stop_reason"]
+        )
+        if json_resp.get("created_at"):
+            model_response.created = int(
+                datetime.fromisoformat(json_resp["created_at"]).timestamp()
+            )
+        else:
+            model_response.created = int(time.time())
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            # deployment models are passed in as 'deployment/<deployment_id>'
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
+                if stream
+                else WatsonXAIEndpoint.TEXT_GENERATION
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return WatsonxTextCompletionResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class WatsonxTextCompletionResponseIterator(BaseModelResponseIterator):
+    # def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+    #     return self.chunk_parser(json.loads(str_line))
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            results = chunk.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
+                is_finished = finish_reason != "not_finished"
+
+                return GenericStreamingChunk(
+                    text=text,
+                    is_finished=is_finished,
+                    finish_reason=finish_reason,
+                    usage=ChatCompletionUsageBlock(
+                        prompt_tokens=results[0].get("input_token_count", 0),
+                        completion_tokens=results[0].get("generated_token_count", 0),
+                        total_tokens=results[0].get("input_token_count", 0)
+                        + results[0].get("generated_token_count", 0),
+                    ),
+                )
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="stop",
+                usage=None,
+            )
+        except Exception as e:
+            raise e
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py
new file mode 100644
index 00000000..359137ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py
@@ -0,0 +1,112 @@
+"""
+Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
+""" + +from typing import Optional + +import httpx + +from litellm.llms.base_llm.embedding.transformation import ( + BaseEmbeddingConfig, + LiteLLMLoggingObj, +) +from litellm.types.llms.openai import AllEmbeddingInputValues +from litellm.types.llms.watsonx import WatsonXAIEndpoint +from litellm.types.utils import EmbeddingResponse, Usage + +from ..common_utils import IBMWatsonXMixin, _get_api_params + + +class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig): + def get_supported_openai_params(self, model: str) -> list: + return [] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + return optional_params + + def transform_embedding_request( + self, + model: str, + input: AllEmbeddingInputValues, + optional_params: dict, + headers: dict, + ) -> dict: + watsonx_api_params = _get_api_params(params=optional_params) + watsonx_auth_payload = self._prepare_payload( + model=model, + api_params=watsonx_api_params, + ) + + return { + "inputs": input, + "parameters": optional_params, + **watsonx_auth_payload, + } + + def get_complete_url( + self, + api_base: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + url = self._get_base_url(api_base=api_base) + endpoint = WatsonXAIEndpoint.EMBEDDINGS.value + if model.startswith("deployment/"): + deployment_id = "/".join(model.split("/")[1:]) + endpoint = endpoint.format(deployment_id=deployment_id) + url = url.rstrip("/") + endpoint + + ## add api version + url = self._add_api_version_to_url( + url=url, api_version=optional_params.pop("api_version", None) + ) + return url + + def transform_embedding_response( + self, + model: str, + raw_response: httpx.Response, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str], + request_data: dict, + optional_params: dict, + litellm_params: dict, + ) -> EmbeddingResponse: + logging_obj.post_call( + original_response=raw_response.text, + ) + json_resp = raw_response.json() + if model_response is None: + model_response = EmbeddingResponse(model=json_resp.get("model_id", None)) + results = json_resp.get("results", []) + embedding_response = [] + for idx, result in enumerate(results): + embedding_response.append( + { + "object": "embedding", + "index": idx, + "embedding": result["embedding"], + } + ) + model_response.object = "list" + model_response.data = embedding_response + input_tokens = json_resp.get("input_token_count", 0) + setattr( + model_response, + "usage", + Usage( + prompt_tokens=input_tokens, + completion_tokens=0, + total_tokens=input_tokens, + ), + ) + return model_response |