about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py5
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py151
2 files changed, 156 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py
new file mode 100644
index 00000000..e94f1859
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/handler.py
@@ -0,0 +1,5 @@
+"""
+Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py
new file mode 100644
index 00000000..f3624d92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/cohere/rerank/transformation.py
@@ -0,0 +1,151 @@
+from typing import Any, Dict, List, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import OptionalRerankParams, RerankRequest
+from litellm.types.utils import RerankResponse
+
+from ..common_utils import CohereError
+
+
+class CohereRerankConfig(BaseRerankConfig):
+    """
+    Reference: https://docs.cohere.com/v2/reference/rerank
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            # Remove trailing slashes and ensure clean base URL
+            api_base = api_base.rstrip("/")
+            if not api_base.endswith("/v1/rerank"):
+                api_base = f"{api_base}/v1/rerank"
+            return api_base
+        return "https://api.cohere.ai/v1/rerank"
+
+    def get_supported_cohere_rerank_params(self, model: str) -> list:
+        return [
+            "query",
+            "documents",
+            "top_n",
+            "max_chunks_per_doc",
+            "rank_fields",
+            "return_documents",
+        ]
+
+    def map_cohere_rerank_params(
+        self,
+        non_default_params: Optional[dict],
+        model: str,
+        drop_params: bool,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        custom_llm_provider: Optional[str] = None,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+        max_tokens_per_doc: Optional[int] = None,
+    ) -> OptionalRerankParams:
+        """
+        Map Cohere rerank params
+
+        No mapping required - returns all supported params
+        """
+        return OptionalRerankParams(
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+            max_chunks_per_doc=max_chunks_per_doc,
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("COHERE_API_KEY")
+                or get_secret_str("CO_API_KEY")
+                or litellm.cohere_key
+            )
+
+        if api_key is None:
+            raise ValueError(
+                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_request(
+        self,
+        model: str,
+        optional_rerank_params: OptionalRerankParams,
+        headers: dict,
+    ) -> dict:
+        if "query" not in optional_rerank_params:
+            raise ValueError("query is required for Cohere rerank")
+        if "documents" not in optional_rerank_params:
+            raise ValueError("documents is required for Cohere rerank")
+        rerank_request = RerankRequest(
+            model=model,
+            query=optional_rerank_params["query"],
+            documents=optional_rerank_params["documents"],
+            top_n=optional_rerank_params.get("top_n", None),
+            rank_fields=optional_rerank_params.get("rank_fields", None),
+            return_documents=optional_rerank_params.get("return_documents", None),
+            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
+        )
+        return rerank_request.model_dump(exclude_none=True)
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform Cohere rerank response
+
+        No transformation required, litellm follows cohere API response format
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise CohereError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        return RerankResponse(**raw_response_json)
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return CohereError(message=error_message, status_code=status_code)
\ No newline at end of file