# Helpers for building, sending, and post-processing per-page partition requests.
from __future__ import annotations

import asyncio
import copy
import io
import json
import logging
from typing import Optional, Tuple, Any

import httpx
import requests
from requests.structures import CaseInsensitiveDict
from requests_toolbelt.multipart.encoder import MultipartEncoder

from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
from unstructured_client._hooks.custom.form_utils import (
    PARTITION_FORM_FILES_KEY,
    PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
    PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
    PARTITION_FORM_PAGE_RANGE_KEY,
    PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
    FormData,
)

logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)


def create_request_body(
    form_data: FormData, page_content: io.BytesIO, filename: str, page_number: int
) -> MultipartEncoder:
    """Build the multipart body for a single split page.

    Args:
        form_data: The original form data; it is run through
            ``prepare_request_payload`` before being encoded.
        page_content: The content sent as the file part.
        filename: The filename reported for the file part.
        page_number: Value sent (as a string) under the starting-page-number key.

    Returns:
        A ``MultipartEncoder`` holding the payload fields, then the file part,
        then the starting page number.
    """
    fields: list[tuple[str, Any]] = []
    for name, entry in prepare_request_payload(form_data).items():
        # A list value becomes one repeated field per element.
        entries = entry if isinstance(entry, list) else [entry]
        fields.extend((name, item) for item in entries)

    file_part = (filename, page_content, "application/pdf")
    fields.append((PARTITION_FORM_FILES_KEY, file_part))
    fields.append((PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, str(page_number)))

    return MultipartEncoder(fields=fields)


def create_httpx_request(
    original_request: requests.Request, body: MultipartEncoder
) -> httpx.Request:
    """Create a POST ``httpx.Request`` carrying the given multipart body.

    Args:
        original_request: The request whose URL and headers are reused.
        body: The multipart body; it is fully serialized via ``to_string()``.

    Returns:
        A new ``httpx.Request`` whose ``Content-Type`` is taken from ``body``.
    """
    base_headers = prepare_request_headers(original_request.headers)
    target_url = original_request.url or ""
    serialized_body = body.to_string()
    merged_headers = {**base_headers, "Content-Type": body.content_type}
    return httpx.Request(
        method="POST",
        url=target_url,
        content=serialized_body,
        headers=merged_headers,
    )


def create_request(
    request: requests.PreparedRequest,
    body: MultipartEncoder,
) -> requests.Request:
    """Create a POST ``requests.Request`` carrying the given multipart body.

    Args:
        request: The request whose URL and headers are reused.
        body: The multipart body, passed through as the request data.

    Returns:
        A new ``requests.Request`` whose ``Content-Type`` is taken from ``body``.
    """
    base_headers = prepare_request_headers(request.headers)
    target_url = request.url or ""
    merged_headers = {**base_headers, "Content-Type": body.content_type}
    return requests.Request(
        method="POST",
        url=target_url,
        data=body,
        headers=merged_headers,
    )


async def call_api_async(
    client: httpx.AsyncClient,
    page: Tuple[io.BytesIO, int],
    original_request: requests.Request,
    form_data: FormData,
    filename: str,
    limiter: asyncio.Semaphore,
) -> tuple[int, dict]:
    """Send one split page to the API asynchronously.

    The request body and the httpx request are built before ``limiter`` is
    acquired; only the send and the JSON decoding happen while holding it.

    Args:
        client: The async HTTP client used to send the request.
        page: A ``(page_content, page_number)`` pair.
        original_request: The request whose URL and headers are reused.
        form_data: The original form data for the partition call.
        filename: The filename reported for the uploaded page.
        limiter: Semaphore held while the request is in flight.

    Returns:
        ``(status_code, parsed_json)`` on success, or ``(500, {})`` when
        sending the request or decoding the response raises.
    """
    page_content, page_number = page
    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_httpx_request(original_request, body)
    async with limiter:
        try:
            response = await client.send(new_request)
            return response.status_code, response.json()
        except Exception:
            # Best-effort: the failure is reported as a 500 instead of being
            # raised, so record the traceback to keep the cause diagnosable.
            logger.error(
                "Failed to send request for page %d", page_number, exc_info=True
            )
            return 500, {}


def call_api(
    client: Optional[requests.Session],
    page: Tuple[io.BytesIO, int],
    request: requests.PreparedRequest,
    form_data: FormData,
    filename: str,
) -> requests.Response:
    """Send one split page to the API synchronously.

    Args:
        client: The session used to prepare and send the request.
        page: A ``(page_content, page_number)`` pair.
        request: The request whose URL and headers are reused.
        form_data: The original form data for the partition call.
        filename: The filename reported for the uploaded page.

    Returns:
        The response from the API, or a freshly constructed, empty
        ``requests.Response`` when sending raises.

    Raises:
        RuntimeError: If ``client`` is ``None``.
    """
    if client is None:
        raise RuntimeError("HTTP client not accessible!")
    page_content, page_number = page

    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_request(request, body)
    prepared_request = client.prepare_request(new_request)

    try:
        return client.send(prepared_request)
    except Exception:
        # Best-effort: an empty response is returned instead of raising, so
        # record the traceback to keep the cause diagnosable.
        logger.error(
            "Failed to send request for page %d", page_number, exc_info=True
        )
        return requests.Response()


def prepare_request_headers(
    headers: CaseInsensitiveDict[str],
) -> CaseInsensitiveDict[str]:
    """Return a copy of the headers without 'Content-Type' and 'Content-Length'.

    The input mapping is deep-copied first, so it is left untouched.

    Args:
        headers: The original request headers.

    Returns:
        The copied headers with the two entries removed (if present).
    """
    cleaned = copy.deepcopy(headers)
    for header_name in ("Content-Type", "Content-Length"):
        cleaned.pop(header_name, None)
    return cleaned


def prepare_request_payload(form_data: FormData) -> FormData:
    """Build the payload for a per-page request from the original form data.

    Works on a deep copy: the split-PDF, allow-failed, files, page-range and
    starting-page-number entries are dropped, then the split-PDF flag is set
    back to ``"false"``.

    Args:
        form_data: The original form data.

    Returns:
        The updated request payload.
    """
    payload = copy.deepcopy(form_data)
    keys_to_drop = (
        PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
        PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
        PARTITION_FORM_FILES_KEY,
        PARTITION_FORM_PAGE_RANGE_KEY,
        PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
    )
    for key in keys_to_drop:
        payload.pop(key, None)
    payload.update({PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false"})
    return payload


def create_response(response: requests.Response, elements: list) -> requests.Response:
    """
    Build a copy of a response whose body is the JSON-serialized elements.

    Args:
        response: The original response object; it is deep-copied.
        elements: The list of elements to be serialized into the
        copy's content.

    Returns:
        The copied response with its content and ``Content-Length`` header
        replaced.
    """
    patched = copy.deepcopy(response)
    serialized = json.dumps(elements).encode()
    patched.headers.update({"Content-Length": str(len(serialized))})
    # The body is stored on the private ``_content`` attribute.
    setattr(patched, "_content", serialized)
    return patched


def log_after_split_response(status_code: int, split_number: int):
    """Log the outcome of partitioning one split set.

    Args:
        status_code: The status code received for the split; 200 is treated
            as success, anything else as failure.
        split_number: The number of the split set, used in the log message.
    """
    if status_code != 200:
        logger.warning(
            "Failed to partition set #%d, its elements will be omitted in the final result.",
            split_number,
        )
        return

    logger.info(
        "Successfully partitioned set #%d, elements added to the final result.",
        split_number,
    )