author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD -> master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat')
 -rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/handler.py         | 179
 -rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/transformation.py  |  26
 2 files changed, 205 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/handler.py
new file mode 100644
index 00000000..c827a8a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/handler.py
@@ -0,0 +1,179 @@
+import json
+from copy import deepcopy
+from typing import Callable, Optional, Union
+
+import httpx
+
+from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import ModelResponse, get_secret
+
+from ..common_utils import AWSEventStreamDecoder
+from .transformation import SagemakerChatConfig
+
+
+class SagemakerChatHandler(BaseAWSLLM):
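+    """
+    Handle an OpenAI-style chat completion call against an Amazon SageMaker
+    runtime endpoint that serves the HF Messages API: sign the request with
+    SigV4, then delegate to the OpenAI-like chat handler.
+    """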
+
+    def _load_credentials(
+        self,
+        optional_params: dict,
+    ):
+        try:
+            from botocore.credentials import Credentials
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+        ## CREDENTIALS ##
+        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
+        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+        aws_session_token = optional_params.pop("aws_session_token", None)
+        aws_region_name = optional_params.pop("aws_region_name", None)
+        aws_role_name = optional_params.pop("aws_role_name", None)
+        aws_session_name = optional_params.pop("aws_session_name", None)
+        aws_profile_name = optional_params.pop("aws_profile_name", None)
+        optional_params.pop(
+            "aws_bedrock_runtime_endpoint", None
+        )  # https://bedrock-runtime.{region_name}.amazonaws.com
+        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
+        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
+
+        ### SET REGION NAME ###
+        if aws_region_name is None:
+            # check env #
+            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+            if litellm_aws_region_name is not None and isinstance(
+                litellm_aws_region_name, str
+            ):
+                aws_region_name = litellm_aws_region_name
+
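+            # NOTE: the standard AWS_REGION env var, when set, takes precedence over AWS_REGION_NAME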
+            standard_aws_region_name = get_secret("AWS_REGION", None)
+            if standard_aws_region_name is not None and isinstance(
+                standard_aws_region_name, str
+            ):
+                aws_region_name = standard_aws_region_name
+
+            if aws_region_name is None:
+                aws_region_name = "us-west-2"
+
+        credentials: Credentials = self.get_credentials(
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+            aws_region_name=aws_region_name,
+            aws_session_name=aws_session_name,
+            aws_profile_name=aws_profile_name,
+            aws_role_name=aws_role_name,
+            aws_web_identity_token=aws_web_identity_token,
+            aws_sts_endpoint=aws_sts_endpoint,
+        )
+        return credentials, aws_region_name
+
+    def _prepare_request(
+        self,
+        credentials,
+        model: str,
+        data: dict,
+        optional_params: dict,
+        aws_region_name: str,
+        extra_headers: Optional[dict] = None,
+    ):
+        try:
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+
+        sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
+        if optional_params.get("stream") is True:
+            api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
+        else:
+            api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations"
+
+        sagemaker_base_url = optional_params.get("sagemaker_base_url", None)
+        if sagemaker_base_url is not None:
+            api_base = sagemaker_base_url
+
+        encoded_data = json.dumps(data).encode("utf-8")
+        headers = {"Content-Type": "application/json"}
+        if extra_headers is not None:
+            headers = {"Content-Type": "application/json", **extra_headers}
+        request = AWSRequest(
+            method="POST", url=api_base, data=encoded_data, headers=headers
+        )
+        sigv4.add_auth(request)
+        if (
+            extra_headers is not None and "Authorization" in extra_headers
+        ):  # prevent sigv4 from overwriting the auth header
+            request.headers["Authorization"] = extra_headers["Authorization"]
+
+        prepped_request = request.prepare()
+
+        return prepped_request
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        litellm_params: dict,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        custom_prompt_dict: Optional[dict] = None,
+        logger_fn=None,
+        acompletion: bool = False,
+        headers: Optional[dict] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ):
+
+        # avoid sharing mutable default arguments across calls
+        custom_prompt_dict = custom_prompt_dict or {}
+        headers = headers or {}
+
+        # _load_credentials pops the aws_* auth params, since completion calls fail if they are forwarded
+        credentials, aws_region_name = self._load_credentials(optional_params)
+        inference_params = deepcopy(optional_params)
+        # pop 'stream' from the params, as a stray 'stream' key raises an error with
+        # sagemaker; it is re-set as an explicit boolean below
+        stream = inference_params.pop("stream", None)
+
+        from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
+
+        openai_like_chat_completions = OpenAILikeChatHandler()
+        inference_params["stream"] = True if stream is True else False
+        _data = SagemakerChatConfig().transform_request(
+            model=model,
+            messages=messages,
+            optional_params=inference_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        prepared_request = self._prepare_request(
+            model=model,
+            data=_data,
+            optional_params=optional_params,
+            credentials=credentials,
+            aws_region_name=aws_region_name,
+        )
+
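+        # decodes AWS event-stream framed responses into OpenAI-style chat chunks when streaming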
+        custom_stream_decoder = AWSEventStreamDecoder(model="", is_messages_api=True)
+
+        return openai_like_chat_completions.completion(
+            model=model,
+            messages=messages,
+            api_base=prepared_request.url,
+            api_key=None,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            logging_obj=logging_obj,
+            optional_params=inference_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            timeout=timeout,
+            encoding=encoding,
+            headers=prepared_request.headers,  # type: ignore
+            custom_endpoint=True,
+            custom_llm_provider="sagemaker_chat",
+            streaming_decoder=custom_stream_decoder,  # type: ignore
+            client=client,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/transformation.py
new file mode 100644
index 00000000..42c7e0d5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/chat/transformation.py
@@ -0,0 +1,26 @@
+"""
+Translates OpenAI's `/v1/chat/completions` requests to Sagemaker's `/invocations` API.
+
+Called if the Sagemaker endpoint supports the HF Messages API.
+
+LiteLLM Docs: https://docs.litellm.ai/docs/providers/aws_sagemaker#sagemaker-messages-api
+Huggingface Docs: https://huggingface.co/docs/text-generation-inference/en/messages_api
+"""
+
+from typing import Union
+
+from httpx._models import Headers
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import SagemakerError
+
+
+class SagemakerChatConfig(OpenAIGPTConfig):
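+    # All request/response translation is inherited from OpenAIGPTConfig;
+    # only the error type raised on failures is Sagemaker-specific.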
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return SagemakerError(
+            status_code=status_code, message=error_message, headers=headers
+        )
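
For reference, a minimal usage sketch of this provider route (the endpoint name "my-endpoint" is a placeholder; the call assumes a SageMaker endpoint serving the HF Messages API, per the LiteLLM docs linked in the docstring above). The "sagemaker_chat/" model prefix is what dispatches the request through SagemakerChatHandler:

import litellm

# "my-endpoint" is a hypothetical SageMaker endpoint name
response = litellm.completion(
    model="sagemaker_chat/my-endpoint",
    messages=[{"role": "user", "content": "Hello"}],
    aws_region_name="us-west-2",  # popped by _load_credentials, never sent to the endpoint
)
print(response.choices[0].message.content)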