path: root/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py  472
1 file changed, 472 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
new file mode 100644
index 00000000..43f4b067
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/handler.py
@@ -0,0 +1,472 @@
+# What is this?
+## Controller file for Predibase Integration - https://predibase.com/
+
+import json
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Union
+
+import httpx # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ custom_prompt,
+ prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.types.utils import LiteLLMLoggingBaseClass
+from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
+
+from ..common_utils import PredibaseError
+
+
+async def make_call(
+ client: AsyncHTTPHandler,
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+ timeout: Optional[Union[float, httpx.Timeout]],
+):
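+    # open a streaming POST; the returned async line iterator is what CustomStreamWrapper consumes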
+ response = await client.post(
+ api_base, headers=headers, data=data, stream=True, timeout=timeout
+ )
+
+ if response.status_code != 200:
+ raise PredibaseError(status_code=response.status_code, message=response.text)
+
+ completion_stream = response.aiter_lines()
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=completion_stream, # Pass the completion stream for logging
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+class PredibaseChatCompletion:
+ def __init__(self) -> None:
+ super().__init__()
+
+ def output_parser(self, generated_text: str):
+ """
+        Parse the output text to remove any special tokens. In the current approach we just check for ChatML tokens.
+
+ Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+ """
+ chat_template_tokens = [
+ "<|assistant|>",
+ "<|system|>",
+ "<|user|>",
+ "<s>",
+ "</s>",
+ ]
+ for token in chat_template_tokens:
+ if generated_text.strip().startswith(token):
+ generated_text = generated_text.replace(token, "", 1)
+ if generated_text.endswith(token):
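+                # reverse the string, strip the first (i.e. trailing) occurrence of the token, then reverse back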
+ generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+ return generated_text
+
+ def process_response( # noqa: PLR0915
+ self,
+ model: str,
+ response: httpx.Response,
+ model_response: ModelResponse,
+ stream: bool,
+ logging_obj: LiteLLMLoggingBaseClass,
+ optional_params: dict,
+ api_key: str,
+ data: Union[dict, str],
+ messages: list,
+ print_verbose,
+ encoding,
+ ) -> ModelResponse:
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ print_verbose(f"raw model_response: {response.text}")
+ ## RESPONSE OBJECT
+ try:
+ completion_response = response.json()
+ except Exception:
+ raise PredibaseError(message=response.text, status_code=422)
+ if "error" in completion_response:
+ raise PredibaseError(
+ message=str(completion_response["error"]),
+ status_code=response.status_code,
+ )
+ else:
+ if not isinstance(completion_response, dict):
+ raise PredibaseError(
+ status_code=422,
+ message=f"'completion_response' is not a dictionary - {completion_response}",
+ )
+ elif "generated_text" not in completion_response:
+ raise PredibaseError(
+ status_code=422,
+ message=f"'generated_text' is not a key response dictionary - {completion_response}",
+ )
+ if len(completion_response["generated_text"]) > 0:
+ model_response.choices[0].message.content = self.output_parser( # type: ignore
+ completion_response["generated_text"]
+ )
+ ## GETTING LOGPROBS + FINISH REASON
+ if (
+ "details" in completion_response
+ and "tokens" in completion_response["details"]
+ ):
+ model_response.choices[0].finish_reason = map_finish_reason(
+ completion_response["details"]["finish_reason"]
+ )
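+            # sum per-token logprobs into a single score and stash it on the message as a private attribute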
+ sum_logprob = 0
+ for token in completion_response["details"]["tokens"]:
+ if token["logprob"] is not None:
+ sum_logprob += token["logprob"]
+ setattr(
+ model_response.choices[0].message, # type: ignore
+ "_logprob",
+ sum_logprob, # [TODO] move this to using the actual logprobs
+ )
+ if "best_of" in optional_params and optional_params["best_of"] > 1:
+ if (
+ "details" in completion_response
+ and "best_of_sequences" in completion_response["details"]
+ ):
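+                # surface each best_of candidate as an additional choice, indexed after the primary completion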
+ choices_list = []
+ for idx, item in enumerate(
+ completion_response["details"]["best_of_sequences"]
+ ):
+ sum_logprob = 0
+ for token in item["tokens"]:
+ if token["logprob"] is not None:
+ sum_logprob += token["logprob"]
+ if len(item["generated_text"]) > 0:
+ message_obj = Message(
+ content=self.output_parser(item["generated_text"]),
+ logprobs=sum_logprob,
+ )
+ else:
+ message_obj = Message(content=None)
+ choice_obj = Choices(
+ finish_reason=map_finish_reason(item["finish_reason"]),
+ index=idx + 1,
+ message=message_obj,
+ )
+ choices_list.append(choice_obj)
+ model_response.choices.extend(choices_list)
+
+ ## CALCULATING USAGE
+ prompt_tokens = 0
+ try:
+ prompt_tokens = litellm.token_counter(messages=messages)
+ except Exception:
+            # this should remain non-blocking; don't block the response from returning if usage calculation fails
+ pass
+ output_text = model_response["choices"][0]["message"].get("content", "")
+ if output_text is not None and len(output_text) > 0:
+ completion_tokens = 0
+ try:
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+                ) ## [TODO] use a model-specific tokenizer
+ except Exception:
+                # this should remain non-blocking; don't block the response from returning if usage calculation fails
+ pass
+ else:
+ completion_tokens = 0
+
+ total_tokens = prompt_tokens + completion_tokens
+
+ model_response.created = int(time.time())
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ )
+ model_response.usage = usage # type: ignore
+
+ ## RESPONSE HEADERS
+ predibase_headers = response.headers
+ response_headers = {}
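+        # forward Predibase's x-* response headers, namespaced under an llm_provider- prefix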
+ for k, v in predibase_headers.items():
+ if k.startswith("x-"):
+ response_headers["llm_provider-{}".format(k)] = v
+
+ model_response._hidden_params["additional_headers"] = response_headers
+
+ return model_response
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_prompt_dict: dict,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key: str,
+ logging_obj,
+ optional_params: dict,
+ tenant_id: str,
+ timeout: Union[float, httpx.Timeout],
+ acompletion=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers: dict = {},
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ headers = litellm.PredibaseConfig().validate_environment(
+ api_key=api_key,
+ headers=headers,
+ messages=messages,
+ optional_params=optional_params,
+ model=model,
+ )
+ completion_url = ""
+ input_text = ""
+ base_url = "https://serving.app.predibase.com"
+
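+        # resolve the serving base URL: explicit api_base first, then the PREDIBASE_API_BASE
+        # env var, falling back to the public Predibase serving endpoint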
+ if "https" in model:
+ completion_url = model
+ elif api_base:
+ base_url = api_base
+ elif "PREDIBASE_API_BASE" in os.environ:
+ base_url = os.getenv("PREDIBASE_API_BASE", "")
+
+ completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
+
+ if optional_params.get("stream", False) is True:
+ completion_url += "/generate_stream"
+ else:
+ completion_url += "/generate"
+
+ if model in custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details["roles"],
+ initial_prompt_value=model_prompt_details["initial_prompt_value"],
+ final_prompt_value=model_prompt_details["final_prompt_value"],
+ messages=messages,
+ )
+ else:
+ prompt = prompt_factory(model=model, messages=messages)
+
+ ## Load Config
+ config = litellm.PredibaseConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+            ): # completion(top_k=3) > predibase_config(top_k=3) <- user-passed params take precedence over config defaults
+ optional_params[k] = v
+
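+        # pop `stream` from optional_params: streaming is signaled via the /generate_stream endpoint, not the payload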
+ stream = optional_params.pop("stream", False)
+
+ data = {
+ "inputs": prompt,
+ "parameters": optional_params,
+ }
+ input_text = prompt
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input_text,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "headers": headers,
+ "api_base": completion_url,
+ "acompletion": acompletion,
+ },
+ )
+ ## COMPLETION CALL
+ if acompletion is True:
+ ### ASYNC STREAMING
+ if stream is True:
+ return self.async_streaming(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=completion_url,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=headers,
+ timeout=timeout,
+ ) # type: ignore
+ else:
+ ### ASYNC COMPLETION
+ return self.async_completion(
+ model=model,
+ messages=messages,
+ data=data,
+ api_base=completion_url,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=False,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ headers=headers,
+ timeout=timeout,
+ ) # type: ignore
+
+ ### SYNC STREAMING
+ if stream is True:
+ response = litellm.module_level_client.post(
+ completion_url,
+ headers=headers,
+ data=json.dumps(data),
+ stream=stream,
+ timeout=timeout, # type: ignore
+ )
+ _response = CustomStreamWrapper(
+ response.iter_lines(),
+ model,
+ custom_llm_provider="predibase",
+ logging_obj=logging_obj,
+ )
+ return _response
+ ### SYNC COMPLETION
+ else:
+ response = litellm.module_level_client.post(
+ url=completion_url,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout, # type: ignore
+ )
+ return self.process_response(
+ model=model,
+ response=response,
+ model_response=model_response,
+ stream=optional_params.get("stream", False),
+ logging_obj=logging_obj, # type: ignore
+ optional_params=optional_params,
+ api_key=api_key,
+ data=data,
+ messages=messages,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ )
+
+ async def async_completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ stream,
+ data: dict,
+ optional_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ ) -> ModelResponse:
+
+ async_handler = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.PREDIBASE,
+ params={"timeout": timeout},
+ )
+ try:
+ response = await async_handler.post(
+ api_base, headers=headers, data=json.dumps(data)
+ )
+ except httpx.HTTPStatusError as e:
+ raise PredibaseError(
+ status_code=e.response.status_code,
+ message="HTTPStatusError - received status_code={}, error_message={}".format(
+ e.response.status_code, e.response.text
+ ),
+ )
+ except Exception as e:
+ for exception in litellm.LITELLM_EXCEPTION_TYPES:
+ if isinstance(e, exception):
+ raise e
+ raise PredibaseError(
+ status_code=500, message="{}".format(str(e))
+        )  # don't use verbose_logger.exception here; the error is surfaced by re-raising
+ return self.process_response(
+ model=model,
+ response=response,
+ model_response=model_response,
+ stream=stream,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ data=data,
+ messages=messages,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ encoding=encoding,
+ )
+
+ async def async_streaming(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ api_key,
+ logging_obj,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ optional_params=None,
+ litellm_params=None,
+ logger_fn=None,
+ headers={},
+ ) -> CustomStreamWrapper:
+ data["stream"] = True
+
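+        # defer the HTTP request: CustomStreamWrapper invokes make_call lazily when the stream is iterated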
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ api_base=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ timeout=timeout,
+ ),
+ model=model,
+ custom_llm_provider="predibase",
+ logging_obj=logging_obj,
+ )
+ return streamwrapper
+
+ def embedding(self, *args, **kwargs):
+ pass