diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py | 168 |
1 files changed, 168 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py new file mode 100644 index 00000000..cd8be691 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/rerank/handler.py @@ -0,0 +1,168 @@ +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, + get_async_httpx_client, +) +from litellm.types.llms.bedrock import BedrockPreparedRequest +from litellm.types.rerank import RerankRequest +from litellm.types.utils import RerankResponse + +from ..base_aws_llm import BaseAWSLLM +from ..common_utils import BedrockError +from .transformation import BedrockRerankConfig + +if TYPE_CHECKING: + from botocore.awsrequest import AWSPreparedRequest +else: + AWSPreparedRequest = Any + + +class BedrockRerankHandler(BaseAWSLLM): + async def arerank( + self, + prepared_request: BedrockPreparedRequest, + client: Optional[AsyncHTTPHandler] = None, + ): + if client is None: + client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK) + try: + response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return BedrockRerankConfig()._transform_response(response.json()) + + def rerank( + self, + model: str, + query: str, + documents: List[Union[str, Dict[str, Any]]], + optional_params: dict, + logging_obj: LitellmLogging, + top_n: Optional[int] = None, + rank_fields: Optional[List[str]] = None, + return_documents: Optional[bool] = True, + max_chunks_per_doc: Optional[int] = None, + _is_async: Optional[bool] = False, + api_base: Optional[str] = None, + extra_headers: Optional[dict] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> RerankResponse: + + request_data = RerankRequest( + model=model, + query=query, + documents=documents, + top_n=top_n, + rank_fields=rank_fields, + return_documents=return_documents, + ) + data = BedrockRerankConfig()._transform_request(request_data) + + prepared_request = self._prepare_request( + model=model, + optional_params=optional_params, + api_base=api_base, + extra_headers=extra_headers, + data=cast(dict, data), + ) + + logging_obj.pre_call( + input=data, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepared_request["endpoint_url"], + "headers": prepared_request["prepped"].headers, + }, + ) + + if _is_async: + return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore + + if client is None or not isinstance(client, HTTPHandler): + client = _get_httpx_client() + try: + response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + logging_obj.post_call( + original_response=response.text, + api_key="", + ) + + response_json = response.json() + + return BedrockRerankConfig()._transform_response(response_json) + + def _prepare_request( + self, + model: str, + api_base: Optional[str], + extra_headers: Optional[dict], + data: dict, + optional_params: dict, + ) -> BedrockPreparedRequest: + try: + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + boto3_credentials_info = self._get_boto_credentials_from_optional_params( + optional_params, model + ) + + ### SET RUNTIME ENDPOINT ### + _, proxy_endpoint_url = self.get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint, + aws_region_name=boto3_credentials_info.aws_region_name, + ) + proxy_endpoint_url = proxy_endpoint_url.replace( + "bedrock-runtime", "bedrock-agent-runtime" + ) + proxy_endpoint_url = f"{proxy_endpoint_url}/rerank" + sigv4 = SigV4Auth( + boto3_credentials_info.credentials, + "bedrock", + boto3_credentials_info.aws_region_name, + ) + # Make POST Request + body = json.dumps(data).encode("utf-8") + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=proxy_endpoint_url, data=body, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + return BedrockPreparedRequest( + endpoint_url=proxy_endpoint_url, + prepped=prepped, + body=body, + data=data, + ) |