aboutsummaryrefslogtreecommitdiff
"""
Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`

In the Huggingface TGI format. 
"""

import json
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from httpx._models import Headers, Response

import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

from ..common_utils import SagemakerError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class SagemakerConfig(BaseConfig):
    """
    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
    """

    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    return_full_text: Optional[bool] = None

    def __init__(
        self,
        max_new_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        return_full_text: Optional[bool] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return SagemakerError(
            message=error_message, status_code=status_code, headers=headers
        )

    def get_supported_openai_params(self, model: str) -> List:
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    if not non_default_params.get(
                        "aws_sagemaker_allow_zero_temp", False
                    ):
                        value = 0.01

                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
        non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
        return optional_params

    def _transform_prompt(
        self,
        model: str,
        messages: List,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
    ) -> str:
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # check if the base huggingface model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                if "llama-2" in model.lower():  # llama-2 model
                    if "chat" in model.lower():  # apply llama2 chat template
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:  # apply regular llama2 template
                        hf_model_name = "meta-llama/Llama-2-7b"
            hf_model_name = (
                hf_model_name or model
            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
            prompt: str = prompt_factory(model=hf_model_name, messages=messages)  # type: ignore

        return prompt

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data: Dict = {"parameters": inference_params}
        if stream is True:
            data["stream"] = True

        custom_prompt_dict = (
            litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
        )

        hf_model_name = litellm_params.get("hf_model_name", None)

        prompt = self._transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        data["inputs"] = prompt

        return data

    async def async_transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        return await asyncify(self.transform_request)(
            model, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        completion_response = raw_response.json()
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=completion_response,
            additional_args={"complete_input_dict": request_data},
        )

        prompt = request_data["inputs"]

        ## RESPONSE OBJECT
        try:
            if isinstance(completion_response, list):
                completion_response_choices = completion_response[0]
            else:
                completion_response_choices = completion_response
            completion_output = ""
            if "generation" in completion_response_choices:
                completion_output += completion_response_choices["generation"]
            elif "generated_text" in completion_response_choices:
                completion_output += completion_response_choices["generated_text"]

            # check if the prompt template is part of output, if so - filter it out
            if completion_output.startswith(prompt) and "<s>" in prompt:
                completion_output = completion_output.replace(prompt, "", 1)

            model_response.choices[0].message.content = completion_output  # type: ignore
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
            )

        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def validate_environment(
        self,
        headers: Optional[dict],
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        headers = {"Content-Type": "application/json"}

        if headers is not None:
            headers = {"Content-Type": "application/json", **headers}

        return headers