# (web-view navigation residue "about / summary / refs / log / tree / commit / diff" was here; it is not part of the module)
import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
from litellm.rerank_api.rerank_utils import get_optional_rerank_params
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankResponse
from litellm.types.router import *
from litellm.utils import ProviderConfigManager, client, exception_type

####### MODULE-LEVEL HANDLERS ###################
# Handler instances created once at import time and reused by every rerank() call.
#   together_rerank       -> used by the "together_ai" branch of rerank()
#   bedrock_rerank        -> used by the "bedrock" branch of rerank()
#   base_llm_http_handler -> used by the "cohere", "litellm_proxy", "azure_ai",
#                            "infinity" and "jina_ai" branches of rerank()
together_rerank = TogetherAIRerank()
bedrock_rerank = BedrockRerankHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################


@client
async def arerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = None,
    max_chunks_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Async: Reranks a list of documents based on their relevance to the query

    Thin async wrapper around the synchronous `rerank` entry point: `rerank`
    is invoked in the event loop's default executor with `arerank=True` added
    to the keyword arguments, and if that call hands back a coroutine it is
    awaited here.

    Args:
        model: Model name, forwarded to `rerank`.
        query: The query to rank the documents against.
        documents: Documents to rank (plain strings or dicts).
        custom_llm_provider: Optional explicit provider, forwarded to `rerank`.
        top_n: Forwarded to `rerank`.
        rank_fields: Forwarded to `rerank`.
        return_documents: Forwarded to `rerank` as-is. Note the default here is
            None, whereas `rerank` itself defaults to True.
        max_chunks_per_doc: Forwarded to `rerank`.
        **kwargs: Any further keyword arguments, forwarded to `rerank`
            (e.g. `max_tokens_per_doc`).

    Returns:
        The awaited result of `rerank`.

    Raises:
        Whatever `rerank` (or the coroutine it returns) raises; exceptions are
        propagated unchanged.
    """
    # We are inside a coroutine, so a loop is guaranteed to be running;
    # get_running_loop() is the recommended way to obtain it.
    loop = asyncio.get_running_loop()

    # Flag consumed by rerank() to select the provider handlers' async path.
    kwargs["arerank"] = True

    # The leading arguments are bound positionally, in rerank()'s parameter order.
    func = partial(
        rerank,
        model,
        query,
        documents,
        custom_llm_provider,
        top_n,
        rank_fields,
        return_documents,
        max_chunks_per_doc,
        **kwargs,
    )

    # Run inside a copy of the current context so contextvars set by the
    # caller remain visible in the executor thread.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)

    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response


@client
def rerank(  # noqa: PLR0915
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[
        Literal[
            "cohere",
            "together_ai",
            "azure_ai",
            "infinity",
            "litellm_proxy",
            "jina_ai",
            "bedrock",
        ]
    ] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    max_tokens_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Reranks a list of documents based on their relevance to the query

    Resolves the provider for `model`, builds the provider-specific optional
    params, records the call on the logging object, and dispatches to the
    matching handler.

    Args:
        model: Model name; passed through `litellm.get_llm_provider`, which
            may return a rewritten model name that is used from then on.
        query: The query to rank the documents against.
        documents: Documents to rank (plain strings or dicts).
        custom_llm_provider: Optional explicit provider hint for
            `litellm.get_llm_provider`.
        top_n: Forwarded to the provider params / handler.
        rank_fields: Forwarded to the provider params / handler.
        return_documents: Forwarded to the provider params / handler.
        max_chunks_per_doc: Forwarded to the provider params / handler.
        max_tokens_per_doc: Forwarded to the provider params.
        **kwargs: Extra settings. Keys read directly here: `headers`,
            `litellm_logging_obj`, `litellm_call_id`, `proxy_server_request`,
            `model_info`, `metadata`, `user`, `client`, `drop_params` and
            `arerank` (popped; selects the handlers' async path). The
            remaining kwargs are used to build `GenericLiteLLMParams`.

    Returns:
        Whatever the selected handler returns. Per the annotation this is a
        `RerankResponse`, or a coroutine yielding one when `arerank=True`.

    Raises:
        Any exception raised while preparing or dispatching the call
        (including `ValueError` for an unsupported provider or a missing
        Together AI / Jina AI key) is logged and then re-raised through
        `exception_type`.
    """
    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    user = kwargs.get("user", None)
    client = kwargs.get("client", None)
    try:
        # Must be popped before GenericLiteLLMParams(**kwargs) so the internal
        # flag is not treated as a litellm param.
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
        # Params that are unique to specific versions of the client for the rerank call
        unique_version_params = {
            "max_chunks_per_doc": max_chunks_per_doc,
            "max_tokens_per_doc": max_tokens_per_doc,
        }
        present_version_params = [
            k for k, v in unique_version_params.items() if v is not None
        ]

        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
            )
        )

        rerank_provider_config: BaseRerankConfig = (
            ProviderConfigManager.get_provider_rerank_config(
                model=model,
                provider=litellm.LlmProviders(_custom_llm_provider),
                api_base=optional_params.api_base,
                present_version_params=present_version_params,
            )
        )

        optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
            rerank_provider_config=rerank_provider_config,
            model=model,
            drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
            query=query,
            documents=documents,
            custom_llm_provider=_custom_llm_provider,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
            max_chunks_per_doc=max_chunks_per_doc,
            max_tokens_per_doc=max_tokens_per_doc,
            non_default_params=kwargs,
        )

        # A timeout supplied as a string is normalised to a float.
        if isinstance(optional_params.timeout, str):
            optional_params.timeout = float(optional_params.timeout)

        model_response = RerankResponse()

        litellm_logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=dict(optional_rerank_params),
            litellm_params={
                "litellm_call_id": litellm_call_id,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
                **optional_params.model_dump(exclude_unset=True),
            },
            custom_llm_provider=_custom_llm_provider,
        )

        def _rerank_via_http_handler(
            api_key: Optional[str], api_base: Optional[str]
        ):
            # Shared call for every provider served by base_llm_http_handler;
            # only the credentials and the endpoint differ per provider.
            return base_llm_http_handler.rerank(
                model=model,
                custom_llm_provider=_custom_llm_provider,
                provider_config=rerank_provider_config,
                optional_rerank_params=optional_rerank_params,
                logging_obj=litellm_logging_obj,
                timeout=optional_params.timeout,
                api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers or litellm.headers or {},
                client=client,
                model_response=model_response,
            )

        # Dispatch on the resolved provider
        if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
            api_key: Optional[str] = (
                dynamic_api_key or optional_params.api_key or litellm.api_key
            )

            # The chain ends in a non-empty literal, so api_base is never None
            # here and needs no None check.
            api_base: Optional[str] = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("COHERE_API_BASE")  # type: ignore
                or "https://api.cohere.com"
            )

            response = _rerank_via_http_handler(api_key=api_key, api_base=api_base)
        elif _custom_llm_provider == "azure_ai":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("AZURE_AI_API_BASE")  # type: ignore
            )

            response = _rerank_via_http_handler(
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
            )
        elif _custom_llm_provider == "infinity":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret_str("INFINITY_API_BASE")
            )

            if api_base is None:
                raise Exception(
                    "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
                )

            # NOTE(review): the global litellm.api_key is deliberately NOT used
            # as a fallback here (an earlier, unused local computed it but it
            # was never sent) - confirm whether that fallback is wanted.
            response = _rerank_via_http_handler(
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
            )
        elif _custom_llm_provider == "together_ai":
            api_key = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.togetherai_api_key
                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
                or litellm.api_key
            )

            if api_key is None:
                raise ValueError(
                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
                )

            response = together_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                _is_async=_is_async,
            )
        elif _custom_llm_provider == "jina_ai":
            if dynamic_api_key is None:
                raise ValueError(
                    "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
                )

            # Falls back to the Jina-specific env var (this previously read
            # BEDROCK_API_BASE, copied from the bedrock branch).
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("JINA_AI_API_BASE")  # type: ignore
            )

            response = _rerank_via_http_handler(
                api_key=dynamic_api_key or optional_params.api_key,
                api_base=api_base,
            )
        elif _custom_llm_provider == "bedrock":
            api_base = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("BEDROCK_API_BASE")  # type: ignore
            )

            response = bedrock_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                _is_async=_is_async,
                optional_params=optional_params.model_dump(exclude_unset=True),
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                client=client,
            )
        else:
            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

        return response
    except Exception as e:
        verbose_logger.error(f"Error in rerank: {str(e)}")
        # NOTE(review): this passes the caller-supplied custom_llm_provider
        # (possibly None), not the resolved _custom_llm_provider - confirm
        # which one exception_type should receive.
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
        )