from __future__ import annotations
import asyncio
import io
import json
import logging
import math
from collections.abc import Awaitable
from typing import Any, Coroutine, Optional, Tuple, Union
import httpx
import nest_asyncio
import requests
from pypdf import PdfReader
from requests_toolbelt.multipart.decoder import MultipartDecoder
from unstructured_client._hooks.custom import form_utils, pdf_utils, request_utils
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
from unstructured_client._hooks.custom.form_utils import (
PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
PARTITION_FORM_FILES_KEY,
PARTITION_FORM_PAGE_RANGE_KEY,
PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
)
from unstructured_client._hooks.types import (
AfterErrorContext,
AfterErrorHook,
AfterSuccessContext,
AfterSuccessHook,
BeforeRequestContext,
BeforeRequestHook,
SDKInitHook,
)
from unstructured_client.models import shared
logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
DEFAULT_STARTING_PAGE_NUMBER = 1
DEFAULT_ALLOW_FAILED = False
DEFAULT_CONCURRENCY_LEVEL = 8
MAX_CONCURRENCY_LEVEL = 15
MIN_PAGES_PER_SPLIT = 2
MAX_PAGES_PER_SPLIT = 20
async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, requests.Response]:
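    """Await `coro` and return its response paired with `index` so that callers
    can restore the original submission order."""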
response = await coro
return index, response
async def run_tasks(coroutines: list[Awaitable], allow_failed: bool = False) -> list[tuple[int, requests.Response]]:
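    """Run the given coroutines and tag each response with its 1-based submission order.

    When `allow_failed` is False, the first non-200 response cancels all
    still-pending requests. A minimal usage sketch (`partition_page` stands in
    for any coroutine returning a `requests.Response`):

        ordered = await run_tasks([partition_page(1), partition_page(2)])
        for index, response in ordered:
            print(index, response.status_code)
    """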
if allow_failed:
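        # Await every request to completion; non-200 responses are collected
        # instead of cancelling the remaining requests.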
responses = await asyncio.gather(*coroutines, return_exceptions=False)
return list(enumerate(responses, 1))
    # TODO: replace with asyncio.TaskGroup for python >=3.11 # pylint: disable=fixme
tasks = [asyncio.create_task(_order_keeper(index, coro)) for index, coro in enumerate(coroutines, 1)]
results = []
remaining_tasks = dict(enumerate(tasks, 1))
for future in asyncio.as_completed(tasks):
index, response = await future
if response.status_code != 200:
# cancel all remaining tasks
for remaining_task in remaining_tasks.values():
remaining_task.cancel()
results.append((index, response))
break
results.append((index, response))
        # Remove the completed task from remaining_tasks so it is not cancelled
        # if a later task fails
del remaining_tasks[index]
# return results in the original order
return sorted(results, key=lambda x: x[0])
def context_is_uvloop() -> bool:
    """Return True if uvloop is installed and we are currently in a uvloop context.

    The asyncio splitting code does not currently work under uvloop.
    """
try:
import uvloop # pylint: disable=import-outside-toplevel
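        # get_event_loop() can raise RuntimeError when no event loop is set
        # (e.g. outside the main thread); treat that as "not running under uvloop".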
loop = asyncio.get_event_loop()
return isinstance(loop, uvloop.Loop)
except (ImportError, RuntimeError):
return False
def get_optimal_split_size(num_pages: int, concurrency_level: int) -> int:
"""Distributes pages to workers evenly based on the number of pages and desired concurrency level."""
if num_pages < MAX_PAGES_PER_SPLIT * concurrency_level:
split_size = math.ceil(num_pages / concurrency_level)
else:
split_size = MAX_PAGES_PER_SPLIT
return max(split_size, MIN_PAGES_PER_SPLIT)
class SplitPdfHook(SDKInitHook, BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
"""
A hook class that splits a PDF file into multiple pages and sends each page as
    a separate request. This hook is designed to be used with a Speakeasy SDK.
Usage:
1. Create an instance of the `SplitPdfHook` class.
    2. Register SDK Init, Before Request, After Success and After Error hooks, as sketched below.
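
    A registration sketch (the hook-registration method names follow Speakeasy's
    generated `Hooks` interface; confirm them against your SDK version):

        def init_hooks(hooks):
            split_pdf_hook = SplitPdfHook()
            hooks.register_sdk_init_hook(split_pdf_hook)
            hooks.register_before_request_hook(split_pdf_hook)
            hooks.register_after_success_hook(split_pdf_hook)
            hooks.register_after_error_hook(split_pdf_hook)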
"""
def __init__(self) -> None:
self.client: Optional[requests.Session] = None
self.coroutines_to_execute: dict[
str, list[Coroutine[Any, Any, requests.Response]]
] = {}
self.api_successful_responses: dict[str, list[requests.Response]] = {}
self.api_failed_responses: dict[str, list[requests.Response]] = {}
self.allow_failed: bool = DEFAULT_ALLOW_FAILED
def sdk_init(
self, base_url: str, client: requests.Session
) -> Tuple[str, requests.Session]:
"""Initializes Split PDF Hook.
Args:
base_url (str): URL of the API.
client (requests.Session): HTTP Client.
Returns:
Tuple[str, requests.Session]: The initialized SDK options.
"""
self.client = client
return base_url, client
# pylint: disable=too-many-return-statements
def before_request(
self, hook_ctx: BeforeRequestContext, request: requests.PreparedRequest
) -> Union[requests.PreparedRequest, Exception]:
"""If `splitPdfPage` is set to `true` in the request, the PDF file is split into
separate pages. Each page is sent as a separate request in parallel. The last
page request is returned by this method. It will return the original request
        when: `splitPdfPage` is set to `false`, the file is not a PDF, or the HTTP
        client has not been initialized.
Args:
hook_ctx (BeforeRequestContext): The hook context containing information about
the operation.
request (requests.PreparedRequest): The request object.
Returns:
Union[requests.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`,
the last page request; otherwise, the original request.
"""
if self.client is None:
logger.warning("HTTP client not accessible! Continuing without splitting.")
return request
if context_is_uvloop():
logger.warning("Splitting is currently incompatible with uvloop. Continuing without splitting.")
return request
# This allows us to use an event loop in an env with an existing loop
# Temporary fix until we can improve the async splitting behavior
nest_asyncio.apply()
operation_id = hook_ctx.operation_id
content_type = request.headers.get("Content-Type")
body = request.body
if not isinstance(body, bytes) or content_type is None:
return request
decoded_body = MultipartDecoder(body, content_type)
form_data = form_utils.parse_form_data(decoded_body)
split_pdf_page = form_data.get(PARTITION_FORM_SPLIT_PDF_PAGE_KEY)
if split_pdf_page is None or split_pdf_page == "false":
logger.info("Partitioning without split.")
return request
logger.info("Preparing to split document for partition.")
file = form_data.get(PARTITION_FORM_FILES_KEY)
if (
file is None
or not isinstance(file, shared.Files)
or not pdf_utils.is_pdf(file)
):
logger.info("Partitioning without split.")
return request
starting_page_number = form_utils.get_starting_page_number(
form_data,
key=PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
fallback_value=DEFAULT_STARTING_PAGE_NUMBER,
)
        if starting_page_number > 1:
            logger.info("Starting page number set to %d", starting_page_number)
self.allow_failed = form_utils.get_split_pdf_allow_failed_param(
form_data,
key=PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
fallback_value=DEFAULT_ALLOW_FAILED,
)
logger.info("Allow failed set to %d", self.allow_failed)
concurrency_level = form_utils.get_split_pdf_concurrency_level_param(
form_data,
key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
fallback_value=DEFAULT_CONCURRENCY_LEVEL,
max_allowed=MAX_CONCURRENCY_LEVEL,
)
logger.info("Concurrency level set to %d", concurrency_level)
limiter = asyncio.Semaphore(concurrency_level)
pdf = PdfReader(io.BytesIO(file.content))
page_range_start, page_range_end = form_utils.get_page_range(
form_data,
key=PARTITION_FORM_PAGE_RANGE_KEY,
max_pages=len(pdf.pages),
)
page_count = page_range_end - page_range_start + 1
logger.info(
"Splitting pages %d to %d (%d total)",
page_range_start,
page_range_end,
page_count,
)
split_size = get_optimal_split_size(
num_pages=page_count, concurrency_level=concurrency_level
)
logger.info("Determined optimal split size of %d pages.", split_size)
# If the doc is small enough, and we aren't slicing it with a page range:
# do not split, just continue with the original request
if split_size >= page_count and page_count == len(pdf.pages):
logger.info(
"Document has too few pages (%d) to be split efficiently. Partitioning without split.",
page_count,
)
return request
        pages = pdf_utils.get_pdf_pages(
            pdf, split_size=split_size, page_start=page_range_start, page_end=page_range_end
        )
logger.info(
"Partitioning %d files with %d page(s) each.",
math.floor(page_count / split_size),
split_size,
)
# Log the remainder pages if there are any
if page_count % split_size > 0:
logger.info(
"Partitioning 1 file with %d page(s).",
page_count % split_size,
)
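
        # Per-page request helper: each invocation opens its own httpx.AsyncClient
        # and is rate-limited via the shared `limiter` semaphore handed to
        # `call_api_async`.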
async def call_api_partial(page):
async with httpx.AsyncClient() as client:
status_code, json_response = await request_utils.call_api_async(
client=client,
original_request=request,
form_data=form_data,
filename=file.file_name,
page=page,
limiter=limiter,
)
                # Convert the httpx response to a requests.Response to preserve
                # compatibility with the synchronous SDK generated by Speakeasy
response = requests.Response()
response.status_code = status_code
response._content = json.dumps( # pylint: disable=W0212
json_response
).encode()
response.headers["Content-Type"] = "application/json"
return response
self.coroutines_to_execute[operation_id] = []
last_page_content = io.BytesIO()
last_page_number = 0
set_index = 1
for page_content, page_index, all_pages_number in pages:
page_number = page_index + starting_page_number
logger.info(
"Partitioning set #%d (pages %d-%d).",
set_index,
page_number,
min(page_number + split_size - 1, all_pages_number),
)
# Check if this set of pages is the last one
if page_index + split_size >= all_pages_number:
last_page_content = page_content
last_page_number = page_number
break
coroutine = call_api_partial((page_content, page_number))
self.coroutines_to_execute[operation_id].append(coroutine)
set_index += 1
        # The `before_request` hook must return a request, so we skip sending the
        # last page in parallel and instead return it at the end of this method
body = request_utils.create_request_body(
form_data, last_page_content, file.file_name, last_page_number
)
last_page_request = request_utils.create_request(request, body)
last_page_prepared_request = self.client.prepare_request(last_page_request)
return last_page_prepared_request
def _await_elements(
self, operation_id: str, response: requests.Response
) -> Optional[list]:
"""
Waits for the partition requests to complete and returns the flattened
elements.
Args:
operation_id (str): The ID of the operation.
response (requests.Response): The response object.
Returns:
Optional[list]: The flattened elements if the partition requests are
completed, otherwise None.
"""
tasks = self.coroutines_to_execute.get(operation_id)
if tasks is None:
return None
ioloop = asyncio.get_event_loop()
task_responses: list[tuple[int, requests.Response]] = ioloop.run_until_complete(
run_tasks(tasks, allow_failed=self.allow_failed)
)
if task_responses is None:
return None
successful_responses = []
failed_responses = []
elements = []
for response_number, res in task_responses:
request_utils.log_after_split_response(res.status_code, response_number)
if res.status_code == 200:
successful_responses.append(res)
elements.append(res.json())
else:
failed_responses.append(res)
if self.allow_failed or not failed_responses:
last_response_number = len(task_responses) + 1
request_utils.log_after_split_response(
response.status_code, last_response_number
)
if response.status_code == 200:
elements.append(response.json())
successful_responses.append(response)
else:
failed_responses.append(response)
self.api_successful_responses[operation_id] = successful_responses
self.api_failed_responses[operation_id] = failed_responses
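        # Each successful response body is a list of elements; flatten the
        # per-response lists into a single list.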
flattened_elements = [element for sublist in elements for element in sublist]
return flattened_elements
def after_success(
self, hook_ctx: AfterSuccessContext, response: requests.Response
) -> Union[requests.Response, Exception]:
"""Executes after a successful API request. Awaits all parallel requests and
combines the responses into a single response object.
Args:
hook_ctx (AfterSuccessContext): The context object containing information
about the hook execution.
response (requests.Response): The response object returned from the API
request.
Returns:
            Union[requests.Response, Exception]: If requests were run in parallel, a
            combined response object; otherwise, the original response. An exception
            may be returned if one occurred during execution.
"""
operation_id = hook_ctx.operation_id
        # Because `before_request` skipped sending the last page in parallel, we pass
        # this response (which contains the last page) to `_await_elements`
elements = self._await_elements(operation_id, response)
        # If failures are disallowed, return the first failed response
if not self.allow_failed and self.api_failed_responses.get(operation_id):
return self.api_failed_responses[operation_id][0]
if elements is None:
return response
updated_response = request_utils.create_response(response, elements)
self._clear_operation(operation_id)
return updated_response
def after_error(
self,
hook_ctx: AfterErrorContext,
response: Optional[requests.Response],
error: Optional[Exception],
) -> Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]:
"""Executes after an unsuccessful API request. Awaits all parallel requests,
if at least one request was successful, combines the responses into a single
response object and doesn't throw an error. It will return an error only if
all requests failed, or there was no PDF split.
Args:
hook_ctx (AfterErrorContext): The AfterErrorContext object containing
information about the hook context.
response (Optional[requests.Response]): The Response object representing
the response received before the exception occurred.
error (Optional[Exception]): The exception object that was thrown.
Returns:
Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]:
If requests were run in parallel, and at least one was successful, a combined
response object; otherwise, the original response and exception.
"""
        # If failures are disallowed, return the response and error objects immediately
if not self.allow_failed:
return (response, error)
operation_id = hook_ctx.operation_id
        # We know this request failed, so we pass a failed or empty response to
        # `_await_elements`, which checks whether at least one of the other requests succeeded
elements = self._await_elements(operation_id, response or requests.Response())
successful_responses = self.api_successful_responses.get(operation_id)
if elements is None or successful_responses is None:
return (response, error)
if len(successful_responses) == 0:
self._clear_operation(operation_id)
return (response, error)
updated_response = request_utils.create_response(
successful_responses[0], elements
)
self._clear_operation(operation_id)
return (updated_response, None)
def _clear_operation(self, operation_id: str) -> None:
"""
Clears the operation data associated with the given operation ID.
Args:
operation_id (str): The ID of the operation to clear.
"""
        self.coroutines_to_execute.pop(operation_id, None)
        self.api_successful_responses.pop(operation_id, None)
        # Also release any stored failed responses so the operation leaves no state behind
        self.api_failed_responses.pop(operation_id, None)