author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/rerank_api/main.py | 333
1 file changed, 333 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py b/.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py
new file mode 100644
index 00000000..ce8ae21c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/rerank_api/main.py
@@ -0,0 +1,333 @@
+import asyncio
+import contextvars
+from functools import partial
+from typing import Any, Coroutine, Dict, List, Literal, Optional, Union
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.llms.bedrock.rerank.handler import BedrockRerankHandler
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
+from litellm.llms.together_ai.rerank.handler import TogetherAIRerank
+from litellm.rerank_api.rerank_utils import get_optional_rerank_params
+from litellm.secret_managers.main import get_secret, get_secret_str
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.router import *
+from litellm.utils import ProviderConfigManager, client, exception_type
+
+####### ENVIRONMENT VARIABLES ###################
+# Initialize any necessary instances or variables here
+together_rerank = TogetherAIRerank()
+bedrock_rerank = BedrockRerankHandler()
+base_llm_http_handler = BaseLLMHTTPHandler()
+#################################################
+
+
+@client
+async def arerank(
+    model: str,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[Literal["cohere", "together_ai"]] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = None,
+    max_chunks_per_doc: Optional[int] = None,
+    **kwargs,
+) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
+    """
+    Async: Reranks a list of documents based on their relevance to the query
+    """
+    try:
+        loop = asyncio.get_event_loop()
+        kwargs["arerank"] = True
+
+        func = partial(
+            rerank,
+            model,
+            query,
+            documents,
+            custom_llm_provider,
+            top_n,
+            rank_fields,
+            return_documents,
+            max_chunks_per_doc,
+            **kwargs,
+        )
+
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        init_response = await loop.run_in_executor(None, func_with_context)
+
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response
+        return response
+    except Exception as e:
+        raise e
+
+
+@client
+def rerank(  # noqa: PLR0915
+    model: str,
+    query: str,
+    documents: List[Union[str, Dict[str, Any]]],
+    custom_llm_provider: Optional[
+        Literal["cohere", "together_ai", "azure_ai", "infinity", "litellm_proxy"]
+    ] = None,
+    top_n: Optional[int] = None,
+    rank_fields: Optional[List[str]] = None,
+    return_documents: Optional[bool] = True,
+    max_chunks_per_doc: Optional[int] = None,
+    max_tokens_per_doc: Optional[int] = None,
+    **kwargs,
+) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
+    """
+    Reranks a list of documents based on their relevance to the query
+    """
+    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
+    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
+    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+    proxy_server_request = kwargs.get("proxy_server_request", None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+    user = kwargs.get("user", None)
+    client = kwargs.get("client", None)
+    try:
+        _is_async = kwargs.pop("arerank", False) is True
+        optional_params = GenericLiteLLMParams(**kwargs)
+        # Params that are unique to specific versions of the client for the rerank call
+        unique_version_params = {
+            "max_chunks_per_doc": max_chunks_per_doc,
+            "max_tokens_per_doc": max_tokens_per_doc,
+        }
+        present_version_params = [
+            k for k, v in unique_version_params.items() if v is not None
+        ]
+
+        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+            litellm.get_llm_provider(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                api_base=optional_params.api_base,
+                api_key=optional_params.api_key,
+            )
+        )
+
+        rerank_provider_config: BaseRerankConfig = (
+            ProviderConfigManager.get_provider_rerank_config(
+                model=model,
+                provider=litellm.LlmProviders(_custom_llm_provider),
+                api_base=optional_params.api_base,
+                present_version_params=present_version_params,
+            )
+        )
+
+        optional_rerank_params: OptionalRerankParams = get_optional_rerank_params(
+            rerank_provider_config=rerank_provider_config,
+            model=model,
+            drop_params=kwargs.get("drop_params") or litellm.drop_params or False,
+            query=query,
+            documents=documents,
+            custom_llm_provider=_custom_llm_provider,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+            max_tokens_per_doc=max_tokens_per_doc,
+            non_default_params=kwargs,
+        )
+
+        if isinstance(optional_params.timeout, str):
+            optional_params.timeout = float(optional_params.timeout)
+
+        model_response = RerankResponse()
+
+        litellm_logging_obj.update_environment_variables(
+            model=model,
+            user=user,
+            optional_params=dict(optional_rerank_params),
+            litellm_params={
+                "litellm_call_id": litellm_call_id,
+                "proxy_server_request": proxy_server_request,
+                "model_info": model_info,
+                "metadata": metadata,
+                "preset_cache_key": None,
+                "stream_response": {},
+                **optional_params.model_dump(exclude_unset=True),
+            },
+            custom_llm_provider=_custom_llm_provider,
+        )
+
+        # Implement rerank logic here based on the custom_llm_provider
+        if _custom_llm_provider == "cohere" or _custom_llm_provider == "litellm_proxy":
+            # Implement Cohere rerank logic
+            api_key: Optional[str] = (
+                dynamic_api_key or optional_params.api_key or litellm.api_key
+            )
+
+            api_base: Optional[str] = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("COHERE_API_BASE")  # type: ignore
+                or "https://api.cohere.com"
+            )
+
+            if api_base is None:
+                raise Exception(
+                    "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
+                )
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                provider_config=rerank_provider_config,
+                optional_rerank_params=optional_rerank_params,
+                logging_obj=litellm_logging_obj,
+                timeout=optional_params.timeout,
+                api_key=api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "azure_ai":
+            api_base = (
+                dynamic_api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("AZURE_AI_API_BASE")  # type: ignore
+            )
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                optional_rerank_params=optional_rerank_params,
+                provider_config=rerank_provider_config,
+                logging_obj=litellm_logging_obj,
+                timeout=optional_params.timeout,
+                api_key=dynamic_api_key or optional_params.api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "infinity":
+            # Implement Infinity rerank logic
+            api_key = dynamic_api_key or optional_params.api_key or litellm.api_key
+
+            api_base = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret_str("INFINITY_API_BASE")
+            )
+
+            if api_base is None:
+                raise Exception(
+                    "Invalid api base. api_base=None. Set in call or via `INFINITY_API_BASE` env var."
+                )
+
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                provider_config=rerank_provider_config,
+                optional_rerank_params=optional_rerank_params,
+                logging_obj=litellm_logging_obj,
+                timeout=optional_params.timeout,
+                api_key=dynamic_api_key or optional_params.api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "together_ai":
+            # Implement Together AI rerank logic
+            api_key = (
+                dynamic_api_key
+                or optional_params.api_key
+                or litellm.togetherai_api_key
+                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
+                or litellm.api_key
+            )
+
+            if api_key is None:
+                raise ValueError(
+                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
+                )
+
+            response = together_rerank.rerank(
+                model=model,
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
+                api_key=api_key,
+                _is_async=_is_async,
+            )
+        elif _custom_llm_provider == "jina_ai":
+
+            if dynamic_api_key is None:
+                raise ValueError(
+                    "Jina AI API key is required, please set 'JINA_AI_API_KEY' in your environment"
+                )
+
+            api_base = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("BEDROCK_API_BASE")  # type: ignore
+            )
+
+            response = base_llm_http_handler.rerank(
+                model=model,
+                custom_llm_provider=_custom_llm_provider,
+                optional_rerank_params=optional_rerank_params,
+                logging_obj=litellm_logging_obj,
+                provider_config=rerank_provider_config,
+                timeout=optional_params.timeout,
+                api_key=dynamic_api_key or optional_params.api_key,
+                api_base=api_base,
+                _is_async=_is_async,
+                headers=headers or litellm.headers or {},
+                client=client,
+                model_response=model_response,
+            )
+        elif _custom_llm_provider == "bedrock":
+            api_base = (
+                dynamic_api_base
+                or optional_params.api_base
+                or litellm.api_base
+                or get_secret("BEDROCK_API_BASE")  # type: ignore
+            )
+
+            response = bedrock_rerank.rerank(
+                model=model,
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                rank_fields=rank_fields,
+                return_documents=return_documents,
+                max_chunks_per_doc=max_chunks_per_doc,
+                _is_async=_is_async,
+                optional_params=optional_params.model_dump(exclude_unset=True),
+                api_base=api_base,
+                logging_obj=litellm_logging_obj,
+                client=client,
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")
+
+        # Placeholder return
+        return response
+    except Exception as e:
+        verbose_logger.error(f"Error in rerank: {str(e)}")
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
+        )
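
For reference, a minimal usage sketch of the two entry points this commit adds. It assumes COHERE_API_KEY is set in the environment and routes through the "cohere" branch above; the model name, query, and documents are illustrative only, not part of the committed file:

    import asyncio
    import litellm

    # Synchronous rerank: get_llm_provider() resolves "cohere" from the
    # model prefix, then the Cohere branch dispatches via the HTTP handler.
    response = litellm.rerank(
        model="cohere/rerank-english-v3.0",  # hypothetical example model
        query="What is the capital of France?",
        documents=[
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
        top_n=1,
    )
    print(response.results)

    # Async variant: arerank() sets kwargs["arerank"] = True and runs the
    # sync path in an executor, awaiting the coroutine it hands back.
    async def main() -> None:
        aresponse = await litellm.arerank(
            model="cohere/rerank-english-v3.0",
            query="What is the capital of France?",
            documents=["Paris is the capital of France."],
            top_n=1,
        )
        print(aresponse.results)

    asyncio.run(main())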