about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py300
1 files changed, 300 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py
new file mode 100644
index 00000000..f52eb2ee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/replicate/chat/handler.py
@@ -0,0 +1,300 @@
+import asyncio
+import json
+import time
+from typing import Callable, List, Union
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.utils import CustomStreamWrapper, ModelResponse
+
+from ..common_utils import ReplicateError
+from .transformation import ReplicateConfig
+
+replicate_config = ReplicateConfig()
+
+
+# Function to handle prediction response (streaming)
+def handle_prediction_response_streaming(
+    prediction_url, api_token, print_verbose, headers: dict, http_client: HTTPHandler
+):
+    previous_output = ""
+    output_string = ""
+
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        time.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = http_client.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                try:
+                    output_string = "".join(response_data["output"])
+                except Exception:
+                    raise ReplicateError(
+                        status_code=422,
+                        message="Unable to parse response. Got={}".format(
+                            response_data["output"]
+                        ),
+                        headers=response.headers,
+                    )
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}",
+                    headers=response.headers,
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+    prediction_url,
+    api_token,
+    print_verbose,
+    headers: dict,
+    http_client: AsyncHTTPHandler,
+):
+    previous_output = ""
+    output_string = ""
+
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        await asyncio.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = await http_client.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                try:
+                    output_string = "".join(response_data["output"])
+                except Exception:
+                    raise ReplicateError(
+                        status_code=422,
+                        message="Unable to parse response. Got={}".format(
+                            response_data["output"]
+                        ),
+                        headers=response.headers,
+                    )
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}",
+                    headers=response.headers,
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
+# Main function for prediction completion
+def completion(
+    model: str,
+    messages: list,
+    api_base: str,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    optional_params: dict,
+    litellm_params: dict,
+    logging_obj,
+    api_key,
+    encoding,
+    custom_prompt_dict={},
+    logger_fn=None,
+    acompletion=None,
+    headers={},
+) -> Union[ModelResponse, CustomStreamWrapper]:
+    headers = replicate_config.validate_environment(
+        api_key=api_key,
+        headers=headers,
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+    )
+    # Start a prediction and get the prediction URL
+    version_id = replicate_config.model_to_version_id(model)
+    input_data = replicate_config.transform_request(
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        headers=headers,
+    )
+
+    if acompletion is not None and acompletion is True:
+        return async_completion(
+            model_response=model_response,
+            model=model,
+            encoding=encoding,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            version_id=version_id,
+            input_data=input_data,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+            print_verbose=print_verbose,
+            headers=headers,
+        )  # type: ignore
+    ## COMPLETION CALL
+    model_response.created = int(
+        time.time()
+    )  # for pricing this must remain right before calling api
+
+    prediction_url = replicate_config.get_complete_url(
+        api_base=api_base,
+        model=model,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+    )
+
+    ## COMPLETION CALL
+    httpx_client = _get_httpx_client(
+        params={"timeout": 600.0},
+    )
+    response = httpx_client.post(
+        url=prediction_url,
+        headers=headers,
+        data=json.dumps(input_data),
+    )
+
+    prediction_url = replicate_config.get_prediction_url(response)
+
+    # Handle the prediction response (streaming or non-streaming)
+    if "stream" in optional_params and optional_params["stream"] is True:
+        print_verbose("streaming request")
+        _response = handle_prediction_response_streaming(
+            prediction_url,
+            api_key,
+            print_verbose,
+            headers=headers,
+            http_client=httpx_client,
+        )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+    else:
+        for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
+            time.sleep(
+                litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry
+            )  # wait to allow response to be generated by replicate - else partial output is generated with status=="processing"
+            response = httpx_client.get(url=prediction_url, headers=headers)
+            if (
+                response.status_code == 200
+                and response.json().get("status") == "processing"
+            ):
+                continue
+            return litellm.ReplicateConfig().transform_response(
+                model=model,
+                raw_response=response,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                request_data=input_data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+            )
+
+    raise ReplicateError(
+        status_code=500,
+        message="No response received from Replicate API after max retries",
+        headers=None,
+    )
+
+
+async def async_completion(
+    model_response: ModelResponse,
+    model: str,
+    messages: List[AllMessageValues],
+    encoding,
+    optional_params: dict,
+    litellm_params: dict,
+    version_id,
+    input_data,
+    api_key,
+    api_base,
+    logging_obj,
+    print_verbose,
+    headers: dict,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+
+    prediction_url = replicate_config.get_complete_url(
+        api_base=api_base,
+        model=model,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+    )
+    async_handler = get_async_httpx_client(
+        llm_provider=litellm.LlmProviders.REPLICATE,
+        params={"timeout": 600.0},
+    )
+    response = await async_handler.post(
+        url=prediction_url, headers=headers, data=json.dumps(input_data)
+    )
+    prediction_url = replicate_config.get_prediction_url(response)
+
+    if "stream" in optional_params and optional_params["stream"] is True:
+        _response = async_handle_prediction_response_streaming(
+            prediction_url,
+            api_key,
+            print_verbose,
+            headers=headers,
+            http_client=async_handler,
+        )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+
+    for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
+        await asyncio.sleep(
+            litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry
+        )  # wait to allow response to be generated by replicate - else partial output is generated with status=="processing"
+        response = await async_handler.get(url=prediction_url, headers=headers)
+        if (
+            response.status_code == 200
+            and response.json().get("status") == "processing"
+        ):
+            continue
+        return litellm.ReplicateConfig().transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=input_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+    # Add a fallback return if no response is received after max retries
+    raise ReplicateError(
+        status_code=500,
+        message="No response received from Replicate API after max retries",
+        headers=None,
+    )