about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py
new file mode 100644
index 00000000..2eabcdbc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py
@@ -0,0 +1,49 @@
+"""
+Calling logic for Databricks embeddings
+"""
+
+from typing import Optional
+
+from litellm.utils import EmbeddingResponse
+
+from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler
+from ..common_utils import DatabricksBase
+
+
class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
    """Embedding handler for Databricks endpoints.

    Resolves the Databricks API base URL and auth headers via
    ``databricks_validate_environment`` and then delegates the actual
    request to the OpenAI-compatible parent handler.
    """

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        api_key: Optional[str],
        api_base: Optional[str],
        optional_params: dict,
        model_response: Optional[EmbeddingResponse] = None,
        client=None,
        aembedding=None,
        custom_endpoint: Optional[bool] = None,
        headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        # Resolve the effective base URL and request headers for this
        # Databricks "embeddings" endpoint before handing off.
        resolved_api_base, resolved_headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="embeddings",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )
        # NOTE(review): custom_endpoint is forced to True on the delegated
        # call — presumably the validated base URL is already the complete
        # endpoint; confirm against OpenAILikeEmbeddingHandler's handling.
        return super().embedding(
            model=model,
            input=input,
            timeout=timeout,
            logging_obj=logging_obj,
            api_key=api_key,
            api_base=resolved_api_base,
            optional_params=optional_params,
            model_response=model_response,
            client=client,
            aembedding=aembedding,
            custom_endpoint=True,
            headers=resolved_headers,
        )