about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py270
1 files changed, 270 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
new file mode 100644
index 00000000..d0ab5d06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
@@ -0,0 +1,270 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
+
+In the Huggingface TGI format. 
+"""
+
+import json
+import time
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from httpx._models import Headers, Response
+
+import litellm
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import SagemakerError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class SagemakerConfig(BaseConfig):
+    """
+    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
+    """
+
+    max_new_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    temperature: Optional[float] = None
+    return_full_text: Optional[bool] = None
+
+    def __init__(
+        self,
+        max_new_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return SagemakerError(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    if not non_default_params.get(
+                        "aws_sagemaker_allow_zero_temp", False
+                    ):
+                        value = 0.01
+
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+        non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
+        return optional_params
+
+    def _transform_prompt(
+        self,
+        model: str,
+        messages: List,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+    ) -> str:
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        elif hf_model_name in custom_prompt_dict:
+            # check if the base huggingface model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[hf_model_name]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        else:
+            if hf_model_name is None:
+                if "llama-2" in model.lower():  # llama-2 model
+                    if "chat" in model.lower():  # apply llama2 chat template
+                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+                    else:  # apply regular llama2 template
+                        hf_model_name = "meta-llama/Llama-2-7b"
+            hf_model_name = (
+                hf_model_name or model
+            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
+            prompt: str = prompt_factory(model=hf_model_name, messages=messages)  # type: ignore
+
+        return prompt
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        inference_params = optional_params.copy()
+        stream = inference_params.pop("stream", False)
+        data: Dict = {"parameters": inference_params}
+        if stream is True:
+            data["stream"] = True
+
+        custom_prompt_dict = (
+            litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
+        )
+
+        hf_model_name = litellm_params.get("hf_model_name", None)
+
+        prompt = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        data["inputs"] = prompt
+
+        return data
+
+    async def async_transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        return await asyncify(self.transform_request)(
+            model, messages, optional_params, litellm_params, headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        completion_response = raw_response.json()
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=completion_response,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        prompt = request_data["inputs"]
+
+        ## RESPONSE OBJECT
+        try:
+            if isinstance(completion_response, list):
+                completion_response_choices = completion_response[0]
+            else:
+                completion_response_choices = completion_response
+            completion_output = ""
+            if "generation" in completion_response_choices:
+                completion_output += completion_response_choices["generation"]
+            elif "generated_text" in completion_response_choices:
+                completion_output += completion_response_choices["generated_text"]
+
+            # check if the prompt template is part of output, if so - filter it out
+            if completion_output.startswith(prompt) and "<s>" in prompt:
+                completion_output = completion_output.replace(prompt, "", 1)
+
+            model_response.choices[0].message.content = completion_output  # type: ignore
+        except Exception:
+            raise SagemakerError(
+                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+                status_code=500,
+            )
+
+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def validate_environment(
+        self,
+        headers: Optional[dict],
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        headers = {"Content-Type": "application/json"}
+
+        if headers is not None:
+            headers = {"Content-Type": "application/json", **headers}
+
+        return headers