diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/infinity/rerank/transformation.py | 116 |
1 files changed, 116 insertions, 0 deletions
"""
Transformation logic from Cohere's /v1/rerank format to Infinity's `/v1/rerank` format.

Why separate file? Make it easy to see how transformation works
"""

import uuid
from typing import List, Optional

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import (
    RerankBilledUnits,
    RerankResponse,
    RerankResponseDocument,
    RerankResponseMeta,
    RerankResponseResult,
    RerankTokens,
)

from .common_utils import InfinityError


class InfinityRerankConfig(CohereRerankConfig):
    """Rerank configuration for Infinity, built on top of the Cohere rerank config.

    Overrides URL construction, header construction and response parsing;
    everything else is inherited from ``CohereRerankConfig``.
    """

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Return the rerank endpoint URL derived from ``api_base``.

        Trailing slashes are stripped and ``/rerank`` is appended unless the
        base already ends with it. ``model`` is not used here.

        Raises:
            ValueError: if ``api_base`` is ``None``.
        """
        if api_base is None:
            raise ValueError("api_base is required for Infinity rerank")
        # Remove trailing slashes and ensure clean base URL
        api_base = api_base.rstrip("/")
        if not api_base.endswith("/rerank"):
            api_base = f"{api_base}/rerank"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Build the request headers for an Infinity rerank call.

        The API key is taken from ``api_key``, falling back to the
        ``INFINITY_API_KEY`` secret and then ``litellm.infinity_key``.
        Caller-supplied ``headers`` take precedence over every default,
        including ``Authorization``.

        Returns:
            A new dict; ``headers`` itself is not modified.
        """
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY") or litellm.infinity_key

        default_headers: dict = {}
        # Only advertise credentials when a key actually exists; otherwise the
        # request would carry the literal string "bearer None".
        if api_key:
            default_headers["Authorization"] = f"bearer {api_key}"
        default_headers["accept"] = "application/json"
        default_headers["content-type"] = "application/json"

        # Caller-provided headers win on key collisions (Authorization included).
        return {**default_headers, **headers}

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """Convert an Infinity rerank HTTP response into a ``RerankResponse``.

        Token accounting is derived from the ``usage`` object of the payload:
        input tokens are ``prompt_tokens`` and output tokens are
        ``total_tokens - prompt_tokens`` (both default to 0 when absent).
        Only ``raw_response`` is read; the remaining parameters are accepted
        for interface compatibility and are not used or mutated.

        Raises:
            InfinityError: if the response body cannot be decoded as JSON.
            ValueError: if the payload has no ``results`` field.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception as err:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from err

        # `or {}` also covers an explicit `"usage": null` in the payload.
        usage = raw_response_json.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        total_tokens = usage.get("total_tokens", 0)

        _billed_units = RerankBilledUnits(**usage)
        _tokens = RerankTokens(
            input_tokens=prompt_tokens,
            output_tokens=total_tokens - prompt_tokens,
        )
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        raw_results = raw_response_json.get("results")
        if raw_results is None:
            # A missing `results` field means the payload is not a rerank
            # response; an empty list is still a valid (empty) ranking.
            raise ValueError(f"No results found in the response={raw_response_json}")

        cohere_results: List[RerankResponseResult] = []
        for result in raw_results:
            _rerank_response = RerankResponseResult(
                index=result.get("index"),
                relevance_score=result.get("relevance_score"),
            )
            document = result.get("document")
            if document:
                # NOTE(review): `document` is passed through as the text
                # itself — presumably a plain string; confirm against the API.
                _rerank_response["document"] = RerankResponseDocument(text=document)
            cohere_results.append(_rerank_response)

        return RerankResponse(
            # Fall back to a generated id when the payload does not provide one.
            id=raw_response_json.get("id") or str(uuid.uuid4()),
            results=cohere_results,
            meta=rerank_meta,
        )  # Return response