about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py 168
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/transformation.py 119
2 files changed, 287 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py
new file mode 100644
index 00000000..cd8be691
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py
@@ -0,0 +1,168 @@
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.types.llms.bedrock import BedrockPreparedRequest
+from litellm.types.rerank import RerankRequest
+from litellm.types.utils import RerankResponse
+
+from ..base_aws_llm import BaseAWSLLM
+from ..common_utils import BedrockError
+from .transformation import BedrockRerankConfig
+
+if TYPE_CHECKING:  # import botocore's request type only while type checking
+    from botocore.awsrequest import AWSPreparedRequest
+else:  # at runtime the name is just an alias of Any
+    AWSPreparedRequest = Any
+
+
+class BedrockRerankHandler(BaseAWSLLM):
+    """Calls the Bedrock `/rerank` HTTP API using SigV4-signed requests."""
+
+    async def arerank(
+        self,
+        prepared_request: BedrockPreparedRequest,
+        client: Optional[AsyncHTTPHandler] = None,
+        logging_obj: Optional[LitellmLogging] = None,
+    ):
+        """
+        Send an already-signed rerank request asynchronously.
+
+        `logging_obj` is optional so existing callers keep working; when given,
+        the raw response is reported via `post_call`, as the sync path does.
+        Raises BedrockError on HTTP error statuses and on timeouts (as 408).
+        """
+        if client is None:
+            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
+        try:
+            response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+        if logging_obj is not None:
+            logging_obj.post_call(original_response=response.text, api_key="")
+        return BedrockRerankConfig()._transform_response(response.json())
+
+    def rerank(
+        self,
+        model: str,
+        query: str,
+        documents: List[Union[str, Dict[str, Any]]],
+        optional_params: dict,
+        logging_obj: LitellmLogging,
+        top_n: Optional[int] = None,
+        rank_fields: Optional[List[str]] = None,
+        return_documents: Optional[bool] = True,
+        max_chunks_per_doc: Optional[int] = None,
+        _is_async: Optional[bool] = False,
+        api_base: Optional[str] = None,
+        extra_headers: Optional[dict] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> RerankResponse:
+        """
+        Rerank `documents` against `query` using the Bedrock model `model`.
+
+        The signed request is logged via `logging_obj.pre_call` before sending.
+        If `_is_async` is truthy the un-awaited `arerank` coroutine is returned.
+        `max_chunks_per_doc` is unused. Raises BedrockError on HTTP errors/timeouts.
+        """
+        request_data = RerankRequest(
+            model=model,
+            query=query,
+            documents=documents,
+            top_n=top_n,
+            rank_fields=rank_fields,
+            return_documents=return_documents,
+        )
+        data = BedrockRerankConfig()._transform_request(request_data)
+
+        prepared_request = self._prepare_request(
+            model=model,
+            optional_params=optional_params,
+            api_base=api_base,
+            extra_headers=extra_headers,
+            data=cast(dict, data),
+        )
+
+        logging_obj.pre_call(
+            input=data,
+            api_key="",
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": prepared_request["endpoint_url"],
+                "headers": prepared_request["prepped"].headers,
+            },
+        )
+
+        if _is_async:
+            async_client = client if isinstance(client, AsyncHTTPHandler) else None
+            return self.arerank(prepared_request, client=async_client, logging_obj=logging_obj)  # type: ignore
+
+        if not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+        try:
+            response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"])  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+        logging_obj.post_call(original_response=response.text, api_key="")
+        return BedrockRerankConfig()._transform_response(response.json())
+
+    def _prepare_request(
+        self,
+        model: str,
+        api_base: Optional[str],
+        extra_headers: Optional[dict],
+        data: dict,
+        optional_params: dict,
+    ) -> BedrockPreparedRequest:
+        """
+        Build the SigV4-signed POST request for the `/rerank` endpoint.
+
+        `data` is JSON-encoded as the body; an `Authorization` extra header replaces
+        the SigV4 one. Raises ImportError when botocore cannot be imported.
+        """
+        try:
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+        aws = self._get_boto_credentials_from_optional_params(optional_params, model)
+
+        # the runtime URL is re-pointed at bedrock-agent-runtime, then `/rerank` added
+        _, runtime_url = self.get_runtime_endpoint(
+            api_base=api_base,
+            aws_bedrock_runtime_endpoint=aws.aws_bedrock_runtime_endpoint,
+            aws_region_name=aws.aws_region_name,
+        )
+        agent_url = runtime_url.replace("bedrock-runtime", "bedrock-agent-runtime")
+        endpoint_url = f"{agent_url}/rerank"
+        sigv4 = SigV4Auth(aws.credentials, "bedrock", aws.aws_region_name)
+        body = json.dumps(data).encode("utf-8")
+        headers = {"Content-Type": "application/json", **(extra_headers or {})}
+        request = AWSRequest(
+            method="POST", url=endpoint_url, data=body, headers=headers
+        )
+        sigv4.add_auth(request)
+        if extra_headers is not None and "Authorization" in extra_headers:
+            # prevent sigv4 from overwriting a caller-supplied auth header
+            request.headers["Authorization"] = extra_headers["Authorization"]
+        prepped = request.prepare()
+        return BedrockPreparedRequest(
+            endpoint_url=endpoint_url, prepped=prepped, body=body, data=data
+        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/transformation.py
new file mode 100644
index 00000000..a5380feb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/transformation.py
@@ -0,0 +1,119 @@
+"""
+Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
+
+Why a separate file? It keeps the transformation logic easy to see in one place.
+"""
+
+import uuid
+from typing import List, Optional, Union
+
+from litellm.types.llms.bedrock import (
+    BedrockRerankBedrockRerankingConfiguration,
+    BedrockRerankConfiguration,
+    BedrockRerankInlineDocumentSource,
+    BedrockRerankModelConfiguration,
+    BedrockRerankQuery,
+    BedrockRerankRequest,
+    BedrockRerankSource,
+    BedrockRerankTextDocument,
+    BedrockRerankTextQuery,
+)
+from litellm.types.rerank import (
+    RerankBilledUnits,
+    RerankRequest,
+    RerankResponse,
+    RerankResponseMeta,
+    RerankResponseResult,
+    RerankTokens,
+)
+
+
+class BedrockRerankConfig:
+    """Converts between litellm rerank types and Bedrock `/rerank` payloads."""
+
+    def _transform_sources(
+        self, documents: List[Union[str, dict]]
+    ) -> List[BedrockRerankSource]:
+        """
+        Wrap each document in a Bedrock INLINE source.
+
+        Strings become TEXT documents; any other value is passed through as a
+        JSON document.
+        """
+
+        def _inline(doc: Union[str, dict]) -> BedrockRerankInlineDocumentSource:
+            """Build the inline-source body for one document."""
+            if isinstance(doc, str):
+                return BedrockRerankInlineDocumentSource(
+                    textDocument=BedrockRerankTextDocument(text=doc), type="TEXT"
+                )
+            return BedrockRerankInlineDocumentSource(jsonDocument=doc, type="JSON")
+
+        return [
+            BedrockRerankSource(inlineDocumentSource=_inline(doc), type="INLINE")
+            for doc in documents
+        ]
+
+    def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
+        """
+        Build the Bedrock `/rerank` request body from a RerankRequest.
+
+        The model name is sent as `modelArn`; `numberOfResults` falls back to the
+        number of documents when `top_n` is falsy (None or 0).
+        """
+        sources = self._transform_sources(request_data.documents)
+
+        text_query = BedrockRerankQuery(
+            textQuery=BedrockRerankTextQuery(text=request_data.query), type="TEXT"
+        )
+
+        # a falsy top_n (None or 0) asks for as many results as there are documents
+        model_settings = BedrockRerankBedrockRerankingConfiguration(
+            modelConfiguration=BedrockRerankModelConfiguration(
+                modelArn=request_data.model
+            ),
+            numberOfResults=request_data.top_n or len(request_data.documents),
+        )
+
+        return BedrockRerankRequest(
+            queries=[text_query],
+            rerankingConfiguration=BedrockRerankConfiguration(
+                bedrockRerankingConfiguration=model_settings,
+                type="BEDROCK_RERANKING_MODEL",
+            ),
+            sources=sources,
+        )
+
+    def _transform_response(self, response: dict) -> RerankResponse:
+        """
+        Convert a Bedrock `/rerank` response body into a RerankResponse.
+
+        Sample Bedrock payload:
+        {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
+
+        The id falls back to a fresh UUID4 when the payload has none. Raises
+        ValueError when `results` is missing or empty.
+        """
+        # a payload without `usage` is billed as one search unit
+        billed_units = RerankBilledUnits(**response.get("usage", {"search_units": 1}))
+        token_counts = RerankTokens(**response.get("usage", {}))
+        meta = RerankResponseMeta(billed_units=billed_units, tokens=token_counts)
+
+        raw_results = response.get("results")
+        if not raw_results:
+            raise ValueError(f"No results found in the response={response}")
+
+        # each Bedrock entry supplies the document index and its relevanceScore
+        results = [
+            RerankResponseResult(
+                index=item.get("index"),
+                relevance_score=item.get("relevanceScore"),
+            )
+            for item in raw_results
+        ]
+
+        return RerankResponse(
+            id=response.get("id") or str(uuid.uuid4()),
+            results=results,
+            meta=meta,
+        )