Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py | 472 |
1 files changed, 472 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
new file mode 100644
index 00000000..43f4b067
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
@@ -0,0 +1,472 @@
+# What is this?
+## Controller file for Predibase Integration - https://predibase.com/
+
+import json
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Union
+
+import httpx  # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.types.utils import LiteLLMLoggingBaseClass
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from ..common_utils import PredibaseError
+
+
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+    timeout: Optional[Union[float, httpx.Timeout]],
+):
+    response = await client.post(
+        api_base, headers=headers, data=data, stream=True, timeout=timeout
+    )
+
+    if response.status_code != 200:
+        raise PredibaseError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
+class PredibaseChatCompletion:
+    def __init__(self) -> None:
+        super().__init__()
+
+    def output_parser(self, generated_text: str):
+        """
+        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
+
+        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+        """
+        chat_template_tokens = [
+            "<|assistant|>",
+            "<|system|>",
+            "<|user|>",
+            "<s>",
+            "</s>",
+        ]
+        for token in chat_template_tokens:
+            if generated_text.strip().startswith(token):
+                generated_text = generated_text.replace(token, "", 1)
+            if generated_text.endswith(token):
+                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+        return generated_text
+
+    def process_response(  # noqa: PLR0915
+        self,
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: LiteLLMLoggingBaseClass,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        try:
+            completion_response = response.json()
+        except Exception:
+            raise PredibaseError(message=response.text, status_code=422)
+        if "error" in completion_response:
+            raise PredibaseError(
+                message=str(completion_response["error"]),
+                status_code=response.status_code,
+            )
+        else:
+            if not isinstance(completion_response, dict):
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'completion_response' is not a dictionary - {completion_response}",
+                )
+            elif "generated_text" not in completion_response:
+                raise PredibaseError(
+                    status_code=422,
+                    message=f"'generated_text' is not a key response dictionary - {completion_response}",
+                )
+            if len(completion_response["generated_text"]) > 0:
+                model_response.choices[0].message.content = self.output_parser(  # type: ignore
+                    completion_response["generated_text"]
+                )
+            ## GETTING LOGPROBS + FINISH REASON
+            if (
+                "details" in completion_response
+                and "tokens" in completion_response["details"]
+            ):
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    completion_response["details"]["finish_reason"]
+                )
+                sum_logprob = 0
+                for token in completion_response["details"]["tokens"]:
+                    if token["logprob"] is not None:
+                        sum_logprob += token["logprob"]
+                setattr(
+                    model_response.choices[0].message,  # type: ignore
+                    "_logprob",
+                    sum_logprob,  # [TODO] move this to using the actual logprobs
+                )
+            if "best_of" in optional_params and optional_params["best_of"] > 1:
+                if (
+                    "details" in completion_response
+                    and "best_of_sequences" in completion_response["details"]
+                ):
+                    choices_list = []
+                    for idx, item in enumerate(
+                        completion_response["details"]["best_of_sequences"]
+                    ):
+                        sum_logprob = 0
+                        for token in item["tokens"]:
+                            if token["logprob"] is not None:
+                                sum_logprob += token["logprob"]
+                        if len(item["generated_text"]) > 0:
+                            message_obj = Message(
+                                content=self.output_parser(item["generated_text"]),
+                                logprobs=sum_logprob,
+                            )
+                        else:
+                            message_obj = Message(content=None)
+                        choice_obj = Choices(
+                            finish_reason=map_finish_reason(item["finish_reason"]),
+                            index=idx + 1,
+                            message=message_obj,
+                        )
+                        choices_list.append(choice_obj)
+                    model_response.choices.extend(choices_list)
+
+        ## CALCULATING USAGE
+        prompt_tokens = 0
+        try:
+            prompt_tokens = litellm.token_counter(messages=messages)
+        except Exception:
+            # this should remain non blocking we should not block a response returning if calculating usage fails
+            pass
+        output_text = model_response["choices"][0]["message"].get("content", "")
+        if output_text is not None and len(output_text) > 0:
+            completion_tokens = 0
+            try:
+                completion_tokens = len(
+                    encoding.encode(
+                        model_response["choices"][0]["message"].get("content", "")
+                    )
+                )  ##[TODO] use a model-specific tokenizer
+            except Exception:
+                # this should remain non blocking we should not block a response returning if calculating usage fails
+                pass
+        else:
+            completion_tokens = 0
+
+        total_tokens = prompt_tokens + completion_tokens
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+        model_response.usage = usage  # type: ignore
+
+        ## RESPONSE HEADERS
+        predibase_headers = response.headers
+        response_headers = {}
+        for k, v in predibase_headers.items():
+            if k.startswith("x-"):
+                response_headers["llm_provider-{}".format(k)] = v
+
+        model_response._hidden_params["additional_headers"] = response_headers
+
+        return model_response
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key: str,
+        logging_obj,
+        optional_params: dict,
+        tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers: dict = {},
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        headers = litellm.PredibaseConfig().validate_environment(
+            api_key=api_key,
+            headers=headers,
+            messages=messages,
+            optional_params=optional_params,
+            model=model,
+        )
+        completion_url = ""
+        input_text = ""
+        base_url = "https://serving.app.predibase.com"
+
+        if "https" in model:
+            completion_url = model
+        elif api_base:
+            base_url = api_base
+        elif "PREDIBASE_API_BASE" in os.environ:
+            base_url = os.getenv("PREDIBASE_API_BASE", "")
+
+        completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
+
+        if optional_params.get("stream", False) is True:
+            completion_url += "/generate_stream"
+        else:
+            completion_url += "/generate"
+
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+
+        ## Load Config
+        config = litellm.PredibaseConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        stream = optional_params.pop("stream", False)
+
+        data = {
+            "inputs": prompt,
+            "parameters": optional_params,
+        }
+        input_text = prompt
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input_text,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": completion_url,
+                "acompletion": acompletion,
+            },
+        )
+        ## COMPLETION CALL
+        if acompletion is True:
+            ### ASYNC STREAMING
+            if stream is True:
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=completion_url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )  # type: ignore
+            else:
+                ### ASYNC COMPLETION
+                return self.async_completion(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=completion_url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    api_key=api_key,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=False,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=headers,
+                    timeout=timeout,
+                )  # type: ignore
+
+        ### SYNC STREAMING
+        if stream is True:
+            response = litellm.module_level_client.post(
+                completion_url,
+                headers=headers,
+                data=json.dumps(data),
+                stream=stream,
+                timeout=timeout,  # type: ignore
+            )
+            _response = CustomStreamWrapper(
+                response.iter_lines(),
+                model,
+                custom_llm_provider="predibase",
+                logging_obj=logging_obj,
+            )
+            return _response
+        ### SYNC COMPLETION
+        else:
+            response = litellm.module_level_client.post(
+                url=completion_url,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,  # type: ignore
+            )
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=optional_params.get("stream", False),
+            logging_obj=logging_obj,  # type: ignore
+            optional_params=optional_params,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        stream,
+        data: dict,
+        optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ) -> ModelResponse:
+
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.PREDIBASE,
+            params={"timeout": timeout},
+        )
+        try:
+            response = await async_handler.post(
+                api_base, headers=headers, data=json.dumps(data)
+            )
+        except httpx.HTTPStatusError as e:
+            raise PredibaseError(
+                status_code=e.response.status_code,
+                message="HTTPStatusError - received status_code={}, error_message={}".format(
+                    e.response.status_code, e.response.text
+                ),
+            )
+        except Exception as e:
+            for exception in litellm.LITELLM_EXCEPTION_TYPES:
+                if isinstance(e, exception):
+                    raise e
+            raise PredibaseError(
+                status_code=500, message="{}".format(str(e))
+            )  # don't use verbose_logger.exception, if exception is raised
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+    ) -> CustomStreamWrapper:
+        data["stream"] = True
+
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+                timeout=timeout,
+            ),
+            model=model,
+            custom_llm_provider="predibase",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    def embedding(self, *args, **kwargs):
+        pass
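For reference, the handler above is normally reached through litellm's provider routing rather than being instantiated directly. The snippet below is a minimal usage sketch, assuming a litellm version that routes "predibase/<deployment>" model names through PredibaseChatCompletion and forwards tenant_id as a provider-specific parameter; the deployment name, tenant id, and API key shown are placeholders.

    import os
    import litellm

    # Placeholder credential - replace with a real Predibase API key before running.
    os.environ["PREDIBASE_API_KEY"] = "pb_..."

    response = litellm.completion(
        model="predibase/llama-3-1-8b-instruct",  # hypothetical deployment name
        messages=[{"role": "user", "content": "What is Predibase?"}],
        tenant_id="my-tenant-id",  # used to build the serving URL in completion()
        max_tokens=64,
    )
    print(response.choices[0].message.content)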