diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py new file mode 100644 index 00000000..e8481e25 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py @@ -0,0 +1,82 @@ +from typing import Literal, Optional, Tuple + +from .exceptions import DatabricksError + + +class DatabricksBase: + def _get_databricks_credentials( + self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict] + ) -> Tuple[str, dict]: + headers = headers or {"Content-Type": "application/json"} + try: + from databricks.sdk import WorkspaceClient + + databricks_client = WorkspaceClient() + + api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" + + if api_key is None: + databricks_auth_headers: dict[str, str] = ( + databricks_client.config.authenticate() + ) + headers = {**databricks_auth_headers, **headers} + + return api_base, headers + except ImportError: + raise DatabricksError( + status_code=400, + message=( + "If the Databricks base URL and API key are not set, the databricks-sdk " + "Python library must be installed. Please install the databricks-sdk, set " + "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, " + "or provide the base URL and API key as arguments." + ), + ) + + def databricks_validate_environment( + self, + api_key: Optional[str], + api_base: Optional[str], + endpoint_type: Literal["chat_completions", "embeddings"], + custom_endpoint: Optional[bool], + headers: Optional[dict], + ) -> Tuple[str, dict]: + if api_key is None and headers is None: + if custom_endpoint is not None: + raise DatabricksError( + status_code=400, + message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", + ) + else: + api_base, headers = self._get_databricks_credentials( + api_base=api_base, api_key=api_key, headers=headers + ) + + if api_base is None: + if custom_endpoint: + raise DatabricksError( + status_code=400, + message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", + ) + else: + api_base, headers = self._get_databricks_credentials( + api_base=api_base, api_key=api_key, headers=headers + ) + + if headers is None: + headers = { + "Authorization": "Bearer {}".format(api_key), + "Content-Type": "application/json", + } + else: + if api_key is not None: + headers.update({"Authorization": "Bearer {}".format(api_key)}) + + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + if endpoint_type == "chat_completions" and custom_endpoint is not True: + api_base = "{}/chat/completions".format(api_base) + elif endpoint_type == "embeddings" and custom_endpoint is not True: + api_base = "{}/embeddings".format(api_base) + return api_base, headers |