Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py  |  123
1 file changed, 123 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py
new file mode 100644
index 00000000..4744ec08
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py
@@ -0,0 +1,123 @@
+from typing import List, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import (
+    BaseEmbeddingConfig,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllEmbeddingInputValues
+from litellm.types.utils import EmbeddingResponse
+
+from ..common_utils import TritonError
+
+
+class TritonEmbeddingConfig(BaseEmbeddingConfig):
+    """
+    Transformations for the Triton /embeddings endpoint (for TensorRT-LLM models)
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return []
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        Map OpenAI params to Triton Embedding params
+        """
+        return optional_params
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return {}
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: AllEmbeddingInputValues,
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {
+            "inputs": [
+                {
+                    "name": "input_text",
+                    "shape": [len(input)],
+                    "datatype": "BYTES",
+                    "data": input,
+                }
+            ]
+        }
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise TritonError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        _embedding_output = []
+
+        _outputs = raw_response_json["outputs"]
+        for output in _outputs:
+            _shape = output["shape"]
+            _data = output["data"]
+            _split_output_data = self.split_embedding_by_shape(_data, _shape)
+
+            for idx, embedding in enumerate(_split_output_data):
+                _embedding_output.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding,
+                    }
+                )
+
+        model_response.model = raw_response_json.get("model_name", "None")
+        model_response.data = _embedding_output
+        return model_response
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return TritonError(
+            message=error_message, status_code=status_code, headers=headers
+        )
+
+    @staticmethod
+    def split_embedding_by_shape(
+        data: List[float], shape: List[int]
+    ) -> List[List[float]]:
+        if len(shape) != 2:
+            raise ValueError("Shape must be of length 2.")
+        embedding_size = shape[1]
+        return [
+            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
+        ]
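
For reference, a minimal sketch of the request body this config builds, assuming a two-string batch (the model name below is a hypothetical placeholder):

    from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig

    # Build the Triton inference payload for two input strings.
    payload = TritonEmbeddingConfig().transform_embedding_request(
        model="my-trtllm-embedder",  # hypothetical model name
        input=["hello world", "goodbye"],
        optional_params={},
        headers={},
    )
    # payload == {
    #     "inputs": [
    #         {
    #             "name": "input_text",
    #             "shape": [2],          # one entry per input string
    #             "datatype": "BYTES",   # Triton's datatype for variable-length text
    #             "data": ["hello world", "goodbye"],
    #         }
    #     ]
    # }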
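
On the response side, transform_embedding_response expects each entry in "outputs" to carry a flat "data" array plus a two-dimensional "shape" of [batch_size, embedding_dim]. A sketch with illustrative values:

    # Shape [2, 3]: two inputs, three-dimensional embeddings, row-major data.
    raw_response_json = {
        "model_name": "my-trtllm-embedder",  # hypothetical
        "outputs": [
            {
                "name": "embedding",
                "shape": [2, 3],
                "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            }
        ],
    }
    # transform_embedding_response turns this into two {"object": "embedding",
    # "index": i, "embedding": [...]} entries on model_response.data and copies
    # "model_name" onto model_response.model.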
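
The splitting itself is plain row-major chunking, so split_embedding_by_shape can be exercised on its own:

    flat = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    rows = TritonEmbeddingConfig.split_embedding_by_shape(flat, [2, 3])
    assert rows == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    # Any shape that is not rank 2 raises ValueError("Shape must be of length 2.")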