author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed | .venv/lib/python3.12/site-packages/litellm/llms/baseten.py
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/baseten.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/baseten.py | 172
1 file changed, 172 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/baseten.py b/.venv/lib/python3.12/site-packages/litellm/llms/baseten.py
new file mode 100644
index 00000000..e1d513d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/baseten.py
@@ -0,0 +1,172 @@
+import json
+import time
+from typing import Callable
+
+import litellm
+from litellm.types.utils import ModelResponse, Usage
+
+
+class BasetenError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+def validate_environment(api_key):
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    if api_key:
+        headers["Authorization"] = f"Api-Key {api_key}"
+    return headers
+
+
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params: dict,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    completion_url_fragment_1 = "https://app.baseten.co/models/"
+    completion_url_fragment_2 = "/predict"
+    model = model
+    prompt = ""
+    for message in messages:
+        if "role" in message:
+            if message["role"] == "user":
+                prompt += f"{message['content']}"
+            else:
+                prompt += f"{message['content']}"
+        else:
+            prompt += f"{message['content']}"
+    data = {
+        "inputs": prompt,
+        "prompt": prompt,
+        "parameters": optional_params,
+        "stream": (
+            True
+            if "stream" in optional_params and optional_params["stream"] is True
+            else False
+        ),
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    response = litellm.module_level_client.post(
+        completion_url_fragment_1 + model + completion_url_fragment_2,
+        headers=headers,
+        data=json.dumps(data),
+        stream=(
+            True
+            if "stream" in optional_params and optional_params["stream"] is True
+            else False
+        ),
+    )
+    if "text/event-stream" in response.headers["Content-Type"] or (
+        "stream" in optional_params and optional_params["stream"] is True
+    ):
+        return response.iter_lines()
+    else:
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        completion_response = response.json()
+        if "error" in completion_response:
+            raise BasetenError(
+                message=completion_response["error"],
+                status_code=response.status_code,
+            )
+        else:
+            if "model_output" in completion_response:
+                if (
+                    isinstance(completion_response["model_output"], dict)
+                    and "data" in completion_response["model_output"]
+                    and isinstance(completion_response["model_output"]["data"], list)
+                ):
+                    model_response.choices[0].message.content = completion_response[  # type: ignore
+                        "model_output"
+                    ][
+                        "data"
+                    ][
+                        0
+                    ]
+                elif isinstance(completion_response["model_output"], str):
+                    model_response.choices[0].message.content = completion_response[  # type: ignore
+                        "model_output"
+                    ]
+            elif "completion" in completion_response and isinstance(
+                completion_response["completion"], str
+            ):
+                model_response.choices[0].message.content = completion_response[  # type: ignore
+                    "completion"
+                ]
+            elif isinstance(completion_response, list) and len(completion_response) > 0:
+                if "generated_text" not in completion_response:
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code,
+                    )
+                model_response.choices[0].message.content = completion_response[0][  # type: ignore
+                    "generated_text"
+                ]
+                ## GETTING LOGPROBS
+                if (
+                    "details" in completion_response[0]
+                    and "tokens" in completion_response[0]["details"]
+                ):
+                    model_response.choices[0].finish_reason = completion_response[0][
+                        "details"
+                    ]["finish_reason"]
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response.choices[0].logprobs = sum_logprob  # type: ignore
+            else:
+                raise BasetenError(
+                    message=f"Unable to parse response. Original response: {response.text}",
+                    status_code=response.status_code,
+                )
+
+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+
+        setattr(model_response, "usage", usage)
+        return model_response
+
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass