# Helpers for reading and validating split-PDF parameters from multipart form data.
from __future__ import annotations

import logging
from typing import Union

from requests_toolbelt.multipart.decoder import MultipartDecoder

from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
from unstructured_client.models import shared

# Module logger, obtained under the client's shared logger name.
logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
# Parsed form data: field name -> text value, uploaded file, or a list of
# text values (a field name that occurred more than once).
FormData = dict[str, Union[str, shared.Files, list[str]]]

# Form field names. Callers pass these as the `key` argument of the getters
# in this module (presumably -- only the files key is used directly here).
PARTITION_FORM_FILES_KEY = "files"
PARTITION_FORM_SPLIT_PDF_PAGE_KEY = "split_pdf_page"
PARTITION_FORM_PAGE_RANGE_KEY = "split_pdf_page_range[]"
PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY = "split_pdf_allow_failed"
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY = "starting_page_number"
PARTITION_FORM_CONCURRENCY_LEVEL_KEY = "split_pdf_concurrency_level"


def get_page_range(form_data: FormData, key: str, max_pages: int) -> tuple[int, int]:
    """Retrieve the split page range from the given form data.

    When the key is absent from the form data, the full range
    ``(1, max_pages)`` is used. An invalid or out-of-bounds range is an
    error, not a silent fallback.

    Args:
        form_data: The form data containing the page range.
        key: The key to look for in the form data.
        max_pages: The page count of the document; the inclusive upper bound
            for both ends of the range.

    Returns:
        The range of pages to send in the request in the form (start, end).

    Raises:
        ValueError: If the value is not a sequence holding at least two
            integers, or if the range is out of bounds or start > end.
    """
    _page_range = form_data.get(key)

    if _page_range is None:
        page_range = (1, max_pages)
    else:
        try:
            # A bare string would be indexed character by character
            # ("12" -> (1, 2)), so a single text value is rejected outright.
            if isinstance(_page_range, (str, bytes)):
                raise ValueError("expected a start and an end value")
            page_range = (int(_page_range[0]), int(_page_range[1]))
        except (ValueError, IndexError, TypeError) as exc:
            # TypeError: value is not subscriptable or its items are not numeric.
            msg = f"{_page_range} is not a valid page range."
            logger.error(msg)
            raise ValueError(msg) from exc

    start, end = page_range

    if not 0 < start <= max_pages or not 0 < end <= max_pages or not start <= end:
        msg = f"Page range {page_range} is out of bounds. Start and end values should be between 1 and {max_pages}."
        logger.error(msg)
        raise ValueError(msg)

    return page_range


def get_starting_page_number(form_data: FormData, key: str, fallback_value: int) -> int:
    """Retrieve the starting page number from the given form data.

    If the stored value is missing, empty, not convertible to an integer, or
    less than 1, the fallback value is returned instead.

    Args:
        form_data: The form data containing the starting page number.
        key: The key to look for in the form data.
        fallback_value: The default value to use in case of an error.

    Returns:
        The starting page number.
    """
    starting_page_number = fallback_value
    try:
        # `or` also maps an empty string to the fallback.
        _starting_page_number = form_data.get(key) or fallback_value
        starting_page_number = int(_starting_page_number)  # type: ignore
    except (ValueError, TypeError):
        # TypeError covers non-scalar values, e.g. a list produced by a
        # form field that was sent more than once.
        logger.warning(
            "'%s' is not a valid integer. Using default value '%d'.",
            key,
            fallback_value,
        )

    if starting_page_number < 1:
        logger.warning(
            "'%s' is less than 1. Using default value '%d'.",
            key,
            fallback_value,
        )
        starting_page_number = fallback_value

    return starting_page_number

def get_split_pdf_allow_failed_param(
    form_data: FormData, key: str, fallback_value: bool,
) -> bool:
    """Retrieve the allow-failed flag that should be used for splitting a pdf.

    If the key is absent, or the stored value is not a "true" / "false"
    literal (compared case-insensitively), the fallback value is returned.

    Args:
        form_data: The form data containing the allow-failed flag.
        key: The key to look for in the form data.
        fallback_value: The default value to use in case of an error.

    Returns:
        The allow-failed flag after validation.
    """
    allow_failed = form_data.get(key)

    if allow_failed is None:
        return fallback_value

    # Only text can be a boolean literal; a list or file value has no
    # `.lower()` and is treated like any other invalid input.
    normalized = allow_failed.lower() if isinstance(allow_failed, str) else None

    if normalized not in ("true", "false"):
        logger.warning(
            "'%s' is not a valid boolean. Using default value '%s'.",
            key,
            fallback_value,
        )
        return fallback_value

    return normalized == "true"

def get_split_pdf_concurrency_level_param(
    form_data: FormData, key: str, fallback_value: int, max_allowed: int
) -> int:
    """Retrieve the concurrency level that should be used for splitting a pdf.

    If the key is absent, or the stored value is not a valid integer or is
    less than 1, the fallback value is returned. A value above `max_allowed`
    is clamped to `max_allowed`.

    Args:
        form_data: The form data containing the desired concurrency level.
        key: The key to look for in the form data.
        fallback_value: The default value to use in case of an error.
        max_allowed: The maximum allowed value for the concurrency level.

    Returns:
        The concurrency level after validation.
    """
    concurrency_level_str = form_data.get(key)

    if concurrency_level_str is None:
        return fallback_value

    try:
        concurrency_level = int(concurrency_level_str)
    except (ValueError, TypeError):
        # TypeError covers non-scalar values, e.g. a list produced by a
        # form field that was sent more than once.
        logger.warning(
            "'%s' is not a valid integer. Using default value '%s'.",
            key,
            fallback_value,
        )
        return fallback_value

    if concurrency_level < 1:
        logger.warning(
            "'%s' is less than 1. Using the default value = %s.",
            key,
            fallback_value,
        )
        return fallback_value

    if concurrency_level > max_allowed:
        logger.warning(
            "'%s' is greater than %s. Using the maximum allowed value = %s.",
            key,
            max_allowed,
            max_allowed,
        )
        return max_allowed

    return concurrency_level


def decode_content_disposition(content_disposition: bytes) -> dict[str, str]:
    """Decode the `Content-Disposition` header and return the parameters as a dictionary.

    The header is split on "; ". The first segment (the disposition type) is
    dropped and every remaining ``key=value`` segment becomes one entry, with
    surrounding double quotes stripped from the value.

    Args:
        content_disposition: The `Content-Disposition` header as bytes.

    Returns:
        A dictionary containing the parameters extracted from the
        `Content-Disposition` header. Segments without "=" are skipped.
    """
    parameters_dict: dict[str, str] = {}

    for segment in content_disposition.decode().split("; ")[1:]:
        # Split on the first "=" only, so a value that itself contains "="
        # (e.g. filename="a=b.pdf") is kept whole.
        param_name, separator, param_value = segment.partition("=")
        if not separator:
            # Not a key/value pair; nothing to record.
            continue
        parameters_dict[param_name] = param_value.strip('"')

    return parameters_dict


def parse_form_data(decoded_data: MultipartDecoder) -> FormData:
    """Build a form-data dictionary from decoded multipart parts.

    The part named by the files key is stored as a `shared.Files` object.
    Every other named part is stored as decoded text, and a name that
    occurs more than once is collected into a list. Parts without a
    `name` parameter are ignored.

    Args:
        decoded_data: The decoded multipart data.

    Returns:
        The parsed form data.

    Raises:
        RuntimeError: If a part has no `Content-Disposition` header.
        ValueError: If the files part has a missing or blank filename.
    """
    parsed: FormData = {}

    for part in decoded_data.parts:
        header = part.headers.get(b"Content-Disposition")
        if header is None:
            raise RuntimeError("Content-Disposition header not found. Can't split pdf file.")

        params = decode_content_disposition(header)
        field_name = params.get("name")
        if field_name is None:
            continue

        # The uploaded file keeps its raw bytes and is never merged into a list.
        if field_name == PARTITION_FORM_FILES_KEY:
            file_name = params.get("filename")
            if not (file_name is not None and file_name.strip()):
                raise ValueError("Filename can't be an empty string.")
            parsed[PARTITION_FORM_FILES_KEY] = shared.Files(part.content, file_name)
            continue

        text = part.content.decode()
        if field_name not in parsed:
            parsed[field_name] = text
            continue

        # Repeated field: grow the existing list, or start one from the
        # previously stored single value.
        existing = parsed[field_name]
        if isinstance(existing, list):
            existing.append(text)
        else:
            parsed[field_name] = [existing, text]

    return parsed