diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py new file mode 100644 index 00000000..f3624d92 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.rerank import OptionalRerankParams, RerankRequest +from litellm.types.utils import RerankResponse + +from ..common_utils import CohereError + + +class CohereRerankConfig(BaseRerankConfig): + """ + Reference: https://docs.cohere.com/v2/reference/rerank + """ + + def __init__(self) -> None: + pass + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base: + # Remove trailing slashes and ensure clean base URL + api_base = api_base.rstrip("/") + if not api_base.endswith("/v1/rerank"): + api_base = f"{api_base}/v1/rerank" + return api_base + return "https://api.cohere.ai/v1/rerank" + + def get_supported_cohere_rerank_params(self, model: str) -> list: + return [ + "query", + "documents", + "top_n", + "max_chunks_per_doc", + "rank_fields", + "return_documents", + ] + + def map_cohere_rerank_params( + self, + non_default_params: Optional[dict], + model: str, + drop_params: bool, + query: str, + documents: List[Union[str, Dict[str, Any]]], + custom_llm_provider: Optional[str] = None, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + max_tokens_per_doc: Optional[int] = None, + ) -> OptionalRerankParams: + """ + Map Cohere rerank params + + No mapping required - returns all supported params + """ + return OptionalRerankParams( + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + max_chunks_per_doc=max_chunks_per_doc, + ) + + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = ( + get_secret_str("COHERE_API_KEY") + or get_secret_str("CO_API_KEY") + or litellm.cohere_key + ) + + if api_key is None: + raise ValueError( + "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'" + ) + + default_headers = { + "Authorization": f"bearer {api_key}", + "accept": "application/json", + "content-type": "application/json", + } + + # If 'Authorization' is provided in headers, it overrides the default. + if "Authorization" in headers: + default_headers["Authorization"] = headers["Authorization"] + + # Merge other headers, overriding any default ones except Authorization + return {**default_headers, **headers} + + def transform_rerank_request( + self, + model: str, + optional_rerank_params: OptionalRerankParams, + headers: dict, + ) -> dict: + if "query" not in optional_rerank_params: + raise ValueError("query is required for Cohere rerank") + if "documents" not in optional_rerank_params: + raise ValueError("documents is required for Cohere rerank") + rerank_request = RerankRequest( + model=model, + query=optional_rerank_params["query"], + documents=optional_rerank_params["documents"], + top_n=optional_rerank_params.get("top_n", None), + rank_fields=optional_rerank_params.get("rank_fields", None), + return_documents=optional_rerank_params.get("return_documents", None), + max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None), + ) + return rerank_request.model_dump(exclude_none=True) + + def transform_rerank_response( + self, + model: str, + raw_response: httpx.Response, + model_response: RerankResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> RerankResponse: + """ + Transform Cohere rerank response + + No transformation required, litellm follows cohere API response format + """ + try: + raw_response_json = raw_response.json() + except Exception: + raise CohereError( + message=raw_response.text, status_code=raw_response.status_code + ) + + return RerankResponse(**raw_response_json) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(message=error_message, status_code=status_code)
\ No newline at end of file |