Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py | 123
1 file changed, 123 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py
new file mode 100644
index 00000000..4744ec08
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py
@@ -0,0 +1,123 @@
+from typing import List, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import (
+    BaseEmbeddingConfig,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllEmbeddingInputValues
+from litellm.types.utils import EmbeddingResponse
+
+from ..common_utils import TritonError
+
+
+class TritonEmbeddingConfig(BaseEmbeddingConfig):
+    """
+    Transformations for triton /embeddings endpoint (This is a trtllm model)
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return []
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        Map OpenAI params to Triton Embedding params
+        """
+        return optional_params
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return {}
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: AllEmbeddingInputValues,
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {
+            "inputs": [
+                {
+                    "name": "input_text",
+                    "shape": [len(input)],
+                    "datatype": "BYTES",
+                    "data": input,
+                }
+            ]
+        }
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise TritonError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        _embedding_output = []
+
+        _outputs = raw_response_json["outputs"]
+        for output in _outputs:
+            _shape = output["shape"]
+            _data = output["data"]
+            _split_output_data = self.split_embedding_by_shape(_data, _shape)
+
+            for idx, embedding in enumerate(_split_output_data):
+                _embedding_output.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding,
+                    }
+                )
+
+        model_response.model = raw_response_json.get("model_name", "None")
+        model_response.data = _embedding_output
+        return model_response
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return TritonError(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    @staticmethod
+    def split_embedding_by_shape(
+        data: List[float], shape: List[int]
+    ) -> List[List[float]]:
+        if len(shape) != 2:
+            raise ValueError("Shape must be of length 2.")
+        embedding_size = shape[1]
+        return [
+            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
+        ]
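
For reference, a minimal sketch of the payload transform_embedding_request builds: a single BYTES tensor in Triton's v2 HTTP inference format, with the batch size as the only shape dimension. The model name and input strings below are made up:

from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig

config = TritonEmbeddingConfig()
# Build the Triton inference payload for a batch of two strings.
payload = config.transform_embedding_request(
    model="my-embedding-model",  # hypothetical model name
    input=["hello world", "good morning"],
    optional_params={},
    headers={},
)
# payload == {
#     "inputs": [
#         {
#             "name": "input_text",
#             "shape": [2],         # batch size
#             "datatype": "BYTES",  # raw UTF-8 strings
#             "data": ["hello world", "good morning"],
#         }
#     ]
# }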
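
On the response side, a sketch of how a raw Triton reply is reshaped into OpenAI-style embedding objects. The response values are invented, and logging_obj is stubbed with None since this method never touches it:

import httpx

from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig
from litellm.types.utils import EmbeddingResponse

# Fake Triton reply: one output tensor holding two 3-dim embeddings,
# flattened into a single data array with an explicit [batch, dim] shape.
raw = httpx.Response(
    status_code=200,
    json={
        "model_name": "my-embedding-model",
        "outputs": [{"shape": [2, 3], "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}],
    },
)
result = TritonEmbeddingConfig().transform_embedding_response(
    model="my-embedding-model",
    raw_response=raw,
    model_response=EmbeddingResponse(),
    logging_obj=None,  # unused by this method; a stub for the sketch
)
# result.model == "my-embedding-model"
# result.data == [
#     {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
#     {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
# ]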
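
The reshaping itself is plain contiguous slicing: split_embedding_by_shape cuts the flat data array into shape[0] rows of shape[1] floats each. A tiny standalone check with made-up numbers:

from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig

flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
# shape == [2, 3]: two embeddings of dimension three
rows = TritonEmbeddingConfig.split_embedding_by_shape(flat, [2, 3])
assert rows == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]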