aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/starlette/formparsers.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/formparsers.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/formparsers.py275
1 file changed, 275 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/formparsers.py b/.venv/lib/python3.12/site-packages/starlette/formparsers.py
new file mode 100644
index 00000000..4551d688
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/formparsers.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import typing
+from dataclasses import dataclass, field
+from enum import Enum
+from tempfile import SpooledTemporaryFile
+from urllib.parse import unquote_plus
+
+from starlette.datastructures import FormData, Headers, UploadFile
+
+if typing.TYPE_CHECKING:
+ import python_multipart as multipart
+ from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
+else:
+ try:
+ try:
+ import python_multipart as multipart
+ from python_multipart.multipart import parse_options_header
+ except ModuleNotFoundError: # pragma: no cover
+ import multipart
+ from multipart.multipart import parse_options_header
+ except ModuleNotFoundError: # pragma: no cover
+ multipart = None
+ parse_options_header = None
+
+
class FormMessage(Enum):
    """Event types emitted by `FormParser`'s querystring-parser callbacks.

    Each event is recorded as a `(FormMessage, bytes)` tuple; the payload is
    empty (`b""`) for the pure marker events.
    """

    FIELD_START = 1  # a new field begins; name/value buffers are reset
    FIELD_NAME = 2  # payload: a chunk of the url-encoded field name
    FIELD_DATA = 3  # payload: a chunk of the url-encoded field value
    FIELD_END = 4  # the current field is complete; emit the (name, value) pair
    END = 5  # the whole querystring has been parsed
+
+
@dataclass
class MultipartPart:
    """Mutable accumulator for a single part of a multipart request body."""

    # Raw Content-Disposition header value, captured in on_header_end
    # (None until that header has been seen).
    content_disposition: bytes | None = None
    # Decoded "name" option from the Content-Disposition header.
    field_name: str = ""
    # Buffered body bytes for plain (non-file) parts.
    data: bytearray = field(default_factory=bytearray)
    # Set when the part carries a "filename"; body bytes then go to this
    # UploadFile instead of `data`.
    file: UploadFile | None = None
    # All raw (lowercased-name, value) header pairs for this part.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
+
+
+def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
+ try:
+ return src.decode(codec)
+ except (UnicodeDecodeError, LookupError):
+ return src.decode("latin-1")
+
+
class MultiPartException(Exception):
    """Raised when a multipart request is malformed or exceeds configured limits.

    Attributes:
        message: Human-readable description of the parsing failure.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        # Pass the message through to Exception as well, so that `str(exc)`
        # and `exc.args` carry it (previously both were empty because only
        # the instance attribute was set).
        super().__init__(message)
+
+
class FormParser:
    """Parse `application/x-www-form-urlencoded` form data from an async byte stream.

    The synchronous `on_*` callbacks are driven by the underlying
    `QuerystringParser` as bytes are fed in; each one records a
    `(FormMessage, bytes)` event that `parse()` later folds into
    (name, value) pairs.
    """

    def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Event queue filled by the callbacks below, drained in parse().
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        self.messages.append((FormMessage.END, b""))

    async def parse(self) -> FormData:
        """Consume the request stream and return the decoded form fields."""
        # Wire this instance's handlers into the querystring parser.
        callbacks: QuerystringCallbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }
        parser = multipart.QuerystringParser(callbacks)

        name_buffer = b""
        value_buffer = b""
        items: list[tuple[str, str | UploadFile]] = []

        async for chunk in self.stream:
            # An empty chunk marks end-of-stream: flush the parser.
            if not chunk:
                parser.finalize()
            else:
                parser.write(chunk)
            # Swap out the event queue and process the drained events.
            pending, self.messages = self.messages, []
            for kind, payload in pending:
                if kind == FormMessage.FIELD_START:
                    name_buffer = b""
                    value_buffer = b""
                elif kind == FormMessage.FIELD_NAME:
                    name_buffer += payload
                elif kind == FormMessage.FIELD_DATA:
                    value_buffer += payload
                elif kind == FormMessage.FIELD_END:
                    # Decode via latin-1 so every raw byte round-trips, then
                    # undo the application/x-www-form-urlencoded escaping.
                    items.append(
                        (
                            unquote_plus(name_buffer.decode("latin-1")),
                            unquote_plus(value_buffer.decode("latin-1")),
                        )
                    )

        return FormData(items)
+
+
class MultiPartParser:
    """Parse `multipart/form-data` request bodies from an async byte stream.

    The underlying `python-multipart` parser drives the synchronous `on_*`
    callbacks below as bytes are written to it. File data is not written
    inside the callbacks: it is queued in `_file_parts_to_write` and flushed
    with `await` in `parse()`, so blocking file I/O runs in a threadpool
    instead of on the event loop.
    """

    spool_max_size = 1024 * 1024  # 1MB
    """The maximum size of the spooled temporary file used to store file data."""
    max_part_size = 1024 * 1024  # 1MB
    """The maximum size of a part in the multipart request."""

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,  # 1MB
    ) -> None:
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        # Accumulated (field_name, value-or-UploadFile) pairs, in arrival order.
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        # Header names/values may arrive split across chunks; buffer partials
        # until on_header_end fires.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # File bytes queued by on_part_data; written with `await` in parse().
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        # File parts awaiting a final seek(0) so callers read from the start.
        self._file_parts_to_finish: list[MultipartPart] = []
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
        # Instance override of the class-level default.
        self.max_part_size = max_part_size

    def on_part_begin(self) -> None:
        # Start accumulating a fresh part.
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Buffer a chunk of the current part's body.

        Raises:
            MultiPartException: if a non-file part grows beyond `max_part_size`.
        """
        message_bytes = data[start:end]
        if self._current_part.file is None:
            # Plain form value: enforce the size cap and buffer in memory.
            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
            self._current_part.data.extend(message_bytes)
        else:
            # File upload: defer the (async) file write to parse().
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        """Finish the current part, appending it to `self.items`."""
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        # Header name may span multiple parser chunks; accumulate.
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        # Header value may span multiple parser chunks; accumulate.
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Commit the buffered header (name, value) pair to the current part."""
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            # Kept separately: on_headers_finished parses it for name/filename.
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append((field, self._current_partial_header_value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Decide whether the current part is a file upload or a plain field.

        Raises:
            MultiPartException: if the "name" option is missing, or the
                configured file/field count limits are exceeded.
        """
        disposition, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
        except KeyError:
            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
        if b"filename" in options:
            # A filename marks this part as a file upload.
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
            filename = _user_safe_decode(options[b"filename"], self._charset)
            # Spills to disk once spool_max_size bytes have been written.
            tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
            # Explicitly a plain form field: body bytes buffer in memory.
            self._current_part.file = None

    def on_end(self) -> None:
        pass

    async def parse(self) -> FormData:
        """Consume the request stream and return the parsed fields and files.

        Raises:
            MultiPartException: if the boundary is missing or any part/limit
                check in the callbacks fails (files are closed first).
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")

        # Callbacks dictionary.
        callbacks: MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
        except MultiPartException as exc:
            # Close all the files if there was an error.
            for file in self._files_to_close_on_error:
                file.close()
            raise exc

        parser.finalize()
        return FormData(self.items)