Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/fileshare/_download.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/storage/fileshare/_download.py | 524
1 file changed, 524 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_download.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_download.py
new file mode 100644
index 00000000..a37bca9a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_download.py
@@ -0,0 +1,524 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import sys
+import threading
+import warnings
+from io import BytesIO
+from typing import (
+    Any, Callable, Generator, IO, Iterator, Optional, Tuple,
+    TYPE_CHECKING
+)
+
+from azure.core.exceptions import HttpResponseError, ResourceModifiedError
+from azure.core.tracing.common import with_current_context
+from ._shared.request_handlers import validate_and_format_range_headers
+from ._shared.response_handlers import parse_length_from_content_range, process_storage_error
+
+if TYPE_CHECKING:
+    from ._generated.operations import FileOperations
+    from ._models import FileProperties
+    from ._shared.models import StorageConfiguration
+
+
+def process_content(data: Any) -> bytes:
+    if data is None:
+        raise ValueError("Response cannot be None.")
+
+    try:
+        return b"".join(list(data))
+    except Exception as error:
+        raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error) from error
+
+
+class _ChunkDownloader(object):  # pylint: disable=too-many-instance-attributes
+    def __init__(
+        self, client: "FileOperations",
+        total_size: int,
+        chunk_size: int,
+        current_progress: int,
+        start_range: int,
+        end_range: int,
+        validate_content: bool,
+        etag: str,
+        stream: Any = None,
+        parallel: Optional[int] = None,
+        progress_hook: Optional[Callable[[int, Optional[int]], None]] = None,
+        **kwargs: Any
+    ) -> None:
+        self.client = client
+        self.etag = etag
+        # Information on the download range/chunk size
+        self.chunk_size = chunk_size
+        self.total_size = total_size
+        self.start_index = start_range
+        self.end_index = end_range
+
+        # The destination that we will write to
+        self.stream = stream
+        self.stream_lock = threading.Lock() if parallel else None
+        self.progress_lock = threading.Lock() if parallel else None
+        self.progress_hook = progress_hook
+
+        # For a parallel download, the stream is always seekable, so we note down the current position
+        # in order to seek to the right place when out-of-order chunks come in
+        self.stream_start = stream.tell() if parallel else 0
+
+        # Download progress so far
+        self.progress_total = current_progress
+
+        # Parameters for each get operation
+        self.validate_content = validate_content
+        self.request_options = kwargs
+
+    def _calculate_range(self, chunk_start: int) -> Tuple[int, int]:
+        if chunk_start + self.chunk_size > self.end_index:
+            chunk_end = self.end_index
+        else:
+            chunk_end = chunk_start + self.chunk_size
+        return chunk_start, chunk_end
+
+    def get_chunk_offsets(self) -> Generator[int, None, None]:
+        index = self.start_index
+        while index < self.end_index:
+            yield index
+            index += self.chunk_size
+
+    def process_chunk(self, chunk_start: int) -> None:
+        chunk_start, chunk_end = self._calculate_range(chunk_start)
+        chunk_data = self._download_chunk(chunk_start, chunk_end - 1)
+        length = chunk_end - chunk_start
+        if length > 0:
+            self._write_to_stream(chunk_data, chunk_start)
+            self._update_progress(length)
+
+    def yield_chunk(self, chunk_start: int) -> bytes:
+        chunk_start, chunk_end = self._calculate_range(chunk_start)
+        return self._download_chunk(chunk_start, chunk_end - 1)
+
+    def _update_progress(self, length: int) -> None:
+        if self.progress_lock:
+            with self.progress_lock:  # pylint: disable=not-context-manager
+                self.progress_total += length
+        else:
+            self.progress_total += length
+
+        if self.progress_hook:
+            self.progress_hook(self.progress_total, self.total_size)
+
+    def _write_to_stream(self, chunk_data: bytes, chunk_start: int) -> None:
+        if self.stream_lock:
+            with self.stream_lock:  # pylint: disable=not-context-manager
+                self.stream.seek(self.stream_start + (chunk_start - self.start_index))
+                self.stream.write(chunk_data)
+        else:
+            self.stream.write(chunk_data)
+
+    def _download_chunk(self, chunk_start: int, chunk_end: int) -> bytes:
+        range_header, range_validation = validate_and_format_range_headers(
+            chunk_start, chunk_end, check_content_md5=self.validate_content
+        )
+
+        try:
+            response: Any = None
+            _, response = self.client.download(
+                range=range_header,
+                range_get_content_md5=range_validation,
+                validate_content=self.validate_content,
+                data_stream_total=self.total_size,
+                download_stream_current=self.progress_total,
+                **self.request_options
+            )
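+            # The etag was captured by the initial request; if it no longer
+            # matches, the file was modified mid-download and the chunks would
+            # not add up to one consistent snapshot of the file.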
+            if response.properties.etag != self.etag:
+                raise ResourceModifiedError(message="The file has been modified while downloading.")
+
+        except HttpResponseError as error:
+            process_storage_error(error)
+
+        chunk_data = process_content(response)
+        return chunk_data
+
+
+class _ChunkIterator(object):
+    """Iterator for chunks in file download stream."""
+
+    def __init__(self, size: int, content: bytes, downloader: Optional[_ChunkDownloader], chunk_size: int) -> None:
+        self.size = size
+        self._chunk_size = chunk_size
+        self._current_content = content
+        self._iter_downloader = downloader
+        self._iter_chunks: Optional[Generator[int, None, None]] = None
+        self._complete = size == 0
+
+    def __len__(self) -> int:
+        return self.size
+
+    def __iter__(self) -> Iterator[bytes]:
+        return self
+
+    def __next__(self) -> bytes:
+        if self._complete:
+            raise StopIteration("Download complete")
+        if not self._iter_downloader:
+            # cut the data obtained from initial GET into chunks
+            if len(self._current_content) > self._chunk_size:
+                return self._get_chunk_data()
+            self._complete = True
+            return self._current_content
+
+        if not self._iter_chunks:
+            self._iter_chunks = self._iter_downloader.get_chunk_offsets()
+
+        # initial GET result still has more than _chunk_size bytes of data
+        if len(self._current_content) >= self._chunk_size:
+            return self._get_chunk_data()
+
+        try:
+            chunk = next(self._iter_chunks)
+            self._current_content += self._iter_downloader.yield_chunk(chunk)
+        except StopIteration as e:
+            self._complete = True
+            if self._current_content:
+                return self._current_content
+            raise e
+
+        return self._get_chunk_data()
+
+    next = __next__  # Python 2 compatibility.
+
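+    # Hand out the next chunk_size slice of the buffered data; any remainder
+    # stays in _current_content for the following call.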
+    def _get_chunk_data(self) -> bytes:
+        chunk_data = self._current_content[: self._chunk_size]
+        self._current_content = self._current_content[self._chunk_size:]
+        return chunk_data
+
+
+class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
+    """A streaming object to download from Azure Storage."""
+
+    name: str
+    """The name of the file being downloaded."""
+    path: str
+    """The full path of the file."""
+    share: str
+    """The name of the share where the file is."""
+    properties: "FileProperties"
+    """The properties of the file being downloaded. If only a range of the data is being
+    downloaded, this will be reflected in the properties."""
+    size: int
+    """The size of the total data in the stream. This will be the byte range if specified,
+    otherwise the total size of the file."""
+
+    def __init__(
+        self, client: "FileOperations" = None,  # type: ignore [assignment]
+        config: "StorageConfiguration" = None,  # type: ignore [assignment]
+        start_range: Optional[int] = None,
+        end_range: Optional[int] = None,
+        validate_content: bool = None,  # type: ignore [assignment]
+        max_concurrency: int = 1,
+        name: str = None,  # type: ignore [assignment]
+        path: str = None,  # type: ignore [assignment]
+        share: str = None,  # type: ignore [assignment]
+        encoding: Optional[str] = None,
+        **kwargs: Any
+    ) -> None:
+        self.name = name
+        self.path = path
+        self.share = share
+        self.size = 0
+
+        self._client = client
+        self._config = config
+        self._start_range = start_range
+        self._end_range = end_range
+        self._max_concurrency = max_concurrency
+        self._encoding = encoding
+        self._validate_content = validate_content
+        self._progress_hook = kwargs.pop('progress_hook', None)
+        self._request_options = kwargs
+        self._location_mode = None
+        self._download_complete = False
+        self._current_content = b""
+        self._file_size = 0
+        self._response = None
+        self._etag = ""
+
+        # The service only provides transactional MD5s for chunks under 4MB.
+        # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first
+        # chunk so a transactional MD5 can be retrieved.
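+        # Anything beyond this first window is fetched afterwards in
+        # max_chunk_get_size pieces (see chunks() and readinto()).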
+        self._first_get_size = (
+            self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size
+        )
+        initial_request_start = self._start_range or 0
+        if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size:
+            initial_request_end = self._end_range
+        else:
+            initial_request_end = initial_request_start + self._first_get_size - 1
+
+        self._initial_range = (initial_request_start, initial_request_end)
+
+        self._response = self._initial_request()
+        self.properties = self._response.properties
+        self.properties.name = self.name
+        self.properties.path = self.path
+        self.properties.share = self.share
+
+        # Set the content length to the download size instead of the size of
+        # the last range
+        self.properties.size = self.size
+
+        # Overwrite the content range to the user requested range
+        self.properties.content_range = f"bytes {self._start_range}-{self._end_range}/{self._file_size}"
+
+        # Overwrite the content MD5 as it is the MD5 for the last range instead
+        # of the stored MD5
+        # TODO: Set to the stored MD5 when the service returns this
+        self.properties.content_md5 = None  # type: ignore [attr-defined]
+
+        if self.size == 0:
+            self._current_content = b""
+        else:
+            self._current_content = process_content(self._response)
+
+    def __len__(self) -> int:
+        return self.size
+
+    def _initial_request(self):
+        range_header, range_validation = validate_and_format_range_headers(
+            self._initial_range[0],
+            self._initial_range[1],
+            start_range_required=False,
+            end_range_required=False,
+            check_content_md5=self._validate_content
+        )
+
+        try:
+            location_mode, response = self._client.download(
+                range=range_header,
+                range_get_content_md5=range_validation,
+                validate_content=self._validate_content,
+                data_stream_total=None,
+                download_stream_current=0,
+                **self._request_options
+            )
+
+            # Check the location we read from to ensure we use the same one
+            # for subsequent requests.
+            self._location_mode = location_mode
+
+            # Parse the total file size and adjust the download size if ranges
+            # were specified
+            self._file_size = parse_length_from_content_range(response.properties.content_range)
+            if self._file_size is None:
+                raise ValueError("Required Content-Range response header is missing or malformed.")
+
+            if self._end_range is not None:
+                # Use the end range index unless it is over the end of the file
+                self.size = min(self._file_size, self._end_range - self._start_range + 1)
+            elif self._start_range is not None:
+                self.size = self._file_size - self._start_range
+            else:
+                self.size = self._file_size
+
+        except HttpResponseError as error:
+            if self._start_range is None and error.response and error.response.status_code == 416:
+                # Get range will fail on an empty file. If the user did not
+                # request a range, do a regular get request in order to get
+                # any properties.
+                try:
+                    _, response = self._client.download(
+                        validate_content=self._validate_content,
+                        data_stream_total=0,
+                        download_stream_current=0,
+                        **self._request_options
+                    )
+                except HttpResponseError as e:
+                    process_storage_error(e)
+
+                # Set the download size to empty
+                self.size = 0
+                self._file_size = 0
+            else:
+                process_storage_error(error)
+
+        # If the file is small, the download is complete at this point.
+        # If file size is large, download the rest of the file in chunks.
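+        # For a ranged GET, properties.size is the length of the range that was
+        # returned, so equality with the requested download size means nothing
+        # remains to be fetched.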
+        if response.properties.size == self.size:
+            self._download_complete = True
+        self._etag = response.properties.etag
+        return response
+
+    def chunks(self) -> Iterator[bytes]:
+        """
+        Iterate over chunks in the download stream.
+
+        :return: An iterator of the chunks in the download stream.
+        :rtype: Iterator[bytes]
+        """
+        if self.size == 0 or self._download_complete:
+            iter_downloader = None
+        else:
+            data_end = self._file_size
+            if self._end_range is not None:
+                # Use the end range index unless it is over the end of the file
+                data_end = min(self._file_size, self._end_range + 1)
+            iter_downloader = _ChunkDownloader(
+                client=self._client,
+                total_size=self.size,
+                chunk_size=self._config.max_chunk_get_size,
+                current_progress=self._first_get_size,
+                start_range=self._initial_range[1] + 1,  # start where the first download ended
+                end_range=data_end,
+                stream=None,
+                parallel=False,
+                validate_content=self._validate_content,
+                use_location=self._location_mode,
+                etag=self._etag,
+                **self._request_options
+            )
+        return _ChunkIterator(
+            size=self.size,
+            content=self._current_content,
+            downloader=iter_downloader,
+            chunk_size=self._config.max_chunk_get_size)
+
+    def readall(self) -> bytes:
+        """Download the contents of this file.
+
+        This operation is blocking until all data is downloaded.
+
+        :return: The entire file content as bytes.
+        :rtype: bytes
+        """
+        stream = BytesIO()
+        self.readinto(stream)
+        data = stream.getvalue()
+        if self._encoding:
+            return data.decode(self._encoding)  # type: ignore [return-value]
+        return data
+
+    def content_as_bytes(self, max_concurrency=1):
+        """DEPRECATED: Download the contents of this file.
+
+        This operation is blocking until all data is downloaded.
+
+        This method is deprecated, use :func:`readall` instead.
+
+        :param int max_concurrency:
+            The number of parallel connections with which to download.
+        :return: The contents of the file as bytes.
+        :rtype: bytes
+        """
+        warnings.warn(
+            "content_as_bytes is deprecated, use readall instead",
+            DeprecationWarning
+        )
+        self._max_concurrency = max_concurrency
+        return self.readall()
+
+    def content_as_text(self, max_concurrency=1, encoding="UTF-8"):
+        """DEPRECATED: Download the contents of this file, and decode as text.
+
+        This operation is blocking until all data is downloaded.
+
+        This method is deprecated, use :func:`readall` instead.
+
+        :param int max_concurrency:
+            The number of parallel connections with which to download.
+        :param str encoding:
+            Text encoding to decode the downloaded bytes. Default is UTF-8.
+        :return: The contents of the file as a str.
+        :rtype: str
+        """
+        warnings.warn(
+            "content_as_text is deprecated, use readall instead",
+            DeprecationWarning
+        )
+        self._max_concurrency = max_concurrency
+        self._encoding = encoding
+        return self.readall()
+
+    def readinto(self, stream: IO[bytes]) -> int:
+        """Download the contents of this file to a stream.
+
+        :param IO[bytes] stream:
+            The stream to download to. This can be an open file-handle,
+            or any writable stream. The stream must be seekable if the download
+            uses more than one parallel connection.
+        :returns: The number of bytes read.
+        :rtype: int
+        """
+        # The stream must be seekable if parallel download is required
+        parallel = self._max_concurrency > 1
+        if parallel:
+            error_message = "Target stream handle must be seekable."
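+            # seekable() settles it on Python 3; the seek() probe below also
+            # covers stream types that do not implement the io interface.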
+            if sys.version_info >= (3,) and not stream.seekable():
+                raise ValueError(error_message)
+
+            try:
+                stream.seek(stream.tell())
+            except (NotImplementedError, AttributeError) as exc:
+                raise ValueError(error_message) from exc
+
+        # Write the content to the user stream
+        stream.write(self._current_content)
+        if self._progress_hook:
+            self._progress_hook(len(self._current_content), self.size)
+
+        if self._download_complete:
+            return self.size
+
+        data_end = self._file_size
+        if self._end_range is not None:
+            # Use the length unless it is over the end of the file
+            data_end = min(self._file_size, self._end_range + 1)
+
+        downloader = _ChunkDownloader(
+            client=self._client,
+            total_size=self.size,
+            chunk_size=self._config.max_chunk_get_size,
+            current_progress=self._first_get_size,
+            start_range=self._initial_range[1] + 1,  # Start where the first download ended
+            end_range=data_end,
+            stream=stream,
+            parallel=parallel,
+            validate_content=self._validate_content,
+            use_location=self._location_mode,
+            progress_hook=self._progress_hook,
+            etag=self._etag,
+            **self._request_options
+        )
+        if parallel:
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor(self._max_concurrency) as executor:
+                list(executor.map(
+                    with_current_context(downloader.process_chunk),
+                    downloader.get_chunk_offsets()
+                ))
+        else:
+            for chunk in downloader.get_chunk_offsets():
+                downloader.process_chunk(chunk)
+        return self.size
+
+    def download_to_stream(self, stream, max_concurrency=1):
+        """DEPRECATED: Download the contents of this file to a stream.
+
+        This method is deprecated, use :func:`readinto` instead.
+
+        :param IO stream:
+            The stream to download to. This can be an open file-handle,
+            or any writable stream. The stream must be seekable if the download
+            uses more than one parallel connection.
+        :param int max_concurrency:
+            The number of parallel connections with which to download.
+        :returns: The properties of the downloaded file.
+        :rtype: Any
+        """
+        warnings.warn(
+            "download_to_stream is deprecated, use readinto instead",
+            DeprecationWarning
+        )
+        self._max_concurrency = max_concurrency
+        self.readinto(stream)
+        return self.properties
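For orientation, a minimal usage sketch (not part of the diff). A StorageStreamDownloader is normally obtained from ShareFileClient.download_file() rather than constructed directly; the connection string, share name, and file paths below are placeholders.

    from azure.storage.fileshare import ShareFileClient

    # Placeholders: substitute a real connection string, share, and path.
    file_client = ShareFileClient.from_connection_string(
        conn_str="<connection-string>",
        share_name="<share>",
        file_path="dir/file.bin",
    )

    # readall() buffers the whole file in memory.
    data = file_client.download_file().readall()

    # chunks() yields the stream piece by piece over a single connection.
    total = 0
    for chunk in file_client.download_file().chunks():
        total += len(chunk)  # e.g. hash or forward each chunk here

    # readinto() writes into a seekable stream; max_concurrency > 1 fetches
    # the ranges after the first GET over parallel connections.
    with open("file.bin", "wb") as handle:
        file_client.download_file(max_concurrency=4).readinto(handle)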