about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py291
1 files changed, 291 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
new file mode 100644
index 00000000..4916cd1c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/common_utils.py
@@ -0,0 +1,291 @@
+from typing import Dict, List, Optional, Union, cast
+
+import httpx
+
+import litellm
+from litellm import verbose_logger
+from litellm.caching import InMemoryCache
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials
+
+
class WatsonXAIError(BaseLLMException):
    """Exception raised for IBM watsonx.ai API failures.

    Thin wrapper over the shared BaseLLMException so provider-agnostic error
    handling elsewhere in litellm can catch watsonx failures uniformly.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        # No extra state of its own -- delegate everything to the base class.
        super().__init__(status_code=status_code, message=message, headers=headers)
+
+
# Process-wide cache of IAM bearer tokens, keyed by IBM Cloud API key.
# Entries are written by generate_iam_token with a TTL slightly shorter
# than the token's own lifetime, so stale tokens are never served.
iam_token_cache = InMemoryCache()
+
+
def get_watsonx_iam_url() -> str:
    """Return the IBM IAM token endpoint, honoring a WATSONX_IAM_URL override."""
    override = get_secret_str("WATSONX_IAM_URL")
    return override or "https://iam.cloud.ibm.com/identity/token"
+
+
def generate_iam_token(api_key=None, **params) -> str:
    """Exchange an IBM Cloud API key for an IAM bearer token, with caching.

    Tokens are cached in ``iam_token_cache`` keyed by the resolved API key,
    with a TTL just under the token's reported lifetime. Extra ``**params``
    are accepted for call-site compatibility but ignored.

    Args:
        api_key: IBM Cloud API key. Falls back to the WX_API_KEY or
            WATSONX_API_KEY environment variables when omitted.

    Returns:
        The IAM access token string.

    Raises:
        ValueError: if no API key is supplied and none is found in the
            environment.
        httpx.HTTPStatusError: if the IAM endpoint returns an error status.
    """
    # Resolve the API key *before* the cache lookup. Previously the lookup
    # used the raw argument (possibly None) while the store used the
    # env-resolved key, so env-based callers never got a cache hit and
    # re-fetched a token on every call.
    if api_key is None:
        api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
    if api_key is None:
        raise ValueError("API key is required")

    result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore

    if result is None:
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        data = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        }
        iam_token_url = get_watsonx_iam_url()
        verbose_logger.debug(
            "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
            iam_token_url,
            headers,
            data,
        )
        response = litellm.module_level_client.post(
            url=iam_token_url, data=data, headers=headers
        )
        response.raise_for_status()
        json_data = response.json()

        result = json_data["access_token"]
        iam_token_cache.set_cache(
            key=api_key,
            value=result,
            ttl=json_data["expires_in"] - 10,  # leave some buffer before expiry
        )

    return cast(str, result)
+
+
+def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
+    if token is not None:
+        return token
+    token = generate_iam_token(api_key)
+    return token
+
+
def _get_api_params(
    params: dict,
) -> WatsonXAPIParams:
    """
    Resolve watsonx.ai project/space/region settings from `params` (which is
    mutated via pop) or, failing that, from environment variables.
    """
    # Accept 'watsonx_project' as an alias of 'project_id' (consistent with the
    # vertex project param). Both keys are removed from `params` either way.
    alias_project = params.pop("watsonx_project", None)
    project_id = params.pop("project_id", alias_project)

    # watsonx.ai deployment space_id
    space_id = params.pop("space_id", None)

    # Region aliases, consistent with how vertex ai + aws regions are accepted.
    region_alias = params.pop("region", None)
    region_name = params.pop("region_name", region_alias)
    if region_name is None:
        wx_region_alias = params.pop("watsonx_region", None)
        region_name = params.pop("watsonx_region_name", wx_region_alias)

    # Fall back to environment variables for anything still unset.
    if project_id is None:
        project_id = (
            get_secret_str("WATSONX_PROJECT_ID")
            or get_secret_str("WX_PROJECT_ID")
            or get_secret_str("PROJECT_ID")
        )
    if region_name is None:
        region_name = (
            get_secret_str("WATSONX_REGION")
            or get_secret_str("WX_REGION")
            or get_secret_str("REGION")
        )
    if space_id is None:
        space_id = (
            get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
            or get_secret_str("WATSONX_SPACE_ID")
            or get_secret_str("WX_SPACE_ID")
            or get_secret_str("SPACE_ID")
        )

    # project_id is the one hard requirement for watsonx.ai calls.
    if project_id is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
        )

    return WatsonXAPIParams(
        project_id=project_id,
        space_id=space_id,
        region_name=region_name,
    )
+
+
def convert_watsonx_messages_to_prompt(
    model: str,
    messages: List[AllMessageValues],
    provider: str,
    custom_prompt_dict: Dict,
) -> str:
    """Render chat messages into a single prompt string for watsonx.ai.

    A model-specific registered template wins; otherwise fall back to a
    provider-specific template, then the generic watsonx prompt factory.
    """
    # Model has a registered custom prompt template (e.g. anthropic / titan).
    if model in custom_prompt_dict:
        template = custom_prompt_dict[model]
        return ptf.custom_prompt(
            messages=messages,
            role_dict=template.get("role_dict", template.get("roles")),
            initial_prompt_value=template.get("initial_prompt_value", ""),
            final_prompt_value=template.get("final_prompt_value", ""),
            bos_token=template.get("bos_token", ""),
            eos_token=template.get("eos_token", ""),
        )
    if provider == "ibm-mistralai":
        return ptf.mistral_instruct_pt(messages=messages)
    return ptf.prompt_factory(  # type: ignore
        model=model, messages=messages, custom_llm_provider="watsonx"
    )
+
+
# Mixin class for shared IBM Watson X functionality
class IBMWatsonXMixin:
    """Shared watsonx.ai helpers: auth headers, base URL, errors, payloads."""

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """Build request headers, resolving auth in priority order:
        caller-supplied Authorization header > explicit/env token >
        WATSONX_ZENAPIKEY > IAM token generated from the API key.
        """
        base_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # Caller already provided an Authorization header -- trust it as-is.
        if "Authorization" in headers:
            return {**base_headers, **headers}

        auth_token = cast(
            Optional[str],
            optional_params.get("token") or get_secret_str("WATSONX_TOKEN"),
        )
        if auth_token:
            headers["Authorization"] = f"Bearer {auth_token}"
            return {**base_headers, **headers}

        zen_api_key = get_secret_str("WATSONX_ZENAPIKEY")
        if zen_api_key:
            headers["Authorization"] = f"ZenApiKey {zen_api_key}"
        else:
            # No token and no Zen key: mint an IAM token from the API key.
            auth_token = _generate_watsonx_token(api_key=api_key, token=auth_token)
            headers["Authorization"] = f"Bearer {auth_token}"
        return {**base_headers, **headers}

    def _get_base_url(self, api_base: Optional[str]) -> str:
        """Resolve the watsonx.ai base URL from the argument or env vars."""
        url = api_base
        if not url:
            url = get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
        if not url:
            url = get_secret_str("WATSONX_URL")
        if not url:
            url = get_secret_str("WX_URL") or get_secret_str("WML_URL")

        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return url

    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
        """Append the `?version=` query param, defaulting to litellm's version."""
        effective_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
        return f"{url}?version={effective_version}"

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map a raw HTTP failure onto the provider-specific exception type."""
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    @staticmethod
    def get_watsonx_credentials(
        optional_params: dict, api_key: Optional[str], api_base: Optional[str]
    ) -> WatsonXCredentials:
        """Resolve api_key / api_base / token from args, optional_params
        (mutated via pop), wx_credentials dicts, and environment variables.

        Raises WatsonXAIError (401) if either api_key or api_base cannot be
        resolved to a string.
        """
        if not api_key:
            api_key = optional_params.pop("apikey", None)
        if not api_key:
            api_key = (
                get_secret_str("WATSONX_APIKEY")
                or get_secret_str("WATSONX_API_KEY")
                or get_secret_str("WX_API_KEY")
            )

        if not api_base:
            # All three aliases are popped so none leak into request params.
            base_fallback = optional_params.pop("base_url", None)
            base_fallback = optional_params.pop("api_base", base_fallback)
            api_base = optional_params.pop("url", base_fallback)
        if not api_base:
            api_base = (
                get_secret_str("WATSONX_API_BASE")
                or get_secret_str("WATSONX_URL")
                or get_secret_str("WX_URL")
                or get_secret_str("WML_URL")
            )

        # follow {provider}_credentials, same as vertex ai
        legacy_credentials = optional_params.pop("watsonx_credentials", None)
        wx_credentials = optional_params.pop("wx_credentials", legacy_credentials)

        token: Optional[str] = None

        if wx_credentials is not None:
            # A credentials dict overrides individually-resolved values.
            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            token = wx_credentials.get(
                "token", wx_credentials.get("watsonx_token", None)
            )

        if api_key is None or not isinstance(api_key, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
            )
        if api_base is None or not isinstance(api_base, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return WatsonXCredentials(
            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
        )

    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
        """Build the model/project (or space, for deployments) request fields."""
        if not model.startswith("deployment/"):
            return {"model_id": model, "project_id": api_params["project_id"]}

        # Deployment endpoints are addressed by space_id, not project_id.
        deployment_space = api_params["space_id"]
        if deployment_space is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
            )
        return {"space_id": deployment_space}