diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py | 291 |
1 files changed, 291 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py new file mode 100644 index 00000000..4916cd1c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py @@ -0,0 +1,291 @@ +from typing import Dict, List, Optional, Union, cast + +import httpx + +import litellm +from litellm import verbose_logger +from litellm.caching import InMemoryCache +from litellm.litellm_core_utils.prompt_templates import factory as ptf +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues +from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials + + +class WatsonXAIError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + headers: Optional[Union[Dict, httpx.Headers]] = None, + ): + super().__init__(status_code=status_code, message=message, headers=headers) + + +iam_token_cache = InMemoryCache() + + +def get_watsonx_iam_url(): + return ( + get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token" + ) + + +def generate_iam_token(api_key=None, **params) -> str: + result: Optional[str] = iam_token_cache.get_cache(api_key) # type: ignore + + if result is None: + headers = {} + headers["Content-Type"] = "application/x-www-form-urlencoded" + if api_key is None: + api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY") + if api_key is None: + raise ValueError("API key is required") + headers["Accept"] = "application/json" + data = { + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": api_key, + } + iam_token_url = get_watsonx_iam_url() + verbose_logger.debug( + "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s", + iam_token_url, + headers, + data, + ) + response = litellm.module_level_client.post( + url=iam_token_url, data=data, headers=headers + ) + response.raise_for_status() + json_data = response.json() + + result = json_data["access_token"] + iam_token_cache.set_cache( + key=api_key, + value=result, + ttl=json_data["expires_in"] - 10, # leave some buffer + ) + + return cast(str, result) + + +def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str: + if token is not None: + return token + token = generate_iam_token(api_key) + return token + + +def _get_api_params( + params: dict, +) -> WatsonXAPIParams: + """ + Find watsonx.ai credentials in the params or environment variables and return the headers for authentication. + """ + # Load auth variables from params + project_id = params.pop( + "project_id", params.pop("watsonx_project", None) + ) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params + space_id = params.pop("space_id", None) # watsonx.ai deployment space_id + region_name = params.pop("region_name", params.pop("region", None)) + if region_name is None: + region_name = params.pop( + "watsonx_region_name", params.pop("watsonx_region", None) + ) # consistent with how vertex ai + aws regions are accepted + + # Load auth variables from environment variables + if project_id is None: + project_id = ( + get_secret_str("WATSONX_PROJECT_ID") + or get_secret_str("WX_PROJECT_ID") + or get_secret_str("PROJECT_ID") + ) + if region_name is None: + region_name = ( + get_secret_str("WATSONX_REGION") + or get_secret_str("WX_REGION") + or get_secret_str("REGION") + ) + if space_id is None: + space_id = ( + get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID") + or get_secret_str("WATSONX_SPACE_ID") + or get_secret_str("WX_SPACE_ID") + or get_secret_str("SPACE_ID") + ) + + if project_id is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", + ) + + return WatsonXAPIParams( + project_id=project_id, + space_id=space_id, + region_name=region_name, + ) + + +def convert_watsonx_messages_to_prompt( + model: str, + messages: List[AllMessageValues], + provider: str, + custom_prompt_dict: Dict, +) -> str: + # handle anthropic prompts and amazon titan prompts + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_dict = custom_prompt_dict[model] + prompt = ptf.custom_prompt( + messages=messages, + role_dict=model_prompt_dict.get( + "role_dict", model_prompt_dict.get("roles") + ), + initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_dict.get("final_prompt_value", ""), + bos_token=model_prompt_dict.get("bos_token", ""), + eos_token=model_prompt_dict.get("eos_token", ""), + ) + return prompt + elif provider == "ibm-mistralai": + prompt = ptf.mistral_instruct_pt(messages=messages) + else: + prompt: str = ptf.prompt_factory( # type: ignore + model=model, messages=messages, custom_llm_provider="watsonx" + ) + return prompt + + +# Mixin class for shared IBM Watson X functionality +class IBMWatsonXMixin: + def validate_environment( + self, + headers: Dict, + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> Dict: + default_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + if "Authorization" in headers: + return {**default_headers, **headers} + token = cast( + Optional[str], + optional_params.get("token") or get_secret_str("WATSONX_TOKEN"), + ) + if token: + headers["Authorization"] = f"Bearer {token}" + elif zen_api_key := get_secret_str("WATSONX_ZENAPIKEY"): + headers["Authorization"] = f"ZenApiKey {zen_api_key}" + else: + token = _generate_watsonx_token(api_key=api_key, token=token) + # build auth headers + headers["Authorization"] = f"Bearer {token}" + return {**default_headers, **headers} + + def _get_base_url(self, api_base: Optional[str]) -> str: + url = ( + api_base + or get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE' + or get_secret_str("WATSONX_URL") + or get_secret_str("WX_URL") + or get_secret_str("WML_URL") + ) + + if url is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.", + ) + return url + + def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str: + api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION + url = url + f"?version={api_version}" + + return url + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers] + ) -> BaseLLMException: + return WatsonXAIError( + status_code=status_code, message=error_message, headers=headers + ) + + @staticmethod + def get_watsonx_credentials( + optional_params: dict, api_key: Optional[str], api_base: Optional[str] + ) -> WatsonXCredentials: + api_key = ( + api_key + or optional_params.pop("apikey", None) + or get_secret_str("WATSONX_APIKEY") + or get_secret_str("WATSONX_API_KEY") + or get_secret_str("WX_API_KEY") + ) + + api_base = ( + api_base + or optional_params.pop( + "url", + optional_params.pop("api_base", optional_params.pop("base_url", None)), + ) + or get_secret_str("WATSONX_API_BASE") + or get_secret_str("WATSONX_URL") + or get_secret_str("WX_URL") + or get_secret_str("WML_URL") + ) + + wx_credentials = optional_params.pop( + "wx_credentials", + optional_params.pop( + "watsonx_credentials", None + ), # follow {provider}_credentials, same as vertex ai + ) + + token: Optional[str] = None + + if wx_credentials is not None: + api_base = wx_credentials.get("url", api_base) + api_key = wx_credentials.get( + "apikey", wx_credentials.get("api_key", api_key) + ) + token = wx_credentials.get( + "token", + wx_credentials.get( + "watsonx_token", None + ), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..' + ) + if api_key is None or not isinstance(api_key, str): + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.", + ) + if api_base is None or not isinstance(api_base, str): + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.", + ) + return WatsonXCredentials( + api_key=api_key, api_base=api_base, token=cast(Optional[str], token) + ) + + def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict: + payload: dict = {} + if model.startswith("deployment/"): + if api_params["space_id"] is None: + raise WatsonXAIError( + status_code=401, + message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", + ) + payload["space_id"] = api_params["space_id"] + return payload + payload["model_id"] = model + payload["project_id"] = api_params["project_id"] + return payload |