aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py90
1 files changed, 90 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
new file mode 100644
index 00000000..842511f3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
@@ -0,0 +1,90 @@
+"""
+Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
+"""
+
+from typing import Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import RerankResponse
+
+
+class AzureAIRerankConfig(CohereRerankConfig):
+ """
+ Azure AI Rerank - Follows the same Spec as Cohere Rerank
+ """
+ def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+ if api_base is None:
+ raise ValueError(
+ "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+ )
+ if not api_base.endswith("/v1/rerank"):
+ api_base = f"{api_base}/v1/rerank"
+ return api_base
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ api_key: Optional[str] = None,
+ ) -> dict:
+ if api_key is None:
+ api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
+
+ if api_key is None:
+ raise ValueError(
+ "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
+ )
+
+ default_headers = {
+ "Authorization": f"Bearer {api_key}",
+ "accept": "application/json",
+ "content-type": "application/json",
+ }
+
+ # If 'Authorization' is provided in headers, it overrides the default.
+ if "Authorization" in headers:
+ default_headers["Authorization"] = headers["Authorization"]
+
+ # Merge other headers, overriding any default ones except Authorization
+ return {**default_headers, **headers}
+
+ def transform_rerank_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: RerankResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ request_data: dict = {},
+ optional_params: dict = {},
+ litellm_params: dict = {},
+ ) -> RerankResponse:
+ rerank_response = super().transform_rerank_response(
+ model=model,
+ raw_response=raw_response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=request_data,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
+ base_model = self._get_base_model(
+ rerank_response._hidden_params.get("llm_provider-azureml-model-group")
+ )
+ rerank_response._hidden_params["model"] = base_model
+ return rerank_response
+
+ def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
+ if azure_model_group is None:
+ return None
+ if azure_model_group == "offer-cohere-rerank-mul-paygo":
+ return "azure_ai/cohere-rerank-v3-multilingual"
+ if azure_model_group == "offer-cohere-rerank-eng-paygo":
+ return "azure_ai/cohere-rerank-v3-english"
+ return azure_model_group