aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py168
1 files changed, 168 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py
new file mode 100644
index 00000000..cd8be691
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py
@@ -0,0 +1,168 @@
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.types.llms.bedrock import BedrockPreparedRequest
+from litellm.types.rerank import RerankRequest
+from litellm.types.utils import RerankResponse
+
+from ..base_aws_llm import BaseAWSLLM
+from ..common_utils import BedrockError
+from .transformation import BedrockRerankConfig
+
+if TYPE_CHECKING:
+ from botocore.awsrequest import AWSPreparedRequest
+else:
+ AWSPreparedRequest = Any
+
+
+class BedrockRerankHandler(BaseAWSLLM):
+ async def arerank(
+ self,
+ prepared_request: BedrockPreparedRequest,
+ client: Optional[AsyncHTTPHandler] = None,
+ ):
+ if client is None:
+ client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
+ try:
+ response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+ return BedrockRerankConfig()._transform_response(response.json())
+
+ def rerank(
+ self,
+ model: str,
+ query: str,
+ documents: List[Union[str, Dict[str, Any]]],
+ optional_params: dict,
+ logging_obj: LitellmLogging,
+ top_n: Optional[int] = None,
+ rank_fields: Optional[List[str]] = None,
+ return_documents: Optional[bool] = True,
+ max_chunks_per_doc: Optional[int] = None,
+ _is_async: Optional[bool] = False,
+ api_base: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ) -> RerankResponse:
+
+ request_data = RerankRequest(
+ model=model,
+ query=query,
+ documents=documents,
+ top_n=top_n,
+ rank_fields=rank_fields,
+ return_documents=return_documents,
+ )
+ data = BedrockRerankConfig()._transform_request(request_data)
+
+ prepared_request = self._prepare_request(
+ model=model,
+ optional_params=optional_params,
+ api_base=api_base,
+ extra_headers=extra_headers,
+ data=cast(dict, data),
+ )
+
+ logging_obj.pre_call(
+ input=data,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": prepared_request["endpoint_url"],
+ "headers": prepared_request["prepped"].headers,
+ },
+ )
+
+ if _is_async:
+ return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
+
+ if client is None or not isinstance(client, HTTPHandler):
+ client = _get_httpx_client()
+ try:
+ response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+ logging_obj.post_call(
+ original_response=response.text,
+ api_key="",
+ )
+
+ response_json = response.json()
+
+ return BedrockRerankConfig()._transform_response(response_json)
+
+ def _prepare_request(
+ self,
+ model: str,
+ api_base: Optional[str],
+ extra_headers: Optional[dict],
+ data: dict,
+ optional_params: dict,
+ ) -> BedrockPreparedRequest:
+ try:
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ except ImportError:
+ raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+ boto3_credentials_info = self._get_boto_credentials_from_optional_params(
+ optional_params, model
+ )
+
+ ### SET RUNTIME ENDPOINT ###
+ _, proxy_endpoint_url = self.get_runtime_endpoint(
+ api_base=api_base,
+ aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
+ aws_region_name=boto3_credentials_info.aws_region_name,
+ )
+ proxy_endpoint_url = proxy_endpoint_url.replace(
+ "bedrock-runtime", "bedrock-agent-runtime"
+ )
+ proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
+ sigv4 = SigV4Auth(
+ boto3_credentials_info.credentials,
+ "bedrock",
+ boto3_credentials_info.aws_region_name,
+ )
+ # Make POST Request
+ body = json.dumps(data).encode("utf-8")
+
+ headers = {"Content-Type": "application/json"}
+ if extra_headers is not None:
+ headers = {"Content-Type": "application/json", **extra_headers}
+ request = AWSRequest(
+ method="POST", url=proxy_endpoint_url, data=body, headers=headers
+ )
+ sigv4.add_auth(request)
+ if (
+ extra_headers is not None and "Authorization" in extra_headers
+ ): # prevent sigv4 from overwriting the auth header
+ request.headers["Authorization"] = extra_headers["Authorization"]
+ prepped = request.prepare()
+
+ return BedrockPreparedRequest(
+ endpoint_url=proxy_endpoint_url,
+ prepped=prepped,
+ body=body,
+ data=data,
+ )