diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py | 343 |
1 file changed, 343 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py
new file mode 100644
index 00000000..56151f89
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/triton/completion/transformation.py
@@ -0,0 +1,343 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` and `/infer` endpoints.
+"""
+
+import json
+from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.llms.base_llm.chat.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+)
+
+from ..common_utils import TritonError
+
+
+class TritonConfig(BaseConfig):
+    """
+    Base class for Triton configurations.
+
+    Handles routing between the /infer and /generate Triton completion endpoints.
+    """
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
+    ) -> BaseLLMException:
+        return TritonError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: Dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> Dict:
+        return {"Content-Type": "application/json"}
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "max_completion_tokens"]
+
+    def map_openai_params(
+        self,
+        non_default_params: Dict,
+        optional_params: Dict,
+        model: str,
+        drop_params: bool,
+    ) -> Dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens" or param == "max_completion_tokens":
+                optional_params[param] = value
+        return optional_params
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        if api_base is None:
+            raise ValueError("api_base is required")
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate" and stream:
+            return api_base + "_stream"
+        return api_base
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        api_base = litellm_params.get("api_base", "")
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate":
+            return TritonGenerateConfig().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                api_key=api_key,
+                json_mode=json_mode,
+            )
+        elif llm_type == "infer":
+            return TritonInferConfig().transform_response(
+                model=model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                request_data=request_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                api_key=api_key,
+                json_mode=json_mode,
+            )
+        return model_response
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        api_base = litellm_params.get("api_base", "")
+        llm_type = self._get_triton_llm_type(api_base)
+        if llm_type == "generate":
+            return TritonGenerateConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
+        elif llm_type == "infer":
+            return TritonInferConfig().transform_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                headers=headers,
+            )
+        return {}
+
+    def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
+        if api_base.endswith("/generate"):
+            return "generate"
+        elif api_base.endswith("/infer"):
+            return "infer"
+        else:
+            raise ValueError(f"Invalid Triton API base: {api_base}")
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ) -> Any:
+        return TritonResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class TritonGenerateConfig(TritonConfig):
+    """
+    Transformations for the Triton /generate endpoint (a TensorRT-LLM model).
+    """
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        inference_params = optional_params.copy()
+        stream = inference_params.pop("stream", False)
+        data_for_triton: Dict[str, Any] = {
+            "text_input": prompt_factory(model=model, messages=messages),
+            "parameters": {
+                "max_tokens": int(optional_params.get("max_tokens", 2000)),
+                "bad_words": [""],
+                "stop_words": [""],
+            },
+            "stream": bool(stream),
+        }
+        data_for_triton["parameters"].update(inference_params)
+        return data_for_triton
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise TritonError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+        model_response.choices = [
+            Choices(index=0, message=Message(content=raw_response_json["text_output"]))
+        ]
+
+        return model_response
+
+
+class TritonInferConfig(TritonConfig):
+    """
+    Transformations for the Triton /infer endpoint (a custom model served on Triton).
+    """
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+
+        text_input = messages[0].get("content", "")
+        data_for_triton = {
+            "inputs": [
+                {
+                    "name": "text_input",
+                    "shape": [1],
+                    "datatype": "BYTES",
+                    "data": [text_input],
+                }
+            ]
+        }
+
+        for k, v in optional_params.items():
+            if not (k == "stream" or k == "max_retries"):
+                datatype = "INT32" if isinstance(v, int) else "BYTES"
+                datatype = "FP32" if isinstance(v, float) else datatype
+                data_for_triton["inputs"].append(
+                    {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
+                )
+
+        if "max_tokens" not in optional_params:
+            data_for_triton["inputs"].append(
+                {
+                    "name": "max_tokens",
+                    "shape": [1],
+                    "datatype": "INT32",
+                    "data": [20],
+                }
+            )
+        return data_for_triton
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: Dict,
+        messages: List[AllMessageValues],
+        optional_params: Dict,
+        litellm_params: Dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise TritonError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        _triton_response_data = raw_response_json["outputs"][0]["data"]
+        triton_response_data: Optional[str] = None
+        if isinstance(_triton_response_data, list):
+            triton_response_data = "".join(_triton_response_data)
+        else:
+            triton_response_data = _triton_response_data
+
+        model_response.choices = [
+            Choices(
+                index=0,
+                message=Message(content=triton_response_data),
+            )
+        ]
+
+        return model_response
+
+
+class TritonResponseIterator(BaseModelResponseIterator):
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+            index = int(chunk.get("index", 0))
+
+            # set values
+            text = chunk.get("text_output", "")
+            finish_reason = chunk.get("stop_reason", "")
+            is_finished = chunk.get("is_finished", False)
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")