Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/databricks')
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py         |  84
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py  | 106
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py         |  82
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/cost_calculator.py      |  66
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py        |  49
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/transformation.py |  48
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/exceptions.py           |  12
 .venv/lib/python3.12/site-packages/litellm/llms/databricks/streaming_utils.py      | 166
8 files changed, 613 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py
new file mode 100644
index 00000000..abb71474
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py
@@ -0,0 +1,84 @@
+"""
+Handles the chat completion request for Databricks
+"""
+
+from typing import Callable, List, Optional, Union, cast
+
+from httpx._config import Timeout
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import CustomStreamingDecoder
+from litellm.utils import ModelResponse
+
+from ...openai_like.chat.handler import OpenAILikeChatHandler
+from ..common_utils import DatabricksBase
+from .transformation import DatabricksConfig
+
+
+class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def completion(
+        self,
+        *,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: Optional[str],
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers: Optional[dict] = None,
+        timeout: Optional[Union[float, Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        custom_endpoint: Optional[bool] = None,
+        streaming_decoder: Optional[CustomStreamingDecoder] = None,
+        fake_stream: bool = False,
+    ):
+        messages = DatabricksConfig()._transform_messages(
+            messages=cast(List[AllMessageValues], messages), model=model
+        )
+        api_base, headers = self.databricks_validate_environment(
+            api_base=api_base,
+            api_key=api_key,
+            endpoint_type="chat_completions",
+            custom_endpoint=custom_endpoint,
+            headers=headers,
+        )
+
+        if optional_params.get("stream") is True:
+            fake_stream = DatabricksConfig()._should_fake_stream(optional_params)
+        else:
+            fake_stream = False
+
+        return super().completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            custom_prompt_dict=custom_prompt_dict,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            acompletion=acompletion,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            headers=headers,
+            timeout=timeout,
+            client=client,
+            custom_endpoint=True,
+            streaming_decoder=streaming_decoder,
+            fake_stream=fake_stream,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py
new file mode 100644
index 00000000..94e02034
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py
@@ -0,0 +1,106 @@
+"""
+Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
+"""
+
+from typing import List, Optional, Union
+
+from pydantic import BaseModel
+
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    handle_messages_with_content_list_to_str_conversion,
+    strip_name_from_messages,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ProviderField
+
+from ...openai_like.chat.transformation import OpenAILikeChatConfig
+
+
+class DatabricksConfig(OpenAILikeChatConfig):
+    """
+    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+    """
+
+    max_tokens: Optional[int] = None
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    top_k: Optional[int] = None
+    stop: Optional[Union[List[str], str]] = None
+    n: Optional[int] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        top_k: Optional[int] = None,
+        stop: Optional[Union[List[str], str]] = None,
+        n: Optional[int] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return its required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Databricks API Key.",
+                field_value="dapi...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Databricks API Base.",
+                field_value="https://adb-..",
+            ),
+        ]
+
+    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
+        return [
+            "stream",
+            "stop",
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "max_completion_tokens",
+            "n",
+            "response_format",
+            "tools",
+            "tool_choice",
+        ]
+
+    def _should_fake_stream(self, optional_params: dict) -> bool:
+        """
+        Databricks doesn't support 'response_format' while streaming
+        """
+        if optional_params.get("response_format") is not None:
+            return True
+
+        return False
+
+    def _transform_messages(
+        self, messages: List[AllMessageValues], model: str
+    ) -> List[AllMessageValues]:
+        """
+        Databricks does not support:
+        - content in list format.
+        - 'name' in user message.
+        """
+        new_messages = []
+        for idx, message in enumerate(messages):
+            if isinstance(message, BaseModel):
+                _message = message.model_dump(exclude_none=True)
+            else:
+                _message = message
+            new_messages.append(_message)
+        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
+        new_messages = strip_name_from_messages(new_messages)
+        return super()._transform_messages(messages=new_messages, model=model)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py
new file mode 100644
index 00000000..e8481e25
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/common_utils.py
@@ -0,0 +1,82 @@
+from typing import Literal, Optional, Tuple
+
+from .exceptions import DatabricksError
+
+
+class DatabricksBase:
+    def _get_databricks_credentials(
+        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
+    ) -> Tuple[str, dict]:
+        headers = headers or {"Content-Type": "application/json"}
+        try:
+            from databricks.sdk import WorkspaceClient
+
+            databricks_client = WorkspaceClient()
+
+            api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
+
+            if api_key is None:
+                databricks_auth_headers: dict[str, str] = (
+                    databricks_client.config.authenticate()
+                )
+                headers = {**databricks_auth_headers, **headers}
+
+            return api_base, headers
+        except ImportError:
+            raise DatabricksError(
+                status_code=400,
+                message=(
+                    "If the Databricks base URL and API key are not set, the databricks-sdk "
+                    "Python library must be installed. Please install the databricks-sdk, set "
+                    "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
+                    "or provide the base URL and API key as arguments."
+                ),
+            )
+
+    def databricks_validate_environment(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        endpoint_type: Literal["chat_completions", "embeddings"],
+        custom_endpoint: Optional[bool],
+        headers: Optional[dict],
+    ) -> Tuple[str, dict]:
+        if api_key is None and headers is None:
+            if custom_endpoint is not None:
+                raise DatabricksError(
+                    status_code=400,
+                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
+                )
+            else:
+                api_base, headers = self._get_databricks_credentials(
+                    api_base=api_base, api_key=api_key, headers=headers
+                )
+
+        if api_base is None:
+            if custom_endpoint:
+                raise DatabricksError(
+                    status_code=400,
+                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
+                )
+            else:
+                api_base, headers = self._get_databricks_credentials(
+                    api_base=api_base, api_key=api_key, headers=headers
+                )
+
+        if headers is None:
+            headers = {
+                "Authorization": "Bearer {}".format(api_key),
+                "Content-Type": "application/json",
+            }
+        else:
+            if api_key is not None:
+                headers.update({"Authorization": "Bearer {}".format(api_key)})
+
+        if api_key is not None:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        if endpoint_type == "chat_completions" and custom_endpoint is not True:
+            api_base = "{}/chat/completions".format(api_base)
+        elif endpoint_type == "embeddings" and custom_endpoint is not True:
+            api_base = "{}/embeddings".format(api_base)
+        return api_base, headers
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/cost_calculator.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/cost_calculator.py
new file mode 100644
index 00000000..5558e133
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/cost_calculator.py
@@ -0,0 +1,66 @@
+"""
+Helper util for handling databricks-specific cost calculation
+- e.g.: handling 'dbrx-instruct-*'
+"""
+
+from typing import Tuple
+
+from litellm.types.utils import Usage
+from litellm.utils import get_model_info
+
+
+def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
+    """
+    Calculates the cost per token for a given model, prompt tokens, and completion tokens.
+
+    Input:
+        - model: str, the model name without provider prefix
+        - usage: LiteLLM Usage block, containing anthropic caching information
+
+    Returns:
+        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+    """
+    base_model = model
+    if model.startswith("databricks/dbrx-instruct") or model.startswith(
+        "dbrx-instruct"
+    ):
+        base_model = "databricks-dbrx-instruct"
+    elif model.startswith("databricks/meta-llama-3.1-70b-instruct") or model.startswith(
+        "meta-llama-3.1-70b-instruct"
+    ):
+        base_model = "databricks-meta-llama-3-1-70b-instruct"
+    elif model.startswith(
+        "databricks/meta-llama-3.1-405b-instruct"
+    ) or model.startswith("meta-llama-3.1-405b-instruct"):
+        base_model = "databricks-meta-llama-3-1-405b-instruct"
+    elif model.startswith("databricks/mixtral-8x7b-instruct-v0.1") or model.startswith(
+        "mixtral-8x7b-instruct-v0.1"
+    ):
+        base_model = "databricks-mixtral-8x7b-instruct"
+    elif model.startswith("databricks/mixtral-8x7b-instruct-v0.1") or model.startswith(
+        "mixtral-8x7b-instruct-v0.1"
+    ):
+        base_model = "databricks-mixtral-8x7b-instruct"
+    elif model.startswith("databricks/bge-large-en") or model.startswith(
+        "bge-large-en"
+    ):
+        base_model = "databricks-bge-large-en"
+    elif model.startswith("databricks/gte-large-en") or model.startswith(
+        "gte-large-en"
+    ):
+        base_model = "databricks-gte-large-en"
+    elif model.startswith("databricks/llama-2-70b-chat") or model.startswith(
+        "llama-2-70b-chat"
+    ):
+        base_model = "databricks-llama-2-70b-chat"
+    ## GET MODEL INFO
+    model_info = get_model_info(model=base_model, custom_llm_provider="databricks")
+
+    ## CALCULATE INPUT COST
+
+    prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
+
+    ## CALCULATE OUTPUT COST
+    completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]
+
+    return prompt_cost, completion_cost
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py
new file mode 100644
index 00000000..2eabcdbc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/handler.py
@@ -0,0 +1,49 @@
+"""
+Calling logic for Databricks embeddings
+"""
+
+from typing import Optional
+
+from litellm.utils import EmbeddingResponse
+
+from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler
+from ..common_utils import DatabricksBase
+
+
+class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        timeout: float,
+        logging_obj,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        optional_params: dict,
+        model_response: Optional[EmbeddingResponse] = None,
+        client=None,
+        aembedding=None,
+        custom_endpoint: Optional[bool] = None,
+        headers: Optional[dict] = None,
+    ) -> EmbeddingResponse:
+        api_base, headers = self.databricks_validate_environment(
+            api_base=api_base,
+            api_key=api_key,
+            endpoint_type="embeddings",
+            custom_endpoint=custom_endpoint,
+            headers=headers,
+        )
+        return super().embedding(
+            model=model,
+            input=input,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            api_base=api_base,
+            optional_params=optional_params,
+            model_response=model_response,
+            client=client,
+            aembedding=aembedding,
+            custom_endpoint=True,
+            headers=headers,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/transformation.py
new file mode 100644
index 00000000..53e3b30d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/embed/transformation.py
@@ -0,0 +1,48 @@
+"""
+Translates from OpenAI's `/v1/embeddings` to Databricks' `/embeddings`
+"""
+
+import types
+from typing import Optional
+
+
+class DatabricksEmbeddingConfig:
+    """
+    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
+    """
+
+    instruction: Optional[str] = (
+        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
+    )
+
+    def __init__(self, instruction: Optional[str] = None) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(
+        self,
+    ):  # no optional openai embedding params supported
+        return []
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        return optional_params
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/exceptions.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/exceptions.py
new file mode 100644
index 00000000..8bb3d435
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/exceptions.py
@@ -0,0 +1,12 @@
+import httpx
+
+
+class DatabricksError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/streaming_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/streaming_utils.py
new file mode 100644
index 00000000..2db53df9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/streaming_utils.py
@@ -0,0 +1,166 @@
+import json
+from typing import Optional
+
+import litellm
+from litellm import verbose_logger
+from litellm.types.llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionUsageBlock,
+)
+from litellm.types.utils import GenericStreamingChunk, Usage
+
+
+class ModelResponseIterator:
+    def __init__(self, streaming_response, sync_stream: bool):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            processed_chunk = litellm.ModelResponseStream(**chunk)
+
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+
+            if processed_chunk.choices[0].delta.content is not None:  # type: ignore
+                text = processed_chunk.choices[0].delta.content  # type: ignore
+
+            if (
+                processed_chunk.choices[0].delta.tool_calls is not None  # type: ignore
+                and len(processed_chunk.choices[0].delta.tool_calls) > 0  # type: ignore
+                and processed_chunk.choices[0].delta.tool_calls[0].function is not None  # type: ignore
+                and processed_chunk.choices[0].delta.tool_calls[0].function.arguments  # type: ignore
+                is not None
+            ):
+                tool_use = ChatCompletionToolCallChunk(
+                    id=processed_chunk.choices[0].delta.tool_calls[0].id,  # type: ignore
+                    type="function",
+                    function=ChatCompletionToolCallFunctionChunk(
+                        name=processed_chunk.choices[0]
+                        .delta.tool_calls[0]  # type: ignore
+                        .function.name,
+                        arguments=processed_chunk.choices[0]
+                        .delta.tool_calls[0]  # type: ignore
+                        .function.arguments,
+                    ),
+                    index=processed_chunk.choices[0].delta.tool_calls[0].index,
+                )
+
+            if processed_chunk.choices[0].finish_reason is not None:
+                is_finished = True
+                finish_reason = processed_chunk.choices[0].finish_reason
+
+            usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
+            if usage_chunk is not None:
+
+                usage = ChatCompletionUsageBlock(
+                    prompt_tokens=usage_chunk.prompt_tokens,
+                    completion_tokens=usage_chunk.completion_tokens,
+                    total_tokens=usage_chunk.total_tokens,
+                )
+
+            return GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=0,
+            )
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        self.response_iterator = self.streaming_response
+        return self
+
+    def __next__(self):
+        if not hasattr(self, "response_iterator"):
+            self.response_iterator = self.streaming_response
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+            chunk = chunk.strip()
+            if len(chunk) > 0:
+                json_chunk = json.loads(chunk)
+                return self.chunk_parser(chunk=json_chunk)
+            else:
+                return GenericStreamingChunk(
+                    text="",
+                    is_finished=False,
+                    finish_reason="",
+                    usage=None,
+                    index=0,
+                    tool_use=None,
+                )
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            verbose_logger.debug(
+                f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
+            )
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+        except Exception as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+            chunk = chunk.strip()
+            if chunk == "[DONE]":
+                raise StopAsyncIteration
+            if len(chunk) > 0:
+                json_chunk = json.loads(chunk)
+                return self.chunk_parser(chunk=json_chunk)
+            else:
+                return GenericStreamingChunk(
+                    text="",
+                    is_finished=False,
+                    finish_reason="",
+                    usage=None,
+                    index=0,
+                    tool_use=None,
+                )
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            verbose_logger.debug(
+                f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
+            )
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="",
+                usage=None,
+                index=0,
+                tool_use=None,
+            )
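
For orientation, here is a minimal sketch of how the chat handler added above is normally reached through litellm's public completion API. The endpoint URL, token, and model name are placeholder assumptions, not values taken from this diff; DatabricksChatCompletion.completion() is invoked internally when the model string carries the "databricks/" prefix, and databricks_validate_environment() appends "/chat/completions" to the api_base.

    import litellm

    # Hypothetical endpoint, token, and model - substitute your own Databricks
    # serving endpoint and personal access token.
    response = litellm.completion(
        model="databricks/databricks-dbrx-instruct",
        messages=[{"role": "user", "content": "Hello from LiteLLM"}],
        api_base="https://adb-0000000000000000.0.azuredatabricks.net/serving-endpoints",
        api_key="dapi-...",
    )
    print(response.choices[0].message.content)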
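
And a small worked example of the arithmetic in cost_per_token() from cost_calculator.py: prompt and completion costs are plain per-token multiplications, with the rates normally read from get_model_info(). The rates below are made-up placeholders, not real Databricks pricing.

    # Placeholder rates - real values come from
    # get_model_info(model=base_model, custom_llm_provider="databricks")
    input_cost_per_token = 0.00000075
    output_cost_per_token = 0.00000225
    prompt_tokens, completion_tokens = 1000, 200

    prompt_cost = prompt_tokens * input_cost_per_token            # 0.00075 USD
    completion_cost = completion_tokens * output_cost_per_token   # 0.00045 USD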