Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py  123
1 file changed, 123 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py
new file mode 100644
index 00000000..4744ec08
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/triton/embedding/transformation.py
@@ -0,0 +1,123 @@
+from typing import List, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import (
+ BaseEmbeddingConfig,
+ LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllEmbeddingInputValues
+from litellm.types.utils import EmbeddingResponse
+
+from ..common_utils import TritonError
+
+
+class TritonEmbeddingConfig(BaseEmbeddingConfig):
+ """
+    Transformations for the Triton /embeddings endpoint (serving a TensorRT-LLM embedding model).
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def get_supported_openai_params(self, model: str) -> list:
+ return []
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ) -> dict:
+ """
+ Map OpenAI params to Triton Embedding params
+ """
+ return optional_params
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> dict:
+ return {}
+
+ def transform_embedding_request(
+ self,
+ model: str,
+ input: AllEmbeddingInputValues,
+ optional_params: dict,
+ headers: dict,
+ ) -> dict:
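+        # Triton's KServe v2 /infer schema: a single BYTES tensor named
+        # "input_text" whose one dimension is the batch size. Assumes
+        # `input` is already a list of strings; e.g. ["hello", "world"]
+        # yields shape [2] and data ["hello", "world"].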
+ return {
+ "inputs": [
+ {
+ "name": "input_text",
+ "shape": [len(input)],
+ "datatype": "BYTES",
+ "data": input,
+ }
+ ]
+ }
+
+ def transform_embedding_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: EmbeddingResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ request_data: dict = {},
+ optional_params: dict = {},
+ litellm_params: dict = {},
+ ) -> EmbeddingResponse:
+ try:
+ raw_response_json = raw_response.json()
+ except Exception:
+ raise TritonError(
+ message=raw_response.text, status_code=raw_response.status_code
+ )
+
+ _embedding_output = []
+
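+        # Each Triton output tensor carries a flat `data` list plus a
+        # [batch_size, embedding_dim] `shape`; split it back into one
+        # vector per input row.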
+ _outputs = raw_response_json["outputs"]
+ for output in _outputs:
+ _shape = output["shape"]
+ _data = output["data"]
+ _split_output_data = self.split_embedding_by_shape(_data, _shape)
+
+            for embedding in _split_output_data:
+                _embedding_output.append(
+                    {
+                        "object": "embedding",
+                        # index globally, so rows from a second output
+                        # tensor do not restart at 0
+                        "index": len(_embedding_output),
+                        "embedding": embedding,
+                    }
+                )
+
+ model_response.model = raw_response_json.get("model_name", "None")
+ model_response.data = _embedding_output
+ return model_response
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return TritonError(
+ message=error_message, status_code=status_code, headers=headers
+ )
+
+ @staticmethod
+ def split_embedding_by_shape(
+ data: List[float], shape: List[int]
+ ) -> List[List[float]]:
+ if len(shape) != 2:
+ raise ValueError("Shape must be of length 2.")
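+        # shape is [batch_size, embedding_dim]; data is its row-major flattening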
+ embedding_size = shape[1]
+ return [
+ data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
+ ]
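
For reference, a minimal sketch of the round trip this config implements, runnable without a Triton server (assumes litellm is importable; the model name and the sample values are made up for illustration):

from litellm.llms.triton.embedding.transformation import TritonEmbeddingConfig

config = TritonEmbeddingConfig()

# Build the KServe v2 payload for a two-sentence batch.
request_body = config.transform_embedding_request(
    model="my-trtllm-embedder",  # hypothetical model name
    input=["hello world", "goodbye"],
    optional_params={},
    headers={},
)
# -> {"inputs": [{"name": "input_text", "shape": [2],
#                 "datatype": "BYTES", "data": ["hello world", "goodbye"]}]}

# Recover per-row vectors from a flat Triton output (illustrative values).
flat_data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
vectors = TritonEmbeddingConfig.split_embedding_by_shape(flat_data, [2, 3])
assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]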