Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py | 300
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/transformation.py | 319
2 files changed, 619 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py
new file mode 100644
index 00000000..f52eb2ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py
@@ -0,0 +1,300 @@
+import asyncio
+import json
+import time
+from typing import Callable, List, Union
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import CustomStreamWrapper, ModelResponse
+
+from ..common_utils import ReplicateError
+from .transformation import ReplicateConfig
+
+replicate_config = ReplicateConfig()
+
+
+# Function to handle prediction response (streaming)
+def handle_prediction_response_streaming(
+    prediction_url, api_token, print_verbose, headers: dict, http_client: HTTPHandler
+):
+    previous_output = ""
+    output_string = ""
+
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        time.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = http_client.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                try:
+                    output_string = "".join(response_data["output"])
+                except Exception:
+                    raise ReplicateError(
+                        status_code=422,
+                        message="Unable to parse response. Got={}".format(
+                            response_data["output"]
+                        ),
+                        headers=response.headers,
+                    )
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}",
+                    headers=response.headers,
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+    prediction_url,
+    api_token,
+    print_verbose,
+    headers: dict,
+    http_client: AsyncHTTPHandler,
+):
+    previous_output = ""
+    output_string = ""
+
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        await asyncio.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = await http_client.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                try:
+                    output_string = "".join(response_data["output"])
+                except Exception:
+                    raise ReplicateError(
+                        status_code=422,
+                        message="Unable to parse response. Got={}".format(
+                            response_data["output"]
+                        ),
+                        headers=response.headers,
+                    )
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}",
+                    headers=response.headers,
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
+# Main function for prediction completion
+def completion(
+    model: str,
+    messages: list,
+    api_base: str,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    optional_params: dict,
+    litellm_params: dict,
+    logging_obj,
+    api_key,
+    encoding,
+    custom_prompt_dict={},
+    logger_fn=None,
+    acompletion=None,
+    headers={},
+) -> Union[ModelResponse, CustomStreamWrapper]:
+    headers = replicate_config.validate_environment(
+        api_key=api_key,
+        headers=headers,
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+    )
+    # Start a prediction and get the prediction URL
+    version_id = replicate_config.model_to_version_id(model)
+    input_data = replicate_config.transform_request(
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        headers=headers,
+    )
+
+    if acompletion is not None and acompletion is True:
+        return async_completion(
+            model_response=model_response,
+            model=model,
+            encoding=encoding,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            version_id=version_id,
+            input_data=input_data,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+            print_verbose=print_verbose,
+            headers=headers,
+        )  # type: ignore
+    ## COMPLETION CALL
+    model_response.created = int(
+        time.time()
+    )  # for pricing this must remain right before calling api
+
+    prediction_url = replicate_config.get_complete_url(
+        api_base=api_base,
+        model=model,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+    )
+
+    ## COMPLETION CALL
+    httpx_client = _get_httpx_client(
+        params={"timeout": 600.0},
+    )
+    response = httpx_client.post(
+        url=prediction_url,
+        headers=headers,
+        data=json.dumps(input_data),
+    )
+
+    prediction_url = replicate_config.get_prediction_url(response)
+
+    # Handle the prediction response (streaming or non-streaming)
+    if "stream" in optional_params and optional_params["stream"] is True:
+        print_verbose("streaming request")
+        _response = handle_prediction_response_streaming(
+            prediction_url,
+            api_key,
+            print_verbose,
+            headers=headers,
+            http_client=httpx_client,
+        )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+    else:
+        for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
+            time.sleep(
+                litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry
+            )  # wait to allow response to be generated by replicate - else partial output is generated with status=="processing"
+            response = httpx_client.get(url=prediction_url, headers=headers)
+            if (
+                response.status_code == 200
+                and response.json().get("status") == "processing"
+            ):
+                continue
+            return litellm.ReplicateConfig().transform_response(
+                model=model,
+                raw_response=response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                request_data=input_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+            )
+
+        raise ReplicateError(
+            status_code=500,
+            message="No response received from Replicate API after max retries",
+            headers=None,
+        )
+
+
+async def async_completion(
+    model_response: ModelResponse,
+    model: str,
+    messages: List[AllMessageValues],
+    encoding,
+    optional_params: dict,
+    litellm_params: dict,
+    version_id,
+    input_data,
+    api_key,
+    api_base,
+    logging_obj,
+    print_verbose,
+    headers: dict,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+
+    prediction_url = replicate_config.get_complete_url(
+        api_base=api_base,
+        model=model,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+    )
+    async_handler = get_async_httpx_client(
+        llm_provider=litellm.LlmProviders.REPLICATE,
+        params={"timeout": 600.0},
+    )
+    response = await async_handler.post(
+        url=prediction_url, headers=headers, data=json.dumps(input_data)
+    )
+    prediction_url = replicate_config.get_prediction_url(response)
+
+    if "stream" in optional_params and optional_params["stream"] is True:
+        _response = async_handle_prediction_response_streaming(
+            prediction_url,
+            api_key,
+            print_verbose,
+            headers=headers,
+            http_client=async_handler,
+        )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+
+    for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
+        await asyncio.sleep(
+            litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry
+        )  # wait to allow response to be generated by replicate - else partial output is generated with status=="processing"
+        response = await async_handler.get(url=prediction_url, headers=headers)
+        if (
+            response.status_code == 200
+            and response.json().get("status") == "processing"
+        ):
+            continue
+        return litellm.ReplicateConfig().transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=input_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+    # Add a fallback return if no response is received after max retries
+    raise ReplicateError(
+        status_code=500,
+        message="No response received from Replicate API after max retries",
+        headers=None,
+    )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/transformation.py
new file mode 100644
index 00000000..75cfe6ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/transformation.py
@@ -0,0 +1,319 @@
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+from litellm.utils import token_counter
+
+from ..common_utils import ReplicateError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+    LoggingClass = LiteLLMLoggingObj
+else:
+    LoggingClass = Any
+
+
+class ReplicateConfig(BaseConfig):
+    """
+    Reference: https://replicate.com/meta/llama-2-70b-chat/api
+    - `prompt` (string): The prompt to send to the model.
+
+    - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.
+
+    - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.
+
+    - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.
+
+    - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.
+
+    - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.
+
+    - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.
+
+    - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either 'end' or '<stop>'.
+
+    - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.
+
+    - `debug` (boolean): If set to `True`, it provides debugging output in logs.
+
+    Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
+    """
+
+    system_prompt: Optional[str] = None
+    max_new_tokens: Optional[int] = None
+    min_new_tokens: Optional[int] = None
+    temperature: Optional[int] = None
+    top_p: Optional[int] = None
+    top_k: Optional[int] = None
+    stop_sequences: Optional[str] = None
+    seed: Optional[int] = None
+    debug: Optional[bool] = None
+
+    def __init__(
+        self,
+        system_prompt: Optional[str] = None,
+        max_new_tokens: Optional[int] = None,
+        min_new_tokens: Optional[int] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[int] = None,
+        top_k: Optional[int] = None,
+        stop_sequences: Optional[str] = None,
+        seed: Optional[int] = None,
+        debug: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "seed",
+            "tools",
+            "tool_choice",
+            "functions",
+            "function_call",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "max_tokens":
+                if "vicuna" in model or "flan" in model:
+                    optional_params["max_length"] = value
+                elif "meta/codellama-13b" in model:
+                    optional_params["max_tokens"] = value
+                else:
+                    optional_params["max_new_tokens"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+
+        return optional_params
+
+    # Function to extract version ID from model string
+    def model_to_version_id(self, model: str) -> str:
+        if ":" in model:
+            split_model = model.split(":")
+            return split_model[1]
+        return model
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return ReplicateError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        version_id = self.model_to_version_id(model)
+        base_url = api_base
+        if "deployments" in version_id:
+            version_id = version_id.replace("deployments/", "")
+            base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        else:  # assume it's a model
+            base_url = f"https://api.replicate.com/v1/models/{version_id}"
+
+        base_url = f"{base_url}/predictions"
+        return base_url
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        ## Load Config
+        config = litellm.ReplicateConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        system_prompt = None
+        if optional_params is not None and "supports_system_prompt" in optional_params:
+            supports_sys_prompt = optional_params.pop("supports_system_prompt")
+        else:
+            supports_sys_prompt = False
+
+        if supports_sys_prompt:
+            for i in range(len(messages)):
+                if messages[i]["role"] == "system":
+                    first_sys_message = messages.pop(i)
+                    system_prompt = convert_content_list_to_str(first_sys_message)
+                    break
+
+        if model in litellm.custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = litellm.custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get(
+                    "initial_prompt_value", ""
+                ),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
+
+        if prompt is None or not isinstance(prompt, str):
+            raise ReplicateError(
+                status_code=400,
+                message="LiteLLM Error - prompt is not a string - {}".format(prompt),
+                headers={},
+            )
+
+        # If system prompt is supported, and a system prompt is provided, use it
+        if system_prompt is not None:
+            input_data = {
+                "prompt": prompt,
+                "system_prompt": system_prompt,
+                **optional_params,
+            }
+        # Otherwise, use the prompt as is
+        else:
+            input_data = {"prompt": prompt, **optional_params}
+
+        version_id = self.model_to_version_id(model)
+        request_data: dict = {"input": input_data}
+        if ":" in version_id and len(version_id) > 64:
+            model_parts = version_id.split(":")
+            if (
+                len(model_parts) > 1 and len(model_parts[1]) == 64
+            ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+                request_data["version"] = model_parts[1]
+
+        return request_data
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LoggingClass,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+        raw_response_json = raw_response.json()
+        if raw_response_json.get("status") != "succeeded":
+            raise ReplicateError(
+                status_code=422,
+                message="LiteLLM Error - prediction not succeeded - {}".format(
+                    raw_response_json
+                ),
+                headers=raw_response.headers,
+            )
+        outputs = raw_response_json.get("output", [])
+        response_str = "".join(outputs)
+        if len(response_str) == 0:  # edge case, where result from replicate is empty
+            response_str = " "
+
+        ## Building RESPONSE OBJECT
+        if len(response_str) >= 1:
+            model_response.choices[0].message.content = response_str  # type: ignore
+
+        # Calculate usage
+        prompt_tokens = token_counter(model=model, messages=messages)
+        completion_tokens = token_counter(
+            model=model,
+            text=response_str,
+            count_response_tokens=True,
+        )
+        model_response.model = "replicate/" + model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
+    def get_prediction_url(self, response: httpx.Response) -> str:
+        """
+        response json: {
+            ...,
+            "urls":{"cancel":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4/cancel","get":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4","stream":"https://stream-b.svc.rno2.c.replicate.net/v1/streams/eot4gbydowuin4snhncydwxt57dfwgsc3w3snycx5nid7oef7jga"}
+        }
+        """
+        response_json = response.json()
+        prediction_url = response_json.get("urls", {}).get("get")
+        if prediction_url is None:
+            raise ReplicateError(
+                status_code=400,
+                message="LiteLLM Error - prediction url is None - {}".format(
+                    response_json
+                ),
+                headers=response.headers,
+            )
+        return prediction_url
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        headers = {
+            "Authorization": f"Token {api_key}",
+            "Content-Type": "application/json",
+        }
+        return headers