from __future__ import annotations

import asyncio
import io
import json
import logging
import math
from collections.abc import Awaitable
from typing import Any, Coroutine, Optional, Tuple, Union

import httpx
import nest_asyncio
import requests
from pypdf import PdfReader
from requests_toolbelt.multipart.decoder import MultipartDecoder

from unstructured_client._hooks.custom import form_utils, pdf_utils, request_utils
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
from unstructured_client._hooks.custom.form_utils import (
    PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
    PARTITION_FORM_FILES_KEY,
    PARTITION_FORM_PAGE_RANGE_KEY,
    PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
    PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
    PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
)
from unstructured_client._hooks.types import (
    AfterErrorContext,
    AfterErrorHook,
    AfterSuccessContext,
    AfterSuccessHook,
    BeforeRequestContext,
    BeforeRequestHook,
    SDKInitHook,
)
from unstructured_client.models import shared

logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)

DEFAULT_STARTING_PAGE_NUMBER = 1
DEFAULT_ALLOW_FAILED = False
DEFAULT_CONCURRENCY_LEVEL = 8
MAX_CONCURRENCY_LEVEL = 15
MIN_PAGES_PER_SPLIT = 2
MAX_PAGES_PER_SPLIT = 20


async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, requests.Response]:
    response = await coro
    return index, response


async def run_tasks(coroutines: list[Awaitable], allow_failed: bool = False) -> list[tuple[int, requests.Response]]:
    if allow_failed:
        responses = await asyncio.gather(*coroutines, return_exceptions=False)
        return list(enumerate(responses, 1))
    # TODO: replace with asyncio.TaskGroup for python >3.11 # pylint: disable=fixme
    tasks = [asyncio.create_task(_order_keeper(index, coro)) for index, coro in enumerate(coroutines, 1)]
    results = []
    remaining_tasks = dict(enumerate(tasks, 1))
    for future in asyncio.as_completed(tasks):
        index, response = await future
        if response.status_code != 200:
            # cancel all remaining tasks
            for remaining_task in remaining_tasks.values():
                remaining_task.cancel()
            results.append((index, response))
            break
        results.append((index, response))
        # remove the completed task so it is not cancelled if a later task fails
        del remaining_tasks[index]
    # return results in the original order
    return sorted(results, key=lambda x: x[0])


def context_is_uvloop():
    """Return true if uvloop is installed and we're currently in a uvloop context.
    Our asyncio splitting code currently doesn't work under uvloop."""
    try:
        import uvloop  # pylint: disable=import-outside-toplevel
        loop = asyncio.get_event_loop()
        return isinstance(loop, uvloop.Loop)
    except (ImportError, RuntimeError):
        return False


def get_optimal_split_size(num_pages: int, concurrency_level: int) -> int:
    """Distributes pages to workers evenly based on the number of pages and the
    desired concurrency level."""
    if num_pages < MAX_PAGES_PER_SPLIT * concurrency_level:
        split_size = math.ceil(num_pages / concurrency_level)
    else:
        split_size = MAX_PAGES_PER_SPLIT

    return max(split_size, MIN_PAGES_PER_SPLIT)


class SplitPdfHook(SDKInitHook, BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
    """
    A hook class that splits a PDF file into multiple pages and sends each page as
    a separate request. This hook is designed to be used with a Speakeasy SDK.

    Usage:
    1. Create an instance of the `SplitPdfHook` class.
    2. Register SDK Init, Before Request, After Success and After Error hooks.
    """
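
    # A minimal registration sketch of the two usage steps above (hedged:
    # `hooks` stands in for the hook registry that the generated SDK
    # constructs and wires up internally):
    #
    #   hook = SplitPdfHook()
    #   hooks.register_sdk_init_hook(hook)
    #   hooks.register_before_request_hook(hook)
    #   hooks.register_after_success_hook(hook)
    #   hooks.register_after_error_hook(hook)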
""" def __init__(self) -> None: self.client: Optional[requests.Session] = None self.coroutines_to_execute: dict[ str, list[Coroutine[Any, Any, requests.Response]] ] = {} self.api_successful_responses: dict[str, list[requests.Response]] = {} self.api_failed_responses: dict[str, list[requests.Response]] = {} self.allow_failed: bool = DEFAULT_ALLOW_FAILED def sdk_init( self, base_url: str, client: requests.Session ) -> Tuple[str, requests.Session]: """Initializes Split PDF Hook. Args: base_url (str): URL of the API. client (requests.Session): HTTP Client. Returns: Tuple[str, requests.Session]: The initialized SDK options. """ self.client = client return base_url, client # pylint: disable=too-many-return-statements def before_request( self, hook_ctx: BeforeRequestContext, request: requests.PreparedRequest ) -> Union[requests.PreparedRequest, Exception]: """If `splitPdfPage` is set to `true` in the request, the PDF file is split into separate pages. Each page is sent as a separate request in parallel. The last page request is returned by this method. It will return the original request when: `splitPdfPage` is set to `false`, the file is not a PDF, or the HTTP has not been initialized. Args: hook_ctx (BeforeRequestContext): The hook context containing information about the operation. request (requests.PreparedRequest): The request object. Returns: Union[requests.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`, the last page request; otherwise, the original request. """ if self.client is None: logger.warning("HTTP client not accessible! Continuing without splitting.") return request if context_is_uvloop(): logger.warning("Splitting is currently incompatible with uvloop. Continuing without splitting.") return request # This allows us to use an event loop in an env with an existing loop # Temporary fix until we can improve the async splitting behavior nest_asyncio.apply() operation_id = hook_ctx.operation_id content_type = request.headers.get("Content-Type") body = request.body if not isinstance(body, bytes) or content_type is None: return request decoded_body = MultipartDecoder(body, content_type) form_data = form_utils.parse_form_data(decoded_body) split_pdf_page = form_data.get(PARTITION_FORM_SPLIT_PDF_PAGE_KEY) if split_pdf_page is None or split_pdf_page == "false": logger.info("Partitioning without split.") return request logger.info("Preparing to split document for partition.") file = form_data.get(PARTITION_FORM_FILES_KEY) if ( file is None or not isinstance(file, shared.Files) or not pdf_utils.is_pdf(file) ): logger.info("Partitioning without split.") return request starting_page_number = form_utils.get_starting_page_number( form_data, key=PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, fallback_value=DEFAULT_STARTING_PAGE_NUMBER, ) if starting_page_number > 1: logger.info("Starting page number set to %d", starting_page_number) logger.info("Starting page number set to %d", starting_page_number) self.allow_failed = form_utils.get_split_pdf_allow_failed_param( form_data, key=PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY, fallback_value=DEFAULT_ALLOW_FAILED, ) logger.info("Allow failed set to %d", self.allow_failed) concurrency_level = form_utils.get_split_pdf_concurrency_level_param( form_data, key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY, fallback_value=DEFAULT_CONCURRENCY_LEVEL, max_allowed=MAX_CONCURRENCY_LEVEL, ) logger.info("Concurrency level set to %d", concurrency_level) limiter = asyncio.Semaphore(concurrency_level) pdf = PdfReader(io.BytesIO(file.content)) page_range_start, 
        page_range_start, page_range_end = form_utils.get_page_range(
            form_data,
            key=PARTITION_FORM_PAGE_RANGE_KEY,
            max_pages=len(pdf.pages),
        )

        page_count = page_range_end - page_range_start + 1
        logger.info(
            "Splitting pages %d to %d (%d total)",
            page_range_start,
            page_range_end,
            page_count,
        )

        split_size = get_optimal_split_size(
            num_pages=page_count, concurrency_level=concurrency_level
        )
        logger.info("Determined optimal split size of %d pages.", split_size)

        # If the doc is small enough, and we aren't slicing it with a page range:
        # do not split, just continue with the original request
        if split_size >= page_count and page_count == len(pdf.pages):
            logger.info(
                "Document has too few pages (%d) to be split efficiently. Partitioning without split.",
                page_count,
            )
            return request

        pages = pdf_utils.get_pdf_pages(pdf, split_size=split_size, page_start=page_range_start, page_end=page_range_end)
        logger.info(
            "Partitioning %d files with %d page(s) each.",
            math.floor(page_count / split_size),
            split_size,
        )
        # Log the remainder pages if there are any
        if page_count % split_size > 0:
            logger.info(
                "Partitioning 1 file with %d page(s).",
                page_count % split_size,
            )

        async def call_api_partial(page):
            async with httpx.AsyncClient() as client:
                status_code, json_response = await request_utils.call_api_async(
                    client=client,
                    original_request=request,
                    form_data=form_data,
                    filename=file.file_name,
                    page=page,
                    limiter=limiter,
                )

                # convert httpx response to requests.Response to preserve
                # compatibility with the synchronous SDK generated by speakeasy
                response = requests.Response()
                response.status_code = status_code
                response._content = json.dumps(  # pylint: disable=W0212
                    json_response
                ).encode()
                response.headers["Content-Type"] = "application/json"
                return response

        self.coroutines_to_execute[operation_id] = []
        last_page_content = io.BytesIO()
        last_page_number = 0
        set_index = 1
        for page_content, page_index, all_pages_number in pages:
            page_number = page_index + starting_page_number
            logger.info(
                "Partitioning set #%d (pages %d-%d).",
                set_index,
                page_number,
                min(page_number + split_size - 1, all_pages_number),
            )
            # Check if this set of pages is the last one
            if page_index + split_size >= all_pages_number:
                last_page_content = page_content
                last_page_number = page_number
                break
            coroutine = call_api_partial((page_content, page_number))
            self.coroutines_to_execute[operation_id].append(coroutine)
            set_index += 1

        # `before_request` method needs to return a request so we skip sending the
        # last page in parallel and return that last page at the end of this method
        body = request_utils.create_request_body(
            form_data, last_page_content, file.file_name, last_page_number
        )
        last_page_request = request_utils.create_request(request, body)
        last_page_prepared_request = self.client.prepare_request(last_page_request)
        return last_page_prepared_request

    def _await_elements(
        self, operation_id: str, response: requests.Response
    ) -> Optional[list]:
        """
        Waits for the partition requests to complete and returns the flattened
        elements.

        Args:
            operation_id (str): The ID of the operation.
            response (requests.Response): The response object.

        Returns:
            Optional[list]: The flattened elements if the partition requests are
            completed, otherwise None.
        """
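        # `nest_asyncio.apply()` (called in `before_request`) is what makes the
        # `run_until_complete` call below safe when the SDK is itself invoked
        # from inside an already-running event loop (e.g. a Jupyter notebook).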
""" tasks = self.coroutines_to_execute.get(operation_id) if tasks is None: return None ioloop = asyncio.get_event_loop() task_responses: list[tuple[int, requests.Response]] = ioloop.run_until_complete( run_tasks(tasks, allow_failed=self.allow_failed) ) if task_responses is None: return None successful_responses = [] failed_responses = [] elements = [] for response_number, res in task_responses: request_utils.log_after_split_response(res.status_code, response_number) if res.status_code == 200: successful_responses.append(res) elements.append(res.json()) else: failed_responses.append(res) if self.allow_failed or not failed_responses: last_response_number = len(task_responses) + 1 request_utils.log_after_split_response( response.status_code, last_response_number ) if response.status_code == 200: elements.append(response.json()) successful_responses.append(response) else: failed_responses.append(response) self.api_successful_responses[operation_id] = successful_responses self.api_failed_responses[operation_id] = failed_responses flattened_elements = [element for sublist in elements for element in sublist] return flattened_elements def after_success( self, hook_ctx: AfterSuccessContext, response: requests.Response ) -> Union[requests.Response, Exception]: """Executes after a successful API request. Awaits all parallel requests and combines the responses into a single response object. Args: hook_ctx (AfterSuccessContext): The context object containing information about the hook execution. response (requests.Response): The response object returned from the API request. Returns: Union[requests.Response, Exception]: If requests were run in parallel, a combined response object; otherwise, the original response. Can return exception if it ocurred during the execution. """ operation_id = hook_ctx.operation_id # Because in `before_request` method we skipped sending last page in parallel # we need to pass response, which contains last page, to `_await_elements` method elements = self._await_elements(operation_id, response) # if fails are disallowed, return the first failed response if not self.allow_failed and self.api_failed_responses.get(operation_id): return self.api_failed_responses[operation_id][0] if elements is None: return response updated_response = request_utils.create_response(response, elements) self._clear_operation(operation_id) return updated_response def after_error( self, hook_ctx: AfterErrorContext, response: Optional[requests.Response], error: Optional[Exception], ) -> Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]: """Executes after an unsuccessful API request. Awaits all parallel requests, if at least one request was successful, combines the responses into a single response object and doesn't throw an error. It will return an error only if all requests failed, or there was no PDF split. Args: hook_ctx (AfterErrorContext): The AfterErrorContext object containing information about the hook context. response (Optional[requests.Response]): The Response object representing the response received before the exception occurred. error (Optional[Exception]): The exception object that was thrown. Returns: Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]: If requests were run in parallel, and at least one was successful, a combined response object; otherwise, the original response and exception. 
""" # if fails are disallowed - return response and error objects immediately if not self.allow_failed: return (response, error) operation_id = hook_ctx.operation_id # We know that this request failed so we pass a failed or empty response to `_await_elements` method # where it checks if at least on of the other requests succeeded elements = self._await_elements(operation_id, response or requests.Response()) successful_responses = self.api_successful_responses.get(operation_id) if elements is None or successful_responses is None: return (response, error) if len(successful_responses) == 0: self._clear_operation(operation_id) return (response, error) updated_response = request_utils.create_response( successful_responses[0], elements ) self._clear_operation(operation_id) return (updated_response, None) def _clear_operation(self, operation_id: str) -> None: """ Clears the operation data associated with the given operation ID. Args: operation_id (str): The ID of the operation to clear. """ self.coroutines_to_execute.pop(operation_id, None) self.api_successful_responses.pop(operation_id, None)