about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/starlette/formparsers.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette/formparsers.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/formparsers.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/formparsers.py275
1 files changed, 275 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/formparsers.py b/.venv/lib/python3.12/site-packages/starlette/formparsers.py
new file mode 100644
index 00000000..4551d688
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/formparsers.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import typing
+from dataclasses import dataclass, field
+from enum import Enum
+from tempfile import SpooledTemporaryFile
+from urllib.parse import unquote_plus
+
+from starlette.datastructures import FormData, Headers, UploadFile
+
+if typing.TYPE_CHECKING:
+    import python_multipart as multipart
+    from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
+else:
+    try:
+        try:
+            import python_multipart as multipart
+            from python_multipart.multipart import parse_options_header
+        except ModuleNotFoundError:  # pragma: no cover
+            import multipart
+            from multipart.multipart import parse_options_header
+    except ModuleNotFoundError:  # pragma: no cover
+        multipart = None
+        parse_options_header = None
+
+
+class FormMessage(Enum):
+    FIELD_START = 1
+    FIELD_NAME = 2
+    FIELD_DATA = 3
+    FIELD_END = 4
+    END = 5
+
+
+@dataclass
+class MultipartPart:
+    content_disposition: bytes | None = None
+    field_name: str = ""
+    data: bytearray = field(default_factory=bytearray)
+    file: UploadFile | None = None
+    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
+
+
+def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
+    try:
+        return src.decode(codec)
+    except (UnicodeDecodeError, LookupError):
+        return src.decode("latin-1")
+
+
+class MultiPartException(Exception):
+    def __init__(self, message: str) -> None:
+        self.message = message
+
+
+class FormParser:
+    def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
+        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
+        self.headers = headers
+        self.stream = stream
+        self.messages: list[tuple[FormMessage, bytes]] = []
+
+    def on_field_start(self) -> None:
+        message = (FormMessage.FIELD_START, b"")
+        self.messages.append(message)
+
+    def on_field_name(self, data: bytes, start: int, end: int) -> None:
+        message = (FormMessage.FIELD_NAME, data[start:end])
+        self.messages.append(message)
+
+    def on_field_data(self, data: bytes, start: int, end: int) -> None:
+        message = (FormMessage.FIELD_DATA, data[start:end])
+        self.messages.append(message)
+
+    def on_field_end(self) -> None:
+        message = (FormMessage.FIELD_END, b"")
+        self.messages.append(message)
+
+    def on_end(self) -> None:
+        message = (FormMessage.END, b"")
+        self.messages.append(message)
+
+    async def parse(self) -> FormData:
+        # Callbacks dictionary.
+        callbacks: QuerystringCallbacks = {
+            "on_field_start": self.on_field_start,
+            "on_field_name": self.on_field_name,
+            "on_field_data": self.on_field_data,
+            "on_field_end": self.on_field_end,
+            "on_end": self.on_end,
+        }
+
+        # Create the parser.
+        parser = multipart.QuerystringParser(callbacks)
+        field_name = b""
+        field_value = b""
+
+        items: list[tuple[str, str | UploadFile]] = []
+
+        # Feed the parser with data from the request.
+        async for chunk in self.stream:
+            if chunk:
+                parser.write(chunk)
+            else:
+                parser.finalize()
+            messages = list(self.messages)
+            self.messages.clear()
+            for message_type, message_bytes in messages:
+                if message_type == FormMessage.FIELD_START:
+                    field_name = b""
+                    field_value = b""
+                elif message_type == FormMessage.FIELD_NAME:
+                    field_name += message_bytes
+                elif message_type == FormMessage.FIELD_DATA:
+                    field_value += message_bytes
+                elif message_type == FormMessage.FIELD_END:
+                    name = unquote_plus(field_name.decode("latin-1"))
+                    value = unquote_plus(field_value.decode("latin-1"))
+                    items.append((name, value))
+
+        return FormData(items)
+
+
+class MultiPartParser:
+    spool_max_size = 1024 * 1024  # 1MB
+    """The maximum size of the spooled temporary file used to store file data."""
+    max_part_size = 1024 * 1024  # 1MB
+    """The maximum size of a part in the multipart request."""
+
+    def __init__(
+        self,
+        headers: Headers,
+        stream: typing.AsyncGenerator[bytes, None],
+        *,
+        max_files: int | float = 1000,
+        max_fields: int | float = 1000,
+        max_part_size: int = 1024 * 1024,  # 1MB
+    ) -> None:
+        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
+        self.headers = headers
+        self.stream = stream
+        self.max_files = max_files
+        self.max_fields = max_fields
+        self.items: list[tuple[str, str | UploadFile]] = []
+        self._current_files = 0
+        self._current_fields = 0
+        self._current_partial_header_name: bytes = b""
+        self._current_partial_header_value: bytes = b""
+        self._current_part = MultipartPart()
+        self._charset = ""
+        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
+        self._file_parts_to_finish: list[MultipartPart] = []
+        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
+        self.max_part_size = max_part_size
+
+    def on_part_begin(self) -> None:
+        self._current_part = MultipartPart()
+
+    def on_part_data(self, data: bytes, start: int, end: int) -> None:
+        message_bytes = data[start:end]
+        if self._current_part.file is None:
+            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
+                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
+            self._current_part.data.extend(message_bytes)
+        else:
+            self._file_parts_to_write.append((self._current_part, message_bytes))
+
+    def on_part_end(self) -> None:
+        if self._current_part.file is None:
+            self.items.append(
+                (
+                    self._current_part.field_name,
+                    _user_safe_decode(self._current_part.data, self._charset),
+                )
+            )
+        else:
+            self._file_parts_to_finish.append(self._current_part)
+            # The file can be added to the items right now even though it's not
+            # finished yet, because it will be finished in the `parse()` method, before
+            # self.items is used in the return value.
+            self.items.append((self._current_part.field_name, self._current_part.file))
+
+    def on_header_field(self, data: bytes, start: int, end: int) -> None:
+        self._current_partial_header_name += data[start:end]
+
+    def on_header_value(self, data: bytes, start: int, end: int) -> None:
+        self._current_partial_header_value += data[start:end]
+
+    def on_header_end(self) -> None:
+        field = self._current_partial_header_name.lower()
+        if field == b"content-disposition":
+            self._current_part.content_disposition = self._current_partial_header_value
+        self._current_part.item_headers.append((field, self._current_partial_header_value))
+        self._current_partial_header_name = b""
+        self._current_partial_header_value = b""
+
+    def on_headers_finished(self) -> None:
+        disposition, options = parse_options_header(self._current_part.content_disposition)
+        try:
+            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
+        except KeyError:
+            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
+        if b"filename" in options:
+            self._current_files += 1
+            if self._current_files > self.max_files:
+                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
+            filename = _user_safe_decode(options[b"filename"], self._charset)
+            tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
+            self._files_to_close_on_error.append(tempfile)
+            self._current_part.file = UploadFile(
+                file=tempfile,  # type: ignore[arg-type]
+                size=0,
+                filename=filename,
+                headers=Headers(raw=self._current_part.item_headers),
+            )
+        else:
+            self._current_fields += 1
+            if self._current_fields > self.max_fields:
+                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
+            self._current_part.file = None
+
+    def on_end(self) -> None:
+        pass
+
+    async def parse(self) -> FormData:
+        # Parse the Content-Type header to get the multipart boundary.
+        _, params = parse_options_header(self.headers["Content-Type"])
+        charset = params.get(b"charset", "utf-8")
+        if isinstance(charset, bytes):
+            charset = charset.decode("latin-1")
+        self._charset = charset
+        try:
+            boundary = params[b"boundary"]
+        except KeyError:
+            raise MultiPartException("Missing boundary in multipart.")
+
+        # Callbacks dictionary.
+        callbacks: MultipartCallbacks = {
+            "on_part_begin": self.on_part_begin,
+            "on_part_data": self.on_part_data,
+            "on_part_end": self.on_part_end,
+            "on_header_field": self.on_header_field,
+            "on_header_value": self.on_header_value,
+            "on_header_end": self.on_header_end,
+            "on_headers_finished": self.on_headers_finished,
+            "on_end": self.on_end,
+        }
+
+        # Create the parser.
+        parser = multipart.MultipartParser(boundary, callbacks)
+        try:
+            # Feed the parser with data from the request.
+            async for chunk in self.stream:
+                parser.write(chunk)
+                # Write file data, it needs to use await with the UploadFile methods
+                # that call the corresponding file methods *in a threadpool*,
+                # otherwise, if they were called directly in the callback methods above
+                # (regular, non-async functions), that would block the event loop in
+                # the main thread.
+                for part, data in self._file_parts_to_write:
+                    assert part.file  # for type checkers
+                    await part.file.write(data)
+                for part in self._file_parts_to_finish:
+                    assert part.file  # for type checkers
+                    await part.file.seek(0)
+                self._file_parts_to_write.clear()
+                self._file_parts_to_finish.clear()
+        except MultiPartException as exc:
+            # Close all the files if there was an error.
+            for file in self._files_to_close_on_error:
+                file.close()
+            raise exc
+
+        parser.finalize()
+        return FormData(self.items)