about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/common_utils.py | 19
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/handler.py | 5
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py | 116
3 files changed, 140 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/common_utils.py
new file mode 100644
index 00000000..99477d1a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/common_utils.py
@@ -0,0 +1,19 @@
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+
+
+class InfinityError(BaseLLMException):
+    """
+    Exception raised for failed Infinity requests.
+
+    Stores the status code and message, plus a synthetic httpx
+    request/response pair, and forwards all four to the base exception.
+    """
+
+    def __init__(self, status_code, message):
+        # No real request object is available here, so a placeholder POST
+        # request is built and a response with the given status code is
+        # attached to it.
+        placeholder_request = httpx.Request(
+            method="POST", url="https://github.com/michaelfeil/infinity"
+        )
+        placeholder_response = httpx.Response(
+            status_code=status_code, request=placeholder_request
+        )
+
+        self.status_code = status_code
+        self.message = message
+        self.request = placeholder_request
+        self.response = placeholder_response
+
+        # Hand the same values to the base class constructor.
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=placeholder_request,
+            response=placeholder_response,
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/handler.py
new file mode 100644
index 00000000..5b8a2c0c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/handler.py
@@ -0,0 +1,5 @@
+"""
+Infinity Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py
new file mode 100644
index 00000000..1e7234ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py
@@ -0,0 +1,116 @@
+"""
+Transformation logic from Cohere's `/v1/rerank` format to Infinity's `/v1/rerank` format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+import uuid
+from typing import List, Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import (
+    RerankBilledUnits,
+    RerankResponse,
+    RerankResponseDocument,
+    RerankResponseMeta,
+    RerankResponseResult,
+    RerankTokens,
+)
+
+from .common_utils import InfinityError
+
+
+class InfinityRerankConfig(CohereRerankConfig):
+    """
+    Rerank configuration for Infinity servers.
+
+    Builds on the Cohere rerank config and overrides URL construction,
+    request-header construction and response parsing.
+    """
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        """
+        Return the rerank endpoint URL derived from `api_base`.
+
+        Trailing slashes are stripped and `/rerank` is appended unless the
+        base URL already ends with it. `model` is not used.
+
+        Raises:
+            ValueError: if `api_base` is None.
+        """
+        if api_base is None:
+            raise ValueError("api_base is required for Infinity rerank")
+        # Remove trailing slashes and ensure clean base URL
+        api_base = api_base.rstrip("/")
+        if not api_base.endswith("/rerank"):
+            api_base = f"{api_base}/rerank"
+        return api_base
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        """
+        Build the HTTP headers for a rerank request.
+
+        When `api_key` is None it is resolved from the `INFINITY_API_KEY`
+        secret, then from `litellm.infinity_key`. An `Authorization` header
+        is only added when a non-empty key is available, so the literal
+        string "bearer None" is never sent.
+
+        Every entry in `headers` (including `Authorization`) overrides the
+        corresponding default. `model` is not used.
+
+        Returns:
+            A new dict; `headers` itself is not modified.
+        """
+        if api_key is None:
+            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key
+
+        default_headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+        if api_key:
+            default_headers["Authorization"] = f"bearer {api_key}"
+
+        # Caller-supplied headers win over every default, Authorization included.
+        return {**default_headers, **headers}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        """
+        Transform an Infinity rerank response into a `RerankResponse`.
+
+        Only `raw_response` is read; the remaining parameters are accepted
+        but unused here (the dict defaults are therefore never mutated).
+
+        Token accounting comes from the response's `usage` object:
+        input tokens are `prompt_tokens`, output tokens are
+        `total_tokens - prompt_tokens`. A missing or null `usage` counts
+        as zero for both.
+
+        Raises:
+            InfinityError: if the response body cannot be parsed as JSON.
+            ValueError: if the parsed body has no `results` field.
+        """
+        try:
+            raw_response_json = raw_response.json()
+        except Exception as err:
+            raise InfinityError(
+                message=raw_response.text, status_code=raw_response.status_code
+            ) from err
+
+        # `or {}` also covers an explicit `"usage": null` in the payload.
+        usage = raw_response_json.get("usage") or {}
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        _billed_units = RerankBilledUnits(**usage)
+        _tokens = RerankTokens(
+            input_tokens=prompt_tokens,
+            output_tokens=usage.get("total_tokens", 0) - prompt_tokens,
+        )
+        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+        # Check the raw field: an empty list is a valid (empty) result set,
+        # a missing/null field is an error.
+        raw_results = raw_response_json.get("results")
+        if raw_results is None:
+            raise ValueError(f"No results found in the response={raw_response_json}")
+
+        cohere_results: List[RerankResponseResult] = []
+        for result in raw_results:
+            _rerank_response = RerankResponseResult(
+                index=result.get("index"),
+                relevance_score=result.get("relevance_score"),
+            )
+            # The document is only attached when the server returned one.
+            if result.get("document"):
+                _rerank_response["document"] = RerankResponseDocument(
+                    text=result.get("document")
+                )
+            cohere_results.append(_rerank_response)
+
+        return RerankResponse(
+            # Fall back to a generated id when the response does not supply one.
+            id=raw_response_json.get("id") or str(uuid.uuid4()),
+            results=cohere_results,
+            meta=rerank_meta,
+        )  # Return response