about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py5
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py90
2 files changed, 95 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py
new file mode 100644
index 00000000..57e7cefd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py
@@ -0,0 +1,5 @@
+"""
+Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
new file mode 100644
index 00000000..842511f3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
@@ -0,0 +1,90 @@
+"""
+Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. 
+"""
+
+from typing import Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import RerankResponse
+
+
+class AzureAIRerankConfig(CohereRerankConfig):
+    """
+    Azure AI Rerank - Follows the same Spec as Cohere Rerank
+    """
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base is None:
+            raise ValueError(
+                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+            )
+        if not api_base.endswith("/v1/rerank"):
+            api_base = f"{api_base}/v1/rerank"
+        return api_base
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
+
+        if api_key is None:
+            raise ValueError(
+                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"Bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge other headers, overriding any default ones except Authorization
+        return {**default_headers, **headers}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        rerank_response = super().transform_rerank_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+        )
+        base_model = self._get_base_model(
+            rerank_response._hidden_params.get("llm_provider-azureml-model-group")
+        )
+        rerank_response._hidden_params["model"] = base_model
+        return rerank_response
+
+    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
+        if azure_model_group is None:
+            return None
+        if azure_model_group == "offer-cohere-rerank-mul-paygo":
+            return "azure_ai/cohere-rerank-v3-multilingual"
+        if azure_model_group == "offer-cohere-rerank-eng-paygo":
+            return "azure_ai/cohere-rerank-v3-english"
+        return azure_model_group