author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py  270
1 file changed, 270 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
new file mode 100644
index 00000000..d0ab5d06
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/sagemaker/completion/transformation.py
@@ -0,0 +1,270 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke`
+
+In the Huggingface TGI format.
+"""
+
+import json
+import time
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from httpx._models import Headers, Response
+
+import litellm
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ custom_prompt,
+ prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import SagemakerError
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+
+class SagemakerConfig(BaseConfig):
+ """
+ Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
+ """
+
+ max_new_tokens: Optional[int] = None
+ top_p: Optional[float] = None
+ temperature: Optional[float] = None
+ return_full_text: Optional[bool] = None
+
+ def __init__(
+ self,
+ max_new_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ return_full_text: Optional[bool] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, Headers]
+ ) -> BaseLLMException:
+ return SagemakerError(
+ message=error_message, status_code=status_code, headers=headers
+ )
+
+ def get_supported_openai_params(self, model: str) -> List:
+ return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ for param, value in non_default_params.items():
+ if param == "temperature":
+ if value == 0.0 or value == 0:
+ # hugging face exception raised when temp==0
+ # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+ if not non_default_params.get(
+ "aws_sagemaker_allow_zero_temp", False
+ ):
+ value = 0.01
+
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "n":
+ optional_params["best_of"] = value
+ optional_params["do_sample"] = (
+ True # Need to sample if you want best of for hf inference endpoints
+ )
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "stop":
+ optional_params["stop"] = value
+ if param == "max_tokens":
+ # HF TGI raises the following exception when max_new_tokens==0
+ # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+ if value == 0:
+ value = 1
+ optional_params["max_new_tokens"] = value
+ non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
+ return optional_params
+
+ def _transform_prompt(
+ self,
+ model: str,
+ messages: List,
+ custom_prompt_dict: dict,
+ hf_model_name: Optional[str],
+ ) -> str:
+ if model in custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", None),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+ messages=messages,
+ )
+ elif hf_model_name in custom_prompt_dict:
+ # check if the base huggingface model has a registered custom prompt
+ model_prompt_details = custom_prompt_dict[hf_model_name]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", None),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+ messages=messages,
+ )
+ else:
+ if hf_model_name is None:
+ if "llama-2" in model.lower(): # llama-2 model
+ if "chat" in model.lower(): # apply llama2 chat template
+ hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
+ else: # apply regular llama2 template
+ hf_model_name = "meta-llama/Llama-2-7b"
+ hf_model_name = (
+ hf_model_name or model
+            ) # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
+ prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore
+
+ return prompt
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ inference_params = optional_params.copy()
+ stream = inference_params.pop("stream", False)
+ data: Dict = {"parameters": inference_params}
+ if stream is True:
+ data["stream"] = True
+
+ custom_prompt_dict = (
+ litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
+ )
+
+ hf_model_name = litellm_params.get("hf_model_name", None)
+
+ prompt = self._transform_prompt(
+ model=model,
+ messages=messages,
+ custom_prompt_dict=custom_prompt_dict,
+ hf_model_name=hf_model_name,
+ )
+ data["inputs"] = prompt
+
+ return data
+
+ async def async_transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ return await asyncify(self.transform_request)(
+ model, messages, optional_params, litellm_params, headers
+ )
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: Response,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+        encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ completion_response = raw_response.json()
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_response,
+ additional_args={"complete_input_dict": request_data},
+ )
+
+ prompt = request_data["inputs"]
+
+ ## RESPONSE OBJECT
+ try:
+ if isinstance(completion_response, list):
+ completion_response_choices = completion_response[0]
+ else:
+ completion_response_choices = completion_response
+ completion_output = ""
+ if "generation" in completion_response_choices:
+ completion_output += completion_response_choices["generation"]
+ elif "generated_text" in completion_response_choices:
+ completion_output += completion_response_choices["generated_text"]
+
+ # check if the prompt template is part of output, if so - filter it out
+ if completion_output.startswith(prompt) and "<s>" in prompt:
+ completion_output = completion_output.replace(prompt, "", 1)
+
+ model_response.choices[0].message.content = completion_output # type: ignore
+ except Exception:
+ raise SagemakerError(
+ message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
+ status_code=500,
+ )
+
+        ## CALCULATING USAGE - count prompt and completion tokens with the provided encoding
+ prompt_tokens = len(encoding.encode(prompt))
+ completion_tokens = len(
+ encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+ )
+
+ model_response.created = int(time.time())
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
+
+ def validate_environment(
+ self,
+ headers: Optional[dict],
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> dict:
+        default_headers = {"Content-Type": "application/json"}
+
+        if headers is not None:
+            default_headers = {**default_headers, **headers}
+
+        return default_headers
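
For orientation, a minimal request-side sketch of how this config is exercised (assuming litellm is installed; the endpoint name and parameter values below are illustrative and not taken from the commit):

    from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

    config = SagemakerConfig()

    # Map OpenAI-style params onto the HF TGI names the endpoint expects:
    # temperature=0 is bumped to 0.01, max_tokens becomes max_new_tokens,
    # and n becomes best_of (which forces do_sample=True).
    params = config.map_openai_params(
        non_default_params={"temperature": 0, "max_tokens": 256, "n": 2},
        optional_params={},
        model="jumpstart-dft-meta-textgeneration-llama-2-7b-chat",  # illustrative endpoint name
        drop_params=False,
    )
    # params == {"temperature": 0.01, "max_new_tokens": 256, "best_of": 2, "do_sample": True}

    # Build the /invoke payload; "llama-2" + "chat" in the model name routes
    # prompt templating through the Llama-2 chat template via prompt_factory.
    payload = config.transform_request(
        model="jumpstart-dft-meta-textgeneration-llama-2-7b-chat",
        messages=[{"role": "user", "content": "Say hello"}],
        optional_params=params,
        litellm_params={},
        headers={},
    )
    # payload == {"parameters": {...}, "inputs": "<Llama-2 chat-templated prompt>"}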
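
And a rough response-side sketch. The raw JSON shape, the no-op logging stub, and the tiktoken encoding are assumptions made for illustration; in production these objects are supplied by litellm's Sagemaker handler.

    import httpx
    import tiktoken
    from litellm.types.utils import ModelResponse
    from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

    class _NoOpLogger:
        # Minimal stand-in for the litellm logging object; only post_call is used here.
        def post_call(self, **kwargs):
            pass

    config = SagemakerConfig()
    request_data = {"inputs": "[INST] Say hello [/INST]", "parameters": {"max_new_tokens": 256}}

    # A TGI-style Sagemaker endpoint typically returns a single-element list
    # with "generated_text" (some JumpStart models use "generation" instead).
    raw = httpx.Response(200, json=[{"generated_text": "Hello! How can I help?"}])

    response = config.transform_response(
        model="jumpstart-dft-meta-textgeneration-llama-2-7b-chat",  # illustrative endpoint name
        raw_response=raw,
        model_response=ModelResponse(),
        logging_obj=_NoOpLogger(),
        request_data=request_data,
        messages=[{"role": "user", "content": "Say hello"}],
        optional_params={},
        litellm_params={},
        encoding=tiktoken.get_encoding("cl100k_base"),
    )
    print(response.choices[0].message.content)  # "Hello! How can I help?"
    print(response.usage.total_tokens)          # prompt + completion token count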