aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/request_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/request_utils.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/request_utils.py | 190
1 files changed, 190 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/request_utils.py b/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/request_utils.py
new file mode 100644
index 00000000..1512e80b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/unstructured_client/_hooks/custom/request_utils.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import asyncio
+import copy
+import io
+import json
+import logging
+from typing import Optional, Tuple, Any
+
+import httpx
+import requests
+from requests.structures import CaseInsensitiveDict
+from requests_toolbelt.multipart.encoder import MultipartEncoder
+
+from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
+from unstructured_client._hooks.custom.form_utils import (
+ PARTITION_FORM_FILES_KEY,
+ PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
+ PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
+ PARTITION_FORM_PAGE_RANGE_KEY,
+ PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
+ FormData,
+)
+
+logger = logging.getLogger(UNSTRUCTURED_CLIENT_LOGGER_NAME)
+
+
def create_request_body(
    form_data: FormData, page_content: io.BytesIO, filename: str, page_number: int
) -> MultipartEncoder:
    """Build the multipart body for one page-level partition request.

    The form data is first passed through ``prepare_request_payload``. Each
    resulting entry becomes one multipart field, except list values, which
    are expanded into one field per item under the same key. The page file
    and the starting page number are appended after the form fields.

    Args:
        form_data: The form data of the original request.
        page_content: The content of the page to send as the file part.
        filename: The filename reported for the file part.
        page_number: The starting page number, sent as a string field.

    Returns:
        A ``MultipartEncoder`` holding all fields of the new request.
    """
    cleaned = prepare_request_payload(form_data)

    fields: list[tuple[str, Any]] = []
    for name, entry in cleaned.items():
        # Treat scalars as one-item lists so both cases share a single path.
        entries = entry if isinstance(entry, list) else [entry]
        fields.extend((name, item) for item in entries)

    file_part = (filename, page_content, "application/pdf")
    fields.append((PARTITION_FORM_FILES_KEY, file_part))
    fields.append((PARTITION_FORM_STARTING_PAGE_NUMBER_KEY, str(page_number)))

    return MultipartEncoder(fields=fields)
+
+
def create_httpx_request(
    original_request: requests.Request, body: MultipartEncoder
) -> httpx.Request:
    """Create an ``httpx`` POST request that carries the given multipart body.

    The headers of the original request are reused after being passed through
    ``prepare_request_headers``; the ``Content-Type`` is then set from the
    encoder so that it matches the new body.

    Args:
        original_request: The request whose URL and headers are reused.
        body: The multipart encoder; it is rendered with ``to_string()``.

    Returns:
        The new ``httpx.Request``.
    """
    base_headers = prepare_request_headers(original_request.headers)
    # An unset URL on the original request falls back to an empty string.
    target_url = original_request.url or ""
    return httpx.Request(
        method="POST",
        url=target_url,
        content=body.to_string(),
        headers={**base_headers, "Content-Type": body.content_type},
    )
+
+
def create_request(
    request: requests.PreparedRequest,
    body: MultipartEncoder,
) -> requests.Request:
    """Create a ``requests`` POST request that carries the given multipart body.

    The headers of the given request are reused after being passed through
    ``prepare_request_headers``; the ``Content-Type`` is then set from the
    encoder so that it matches the new body.

    Args:
        request: The prepared request whose URL and headers are reused.
        body: The multipart encoder, passed on as the request data.

    Returns:
        The new, not yet prepared, ``requests.Request``.
    """
    base_headers = prepare_request_headers(request.headers)
    # An unset URL on the given request falls back to an empty string.
    target_url = request.url or ""
    return requests.Request(
        method="POST",
        url=target_url,
        data=body,
        headers={**base_headers, "Content-Type": body.content_type},
    )
+
+
async def call_api_async(
    client: httpx.AsyncClient,
    page: Tuple[io.BytesIO, int],
    original_request: requests.Request,
    form_data: FormData,
    filename: str,
    limiter: asyncio.Semaphore,
) -> tuple[int, dict]:
    """Send one page to the API asynchronously and return its parsed result.

    The request body and the ``httpx`` request are built before the limiter
    is acquired; only the send and the JSON decoding run while holding it.

    Args:
        client: The async HTTP client used to send the request.
        page: A ``(page_content, page_number)`` pair.
        original_request: The request whose URL and headers are reused.
        form_data: The form data of the original request.
        filename: The filename reported for the uploaded page.
        limiter: Semaphore held while the request is in flight.

    Returns:
        ``(status_code, decoded_json)`` on success. If sending the request or
        decoding the response body raises, ``(500, {})`` is returned instead
        and the failure is logged.
    """
    page_content, page_number = page
    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_httpx_request(original_request, body)
    async with limiter:
        try:
            response = await client.send(new_request)
            return response.status_code, response.json()
        except Exception:
            # Best-effort by design: a failing page must not abort the others.
            # logger.exception records the traceback so the cause (network
            # error, non-JSON body, ...) is not lost.
            logger.exception("Failed to send request for page %d", page_number)
            return 500, {}
