diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py | 84 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py | 106 |
2 files changed, 190 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py new file mode 100644 index 00000000..abb71474 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/handler.py @@ -0,0 +1,84 @@ +""" +Handles the chat completion request for Databricks +""" + +from typing import Callable, List, Optional, Union, cast + +from httpx._config import Timeout + +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import CustomStreamingDecoder +from litellm.utils import ModelResponse + +from ...openai_like.chat.handler import OpenAILikeChatHandler +from ..common_utils import DatabricksBase +from .transformation import DatabricksConfig + + +class DatabricksChatCompletion(OpenAILikeChatHandler, DatabricksBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def completion( + self, + *, + model: str, + messages: list, + api_base: str, + custom_llm_provider: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key: Optional[str], + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + custom_endpoint: Optional[bool] = None, + streaming_decoder: Optional[CustomStreamingDecoder] = None, + fake_stream: bool = False, + ): + messages = DatabricksConfig()._transform_messages( + messages=cast(List[AllMessageValues], messages), model=model + ) + api_base, headers = self.databricks_validate_environment( + api_base=api_base, + api_key=api_key, + endpoint_type="chat_completions", + custom_endpoint=custom_endpoint, + headers=headers, + ) + + if optional_params.get("stream") is True: + fake_stream = DatabricksConfig()._should_fake_stream(optional_params) + else: + fake_stream = False + + return super().completion( + model=model, + messages=messages, + api_base=api_base, + custom_llm_provider=custom_llm_provider, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + api_key=api_key, + logging_obj=logging_obj, + optional_params=optional_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=headers, + timeout=timeout, + client=client, + custom_endpoint=True, + streaming_decoder=streaming_decoder, + fake_stream=fake_stream, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py new file mode 100644 index 00000000..94e02034 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/databricks/chat/transformation.py @@ -0,0 +1,106 @@ +""" +Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` +""" + +from typing import List, Optional, Union + +from pydantic import BaseModel + +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + handle_messages_with_content_list_to_str_conversion, + strip_name_from_messages, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ProviderField + +from ...openai_like.chat.transformation import OpenAILikeChatConfig + + +class DatabricksConfig(OpenAILikeChatConfig): + """ + Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request + """ + + max_tokens: Optional[int] = None + temperature: Optional[int] = None + top_p: Optional[int] = None + top_k: Optional[int] = None + stop: Optional[Union[List[str], str]] = None + n: Optional[int] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + top_k: Optional[int] = None, + stop: Optional[Union[List[str], str]] = None, + n: Optional[int] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def get_required_params(self) -> List[ProviderField]: + """For a given provider, return it's required fields with a description""" + return [ + ProviderField( + field_name="api_key", + field_type="string", + field_description="Your Databricks API Key.", + field_value="dapi...", + ), + ProviderField( + field_name="api_base", + field_type="string", + field_description="Your Databricks API Base.", + field_value="https://adb-..", + ), + ] + + def get_supported_openai_params(self, model: Optional[str] = None) -> list: + return [ + "stream", + "stop", + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "n", + "response_format", + "tools", + "tool_choice", + ] + + def _should_fake_stream(self, optional_params: dict) -> bool: + """ + Databricks doesn't support 'response_format' while streaming + """ + if optional_params.get("response_format") is not None: + return True + + return False + + def _transform_messages( + self, messages: List[AllMessageValues], model: str + ) -> List[AllMessageValues]: + """ + Databricks does not support: + - content in list format. + - 'name' in user message. + """ + new_messages = [] + for idx, message in enumerate(messages): + if isinstance(message, BaseModel): + _message = message.model_dump(exclude_none=True) + else: + _message = message + new_messages.append(_message) + new_messages = handle_messages_with_content_list_to_str_conversion(new_messages) + new_messages = strip_name_from_messages(new_messages) + return super()._transform_messages(messages=new_messages, model=model) |