Diffstat (limited to '.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/split_pdf_hook.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/split_pdf_hook.py  445
1 file changed, 445 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/split_pdf_hook.py b/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/split_pdf_hook.py
new file mode 100644
index 00000000..54a3b89e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/split_pdf_hook.py
@@ -0,0 +1,445 @@
+from __future__ import annotations
+
+import asyncio
+import io
+import json
+import logging
+import math
+from collections.abc import Awaitable
+from typing import Any, Coroutine, Optional, Tuple, Union
+
+import httpx
+import nest_asyncio
+import requests
+from pypdf import PdfReader
+from requests_toolbelt.multipart.decoder import MultipartDecoder
+
+from unstructured_client._hooks.custom import form_utils, pdf_utils, request_utils
+from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
+from unstructured_client._hooks.custom.form_utils import (
+    PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
+    PARTITION_FORM_FILES_KEY,
+    PARTITION_FORM_PAGE_RANGE_KEY,
+    PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
+    PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
+    PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
+)
+from unstructured_client._hooks.types import (
+    AfterErrorContext,
+    AfterErrorHook,
+    AfterSuccessContext,
+    AfterSuccessHook,
+    BeforeRequestContext,
+    BeforeRequestHook,
+    SDKInitHook,
+)
+from unstructured_client.models import shared
+
+logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
+
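+# Defaults and bounds for the splitting behavior. Each split sent in parallel
+# carries between MIN_PAGES_PER_SPLIT and MAX_PAGES_PER_SPLIT pages; see
+# get_optimal_split_size below.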
+DEFAULT_STARTING_PAGE_NUMBER = 1
+DEFAULT_ALLOW_FAILED = False
+DEFAULT_CONCURRENCY_LEVEL = 8
+MAX_CONCURRENCY_LEVEL = 15
+MIN_PAGES_PER_SPLIT = 2
+MAX_PAGES_PER_SPLIT = 20
+
+
+async def _order_keeper(index: int, coro: Awaitable) -> Tuple[int, requests.Response]:
+    response = await coro
+    return index, response
+
+
+async def run_tasks(coroutines: list[Awaitable], allow_failed: bool = False) -> list[tuple[int, requests.Response]]:
+    if allow_failed:
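+        # When failures are allowed, simply wait for every response and keep
+        # 1-based ordering to match the sequential path below.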
+        responses = await asyncio.gather(*coroutines, return_exceptions=False)
+        return list(enumerate(responses, 1))
+    # TODO: replace with asyncio.TaskGroup for python >= 3.11 # pylint: disable=fixme
+    tasks = [asyncio.create_task(_order_keeper(index, coro)) for index, coro in enumerate(coroutines, 1)]
+    results = []
+    remaining_tasks = dict(enumerate(tasks, 1))
+    for future in asyncio.as_completed(tasks):
+        index, response = await future
+        if response.status_code != 200:
+            # cancel all remaining tasks
+            for remaining_task in remaining_tasks.values():
+                remaining_task.cancel()
+            results.append((index, response))
+            break
+        results.append((index, response))
+        # this task completed successfully, so drop it from the set of tasks to cancel on failure
+        del remaining_tasks[index]
+    # return results in the original order
+    return sorted(results, key=lambda x: x[0])
+
+
+def context_is_uvloop():
+    """Return true if uvloop is installed and we're currently in a uvloop context. Our asyncio splitting code currently doesn't work under uvloop."""
+    try:
+        import uvloop  # pylint: disable=import-outside-toplevel
+        loop = asyncio.get_event_loop()
+        return isinstance(loop, uvloop.Loop)
+    except (ImportError, RuntimeError):
+        return False
+
+
+def get_optimal_split_size(num_pages: int, concurrency_level: int) -> int:
+    """Distributes pages to workers evenly based on the number of pages and desired concurrency level."""
+    if num_pages < MAX_PAGES_PER_SPLIT * concurrency_level:
+        split_size = math.ceil(num_pages / concurrency_level)
+    else:
+        split_size = MAX_PAGES_PER_SPLIT
+
+    return max(split_size, MIN_PAGES_PER_SPLIT)
+
+
+class SplitPdfHook(SDKInitHook, BeforeRequestHook, AfterSuccessHook, AfterErrorHook):
+    """
+    A hook class that splits a PDF file into multiple pages and sends each page as
+    a separate request. This hook is designed to be used with a Speakeasy SDK.
+
+    Usage:
+    1. Create an instance of the `SplitPdfHook` class.
+    2. Register SDK Init, Before Request, After Success and After Error hooks.
+    """
+
+    def __init__(self) -> None:
+        self.client: Optional[requests.Session] = None
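+        # All remaining state is keyed by operation_id, so concurrent operations
+        # sharing this hook instance do not interfere with one another.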
+        self.coroutines_to_execute: dict[
+            str, list[Coroutine[Any, Any, requests.Response]]
+        ] = {}
+        self.api_successful_responses: dict[str, list[requests.Response]] = {}
+        self.api_failed_responses: dict[str, list[requests.Response]] = {}
+        self.allow_failed: bool = DEFAULT_ALLOW_FAILED
+
+    def sdk_init(
+            self, base_url: str, client: requests.Session
+    ) -> Tuple[str, requests.Session]:
+        """Initializes Split PDF Hook.
+
+        Args:
+            base_url (str): URL of the API.
+            client (requests.Session): HTTP Client.
+
+        Returns:
+            Tuple[str, requests.Session]: The initialized SDK options.
+        """
+        self.client = client
+        return base_url, client
+
+
+    # pylint: disable=too-many-return-statements
+    def before_request(
+            self, hook_ctx: BeforeRequestContext, request: requests.PreparedRequest
+    ) -> Union[requests.PreparedRequest, Exception]:
+        """If `splitPdfPage` is set to `true` in the request, the PDF file is split into
+        separate pages. Each page is sent as a separate request in parallel. The last
+        page request is returned by this method. It will return the original request
+        when: `splitPdfPage` is set to `false`, the file is not a PDF, or the HTTP
+        client has not been initialized.
+
+        Args:
+            hook_ctx (BeforeRequestContext): The hook context containing information about
+            the operation.
+            request (requests.PreparedRequest): The request object.
+
+        Returns:
+            Union[requests.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`,
+            the last page request; otherwise, the original request.
+        """
+        if self.client is None:
+            logger.warning("HTTP client not accessible! Continuing without splitting.")
+            return request
+
+        if context_is_uvloop():
+            logger.warning("Splitting is currently incompatible with uvloop. Continuing without splitting.")
+            return request
+
+        # This allows us to use an event loop in an env with an existing loop
+        # Temporary fix until we can improve the async splitting behavior
+        nest_asyncio.apply()
+        operation_id = hook_ctx.operation_id
+        content_type = request.headers.get("Content-Type")
+        body = request.body
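+        # Only requests with an in-memory multipart body can be split.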
+        if not isinstance(body, bytes) or content_type is None:
+            return request
+
+        decoded_body = MultipartDecoder(body, content_type)
+        form_data = form_utils.parse_form_data(decoded_body)
+        split_pdf_page = form_data.get(PARTITION_FORM_SPLIT_PDF_PAGE_KEY)
+        if split_pdf_page is None or split_pdf_page == "false":
+            logger.info("Partitioning without split.")
+            return request
+
+        logger.info("Preparing to split document for partition.")
+        file = form_data.get(PARTITION_FORM_FILES_KEY)
+        if (
+                file is None
+                or not isinstance(file, shared.Files)
+                or not pdf_utils.is_pdf(file)
+        ):
+            logger.info("Partitioning without split.")
+            return request
+
+        starting_page_number = form_utils.get_starting_page_number(
+            form_data,
+            key=PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
+            fallback_value=DEFAULT_STARTING_PAGE_NUMBER,
+        )
+        if starting_page_number > 1:
+            logger.info("Starting page number set to %d", starting_page_number)
+
+        self.allow_failed = form_utils.get_split_pdf_allow_failed_param(
+            form_data,
+            key=PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
+            fallback_value=DEFAULT_ALLOW_FAILED,
+        )
+        logger.info("Allow failed set to %d", self.allow_failed)
+
+        concurrency_level = form_utils.get_split_pdf_concurrency_level_param(
+            form_data,
+            key=PARTITION_FORM_CONCURRENCY_LEVEL_KEY,
+            fallback_value=DEFAULT_CONCURRENCY_LEVEL,
+            max_allowed=MAX_CONCURRENCY_LEVEL,
+        )
+        logger.info("Concurrency level set to %d", concurrency_level)
+        limiter = asyncio.Semaphore(concurrency_level)
+
+        pdf = PdfReader(io.BytesIO(file.content))
+
+        page_range_start, page_range_end = form_utils.get_page_range(
+            form_data,
+            key=PARTITION_FORM_PAGE_RANGE_KEY,
+            max_pages=len(pdf.pages),
+        )
+
+        page_count = page_range_end - page_range_start + 1
+        logger.info(
+            "Splitting pages %d to %d (%d total)",
+            page_range_start,
+            page_range_end,
+            page_count,
+        )
+
+        split_size = get_optimal_split_size(
+            num_pages=page_count, concurrency_level=concurrency_level
+        )
+        logger.info("Determined optimal split size of %d pages.", split_size)
+
+        # If the doc is small enough, and we aren't slicing it with a page range:
+        # do not split, just continue with the original request
+        if split_size >= page_count and page_count == len(pdf.pages):
+            logger.info(
+                "Document has too few pages (%d) to be split efficiently. Partitioning without split.",
+                page_count,
+            )
+            return request
+
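+        # Each item yielded by `get_pdf_pages` is a tuple of (split content, 0-based
+        # index of the split's first page, total number of pages being processed).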
+        pages = pdf_utils.get_pdf_pages(pdf, split_size=split_size, page_start=page_range_start, page_end=page_range_end)
+        logger.info(
+            "Partitioning %d files with %d page(s) each.",
+            math.floor(page_count / split_size),
+            split_size,
+        )
+
+        # Log the remainder pages if there are any
+        if page_count % split_size > 0:
+            logger.info(
+                "Partitioning 1 file with %d page(s).",
+                page_count % split_size,
+            )
+
+        async def call_api_partial(page):
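+            # Each coroutine opens its own async client; the shared semaphore
+            # still bounds how many requests run concurrently.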
+            async with httpx.AsyncClient() as client:
+                status_code, json_response = await request_utils.call_api_async(
+                    client=client,
+                    original_request=request,
+                    form_data=form_data,
+                    filename=file.file_name,
+                    page=page,
+                    limiter=limiter,
+                )
+
+                # convert httpx response to requests.Response to preserve
+                # compatibility with the synchronous SDK generated by speakeasy
+                response = requests.Response()
+                response.status_code = status_code
+                response._content = json.dumps(  # pylint: disable=W0212
+                    json_response
+                ).encode()
+                response.headers["Content-Type"] = "application/json"
+                return response
+
+        self.coroutines_to_execute[operation_id] = []
+        last_page_content = io.BytesIO()
+        last_page_number = 0
+        set_index = 1
+        for page_content, page_index, all_pages_number in pages:
+            page_number = page_index + starting_page_number
+            logger.info(
+                "Partitioning set #%d (pages %d-%d).",
+                set_index,
+                page_number,
+                min(page_number + split_size - 1, all_pages_number),
+            )
+            # Check if this set of pages is the last one
+            if page_index + split_size >= all_pages_number:
+                last_page_content = page_content
+                last_page_number = page_number
+                break
+            coroutine = call_api_partial((page_content, page_number))
+            self.coroutines_to_execute[operation_id].append(coroutine)
+            set_index += 1
+        # The `before_request` hook must return a request, so we skip sending the
+        # last page in parallel and return it at the end of this method instead.
+
+        body = request_utils.create_request_body(
+            form_data, last_page_content, file.file_name, last_page_number
+        )
+        last_page_request = request_utils.create_request(request, body)
+        last_page_prepared_request = self.client.prepare_request(last_page_request)
+        return last_page_prepared_request
+
+    def _await_elements(
+            self, operation_id: str, response: requests.Response
+    ) -> Optional[list]:
+        """
+        Waits for the partition requests to complete and returns the flattened
+        elements.
+
+        Args:
+            operation_id (str): The ID of the operation.
+            response (requests.Response): The response object.
+
+        Returns:
+            Optional[list]: The flattened elements if the partition requests are
+            completed, otherwise None.
+        """
+        tasks = self.coroutines_to_execute.get(operation_id)
+        if tasks is None:
+            return None
+
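+        # nest_asyncio (applied in before_request) allows run_until_complete to
+        # work even when an outer event loop is already running.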
+        ioloop = asyncio.get_event_loop()
+        task_responses: list[tuple[int, requests.Response]] = ioloop.run_until_complete(
+            run_tasks(tasks, allow_failed=self.allow_failed)
+        )
+
+        if task_responses is None:
+            return None
+
+        successful_responses = []
+        failed_responses = []
+        elements = []
+        for response_number, res in task_responses:
+            request_utils.log_after_split_response(res.status_code, response_number)
+            if res.status_code == 200:
+                successful_responses.append(res)
+                elements.append(res.json())
+            else:
+                failed_responses.append(res)
+
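+        # The last page was sent synchronously, so its response is folded in here
+        # unless a parallel request already failed with failures disallowed.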
+        if self.allow_failed or not failed_responses:
+            last_response_number = len(task_responses) + 1
+            request_utils.log_after_split_response(
+                response.status_code, last_response_number
+            )
+            if response.status_code == 200:
+                elements.append(response.json())
+                successful_responses.append(response)
+            else:
+                failed_responses.append(response)
+
+        self.api_successful_responses[operation_id] = successful_responses
+        self.api_failed_responses[operation_id] = failed_responses
+        flattened_elements = [element for sublist in elements for element in sublist]
+        return flattened_elements
+
+    def after_success(
+            self, hook_ctx: AfterSuccessContext, response: requests.Response
+    ) -> Union[requests.Response, Exception]:
+        """Executes after a successful API request. Awaits all parallel requests and
+        combines the responses into a single response object.
+
+        Args:
+            hook_ctx (AfterSuccessContext): The context object containing information
+            about the hook execution.
+            response (requests.Response): The response object returned from the API
+            request.
+
+        Returns:
+            Union[requests.Response, Exception]: If requests were run in parallel, a
+            combined response object; otherwise, the original response. Can return
+            an exception if one occurred during execution.
+        """
+        operation_id = hook_ctx.operation_id
+        # Because the `before_request` method skipped sending the last page in
+        # parallel, we pass this response, which contains the last page, to `_await_elements`
+        elements = self._await_elements(operation_id, response)
+
+        # if fails are disallowed, return the first failed response
+        if not self.allow_failed and self.api_failed_responses.get(operation_id):
+            return self.api_failed_responses[operation_id][0]
+
+        if elements is None:
+            return response
+
+        updated_response = request_utils.create_response(response, elements)
+        self._clear_operation(operation_id)
+        return updated_response
+
+    def after_error(
+            self,
+            hook_ctx: AfterErrorContext,
+            response: Optional[requests.Response],
+            error: Optional[Exception],
+    ) -> Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]:
+        """Executes after an unsuccessful API request. Awaits all parallel requests,
+        if at least one request was successful, combines the responses into a single
+        response object and doesn't throw an error. It will return an error only if
+        all requests failed, or there was no PDF split.
+
+        Args:
+            hook_ctx (AfterErrorContext): The AfterErrorContext object containing
+            information about the hook context.
+            response (Optional[requests.Response]): The Response object representing
+            the response received before the exception occurred.
+            error (Optional[Exception]): The exception object that was thrown.
+
+        Returns:
+            Union[Tuple[Optional[requests.Response], Optional[Exception]], Exception]:
+            If requests were run in parallel, and at least one was successful, a combined
+            response object; otherwise, the original response and exception.
+        """
+
+        # if fails are disallowed - return response and error objects immediately
+        if not self.allow_failed:
+            return (response, error)
+
+        operation_id = hook_ctx.operation_id
+        # We know this request failed, so we pass a failed or empty response to the
+        # `_await_elements` method, which checks whether at least one of the other requests succeeded
+        elements = self._await_elements(operation_id, response or requests.Response())
+        successful_responses = self.api_successful_responses.get(operation_id)
+
+        if elements is None or successful_responses is None:
+            return (response, error)
+
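+        # If every request failed, there is nothing to combine; fall through to
+        # the original response/error pair.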
+        if len(successful_responses) == 0:
+            self._clear_operation(operation_id)
+            return (response, error)
+
+        updated_response = request_utils.create_response(
+            successful_responses[0], elements
+        )
+        self._clear_operation(operation_id)
+        return (updated_response, None)
+
+    def _clear_operation(self, operation_id: str) -> None:
+        """
+        Clears the operation data associated with the given operation ID.
+
+        Args:
+            operation_id (str): The ID of the operation to clear.
+        """
+        self.coroutines_to_execute.pop(operation_id, None)
+        self.api_successful_responses.pop(operation_id, None)
+        self.api_failed_responses.pop(operation_id, None)