author      S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py | 270 |
1 file changed, 270 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
new file mode 100644
index 00000000..d0ab5d06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
@@ -0,0 +1,270 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
+
+In the Huggingface TGI format.
+"""
+
+import json
+import time
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from httpx._models import Headers, Response
+
+import litellm
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import SagemakerError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class SagemakerConfig(BaseConfig):
+    """
+    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
+    """
+
+    max_new_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    temperature: Optional[float] = None
+    return_full_text: Optional[bool] = None
+
+    def __init__(
+        self,
+        max_new_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return SagemakerError(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    if not non_default_params.get(
+                        "aws_sagemaker_allow_zero_temp", False
+                    ):
+                        value = 0.01
+
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+
+        non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
+        return optional_params
+
+    def _transform_prompt(
+        self,
+        model: str,
+        messages: List,
+        custom_prompt_dict: dict,
+        hf_model_name: Optional[str],
+    ) -> str:
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        elif hf_model_name in custom_prompt_dict:
+            # check if the base huggingface model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[hf_model_name]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                messages=messages,
+            )
+        else:
+            if hf_model_name is None:
+                if "llama-2" in model.lower():  # llama-2 model
+                    if "chat" in model.lower():  # apply llama2 chat template
+                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+                    else:  # apply regular llama2 template
+                        hf_model_name = "meta-llama/Llama-2-7b"
+            hf_model_name = (
+                hf_model_name or model
+            )  # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
+            prompt: str = prompt_factory(model=hf_model_name, messages=messages)  # type: ignore
+
+        return prompt
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        inference_params = optional_params.copy()
+        stream = inference_params.pop("stream", False)
+        data: Dict = {"parameters": inference_params}
+        if stream is True:
+            data["stream"] = True
+
+        custom_prompt_dict = (
+            litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
+        )
+
+        hf_model_name = litellm_params.get("hf_model_name", None)
+
+        prompt = self._transform_prompt(
+            model=model,
+            messages=messages,
+            custom_prompt_dict=custom_prompt_dict,
+            hf_model_name=hf_model_name,
+        )
+        data["inputs"] = prompt
+
+        return data
+
+    async def async_transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        return await asyncify(self.transform_request)(
+            model, messages, optional_params, litellm_params, headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        completion_response = raw_response.json()
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=completion_response,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        prompt = request_data["inputs"]
+
+        ## RESPONSE OBJECT
+        try:
+            if isinstance(completion_response, list):
+                completion_response_choices = completion_response[0]
+            else:
+                completion_response_choices = completion_response
+            completion_output = ""
+            if "generation" in completion_response_choices:
+                completion_output += completion_response_choices["generation"]
+            elif "generated_text" in completion_response_choices:
+                completion_output += completion_response_choices["generated_text"]
+
+            # check if the prompt template is part of output, if so - filter it out
+            if completion_output.startswith(prompt) and "<s>" in prompt:
+                completion_output = completion_output.replace(prompt, "", 1)
+
+            model_response.choices[0].message.content = completion_output  # type: ignore
+        except Exception:
+            raise SagemakerError(
+                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+                status_code=500,
+            )
+
+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def validate_environment(
+        self,
+        headers: Optional[dict],
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        headers = {"Content-Type": "application/json"}
+
+        if headers is not None:
+            headers = {"Content-Type": "application/json", **headers}
+
+        return headers
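For reference, below is a minimal sketch of how the vendored SagemakerConfig added in this commit could be exercised directly. It assumes litellm is installed in this virtualenv; the SageMaker endpoint name is a made-up example, and the commented outputs describe the expected shape rather than captured results.

    # Usage sketch (assumptions: litellm importable; endpoint name illustrative only)
    from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

    config = SagemakerConfig()

    # OpenAI-style params are remapped to HF TGI names: temperature=0 is bumped to
    # 0.01 (unless `aws_sagemaker_allow_zero_temp` is set), max_tokens=0 is floored
    # at 1 and renamed to max_new_tokens, and n becomes best_of with do_sample=True.
    params = config.map_openai_params(
        non_default_params={"temperature": 0, "max_tokens": 0, "n": 2},
        optional_params={},
        model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
        drop_params=False,
    )
    # params -> {'temperature': 0.01, 'max_new_tokens': 1, 'best_of': 2, 'do_sample': True}

    # Build the /invoke payload; hf_model_name selects the llama-2 chat prompt template.
    body = config.transform_request(
        model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
        messages=[{"role": "user", "content": "Hello"}],
        optional_params=params,
        litellm_params={"hf_model_name": "meta-llama/Llama-2-7b-chat-hf"},
        headers={},
    )
    # body -> {"parameters": {...}, "inputs": "<rendered llama-2 chat prompt>"}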