about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/rerank_api
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/rerank_api
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/rerank_api')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py333
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/rerank_api/rerank_utils.py46
2 files changed, 379 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py b/.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py
new file mode 100644
index 00000000..ce8ae21c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py
@@ -0,0 +1,333 @@
+import asyncio
+import contextvars
+from functools import partial
+from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
+from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.rerank_api.rerank_utils import get_optional_rerank_params
+from litellm.secret_managers.main import get_secret, get_secret_str
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.router import *
+from litellm.utils import ProviderConfigManager, client, exception_type
+
+####### ENVIRONMENT VARIABLES ###################
+# Initialize any necessary instances or variables here
+together_rerank = TogetherAIRerank()
+bedrock_rerank = BedrockRerankHandler()
+base_llm_http_handler = BaseLLMHTTPHandler()
+#################################################
+
+
+@client
+async def arerank(
+    model: str,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = None,
+    max_chunks_per_doc: Optional[int] = None,
+    **kwargs,
+) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
+    """
+    Async: Reranks a list of documents based on their relevance to the query
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["arerank"] = True
+
+        func = partial(
+            rerank,
+            model,
+            query,
+            documents,
+            custom_llm_provider,
+            top_n,
+            rank_fields,
+            return_documents,
+            max_chunks_per_doc,
+            **kwargs,
+        )
+
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response
+        return response
+    except Exception as e:
+        raise e
+
+
+@client
+def rerank(  # noqa: PLR0915
+    model: str,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+    ] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
+    max_tokens_per_doc: Optional[int] = None,
+    **kwargs,
+) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
+    """
+    Reranks a list of documents based on their relevance to the query
+    """
+    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
+    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
+    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+    proxy_server_request = kwargs.get("proxy_server_request", None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+    user = kwargs.get("user", None)
+    client = kwargs.get("client", None)
+    try:
+        _is_async = kwargs.pop("arerank", False) is True
+        optional_params = GenericLiteLLMParams(**kwargs)
+        # Params that are unique to specific versions of the client for the rerank call
+        unique_version_params = {
+            "max_chunks_per_doc": max_chunks_per_doc,
+            "max_tokens_per_doc": max_tokens_per_doc,
+        }
+        present_version_params = [
+            k for k, v in unique_version_params.items() if v is not None
+        ]
+
+        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+            litellm.get_llm_provider(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                api_base=optional_params.api_base,
+                api_key=optional_params.api_key,
+            )
+        )
+
+        rerank_provider_config: BaseRerankConfig = (
+            ProviderConfigManager.get_provider_rerank_config(
+                model=model,
+                provider=litellm.LlmProviders(_custom_llm_provider),
+                api_base=optional_params.api_base,
+                present_version_params=present_version_params,
+            )
+        )
+
+        optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
+            rerank_provider_config=rerank_provider_config,
+            model=model,
+            drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
+            query=query,
+            documents=documents,
+            custom_llm_provider=_custom_llm_provider,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+            max_tokens_per_doc=max_tokens_per_doc,
+            non_default_params=kwargs,
+        )
+
+        if isinstance(optional_params.timeout, str):
+            optional_params.timeout = float(optional_params.timeout)
+
+        model_response = RerankResponse()
+
+        litellm_logging_obj.update_environment_variables(
+            model=model,
+            user=user,
+            optional_params=dict(optional_rerank_params),
+            litellm_params={
+                "litellm_call_id": litellm_call_id,
+                "proxy_server_request": proxy_server_request,
+                "model_info": model_info,
+                "metadata": metadata,
+                "preset_cache_key": None,
+                "stream_response": {},
+                **optional_params.model_dump(exclude_unset=True),
+            },
+            custom_llm_provider=_custom_llm_provider,
+        )
+
+        # Implement rerank logic here based on the custom_llm_provider
+        if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
+            # Implement Cohere rerank logic
+            api_key: Optional[str] = (
+                dynamic_api_key or optional_params.api_key or litellm.api_key
+            )
+
+            api_base: Optional[str] = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("COHERE_API_BASE")  # type: ignore
+                or "https://api.cohere.com"
+            )
+
+            if api_base is None:
+                raise Exception(
+                    "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
+                )
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                provider_config=rerank_provider_config,
+                optional_rerank_params=optional_rerank_params,
+                logging_obj=litellm_logging_obj,
+                timeout=optional_params.timeout,
+                api_key=api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "azure_ai":
+            api_base = (
+                dynamic_api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("AZURE_AI_API_BASE")  # type: ignore
+            )
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                optional_rerank_params=optional_rerank_params,
+                provider_config=rerank_provider_config,
+                logging_obj=litellm_logging_obj,
+                timeout=optional_params.timeout,
+                api_key=dynamic_api_key or optional_params.api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "infinity":
+            # Implement Infinity rerank logic
+            api_key = dynamic_api_key or optional_params.api_key or litellm.api_key
+
+            api_base = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret_str("INFINITY_API_BASE")
+            )
+
+            if api_base is None:
+                raise Exception(
+                    "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
+                )
+
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                provider_config=rerank_provider_config,
+                optional_rerank_params=optional_rerank_params,
+                logging_obj=litellm_logging_obj,
+                timeout=optional_params.timeout,
+                api_key=dynamic_api_key or optional_params.api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "together_ai":
+            # Implement Together AI rerank logic
+            api_key = (
+                dynamic_api_key
+                or optional_params.api_key
+                or litellm.togetherai_api_key
+                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
+                or litellm.api_key
+            )
+
+            if api_key is None:
+                raise ValueError(
+                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
+                )
+
+            response = together_rerank.rerank(
+                model=model,
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
+                api_key=api_key,
+                _is_async=_is_async,
+            )
+        elif _custom_llm_provider == "jina_ai":
+
+            if dynamic_api_key is None:
+                raise ValueError(
+                    "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
+                )
+
+            api_base = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("BEDROCK_API_BASE")  # type: ignore
+            )
+
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                optional_rerank_params=optional_rerank_params,
+                logging_obj=litellm_logging_obj,
+                provider_config=rerank_provider_config,
+                timeout=optional_params.timeout,
+                api_key=dynamic_api_key or optional_params.api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "bedrock":
+            api_base = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("BEDROCK_API_BASE")  # type: ignore
+            )
+
+            response = bedrock_rerank.rerank(
+                model=model,
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
+                _is_async=_is_async,
+                optional_params=optional_params.model_dump(exclude_unset=True),
+                api_base=api_base,
+                logging_obj=litellm_logging_obj,
+                client=client,
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
+
+        # Placeholder return
+        return response
+    except Exception as e:
+        verbose_logger.error(f"Error in rerank: {str(e)}")
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/rerank_api/rerank_utils.py b/.venv/lib/python3.12/site-packages/litellm/rerank_api/rerank_utils.py
new file mode 100644
index 00000000..f70ec015
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/rerank_api/rerank_utils.py
@@ -0,0 +1,46 @@
+from typing import Any, Dict, List, Optional, Union
+
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.types.rerank import OptionalRerankParams
+
+
+def get_optional_rerank_params(
+    rerank_provider_config: BaseRerankConfig,
+    model: str,
+    drop_params: bool,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[str] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
+    max_tokens_per_doc: Optional[int] = None,
+    non_default_params: Optional[dict] = None,
+) -> OptionalRerankParams:
+    all_non_default_params = non_default_params or {}
+    if query is not None:
+        all_non_default_params["query"] = query
+    if top_n is not None:
+        all_non_default_params["top_n"] = top_n
+    if documents is not None:
+        all_non_default_params["documents"] = documents
+    if return_documents is not None:
+        all_non_default_params["return_documents"] = return_documents
+    if max_chunks_per_doc is not None:
+        all_non_default_params["max_chunks_per_doc"] = max_chunks_per_doc
+    if max_tokens_per_doc is not None:
+        all_non_default_params["max_tokens_per_doc"] = max_tokens_per_doc
+    return rerank_provider_config.map_cohere_rerank_params(
+        model=model,
+        drop_params=drop_params,
+        query=query,
+        documents=documents,
+        custom_llm_provider=custom_llm_provider,
+        top_n=top_n,
+        rank_fields=rank_fields,
+        return_documents=return_documents,
+        max_chunks_per_doc=max_chunks_per_doc,
+        max_tokens_per_doc=max_tokens_per_doc,
+        non_default_params=all_non_default_params,
+    )