# What is this?
## Handler file for the TextCompletionCodestral integration - https://codestral.com/

import json
from functools import partial
from typing import Callable, List, Optional, Union

import httpx  # type: ignore

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.utils import TextChoices
from litellm.utils import CustomStreamWrapper, TextCompletionResponse


class TextCompletionCodestralError(Exception):
    """Error raised when a Codestral text-completion request fails.

    Stores the status code and message alongside ``httpx`` request/response
    objects. When the caller does not supply a request or a response, a
    placeholder is built so both attributes are always present.
    """

    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message

        # No request given: synthesize a POST pointing at the REST API docs URL.
        if request is None:
            request = httpx.Request(
                method="POST",
                url="https://docs.codestral.com/user-guide/inference/rest_api",
            )
        self.request = request

        # No response given: synthesize one carrying only the status code,
        # bound to whichever request was chosen above.
        if response is None:
            response = httpx.Response(status_code=status_code, request=self.request)
        self.response = response

        # Hand the message to the base Exception so str(error) shows it.
        super().__init__(self.message)


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    """Open a streaming POST against the Codestral endpoint.

    Args:
        client: Async HTTP handler used to send the request.
        api_base: Full URL to post to.
        headers: HTTP headers for the request.
        data: Request body, already serialized to a string.
        model: Model name (not used by this function).
        messages: Original input, forwarded to the logging hook.
        logging_obj: Object exposing ``post_call``; invoked once the stream
            has been opened.

    Returns:
        The async line iterator obtained from ``response.aiter_lines()``.

    Raises:
        TextCompletionCodestralError: If the response status is not 200.
    """
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        # The request was sent with stream=True, so the body has not been
        # loaded yet. httpx raises ResponseNotRead if ``.text`` is accessed
        # before the body is read, which would mask the real API error.
        await response.aread()
        raise TextCompletionCodestralError(
            status_code=response.status_code, message=response.text
        )

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class CodestralTextCompletion:
    def __init__(self) -> None:
        """Initialise the handler; no instance attributes are set here."""
        # Only delegates to the parent initialiser.
        super().__init__()

    def _validate_environment(
        self,
        api_key: Optional[str],
        user_headers: dict,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text

    def process_text_completion_response(
        self,
        model: str,
        response: httpx.Response,
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: LiteLLMLogging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> TextCompletionResponse:
        """Convert a raw Codestral HTTP response into a ``TextCompletionResponse``.

        Args:
            model: Not used by this method.
            response: HTTP response returned by the Codestral endpoint.
            model_response: Not used by this method; a new response object is
                built instead.
            stream: Not used by this method.
            logging_obj: Logging object whose ``post_call`` hook receives the
                raw response text.
            optional_params: Not used by this method.
            api_key: Forwarded to the logging hook.
            data: Request payload, forwarded to the logging hook.
            messages: Original input, forwarded to the logging hook.
            print_verbose: Callable used to emit the raw response text.
            encoding: Not used by this method.

        Returns:
            A ``TextCompletionResponse`` built from the response JSON, with one
            ``TextChoices`` entry per item in ``choices``.

        Raises:
            TextCompletionCodestralError: With the response status code when
                the status is not 200, or with status 422 when the body
                cannot be decoded as JSON.
        """
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"codestral api: raw model_response: {response.text}")
        ## RESPONSE OBJECT
        if response.status_code != 200:
            raise TextCompletionCodestralError(
                message=str(response.text),
                status_code=response.status_code,
            )
        try:
            completion_response = response.json()
        except Exception as err:
            # Keep the decode failure attached as the cause for debugging.
            raise TextCompletionCodestralError(
                message=response.text, status_code=422
            ) from err

        _choices: List[TextChoices] = []
        # ``or []`` / ``or {}`` below tolerate explicit JSON nulls as well as
        # missing keys.
        for choice in completion_response.get("choices") or []:
            # This is what 1 choice looks like from codestral API
            # {
            #     "index": 0,
            #     "message": {
            #     "role": "assistant",
            #     "content": "\n assert is_odd(1)\n assert",
            #     "tool_calls": null
            #     },
            #     "finish_reason": "length",
            #     "logprobs": null
            #     }
            _choice_message = choice.get("message") or {}
            _choices.append(
                litellm.utils.TextChoices(
                    finish_reason=choice.get("finish_reason"),
                    index=choice.get("index"),
                    text=_choice_message.get("content"),
                    logprobs=choice.get("logprobs"),
                )
            )

        return litellm.TextCompletionResponse(
            id=completion_response.get("id"),
            choices=_choices,
            created=completion_response.get("created"),
            model=completion_response.get("model"),
            usage=completion_response.get("usage"),
            stream=False,
            object=completion_response.get("object"),
        )

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
        """Entry point for Codestral text completion (sync, async, streaming).

        Builds the prompt and request payload, fires the pre-call logging hook
        and dispatches to one of four paths depending on ``acompletion`` and
        the ``stream`` option.

        Args:
            model: Model name sent in the payload; also the key looked up in
                ``custom_prompt_dict``.
            messages: Input messages turned into a single prompt string.
            api_base: Endpoint URL. When falsy (and ``custom_endpoint`` is not
                ``True``) the default FIM completions URL is used.
            custom_prompt_dict: Per-model custom prompt templates.
            model_response: Forwarded to the response-processing helpers.
            print_verbose: Verbose-logging callable, forwarded to helpers.
            encoding: Forwarded to helpers.
            api_key: API key used for the ``Authorization`` header.
            logging_obj: Logging object; ``pre_call`` is invoked here.
            optional_params: Extra request parameters. Mutated in place:
                ``custom_endpoint`` and ``stream`` are popped, and config
                defaults are added for keys not already present.
            timeout: Forwarded to the async paths only.
            acompletion: When ``True`` the async implementations are used.
            litellm_params: Forwarded to the async paths.
            logger_fn: Forwarded to the async paths.
            headers: Extra headers merged over the default ones. The default
                dict is never mutated here.

        Returns:
            For the async paths, the un-awaited coroutine returned by
            ``async_streaming`` / ``async_completion``. For sync streaming, a
            ``CustomStreamWrapper``. Otherwise the processed
            ``TextCompletionResponse``.

        Raises:
            ValueError: If ``api_key`` is ``None``.
            TextCompletionCodestralError: From response processing on the
                sync non-streaming path.
        """
        headers = self._validate_environment(api_key, headers)

        if optional_params.pop("custom_endpoint", None) is True:
            completion_url = api_base
        else:
            completion_url = (
                api_base or "https://codestral.mistral.ai/v1/fim/completions"
            )

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.CodestralTextCompletionConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        # ``stream`` is removed from optional_params here, so the payload built
        # below does not contain it; streaming paths must add it back.
        stream = optional_params.pop("stream", False)

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            # Ask the API for a streamed reply, as async_streaming does.
            # Without this the body carries no "stream" flag, yet the reply
            # is consumed line by line as a stream.
            data["stream"] = True
            response = litellm.module_level_client.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            # NOTE(review): the provider tag here is "codestral" while
            # async_streaming uses "text-completion-codestral" - confirm
            # which one the stream wrapper expects.
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="codestral",
                logging_obj=logging_obj,
            )
            return _response

        ### SYNC COMPLETION
        response = litellm.module_level_client.post(
            url=completion_url,
            headers=headers,
            data=json.dumps(data),
        )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=optional_params.get("stream", False),
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:
        """Send a non-streaming completion request asynchronously.

        Posts ``data`` (serialized as JSON) to ``api_base`` through an async
        HTTP client configured with ``timeout``, then passes the raw response
        to ``process_text_completion_response``.

        Raises:
            TextCompletionCodestralError: With the response status code when
                the request raises ``httpx.HTTPStatusError``; with status 500
                for any other exception raised while sending the request or
                serializing the payload.
        """
        http_client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
            params={"timeout": timeout},
        )
        try:
            # Serialization stays inside the try so that its failures are
            # wrapped like any other send-time error.
            raw_response = await http_client.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as status_err:
            failed = status_err.response
            raise TextCompletionCodestralError(
                status_code=failed.status_code,
                message=f"HTTPStatusError - {failed.text}",
            )
        except Exception as err:
            # don't use verbose_logger.exception, if exception is raised
            raise TextCompletionCodestralError(status_code=500, message=str(err))

        return self.process_text_completion_response(
            model=model,
            response=raw_response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        """Prepare a lazily-started async streaming completion.

        Marks the payload as streaming (mutating ``data`` in place) and
        returns a ``CustomStreamWrapper`` with no stream attached yet; the
        wrapper is given a partial of ``make_call`` holding the URL, headers
        and serialized payload.
        """
        data["stream"] = True

        # Everything make_call needs except the HTTP client is bound here.
        deferred_call = partial(
            make_call,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            model=model,
            messages=messages,
            logging_obj=logging_obj,
        )
        return CustomStreamWrapper(
            completion_stream=None,
            make_call=deferred_call,
            model=model,
            custom_llm_provider="text-completion-codestral",
            logging_obj=logging_obj,
        )

    def embedding(self, *args, **kwargs):
        pass