| field | value | date |
|---|---|---|
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py | 472 |
| -rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py | 180 |

2 files changed, 652 insertions, 0 deletions
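Before the diff itself, a minimal usage sketch of what these two files enable: routing an OpenAI-style chat call to a Predibase serving deployment through litellm. The model name, tenant id, and environment variable names below are placeholders based on litellm's usual provider conventions; they are not part of this commit.

```python
# Minimal sketch (assumptions: the "predibase/" model prefix routes to this handler,
# and PREDIBASE_API_KEY / PREDIBASE_TENANT_ID are read from the environment).
import os

import litellm

os.environ["PREDIBASE_API_KEY"] = "pb_..."        # placeholder key
os.environ["PREDIBASE_TENANT_ID"] = "my-tenant"   # placeholder tenant id

response = litellm.completion(
    model="predibase/llama-3-1-8b-instruct",  # placeholder deployment name
    messages=[{"role": "user", "content": "Hello, what is Predibase?"}],
    max_tokens=256,
)
print(response.choices[0].message.content)
```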
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
new file mode 100644
index 00000000..43f4b067
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
@@ -0,0 +1,472 @@
+# What is this?
+## Controller file for Predibase Integration - https://predibase.com/
+
+import json
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Union
+
+import httpx  # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.types.utils import LiteLLMLoggingBaseClass
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from ..common_utils import PredibaseError
+
+
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+    timeout: Optional[Union[float, httpx.Timeout]],
+):
+    response = await client.post(
+        api_base, headers=headers, data=data, stream=True, timeout=timeout
+    )
+
+    if response.status_code != 200:
+        raise PredibaseError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
+class PredibaseChatCompletion:
+    def __init__(self) -> None:
+        super().__init__()
+
+    def output_parser(self, generated_text: str):
+        """
+        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
+
+        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+        """
+        chat_template_tokens = [
+            "<|assistant|>",
+            "<|system|>",
+            "<|user|>",
+            "<s>",
+            "</s>",
+        ]
+        for token in chat_template_tokens:
+            if generated_text.strip().startswith(token):
+                generated_text = generated_text.replace(token, "", 1)
+            if generated_text.endswith(token):
+                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+        return generated_text
+
+    def process_response(  # noqa: PLR0915
+        self,
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: LiteLLMLoggingBaseClass,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except Exception:
+            raise PredibaseError(message=response.text, status_code=422)
+        if "error" in completion_response:
+            raise PredibaseError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        else:
+            if not isinstance(completion_response, dict):
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key response dictionary - {completion_response}",
+                )
+            if len(completion_response["generated_text"]) > 0:
+                model_response.choices[0].message.content = self.output_parser(  # type: ignore
+                    completion_response["generated_text"]
+                )
+            ## GETTING LOGPROBS + FINISH REASON
+            if (
+                "details" in completion_response
+                and "tokens" in completion_response["details"]
+            ):
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    completion_response["details"]["finish_reason"]
+                )
+                sum_logprob = 0
+                for token in completion_response["details"]["tokens"]:
+                    if token["logprob"] is not None:
+                        sum_logprob += token["logprob"]
+                setattr(
+                    model_response.choices[0].message,  # type: ignore
+                    "_logprob",
+                    sum_logprob,  # [TODO] move this to using the actual logprobs
+                )
+            if "best_of" in optional_params and optional_params["best_of"] > 1:
+                if (
+                    "details" in completion_response
+                    and "best_of_sequences" in completion_response["details"]
+                ):
+                    choices_list = []
+                    for idx, item in enumerate(
+                        completion_response["details"]["best_of_sequences"]
+                    ):
+                        sum_logprob = 0
+                        for token in item["tokens"]:
+                            if token["logprob"] is not None:
+                                sum_logprob += token["logprob"]
+                        if len(item["generated_text"]) > 0:
+                            message_obj = Message(
+                                content=self.output_parser(item["generated_text"]),
+                                logprobs=sum_logprob,
+                            )
+                        else:
+                            message_obj = Message(content=None)
+                        choice_obj = Choices(
+                            finish_reason=map_finish_reason(item["finish_reason"]),
+                            index=idx + 1,
+                            message=message_obj,
+                        )
+                        choices_list.append(choice_obj)
+                    model_response.choices.extend(choices_list)
+
+        ## CALCULATING USAGE
+        prompt_tokens = 0
+        try:
+            prompt_tokens = litellm.token_counter(messages=messages)
+        except Exception:
+            # this should remain non blocking we should not block a response returning if calculating usage fails
+            pass
+        output_text = model_response["choices"][0]["message"].get("content", "")
+        if output_text is not None and len(output_text) > 0:
+            completion_tokens = 0
+            try:
+                completion_tokens = len(
+                    encoding.encode(
+                        model_response["choices"][0]["message"].get("content", "")
+                    )
+                )  ##[TODO] use a model-specific tokenizer
+            except Exception:
+                # this should remain non blocking we should not block a response returning if calculating usage fails
+                pass
+        else:
+            completion_tokens = 0
+
+        total_tokens = prompt_tokens + completion_tokens
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        model_response.usage = usage  # type: ignore
+
+        ## RESPONSE HEADERS
+        predibase_headers = response.headers
+        response_headers = {}
+        for k, v in predibase_headers.items():
+            if k.startswith("x-"):
+                response_headers["llm_provider-{}".format(k)] = v
+
+        model_response._hidden_params["additional_headers"] = response_headers
+
+        return model_response
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: str,
+        logging_obj,
+        optional_params: dict,
+        tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers: dict = {},
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        headers = litellm.PredibaseConfig().validate_environment(
+            api_key=api_key,
+            headers=headers,
+            messages=messages,
+            optional_params=optional_params,
+            model=model,
+        )
+        completion_url = ""
+        input_text = ""
+        base_url = "https://serving.app.predibase.com"
+
+        if "https" in model:
+            completion_url = model
+        elif api_base:
+            base_url = api_base
+        elif "PREDIBASE_API_BASE" in os.environ:
+            base_url = os.getenv("PREDIBASE_API_BASE", "")
+
+        completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
+
+        if optional_params.get("stream", False) is True:
+            completion_url += "/generate_stream"
+        else:
+            completion_url += "/generate"
+
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+
+        ## Load Config
+        config = litellm.PredibaseConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        stream = optional_params.pop("stream", False)
+
+        data = {
+            "inputs": prompt,
+            "parameters": optional_params,
+        }
+        input_text = prompt
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input_text,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": completion_url,
+                "acompletion": acompletion,
+            },
+        )
+        ## COMPLETION CALL
+        if acompletion is True:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=completion_url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )  # type: ignore
+            else:
+                ### ASYNC COMPLETION
+                return self.async_completion(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=completion_url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=False,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )  # type: ignore
+
+        ### SYNC STREAMING
+        if stream is True:
+            response = litellm.module_level_client.post(
+                completion_url,
+                headers=headers,
+                data=json.dumps(data),
+                stream=stream,
+                timeout=timeout,  # type: ignore
+            )
+            _response = CustomStreamWrapper(
+                response.iter_lines(),
+                model,
+                custom_llm_provider="predibase",
+                logging_obj=logging_obj,
+            )
+            return _response
+        ### SYNC COMPLETION
+        else:
+            response = litellm.module_level_client.post(
+                url=completion_url,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,  # type: ignore
+            )
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,  # type: ignore
+            optional_params=optional_params,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        data: dict,
+        optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ) -> ModelResponse:
+
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.PREDIBASE,
+            params={"timeout": timeout},
+        )
+        try:
+            response = await async_handler.post(
+                api_base, headers=headers, data=json.dumps(data)
+            )
+        except httpx.HTTPStatusError as e:
+            raise PredibaseError(
+                status_code=e.response.status_code,
+                message="HTTPStatusError - received status_code={}, error_message={}".format(
+                    e.response.status_code, e.response.text
+                ),
+            )
+        except Exception as e:
+            for exception in litellm.LITELLM_EXCEPTION_TYPES:
+                if isinstance(e, exception):
+                    raise e
+            raise PredibaseError(
+                status_code=500, message="{}".format(str(e))
+            )  # don't use verbose_logger.exception, if exception is raised
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ) -> CustomStreamWrapper:
+        data["stream"] = True
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+                timeout=timeout,
+            ),
+            model=model,
+            custom_llm_provider="predibase",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    def embedding(self, *args, **kwargs):
+        pass
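As an aside before the second file: `output_parser` above strips leading and trailing ChatML-style control tokens from the generated text. The snippet below is a standalone illustration of that behavior, not part of the committed diff; the sample string is invented.

```python
# Standalone illustration of the token-stripping loop in output_parser.
generated_text = "<|assistant|>The capital of France is Paris.</s>"  # invented sample

chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
for token in chat_template_tokens:
    if generated_text.strip().startswith(token):
        generated_text = generated_text.replace(token, "", 1)  # drop a leading token
    if generated_text.endswith(token):
        # reverse, drop the first (i.e. trailing) occurrence, reverse back
        generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]

print(generated_text)  # -> The capital of France is Paris.
```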
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
new file mode 100644
index 00000000..f5742386
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
@@ -0,0 +1,180 @@
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PredibaseError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class PredibaseConfig(BaseConfig):
+    """
+    Reference: https://docs.predibase.com/user-guide/inference/rest_api
+    """
+
+    adapter_id: Optional[str] = None
+    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
+    best_of: Optional[int] = None
+    decoder_input_details: Optional[bool] = None
+    details: bool = True  # enables returning logprobs + best of
+    max_new_tokens: int = (
+        256  # openai default - requests hang if max_new_tokens not given
+    )
+    repetition_penalty: Optional[float] = None
+    return_full_text: Optional[bool] = (
+        False  # by default don't return the input as part of the output
+    )
+    seed: Optional[int] = None
+    stop: Optional[List[str]] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[int] = None
+    truncate: Optional[int] = None
+    typical_p: Optional[float] = None
+    watermark: Optional[bool] = None
+
+    def __init__(
+        self,
+        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        details: Optional[bool] = None,
+        max_new_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[int] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "stream",
+            "temperature",
+            "max_completion_tokens",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens" or param == "max_completion_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Predibase transformation currently done in handler.py. Need to migrate to this file."
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PredibaseError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
+            )
+
+        default_headers = {
+            "content-type": "application/json",
+            "Authorization": "Bearer {}".format(api_key),
+        }
+        if headers is not None and isinstance(headers, dict):
+            headers = {**default_headers, **headers}
+        return headers