Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py | 121
1 file changed, 121 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py
new file mode 100644
index 00000000..359137ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/watsonx/embed/transformation.py
@@ -0,0 +1,121 @@
+"""
+Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
+"""
+
+from typing import Optional
+
+import httpx
+
+from litellm.llms.base_llm.embedding.transformation import (
+    BaseEmbeddingConfig,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllEmbeddingInputValues
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+from litellm.types.utils import EmbeddingResponse, Usage
+
+from ..common_utils import IBMWatsonXMixin, _get_api_params
+
+
+class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
+    def get_supported_openai_params(self, model: str) -> list:
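+        # watsonx's embeddings route exposes no OpenAI-compatible optional params.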
+        return []
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
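+        # Nothing to map, so provider-specific params pass through unchanged.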
+        return optional_params
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: AllEmbeddingInputValues,
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        watsonx_api_params = _get_api_params(params=optional_params)
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
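+        # The mixin-built payload supplies model_id / project scoping for the request body.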
+
+        return {
+            "inputs": input,
+            "parameters": optional_params,
+            **watsonx_auth_payload,
+        }
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
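+        # For "deployment/<id>" models, fill the deployment id into the endpoint template.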
+        if model.startswith("deployment/"):
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str],
+        request_data: dict,
+        optional_params: dict,
+        litellm_params: dict,
+    ) -> EmbeddingResponse:
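+        # Log the raw response text before any parsing, so malformed payloads are still captured.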
+        logging_obj.post_call(
+            original_response=raw_response.text,
+        )
+        json_resp = raw_response.json()
+        if model_response is None:
+            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": result["embedding"],
+                }
+            )
+        model_response.object = "list"
+        model_response.data = embedding_response
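+        # watsonx reports only input_token_count; embeddings consume no completion tokens.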
+        input_tokens = json_resp.get("input_token_count", 0)
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
+        )
+        return model_response
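
For context, a minimal usage sketch of the route this config backs. The environment
variables and model id follow litellm's documented watsonx setup and are assumptions,
not part of this diff:

import litellm

# Assumes WATSONX_URL, WATSONX_APIKEY, and WATSONX_PROJECT_ID are set in the
# environment; the model id below is an example from litellm's watsonx docs.
response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",
    input=["good morning from litellm"],
)
print(len(response.data[0]["embedding"]))  # embedding vector built in transform_embedding_response
print(response.usage)  # prompt_tokens mirrors watsonx's input_token_count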