from __future__ import annotations
import asyncio
import copy
import io
import json
import logging
from typing import Optional, Tuple, Any
import httpx
import requests
from requests.structures import CaseInsensitiveDict
from requests_toolbelt.multipart.encoder import MultipartEncoder
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
from unstructured_client._hooks.custom.form_utils import (
PARTITION_FORM_FILES_KEY,
PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
PARTITION_FORM_PAGE_RANGE_KEY,
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
FormData,
)
# Module-level logger, looked up under the logger name imported from the
# hooks' common module so every helper below logs through the same logger.
logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
def create_request_body(
    form_data: FormData, page_content: io.BytesIO, filename: str, page_number: int
) -> MultipartEncoder:
    """Build the multipart body for one split request.

    The form data is first passed through ``prepare_request_payload``. Each
    remaining entry becomes one multipart field; a list value is expanded
    into one field per item, all under the same key. The page content is
    then attached as an ``application/pdf`` file part, followed by the
    starting page number rendered as a string.

    Args:
        form_data: The original form data of the request.
        page_content: Binary content sent as the file part.
        filename: File name reported for the file part.
        page_number: Value sent under the starting-page-number key.

    Returns:
        A ``MultipartEncoder`` wrapping the assembled fields.
    """
    cleaned = prepare_request_payload(form_data)

    # Flatten list values so that a repeated key is emitted once per item.
    fields: list[tuple[str, Any]] = [
        (name, item)
        for name, raw in cleaned.items()
        for item in (raw if isinstance(raw, list) else [raw])
    ]

    file_part = (filename, page_content, "application/pdf")
    fields.append((PARTITION_FORM_FILES_KEY, file_part))
    fields.append((PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, str(page_number)))

    return MultipartEncoder(fields=fields)
def create_httpx_request(
    original_request: requests.Request, body: MultipartEncoder
) -> httpx.Request:
    """Create the ``httpx`` POST request that carries one multipart body.

    Headers are taken from the original request after being cleaned by
    ``prepare_request_headers``; ``Content-Type`` is then set from the
    encoder so that it matches the body being sent.

    Args:
        original_request: Request whose URL and headers are reused.
        body: Multipart encoder holding the fields to send.

    Returns:
        An ``httpx.Request`` targeting the original URL (or ``""`` when the
        original URL is falsy) with the fully rendered body as content.
    """
    base_headers = prepare_request_headers(original_request.headers)
    target_url = original_request.url or ""
    # httpx receives the body fully rendered rather than as the encoder object.
    rendered_body = body.to_string()
    request_headers = {**base_headers, "Content-Type": body.content_type}
    return httpx.Request(
        method="POST",
        url=target_url,
        content=rendered_body,
        headers=request_headers,
    )
def create_request(
    request: requests.PreparedRequest,
    body: MultipartEncoder,
) -> requests.Request:
    """Create the ``requests`` POST request that carries one multipart body.

    Headers are taken from the given request after being cleaned by
    ``prepare_request_headers``; ``Content-Type`` is then set from the
    encoder so that it matches the body being sent.

    Args:
        request: Prepared request whose URL and headers are reused.
        body: Multipart encoder passed through as the request data.

    Returns:
        A ``requests.Request`` targeting the original URL (or ``""`` when
        the original URL is falsy).
    """
    base_headers = prepare_request_headers(request.headers)
    target_url = request.url or ""
    request_headers = {**base_headers, "Content-Type": body.content_type}
    return requests.Request(
        method="POST",
        url=target_url,
        data=body,
        headers=request_headers,
    )
async def call_api_async(
    client: httpx.AsyncClient,
    page: Tuple[io.BytesIO, int],
    original_request: requests.Request,
    form_data: FormData,
    filename: str,
    limiter: asyncio.Semaphore,
) -> tuple[int, dict]:
    """Send one page to the API asynchronously and return its decoded result.

    The request is built before the limiter is acquired; only the send and
    the JSON decoding of the response run while the limiter is held.

    Args:
        client: Async HTTP client used to send the request.
        page: Pair of (page content, page number).
        original_request: Request whose URL and headers are reused.
        form_data: The original form data of the request.
        filename: File name reported for the uploaded page.
        limiter: Semaphore held for the duration of the send.

    Returns:
        ``(status_code, decoded JSON body)`` on success. If sending the
        request or decoding the response raises, the failure is logged and
        ``(500, {})`` is returned instead of propagating the exception.
    """
    page_content, page_number = page
    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_httpx_request(original_request, body)
    async with limiter:
        try:
            response = await client.send(new_request)
            return response.status_code, response.json()
        except Exception:
            # Deliberately best-effort: a failing page must not abort the
            # others. Attach the traceback so the cause (transport error,
            # non-JSON body, ...) is not lost.
            logger.error(
                "Failed to send request for page %d", page_number, exc_info=True
            )
            return 500, {}
