Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/formparsers.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/starlette/formparsers.py  275
1 file changed, 275 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/starlette/formparsers.py b/.venv/lib/python3.12/site-packages/starlette/formparsers.py
new file mode 100644
index 00000000..4551d688
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/formparsers.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import typing
+from dataclasses import dataclass, field
+from enum import Enum
+from tempfile import SpooledTemporaryFile
+from urllib.parse import unquote_plus
+
+from starlette.datastructures import FormData, Headers, UploadFile
+
+if typing.TYPE_CHECKING:
+    import python_multipart as multipart
+    from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
+else:
+    try:
+        try:
+            import python_multipart as multipart
+            from python_multipart.multipart import parse_options_header
+        except ModuleNotFoundError:  # pragma: no cover
+            import multipart
+            from multipart.multipart import parse_options_header
+    except ModuleNotFoundError:  # pragma: no cover
+        multipart = None
+        parse_options_header = None
+
+
+class FormMessage(Enum):
+    FIELD_START = 1
+    FIELD_NAME = 2
+    FIELD_DATA = 3
+    FIELD_END = 4
+    END = 5
+
+
+@dataclass
+class MultipartPart:
+    content_disposition: bytes | None = None
+    field_name: str = ""
+    data: bytearray = field(default_factory=bytearray)
+    file: UploadFile | None = None
+    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
+
+
+def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
+    try:
+        return src.decode(codec)
+    except (UnicodeDecodeError, LookupError):
+        return src.decode("latin-1")
+
+
+class MultiPartException(Exception):
+    def __init__(self, message: str) -> None:
+        self.message = message
+
+
+class FormParser:
+    def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
+        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
+        self.headers = headers
+        self.stream = stream
+        self.messages: list[tuple[FormMessage, bytes]] = []
+
+    def on_field_start(self) -> None:
+        message = (FormMessage.FIELD_START, b"")
+        self.messages.append(message)
+
+    def on_field_name(self, data: bytes, start: int, end: int) -> None:
+        message = (FormMessage.FIELD_NAME, data[start:end])
+        self.messages.append(message)
+
+    def on_field_data(self, data: bytes, start: int, end: int) -> None:
+        message = (FormMessage.FIELD_DATA, data[start:end])
+        self.messages.append(message)
+
+    def on_field_end(self) -> None:
+        message = (FormMessage.FIELD_END, b"")
+        self.messages.append(message)
+
+    def on_end(self) -> None:
+        message = (FormMessage.END, b"")
+        self.messages.append(message)
+
+    async def parse(self) -> FormData:
+        # Callbacks dictionary.
+        callbacks: QuerystringCallbacks = {
+            "on_field_start": self.on_field_start,
+            "on_field_name": self.on_field_name,
+            "on_field_data": self.on_field_data,
+            "on_field_end": self.on_field_end,
+            "on_end": self.on_end,
+        }
+
+        # Create the parser.
+        parser = multipart.QuerystringParser(callbacks)
+        field_name = b""
+        field_value = b""
+
+        items: list[tuple[str, str | UploadFile]] = []
+
+        # Feed the parser with data from the request.
+        async for chunk in self.stream:
+            if chunk:
+                parser.write(chunk)
+            else:
+                parser.finalize()
+            messages = list(self.messages)
+            self.messages.clear()
+            for message_type, message_bytes in messages:
+                if message_type == FormMessage.FIELD_START:
+                    field_name = b""
+                    field_value = b""
+                elif message_type == FormMessage.FIELD_NAME:
+                    field_name += message_bytes
+                elif message_type == FormMessage.FIELD_DATA:
+                    field_value += message_bytes
+                elif message_type == FormMessage.FIELD_END:
+                    name = unquote_plus(field_name.decode("latin-1"))
+                    value = unquote_plus(field_value.decode("latin-1"))
+                    items.append((name, value))
+
+        return FormData(items)
+
+
+class MultiPartParser:
+    spool_max_size = 1024 * 1024  # 1MB
+    """The maximum size of the spooled temporary file used to store file data."""
+    max_part_size = 1024 * 1024  # 1MB
+    """The maximum size of a part in the multipart request."""
+
+    def __init__(
+        self,
+        headers: Headers,
+        stream: typing.AsyncGenerator[bytes, None],
+        *,
+        max_files: int | float = 1000,
+        max_fields: int | float = 1000,
+        max_part_size: int = 1024 * 1024,  # 1MB
+    ) -> None:
+        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
+        self.headers = headers
+        self.stream = stream
+        self.max_files = max_files
+        self.max_fields = max_fields
+        self.items: list[tuple[str, str | UploadFile]] = []
+        self._current_files = 0
+        self._current_fields = 0
+        self._current_partial_header_name: bytes = b""
+        self._current_partial_header_value: bytes = b""
+        self._current_part = MultipartPart()
+        self._charset = ""
+        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
+        self._file_parts_to_finish: list[MultipartPart] = []
+        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
+        self.max_part_size = max_part_size
+
+    def on_part_begin(self) -> None:
+        self._current_part = MultipartPart()
+
+    def on_part_data(self, data: bytes, start: int, end: int) -> None:
+        message_bytes = data[start:end]
+        if self._current_part.file is None:
+            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
+                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
+            self._current_part.data.extend(message_bytes)
+        else:
+            self._file_parts_to_write.append((self._current_part, message_bytes))
+
+    def on_part_end(self) -> None:
+        if self._current_part.file is None:
+            self.items.append(
+                (
+                    self._current_part.field_name,
+                    _user_safe_decode(self._current_part.data, self._charset),
+                )
+            )
+        else:
+            self._file_parts_to_finish.append(self._current_part)
+            # The file can be added to the items right now even though it's not
+            # finished yet, because it will be finished in the `parse()` method, before
+            # self.items is used in the return value.
+            self.items.append((self._current_part.field_name, self._current_part.file))
+
+    def on_header_field(self, data: bytes, start: int, end: int) -> None:
+        self._current_partial_header_name += data[start:end]
+
+    def on_header_value(self, data: bytes, start: int, end: int) -> None:
+        self._current_partial_header_value += data[start:end]
+
+    def on_header_end(self) -> None:
+        field = self._current_partial_header_name.lower()
+        if field == b"content-disposition":
+            self._current_part.content_disposition = self._current_partial_header_value
+        self._current_part.item_headers.append((field, self._current_partial_header_value))
+        self._current_partial_header_name = b""
+        self._current_partial_header_value = b""
+
+    def on_headers_finished(self) -> None:
+        disposition, options = parse_options_header(self._current_part.content_disposition)
+        try:
+            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
+        except KeyError:
+            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
+        if b"filename" in options:
+            self._current_files += 1
+            if self._current_files > self.max_files:
+                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
+            filename = _user_safe_decode(options[b"filename"], self._charset)
+            tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
+            self._files_to_close_on_error.append(tempfile)
+            self._current_part.file = UploadFile(
+                file=tempfile,  # type: ignore[arg-type]
+                size=0,
+                filename=filename,
+                headers=Headers(raw=self._current_part.item_headers),
+            )
+        else:
+            self._current_fields += 1
+            if self._current_fields > self.max_fields:
+                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
+            self._current_part.file = None
+
+    def on_end(self) -> None:
+        pass
+
+    async def parse(self) -> FormData:
+        # Parse the Content-Type header to get the multipart boundary.
+        _, params = parse_options_header(self.headers["Content-Type"])
+        charset = params.get(b"charset", "utf-8")
+        if isinstance(charset, bytes):
+            charset = charset.decode("latin-1")
+        self._charset = charset
+        try:
+            boundary = params[b"boundary"]
+        except KeyError:
+            raise MultiPartException("Missing boundary in multipart.")
+
+        # Callbacks dictionary.
+        callbacks: MultipartCallbacks = {
+            "on_part_begin": self.on_part_begin,
+            "on_part_data": self.on_part_data,
+            "on_part_end": self.on_part_end,
+            "on_header_field": self.on_header_field,
+            "on_header_value": self.on_header_value,
+            "on_header_end": self.on_header_end,
+            "on_headers_finished": self.on_headers_finished,
+            "on_end": self.on_end,
+        }
+
+        # Create the parser.
+        parser = multipart.MultipartParser(boundary, callbacks)
+        try:
+            # Feed the parser with data from the request.
+            async for chunk in self.stream:
+                parser.write(chunk)
+                # Write file data, it needs to use await with the UploadFile methods
+                # that call the corresponding file methods *in a threadpool*,
+                # otherwise, if they were called directly in the callback methods above
+                # (regular, non-async functions), that would block the event loop in
+                # the main thread.
+                for part, data in self._file_parts_to_write:
+                    assert part.file  # for type checkers
+                    await part.file.write(data)
+                for part in self._file_parts_to_finish:
+                    assert part.file  # for type checkers
+                    await part.file.seek(0)
+                self._file_parts_to_write.clear()
+                self._file_parts_to_finish.clear()
+        except MultiPartException as exc:
+            # Close all the files if there was an error.
+            for file in self._files_to_close_on_error:
+                file.close()
+            raise exc
+
+        parser.finalize()
+        return FormData(self.items)
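
For orientation, here is a minimal usage sketch of `FormParser` (not part of the diff above; the field names and values are invented). It drives the parser the same way Starlette's `Request.form()` does for `application/x-www-form-urlencoded` bodies: the request headers plus an async byte stream go in, a `FormData` comes out. It assumes `python-multipart` is installed.

import asyncio
import typing

from starlette.datastructures import Headers
from starlette.formparsers import FormParser


async def main() -> None:
    # Hypothetical urlencoded body. The trailing empty chunk matters: it is the
    # signal FormParser.parse() uses to finalize the underlying QuerystringParser.
    async def stream() -> typing.AsyncGenerator[bytes, None]:
        yield b"name=Jane+Doe&city=A%C3%A9ropolis"
        yield b""

    headers = Headers({"content-type": "application/x-www-form-urlencoded"})
    form = await FormParser(headers, stream()).parse()
    print(dict(form))  # {'name': 'Jane Doe', 'city': 'Aéropolis'}


asyncio.run(main())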
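
A similarly hedged sketch for `MultiPartParser`: the headers must carry the `multipart/form-data` boundary, the stream yields the raw body, and `parse()` returns a `FormData` in which file parts are `UploadFile` objects backed by spooled temporary files. The boundary and payload below are made up for illustration; in Starlette this path is normally reached through `Request.form()`.

import asyncio
import typing

from starlette.datastructures import Headers
from starlette.formparsers import MultiPartParser


async def main() -> None:
    boundary = "example-boundary"  # hypothetical boundary value
    headers = Headers({"content-type": f"multipart/form-data; boundary={boundary}"})

    body = (
        f"--{boundary}\r\n"
        'Content-Disposition: form-data; name="greeting"\r\n'
        "\r\n"
        "hello\r\n"
        f"--{boundary}\r\n"
        'Content-Disposition: form-data; name="upload"; filename="hello.txt"\r\n'
        "Content-Type: text/plain\r\n"
        "\r\n"
        "file contents\r\n"
        f"--{boundary}--\r\n"
    ).encode("utf-8")

    async def stream() -> typing.AsyncGenerator[bytes, None]:
        # The parser accepts the body in arbitrarily sized chunks; one chunk is enough here.
        yield body

    form = await MultiPartParser(headers, stream()).parse()
    print(form["greeting"])                      # "hello"
    upload = form["upload"]                      # an UploadFile instance
    print(upload.filename, await upload.read())  # hello.txt b'file contents'


asyncio.run(main())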