aboutsummaryrefslogtreecommitdiff
import json
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

import httpx

from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    Choices,
    GenericStreamingChunk,
    Message,
    ModelResponse,
    Usage,
)
from litellm.utils import token_counter

from ..common_utils import ClarifaiError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any


class ClarifaiConfig(BaseConfig):
    """
    Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
    """

    max_tokens: Optional[int] = None
    temperature: Optional[int] = None
    top_k: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        return [
            "temperature",
            "max_tokens",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "temperature":
                optional_params["temperature"] = value
            elif param == "max_tokens":
                optional_params["max_tokens"] = value

        return optional_params

    def _completions_to_model(self, prompt: str, optional_params: dict) -> dict:
        params = {}
        if temperature := optional_params.get("temperature"):
            params["temperature"] = temperature
        if max_tokens := optional_params.get("max_tokens"):
            params["max_tokens"] = max_tokens
        return {
            "inputs": [{"data": {"text": {"raw": prompt}}}],
            "model": {"output_info": {"params": params}},
        }

    def _convert_model_to_url(self, model: str, api_base: str):
        user_id, app_id, model_id = model.split(".")
        return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        prompt = " ".join(convert_content_list_to_str(message) for message in messages)

        ## Load Config
        config = self.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        data = self._completions_to_model(
            prompt=prompt, optional_params=optional_params
        )

        return data

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }

        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return ClarifaiError(message=error_message, status_code=status_code)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        ## RESPONSE OBJECT
        try:
            completion_response = raw_response.json()
        except httpx.HTTPStatusError as e:
            raise ClarifaiError(
                message=str(e),
                status_code=raw_response.status_code,
            )
        except Exception as e:
            raise ClarifaiError(
                message=str(e),
                status_code=422,
            )
        # print(completion_response)
        try:
            choices_list = []
            for idx, item in enumerate(completion_response["outputs"]):
                if len(item["data"]["text"]["raw"]) > 0:
                    message_obj = Message(content=item["data"]["text"]["raw"])
                else:
                    message_obj = Message(content=None)
                choice_obj = Choices(
                    finish_reason="stop",
                    index=idx + 1,  # check
                    message=message_obj,
                )
                choices_list.append(choice_obj)
            model_response.choices = choices_list  # type: ignore

        except Exception as e:
            raise ClarifaiError(
                message=str(e),
                status_code=422,
            )

        # Calculate Usage
        prompt_tokens = token_counter(model=model, messages=messages)
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content"))
        )
        model_response.model = model
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )
        return model_response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return ClarifaiModelResponseIterator(
            model_response=streaming_response,
            json_mode=json_mode,
        )


class ClarifaiModelResponseIterator(FakeStreamResponseIterator):
    def __init__(
        self,
        model_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        json_mode: Optional[bool] = False,
    ):
        super().__init__(
            model_response=model_response,
            json_mode=json_mode,
        )

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            text = (
                chunk.get("outputs", "")[0]
                .get("data", "")
                .get("text", "")
                .get("raw", "")
            )

            index: int = 0

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")