def call_api(
    client: Optional[requests.Session],
    page: Tuple[io.BytesIO, int],
    request: requests.PreparedRequest,
    form_data: FormData,
    filename: str,
) -> requests.Response:
    """Send one page to the API synchronously and return the response.

    Args:
        client: Session used to prepare and send the request.
        page: Pair of (page content, page number).
        request: Prepared request whose URL and headers are reused.
        form_data: The original form data of the request.
        filename: File name reported for the uploaded page.

    Returns:
        The response returned by ``client.send``. If sending raises, the
        failure is logged and a freshly constructed, empty
        ``requests.Response()`` is returned instead of propagating the
        exception.

    Raises:
        RuntimeError: If ``client`` is ``None``.
    """
    if client is None:
        raise RuntimeError("HTTP client not accessible!")

    page_content, page_number = page
    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_request(request, body)
    prepared_request = client.prepare_request(new_request)

    try:
        return client.send(prepared_request)
    except Exception:
        # Deliberately best-effort: a failing page must not abort the
        # others. Attach the traceback so the cause is not lost.
        logger.error(
            "Failed to send request for page %d", page_number, exc_info=True
        )
        return requests.Response()
def prepare_request_headers(
    headers: CaseInsensitiveDict[str],
) -> CaseInsensitiveDict[str]:
    """Return a copy of the headers without 'Content-Type' and 'Content-Length'.

    The input mapping is deep-copied first, so the caller's headers are
    never modified. Headers that are absent are silently ignored.

    Args:
        headers: The original request headers.

    Returns:
        A new mapping holding every other header unchanged.
    """
    cleaned = copy.deepcopy(headers)
    for dropped in ("Content-Type", "Content-Length"):
        cleaned.pop(dropped, None)
    return cleaned
def prepare_request_payload(form_data: FormData) -> FormData:
    """Return a cleaned copy of the form data for a single split request.

    The split-PDF options, the files entry, the page range and the starting
    page number are removed, and the split-PDF-page option is then set to
    ``"false"``. The input mapping is left untouched.

    Args:
        form_data: The original form data.

    Returns:
        A new, independent payload containing the remaining entries.
    """
    # Drop the unwanted keys from a shallow copy first, so the deep copy
    # below never touches values that are discarded anyway.
    # NOTE(review): the entry under the files key presumably holds the
    # uploaded document, which would be costly (or impossible) to deep-copy
    # -- confirm against the FormData definition.
    remaining = copy.copy(form_data)
    for discarded in (
        PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
        PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
        PARTITION_FORM_FILES_KEY,
        PARTITION_FORM_PAGE_RANGE_KEY,
        PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
    ):
        remaining.pop(discarded, None)

    # Deep-copy what is left so the returned payload shares no mutable state
    # with the caller's form data.
    payload = copy.deepcopy(remaining)
    payload.update({PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false"})
    return payload
def create_response(response: requests.Response, elements: list) -> requests.Response:
    """Return a copy of a response whose body is the given elements as JSON.

    The original response object is deep-copied and left unmodified. On the
    copy, the body is replaced by the JSON-serialized elements and the
    ``Content-Length`` header is set to the length of that body in bytes.

    Args:
        response: The original response object.
        elements: The list of elements to serialize into the new body.

    Returns:
        The copied response carrying the new content and content length.
    """
    patched = copy.deepcopy(response)
    serialized = json.dumps(elements).encode()
    patched.headers["Content-Length"] = str(len(serialized))
    # The body is stored on the private ``_content`` attribute; setattr keeps
    # the write explicit.
    setattr(patched, "_content", serialized)
    return patched
def log_after_split_response(status_code: int, split_number: int):
    """Log the outcome of partitioning one split.

    A status code of exactly 200 is logged at info level as a success; any
    other status code is logged at warning level as a failure.

    Args:
        status_code: Status code obtained for the split.
        split_number: Number of the split, used in the log message.
    """
    if status_code != 200:
        logger.warning(
            "Failed to partition set #%d, its elements will be omitted in the final result.",
            split_number,
        )
        return

    logger.info(
        "Successfully partitioned set #%d, elements added to the final result.",
        split_number,
    )