aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py116
1 files changed, 116 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py
new file mode 100644
index 00000000..1e7234ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py
@@ -0,0 +1,116 @@
+"""
+Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+import uuid
+from typing import List, Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.rerank import (
+ RerankBilledUnits,
+ RerankResponse,
+ RerankResponseDocument,
+ RerankResponseMeta,
+ RerankResponseResult,
+ RerankTokens,
+)
+
+from .common_utils import InfinityError
+
+
+class InfinityRerankConfig(CohereRerankConfig):
+ def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+ if api_base is None:
+ raise ValueError("api_base is required for Infinity rerank")
+ # Remove trailing slashes and ensure clean base URL
+ api_base = api_base.rstrip("/")
+ if not api_base.endswith("/rerank"):
+ api_base = f"{api_base}/rerank"
+ return api_base
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ if api_key is None:
+ api_key = (
+ get_secret_str("INFINITY_API_KEY")
+ or get_secret_str("INFINITY_API_KEY")
+ or litellm.infinity_key
+ )
+
+ default_headers = {
+ "Authorization": f"bearer {api_key}",
+ "accept": "application/json",
+ "content-type": "application/json",
+ }
+
+ # If 'Authorization' is provided in headers, it overrides the default.
+ if "Authorization" in headers:
+ default_headers["Authorization"] = headers["Authorization"]
+
+ # Merge other headers, overriding any default ones except Authorization
+ return {**default_headers, **headers}
+
+ def transform_rerank_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: RerankResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ request_data: dict = {},
+ optional_params: dict = {},
+ litellm_params: dict = {},
+ ) -> RerankResponse:
+ """
+ Transform Infinity rerank response
+
+ No transformation required, Infinity follows Cohere API response format
+ """
+ try:
+ raw_response_json = raw_response.json()
+ except Exception:
+ raise InfinityError(
+ message=raw_response.text, status_code=raw_response.status_code
+ )
+
+ _billed_units = RerankBilledUnits(**raw_response_json.get("usage", {}))
+ _tokens = RerankTokens(
+ input_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+ output_tokens=(
+ raw_response_json.get("usage", {}).get("total_tokens", 0)
+ - raw_response_json.get("usage", {}).get("prompt_tokens", 0)
+ ),
+ )
+ rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
+
+ cohere_results: List[RerankResponseResult] = []
+ if raw_response_json.get("results"):
+ for result in raw_response_json.get("results"):
+ _rerank_response = RerankResponseResult(
+ index=result.get("index"),
+ relevance_score=result.get("relevance_score"),
+ )
+ if result.get("document"):
+ _rerank_response["document"] = RerankResponseDocument(
+ text=result.get("document")
+ )
+ cohere_results.append(_rerank_response)
+ if cohere_results is None:
+ raise ValueError(f"No results found in the response={raw_response_json}")
+
+ return RerankResponse(
+ id=raw_response_json.get("id") or str(uuid.uuid4()),
+ results=cohere_results,
+ meta=rerank_meta,
+ ) # Return response