+
+
def call_api(
    client: Optional[requests.Session],
    page: Tuple[io.BytesIO, int],
    request: requests.PreparedRequest,
    form_data: FormData,
    filename: str,
) -> requests.Response:
    """Send one page to the API synchronously and return the response.

    Args:
        client: The session used to prepare and send the request.
        page: A ``(page_content, page_number)`` pair.
        request: The prepared request whose URL and headers are reused.
        form_data: The form data of the original request.
        filename: The filename reported for the uploaded page.

    Returns:
        The response returned by the session. If sending raises, an empty,
        default-constructed ``requests.Response`` is returned instead and the
        failure is logged.

    Raises:
        RuntimeError: If ``client`` is ``None``.
    """
    if client is None:
        raise RuntimeError("HTTP client not accessible!")
    page_content, page_number = page

    body = create_request_body(form_data, page_content, filename, page_number)
    new_request = create_request(request, body)
    prepared_request = client.prepare_request(new_request)

    try:
        return client.send(prepared_request)
    except Exception:
        # Best-effort by design: a failing page must not abort the others.
        # logger.exception records the traceback so the cause is not lost.
        logger.exception("Failed to send request for page %d", page_number)
        return requests.Response()
+
+
def prepare_request_headers(
    headers: CaseInsensitiveDict[str],
) -> CaseInsensitiveDict[str]:
    """Return a copy of the headers without 'Content-Type' and 'Content-Length'.

    The input mapping is deep-copied first, so the caller's headers are left
    untouched. Missing headers are ignored.

    Args:
        headers: The original request headers.

    Returns:
        The copied headers with both entries removed.
    """
    cleaned = copy.deepcopy(headers)
    for header_name in ("Content-Type", "Content-Length"):
        cleaned.pop(header_name, None)
    return cleaned
+
+
def prepare_request_payload(form_data: FormData) -> FormData:
    """Derive the payload for a page-level request from the original form data.

    Works on a deep copy, so ``form_data`` itself is not modified. The split
    flags, the files entry, the page range and the starting page number are
    removed, and the split-PDF flag is then set to ``"false"``.

    Args:
        form_data: The original form data.

    Returns:
        The updated request payload.
    """
    payload = copy.deepcopy(form_data)

    dropped_keys = (
        PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
        PARTITION_FORM_SPLIT_PDF_ALLOW_FAILED_KEY,
        PARTITION_FORM_FILES_KEY,
        PARTITION_FORM_PAGE_RANGE_KEY,
        PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
    )
    for key in dropped_keys:
        payload.pop(key, None)

    # Re-added after the removals above, always with the value "false".
    payload.update({PARTITION_FORM_SPLIT_PDF_PAGE_KEY: "false"})
    return payload
+
+
def create_response(response: requests.Response, elements: list) -> requests.Response:
    """
    Build a copy of a response whose body is the JSON-serialized elements.

    The original response is deep-copied and left untouched. On the copy, the
    ``Content-Length`` header is set to the byte length of the new body and
    the ``_content`` attribute is replaced with that body.

    Args:
        response: The original response object.
        elements: The list of elements to be serialized and added to
            the response.

    Returns:
        The modified copy of the response.
    """
    rebuilt = copy.deepcopy(response)
    payload_bytes = json.dumps(elements).encode()
    rebuilt.headers.update({"Content-Length": str(len(payload_bytes))})
    rebuilt._content = payload_bytes
    return rebuilt
+
+
def log_after_split_response(status_code: int, split_number: int):
    """Log the outcome of partitioning one split of the document.

    Args:
        status_code: The HTTP status code received for the split; only 200
            is logged as a success.
        split_number: The number of the split, used in the log message.
    """
    if status_code != 200:
        logger.warning(
            "Failed to partition set #%d, its elements will be omitted in the final result.",
            split_number,
        )
        return

    logger.info(
        "Successfully partitioned set #%d, elements added to the final result.",
        split_number,
    )