diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared')
14 files changed, 4460 insertions, 0 deletions
# diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py new file mode 100644 index 00000000..a8b1a27d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py @@ -0,0 +1,54 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""URL quoting, base64 and HMAC-SHA256 helpers."""

import base64
import hashlib
import hmac
# The original also carried a Python 2 ``urllib2`` fallback; this file is
# installed under a python3.12 site-packages tree, so that branch was dead.
from urllib.parse import quote, unquote


def url_quote(url):
    """Return *url* percent-encoded with ``urllib.parse.quote`` defaults."""
    return quote(url)


def url_unquote(url):
    """Return *url* with percent-escapes decoded by ``urllib.parse.unquote``."""
    return unquote(url)


def encode_base64(data):
    """Base64-encode *data* and return the result as text.

    :param data: ``bytes``, or ``str`` (encoded as UTF-8 before encoding).
    :returns: The base64 representation as a ``str``.
    """
    if isinstance(data, str):
        data = data.encode('utf-8')
    encoded = base64.b64encode(data)
    return encoded.decode('utf-8')


def decode_base64_to_bytes(data):
    """Decode base64 *data* (``str`` or ``bytes``) and return raw ``bytes``."""
    if isinstance(data, str):
        data = data.encode('utf-8')
    return base64.b64decode(data)


def decode_base64_to_text(data):
    """Decode base64 *data* and interpret the decoded bytes as UTF-8 text."""
    decoded_bytes = decode_base64_to_bytes(data)
    return decoded_bytes.decode('utf-8')


def sign_string(key, string_to_sign, key_is_base64=True):
    """Return the base64-encoded HMAC-SHA256 of *string_to_sign*.

    :param key: The signing key. When *key_is_base64* is true it is
        base64-decoded first; otherwise a ``str`` key is UTF-8 encoded and a
        ``bytes`` key is used as-is.
    :param string_to_sign: The message; a ``str`` is UTF-8 encoded.
    :param key_is_base64: Whether *key* is base64 text (the default).
    :returns: The digest, base64-encoded, as a ``str``.
    """
    if key_is_base64:
        key = decode_base64_to_bytes(key)
    elif isinstance(key, str):
        key = key.encode('utf-8')
    if isinstance(string_to_sign, str):
        string_to_sign = string_to_sign.encode('utf-8')
    digest = hmac.new(key, string_to_sign, hashlib.sha256).digest()
    return encode_base64(digest)


# diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py new file mode 100644 index 00000000..44c563d8
# --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py @@ -0,0 +1,244 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import re
from typing import List, Tuple
from urllib.parse import unquote, urlparse
from functools import cmp_to_key

try:
    from yarl import URL
except ImportError:
    pass

try:
    from azure.core.pipeline.transport import AioHttpTransport  # pylint: disable=non-abstract-transport-import
except ImportError:
    AioHttpTransport = None

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy

from . import sign_string

logger = logging.getLogger(__name__)

# Per-character sort weights, indexed by code point. Both tables have exactly
# 128 entries, so only code points 0-127 can be looked up. A weight of 0 makes
# `compare` skip the character at that level.
table_lv0 = [
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x71c, 0x0, 0x71f, 0x721, 0x723, 0x725,
    0x0, 0x0, 0x0, 0x72d, 0x803, 0x0, 0x0, 0x733, 0x0, 0xd03, 0xd1a, 0xd1c, 0xd1e,
    0xd20, 0xd22, 0xd24, 0xd26, 0xd28, 0xd2a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0xe02, 0xe09, 0xe0a, 0xe1a, 0xe21, 0xe23, 0xe25, 0xe2c, 0xe32, 0xe35, 0xe36, 0xe48, 0xe51,
    0xe70, 0xe7c, 0xe7e, 0xe89, 0xe8a, 0xe91, 0xe99, 0xe9f, 0xea2, 0xea4, 0xea6, 0xea7, 0xea9,
    0x0, 0x0, 0x0, 0x743, 0x744, 0x748, 0xe02, 0xe09, 0xe0a, 0xe1a, 0xe21, 0xe23, 0xe25,
    0xe2c, 0xe32, 0xe35, 0xe36, 0xe48, 0xe51, 0xe70, 0xe7c, 0xe7e, 0xe89, 0xe8a, 0xe91, 0xe99,
    0xe9f, 0xea2, 0xea4, 0xea6, 0xea7, 0xea9, 0x0, 0x74c, 0x0, 0x750, 0x0,
]

table_lv4 = [
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8012, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8212, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
    0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
]


def compare(lhs: str, rhs: str) -> int:  # pylint:disable=too-many-return-statements
    """Three-way compare two strings using the weight tables above.

    The strings are walked once per table (``table_lv0`` first, then
    ``table_lv4``). At each level, characters whose weight is 0 are skipped and
    the first pair of differing non-zero weights decides the order. When both
    strings are exhausted at a level the walk restarts at the next level.

    :param lhs: Left operand.
    :param rhs: Right operand.
    :returns: -1, 0 or 1, suitable for ``functools.cmp_to_key``.
    :raises IndexError: If either string contains a code point >= 128, which
        is outside the 128-entry tables.
    """
    tables = [table_lv0, table_lv4]
    curr_level, i, j, n = 0, 0, 0, len(tables)
    lhs_len = len(lhs)
    rhs_len = len(rhs)
    while curr_level < n:
        # On the last level, diverging cursors mean one side skipped more
        # zero-weight characters than the other; that decides the order.
        if curr_level == (n - 1) and i != j:
            if i > j:
                return -1
            if i < j:
                return 1
            return 0

        # 0x1 is the end-of-string sentinel weight.
        w1 = tables[curr_level][ord(lhs[i])] if i < lhs_len else 0x1
        w2 = tables[curr_level][ord(rhs[j])] if j < rhs_len else 0x1

        if w1 == 0x1 and w2 == 0x1:
            # Both strings exhausted at this level: restart on the next table.
            i = 0
            j = 0
            curr_level += 1
        elif w1 == w2:
            i += 1
            j += 1
        elif w1 == 0:
            i += 1
        elif w2 == 0:
            j += 1
        else:
            if w1 < w2:
                return -1
            if w1 > w2:
                return 1
            return 0
    return 0


# wraps a given exception with the desired exception type
def _wrap_exception(ex, desired_type):
    """Return ``desired_type`` built from the first argument of *ex* (or "")."""
    msg = ""
    if ex.args:
        msg = ex.args[0]
    return desired_type(msg)


# This method attempts to emulate the sorting done by the service
def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Return *input_headers* ordered by header name using :func:`compare`.

    :param input_headers: ``(name, value)`` pairs.
    :returns: The pairs sorted by name.
    :raises ValueError: If a header name contains a character the weight
        tables cannot rank.
    """
    # Build dict of tuples and list of keys
    header_dict = {}
    header_keys = []
    for k, v in input_headers:
        header_dict[k] = v
        header_keys.append(k)

    try:
        header_keys = sorted(header_keys, key=cmp_to_key(compare))
    except (ValueError, IndexError) as exc:
        # `compare` indexes 128-entry tables by code point, so a non-ASCII
        # name surfaces as IndexError; report it as the documented ValueError.
        raise ValueError("Illegal character encountered when sorting headers.") from exc

    # Build list of sorted tuples
    sorted_headers = []
    for key in header_keys:
        sorted_headers.append((key, header_dict.pop(key)))
    return sorted_headers


class AzureSigningError(ClientAuthenticationError):
    """
    Represents a fatal error when attempting to sign a request.
    In general, the cause of this exception is user error. For example, the given account key is not valid.
    Please visit https://learn.microsoft.com/azure/storage/common/storage-create-storage-account for more info.
    """


class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
    """Pipeline policy that adds a ``SharedKey`` Authorization header.

    The string to sign is assembled in :meth:`on_request` from the verb, a
    fixed list of standard headers, the ``x-ms-*`` headers, the resource path
    and the query parameters, then signed with the account key.
    """

    def __init__(self, account_name, account_key):
        self.account_name = account_name
        self.account_key = account_key
        super().__init__()

    @staticmethod
    def _get_headers(request, headers_to_sign):
        """Return the values of *headers_to_sign*, one per line, in order.

        Headers with falsy values are treated as absent, and a
        ``content-length`` of ``'0'`` is signed as an empty string.
        """
        headers = {name.lower(): value for name, value in request.http_request.headers.items() if value}
        if 'content-length' in headers and headers['content-length'] == '0':
            del headers['content-length']
        return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n'

    @staticmethod
    def _get_verb(request):
        """Return the HTTP method followed by a newline."""
        return request.http_request.method + '\n'

    def _get_canonicalized_resource(self, request):
        """Return ``/<account_name><path>`` for the request URL.

        When the transport (possibly wrapped up to two levels deep via
        ``_transport``) is an ``AioHttpTransport``, the path is passed through
        ``yarl.URL`` first.
        """
        uri_path = urlparse(request.http_request.url).path
        try:
            transport = request.context.transport
            inner = getattr(transport, "_transport", None)
            innermost = getattr(inner, "_transport", None)
            if isinstance(transport, AioHttpTransport) or \
                    isinstance(inner, AioHttpTransport) or \
                    isinstance(innermost, AioHttpTransport):
                uri_path = URL(uri_path)
                return '/' + self.account_name + str(uri_path)
        except TypeError:
            # AioHttpTransport is None when its import failed; isinstance()
            # then raises TypeError and the plain path below is used.
            pass
        return '/' + self.account_name + uri_path

    @staticmethod
    def _get_canonicalized_headers(request):
        """Return the sorted ``x-ms-*`` headers as ``name:value\\n`` lines."""
        string_to_sign = ''
        x_ms_headers = []
        for name, value in request.http_request.headers.items():
            # Lower-case before matching so that e.g. "X-Ms-Version" is signed
            # too; the name is emitted in lower case either way.
            lowered = name.lower()
            if lowered.startswith('x-ms-'):
                x_ms_headers.append((lowered, value))
        x_ms_headers = _storage_header_sort(x_ms_headers)
        for name, value in x_ms_headers:
            if value is not None:
                string_to_sign += ''.join([name, ':', value, '\n'])
        return string_to_sign

    @staticmethod
    def _get_canonicalized_resource_query(request):
        """Return the query parameters as sorted ``\\nname:value`` segments.

        Names are lower-cased, values are URL-unquoted, and parameters whose
        value is ``None`` are omitted.
        """
        sorted_queries = list(request.http_request.query.items())
        sorted_queries.sort()

        string_to_sign = ''
        for name, value in sorted_queries:
            if value is not None:
                string_to_sign += '\n' + name.lower() + ':' + unquote(value)

        return string_to_sign

    def _add_authorization_header(self, request, string_to_sign):
        """Sign *string_to_sign* and set the request's Authorization header.

        :raises AzureSigningError: If anything fails while signing.
        """
        try:
            signature = sign_string(self.account_key, string_to_sign)
            auth_string = 'SharedKey ' + self.account_name + ':' + signature
            request.http_request.headers['Authorization'] = auth_string
        except Exception as ex:
            # Wrap any error that occurred as signing error
            # Doing so will clarify/locate the source of problem
            raise _wrap_exception(ex, AzureSigningError) from ex

    def on_request(self, request):
        """Build the string to sign for *request* and attach the signature."""
        # NOTE(review): 'byte_range' is looked up as a header name; presumably
        # it stands in for the Range slot of the signature - confirm.
        string_to_sign = \
            self._get_verb(request) + \
            self._get_headers(
                request,
                [
                    'content-encoding', 'content-language', 'content-length',
                    'content-md5', 'content-type', 'date', 'if-modified-since',
                    'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range'
                ]
            ) + \
            self._get_canonicalized_headers(request) + \
            self._get_canonicalized_resource(request) + \
            self._get_canonicalized_resource_query(request)

        self._add_authorization_header(request, string_to_sign)
        # logger.debug("String_to_sign=%s", string_to_sign)


class StorageHttpChallenge(object):
    def __init__(self, challenge):
        """ Parses an HTTP WWW-Authentication Bearer challenge from the Storage service.

        :param str challenge: The header value, ``<scheme> <name=value ...>``.
        :raises ValueError: If the challenge is empty, or lacks
            ``authorization_uri`` or ``resource_id``.
        """
        if not challenge:
            raise ValueError("Challenge cannot be empty")

        self._parameters = {}
        self.scheme, trimmed_challenge = challenge.strip().split(" ", 1)

        # name=value pairs either comma or space separated with values possibly being
        # enclosed in quotes
        for item in re.split('[, ]', trimmed_challenge):
            comps = item.split("=")
            if len(comps) == 2:
                key = comps[0].strip(' "')
                value = comps[1].strip(' "')
                if key:
                    self._parameters[key] = value

        # Extract and verify required parameters
        self.authorization_uri = self._parameters.get('authorization_uri')
        if not self.authorization_uri:
            raise ValueError("Authorization Uri not found")

        self.resource_id = self._parameters.get('resource_id')
        if not self.resource_id:
            raise ValueError("Resource id not found")

        # The tenant id is the first path segment of the authorization URI.
        uri_path = urlparse(self.authorization_uri).path.lstrip("/")
        self.tenant_id = uri_path.split("/")[0]

    def get_value(self, key):
        """Return the parsed challenge parameter *key*, or ``None`` if absent."""
        return self._parameters.get(key)


# diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py new file mode 100644 index 00000000..9dc8d2ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py @@ -0,0 +1,458 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import uuid
from typing import (
    Any,
    cast,
    Dict,
    Iterator,
    Optional,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from urllib.parse import parse_qs, quote

from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpTransport, RequestsTransport  # pylint: disable=non-abstract-transport-import, no-name-in-module
from azure.core.pipeline.policies import (
    AzureSasCredentialPolicy,
    ContentDecodePolicy,
    DistributedTracingPolicy,
    HttpLoggingPolicy,
    ProxyPolicy,
    RedirectPolicy,
    UserAgentPolicy,
)

from .authentication import SharedKeyCredentialPolicy
from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE
from .models import LocationMode, StorageConfiguration
from .policies import (
    ExponentialRetry,
    QueueMessagePolicy,
    StorageBearerTokenCredentialPolicy,
    StorageContentValidation,
    StorageHeadersPolicy,
    StorageHosts,
    StorageLoggingPolicy,
    StorageRequestHook,
    StorageResponseHook,
)
from .request_handlers import serialize_batch_body, _get_batch_request_delimiter
from .response_handlers import PartialBatchErrorException, process_storage_error
from .shared_access_signature import QueryStringConstants
from .._version import VERSION
from .._shared_access_signature import _is_credential_sastoken

if TYPE_CHECKING:
    from azure.core.credentials_async import AsyncTokenCredential
    from azure.core.pipeline.transport import HttpRequest, HttpResponse  # pylint: disable=C4756

_LOGGER = logging.getLogger(__name__)
# Connection-string keys that carry the primary/secondary endpoint per service.
_SERVICE_PARAMS = {
    "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"},
    "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"},
    "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"},
    "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"},
}


class StorageAccountHostsMixin(object):
    """Shared client base: resolves hosts and credential, builds the pipeline.

    Subclasses are expected to provide ``_client`` and ``_format_url``; both
    are used here but not defined in this class.
    """

    _client: Any

    def __init__(
        self,
        parsed_url: Any,
        service: str,
        credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None,  # pylint: disable=line-too-long
        **kwargs: Any
    ) -> None:
        """Derive account name, hosts and pipeline from *parsed_url*.

        :param parsed_url: A parsed URL exposing ``scheme``, ``netloc`` and ``path``.
        :param str service: One of "blob", "queue", "file-share" or "dfs".
        :param credential: Account key, key dict, named-key/SAS credential or
            token credential; ``None`` for anonymous access.
        :raises ValueError: For an unknown service, or a token credential used
            with a non-HTTPS scheme.
        """
        self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
        self._hosts = kwargs.get("_hosts")
        self.scheme = parsed_url.scheme
        self._is_localhost = False

        if service not in ["blob", "queue", "file-share", "dfs"]:
            raise ValueError(f"Invalid service: {service}")
        service_name = service.split('-')[0]
        account = parsed_url.netloc.split(f".{service_name}.core.")

        self.account_name = account[0] if len(account) > 1 else None
        # Both loopback spellings are guarded by "no account name yet". The
        # original `not A and B or C` let the 127.0.0.1 check bypass it.
        if not self.account_name and parsed_url.netloc.startswith(("localhost", "127.0.0.1")):
            self._is_localhost = True
            self.account_name = parsed_url.path.strip("/")

        self.credential = _format_shared_key_credential(self.account_name, credential)
        if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
            raise ValueError("Token credential is only supported with HTTPS.")

        secondary_hostname = None
        if hasattr(self.credential, "account_name"):
            self.account_name = self.credential.account_name
            secondary_hostname = f"{self.credential.account_name}-secondary.{service_name}.{SERVICE_HOST_BASE}"

        if not self._hosts:
            if len(account) > 1:
                secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary")
            if kwargs.get("secondary_hostname"):
                secondary_hostname = kwargs["secondary_hostname"]
            primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
            self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}

        self._sdk_moniker = f"storage-{service}/{VERSION}"
        self._config, self._pipeline = self._create_pipeline(self.credential, sdk_moniker=self._sdk_moniker, **kwargs)

    def __enter__(self):
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self):
        """ This method is to close the sockets opened by the client.
        It need not be used when using with a context manager.
        """
        self._client.close()

    @property
    def url(self):
        """The full endpoint URL to this entity, including SAS token if used.

        This could be either the primary endpoint,
        or the secondary endpoint depending on the current :func:`location_mode`.
        :returns: The full endpoint URL to this entity, including SAS token if used.
        :rtype: str
        """
        return self._format_url(self._hosts[self._location_mode])

    @property
    def primary_endpoint(self):
        """The full primary endpoint URL.

        :rtype: str
        """
        return self._format_url(self._hosts[LocationMode.PRIMARY])

    @property
    def primary_hostname(self):
        """The hostname of the primary endpoint.

        :rtype: str
        """
        return self._hosts[LocationMode.PRIMARY]

    @property
    def secondary_endpoint(self):
        """The full secondary endpoint URL if configured.

        If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :rtype: str
        :raise ValueError:
        """
        if not self._hosts[LocationMode.SECONDARY]:
            raise ValueError("No secondary host configured.")
        return self._format_url(self._hosts[LocationMode.SECONDARY])

    @property
    def secondary_hostname(self):
        """The hostname of the secondary endpoint.

        If not available this will be None. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :rtype: Optional[str]
        """
        return self._hosts[LocationMode.SECONDARY]

    @property
    def location_mode(self):
        """The location mode that the client is currently using.

        By default this will be "primary". Options include "primary" and "secondary".

        :rtype: str
        """

        return self._location_mode

    @location_mode.setter
    def location_mode(self, value):
        if self._hosts.get(value):
            self._location_mode = value
            self._client._config.url = self.url  # pylint: disable=protected-access
        else:
            raise ValueError(f"No host URL for location mode: {value}")

    @property
    def api_version(self):
        """The version of the Storage API used for requests.

        :rtype: str
        """
        return self._client._config.version  # pylint: disable=protected-access

    def _format_query_string(
        self, sas_token: Optional[str],
        credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]],  # pylint: disable=line-too-long
        snapshot: Optional[str] = None,
        share_snapshot: Optional[str] = None
    ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]:  # pylint: disable=line-too-long
        """Build the URL query string and return it with the remaining credential.

        A credential that is itself a SAS token string is folded into the
        query string and returned as ``None``; otherwise *sas_token* (if any)
        is appended and the credential is passed through unchanged.

        :raises ValueError: If both a URL SAS token and an
            ``AzureSasCredential`` are supplied.
        """
        query_str = "?"
        if snapshot:
            query_str += f"snapshot={snapshot}&"
        if share_snapshot:
            query_str += f"sharesnapshot={share_snapshot}&"
        if sas_token and isinstance(credential, AzureSasCredential):
            raise ValueError(
                "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
        if _is_credential_sastoken(credential):
            credential = cast(str, credential)
            query_str += credential.lstrip("?")
            credential = None
        elif sas_token:
            query_str += sas_token
        return query_str.rstrip("?&"), credential

    def _create_pipeline(
        self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None,  # pylint: disable=line-too-long
        **kwargs: Any
    ) -> Tuple[StorageConfiguration, Pipeline]:
        """Select the credential policy and assemble the request pipeline.

        :returns: ``(config, pipeline)``. A caller-supplied ``_pipeline``
            keyword is returned as-is.
        :raises TypeError: If *credential* is of an unsupported type.
        """
        self._credential_policy: Any = None
        if hasattr(credential, "get_token"):
            if kwargs.get('audience'):
                audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE
            else:
                audience = STORAGE_OAUTH_SCOPE
            self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience)
        elif isinstance(credential, SharedKeyCredentialPolicy):
            self._credential_policy = credential
        elif isinstance(credential, AzureSasCredential):
            self._credential_policy = AzureSasCredentialPolicy(credential)
        elif credential is not None:
            raise TypeError(f"Unsupported credential: {type(credential)}")

        config = kwargs.get("_configuration") or create_configuration(**kwargs)
        if kwargs.get("_pipeline"):
            return config, kwargs["_pipeline"]
        transport = kwargs.get("transport")
        kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
        kwargs.setdefault("read_timeout", READ_TIMEOUT)
        if not transport:
            transport = RequestsTransport(**kwargs)
        policies = [
            QueueMessagePolicy(),
            config.proxy_policy,
            config.user_agent_policy,
            StorageContentValidation(),
            ContentDecodePolicy(response_encoding="utf-8"),
            RedirectPolicy(**kwargs),
            StorageHosts(hosts=self._hosts, **kwargs),
            config.retry_policy,
            config.headers_policy,
            StorageRequestHook(**kwargs),
            self._credential_policy,
            config.logging_policy,
            StorageResponseHook(**kwargs),
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs)
        ]
        if kwargs.get("_additional_pipeline_policies"):
            policies = policies + kwargs.get("_additional_pipeline_policies")  # type: ignore
        config.transport = transport  # type: ignore
        return config, Pipeline(transport, policies=policies)

    def _batch_send(
        self,
        *reqs: "HttpRequest",
        **kwargs: Any
    ) -> Iterator["HttpResponse"]:
        """Given a series of request, do a Storage batch call.

        :param HttpRequest reqs: A collection of HttpRequest objects.
        :returns: An iterator of HttpResponse objects.
        :rtype: Iterator[HttpResponse]
        """
        # Pop it here, so requests doesn't feel bad about additional kwarg
        raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
        batch_id = str(uuid.uuid1())

        request = self._client._client.post(  # pylint: disable=protected-access
            url=(
                f'{self.scheme}://{self.primary_hostname}/'
                f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}"
                f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}"
            ),
            headers={
                'x-ms-version': self.api_version,
                "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False)
            }
        )

        policies = [StorageHeadersPolicy()]
        if self._credential_policy:
            policies.append(self._credential_policy)

        request.set_multipart_mixed(
            *reqs,
            policies=policies,
            enforce_https=False
        )

        Pipeline._prepare_multipart_mixed_request(request)  # pylint: disable=protected-access
        body = serialize_batch_body(request.multipart_mixed_info[0], batch_id)
        request.set_bytes_body(body)

        # The body is already serialized above; detach the multipart info for
        # the pipeline run and restore it afterwards.
        temp = request.multipart_mixed_info
        request.multipart_mixed_info = None
        pipeline_response = self._pipeline.run(
            request, **kwargs
        )
        response = pipeline_response.http_response
        request.multipart_mixed_info = temp

        try:
            if response.status_code not in [202]:
                raise HttpResponseError(response=response)
            parts = response.parts()
            if raise_on_any_failure:
                # Materialize the parts obtained above instead of calling
                # response.parts() a second time.
                parts = list(parts)
                if any(p for p in parts if not 200 <= p.status_code < 300):
                    error = PartialBatchErrorException(
                        message="There is a partial failure in the batch operation.",
                        response=response, parts=parts
                    )
                    raise error
                return iter(parts)
            return parts  # type: ignore [no-any-return]
        except HttpResponseError as error:
            process_storage_error(error)


class TransportWrapper(HttpTransport):
    """Wrapper class that ensures that an inner client created
    by a `get_client` method does not close the outer transport for the parent
    when used in a context manager.
    """
    def __init__(self, transport):
        self._transport = transport

    def send(self, request, **kwargs):
        return self._transport.send(request, **kwargs)

    def open(self):
        pass

    def close(self):
        pass

    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass


def _format_shared_key_credential(
    account_name: Optional[str],
    credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None  # pylint: disable=line-too-long
) -> Any:
    """Turn a shared-key style credential into a ``SharedKeyCredentialPolicy``.

    A ``str`` is treated as the account key for *account_name*; a ``dict`` must
    carry ``account_name`` and ``account_key``; an ``AzureNamedKeyCredential``
    supplies both. Any other credential is returned unchanged.

    :raises ValueError: If the account name cannot be determined or a required
        dict key is missing.
    """
    if isinstance(credential, str):
        if not account_name:
            raise ValueError("Unable to determine account name for shared key credential.")
        credential = {"account_name": account_name, "account_key": credential}
    if isinstance(credential, dict):
        if "account_name" not in credential:
            raise ValueError("Shared key credential missing 'account_name'")
        if "account_key" not in credential:
            raise ValueError("Shared key credential missing 'account_key'")
        return SharedKeyCredentialPolicy(**credential)
    if isinstance(credential, AzureNamedKeyCredential):
        return SharedKeyCredentialPolicy(credential.named_key.name, credential.named_key.key)
    return credential


def parse_connection_str(
    conn_str: str,
    credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]],
    service: str
) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]:  # pylint: disable=line-too-long
    """Extract endpoints and credential from a storage connection string.

    :param str conn_str: ``Key=Value`` pairs separated by ``;``.
    :param credential: An explicit credential; when falsy, one is taken from
        the account name/key or the shared access signature in *conn_str*.
    :param str service: A key of ``_SERVICE_PARAMS``.
    :returns: ``(primary, secondary, credential)``.
    :raises ValueError: If the string is malformed, names only a secondary
        endpoint, or lacks the details needed to build a primary endpoint.
    """
    conn_str = conn_str.rstrip(";")
    conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")]
    if any(len(tup) != 2 for tup in conn_settings_list):
        raise ValueError("Connection string is either blank or malformed.")
    conn_settings = dict((key.upper(), val) for key, val in conn_settings_list)
    endpoints = _SERVICE_PARAMS[service]
    primary = None
    secondary = None
    if not credential:
        try:
            credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]}
        except KeyError:
            credential = conn_settings.get("SHAREDACCESSSIGNATURE")
    if endpoints["primary"] in conn_settings:
        primary = conn_settings[endpoints["primary"]]
        if endpoints["secondary"] in conn_settings:
            secondary = conn_settings[endpoints["secondary"]]
    else:
        if endpoints["secondary"] in conn_settings:
            raise ValueError("Connection string specifies only secondary endpoint.")
        try:
            primary = (
                f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://"
                f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}"
            )
            secondary = (
                f"{conn_settings['ACCOUNTNAME']}-secondary."
                f"{service}.{conn_settings['ENDPOINTSUFFIX']}"
            )
        except KeyError:
            # Not enough detail for the explicit form; fall through to the
            # https default below.
            pass

    if not primary:
        try:
            primary = (
                f"https://{conn_settings['ACCOUNTNAME']}."
                f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}"
            )
        except KeyError as exc:
            raise ValueError("Connection string missing required connection details.") from exc
    if service == "dfs":
        primary = primary.replace(".blob.", ".dfs.")
        if secondary:
            secondary = secondary.replace(".blob.", ".dfs.")
    return primary, secondary, credential


def create_configuration(**kwargs: Any) -> StorageConfiguration:
    """Build a ``StorageConfiguration`` with the default policies attached.

    A caller-supplied ``retry_policy`` keyword replaces the default
    ``ExponentialRetry``.
    """
    # Backwards compatibility if someone is not passing sdk_moniker
    if not kwargs.get("sdk_moniker"):
        kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}"
    config = StorageConfiguration(**kwargs)
    config.headers_policy = StorageHeadersPolicy(**kwargs)
    config.user_agent_policy = UserAgentPolicy(**kwargs)
    config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs)
    config.logging_policy = StorageLoggingPolicy(**kwargs)
    config.proxy_policy = ProxyPolicy(**kwargs)
    return config


def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a URL query string into a snapshot value and a SAS token.

    :returns: ``(snapshot, sas_token)``. ``snapshot`` comes from the
        ``snapshot`` or ``sharesnapshot`` parameter; ``sas_token`` is the
        re-joined subset of parameters named in ``QueryStringConstants``, or
        ``None`` when there are none.
    """
    sas_values = QueryStringConstants.to_list()
    # parse_qs yields a list per key; only the first value is kept.
    parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
    sas_params = [f"{k}={quote(v, safe='')}" for k, v in parsed_query.items() if k in sas_values]
    sas_token = None
    if sas_params:
        sas_token = "&".join(sas_params)

    snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
    return snapshot, sas_token


# diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py new file mode 100644 index 00000000..6186b29d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py @@ -0,0 +1,280 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# mypy: disable-error-code="attr-defined" + +import logging +from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union + +from azure.core.async_paging import AsyncList +from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import HttpResponseError +from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies import ( + AsyncRedirectPolicy, + AzureSasCredentialPolicy, + ContentDecodePolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, +) +from azure.core.pipeline.transport import AsyncHttpTransport + +from .authentication import SharedKeyCredentialPolicy +from .base_client import create_configuration +from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE +from .models import StorageConfiguration +from .policies import ( + QueueMessagePolicy, + StorageContentValidation, + StorageHeadersPolicy, + StorageHosts, + StorageRequestHook, +) +from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .response_handlers import PartialBatchErrorException, process_storage_error +from .._shared_access_signature import _is_credential_sastoken + +if TYPE_CHECKING: + from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756 +_LOGGER = logging.getLogger(__name__) + +_SERVICE_PARAMS = { + "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"}, + "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"}, + "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"}, + "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"}, +} + + +class AsyncStorageAccountHostsMixin(object): + + def 
__enter__(self): + raise TypeError("Async client only supports 'async with'.") + + def __exit__(self, *args): + pass + + async def __aenter__(self): + await self._client.__aenter__() + return self + + async def __aexit__(self, *args): + await self._client.__aexit__(*args) + + async def close(self): + """ This method is to close the sockets opened by the client. + It need not be used when using with a context manager. + """ + await self._client.close() + + def _format_query_string( + self, sas_token: Optional[str], + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long + snapshot: Optional[str] = None, + share_snapshot: Optional[str] = None + ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long + query_str = "?" + if snapshot: + query_str += f"snapshot={snapshot}&" + if share_snapshot: + query_str += f"sharesnapshot={share_snapshot}&" + if sas_token and isinstance(credential, AzureSasCredential): + raise ValueError( + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + if _is_credential_sastoken(credential): + query_str += credential.lstrip("?") # type: ignore [union-attr] + credential = None + elif sas_token: + query_str += sas_token + return query_str.rstrip("?&"), credential + + def _create_pipeline( + self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Tuple[StorageConfiguration, AsyncPipeline]: + self._credential_policy: Optional[ + Union[AsyncStorageBearerTokenCredentialPolicy, + SharedKeyCredentialPolicy, + AzureSasCredentialPolicy]] = None + if hasattr(credential, 'get_token'): + if kwargs.get('audience'): + audience = str(kwargs.pop('audience')).rstrip('/') + 
DEFAULT_OAUTH_SCOPE + else: + audience = STORAGE_OAUTH_SCOPE + self._credential_policy = AsyncStorageBearerTokenCredentialPolicy( + cast(AsyncTokenCredential, credential), audience) + elif isinstance(credential, SharedKeyCredentialPolicy): + self._credential_policy = credential + elif isinstance(credential, AzureSasCredential): + self._credential_policy = AzureSasCredentialPolicy(credential) + elif credential is not None: + raise TypeError(f"Unsupported credential: {type(credential)}") + config = kwargs.get('_configuration') or create_configuration(**kwargs) + if kwargs.get('_pipeline'): + return config, kwargs['_pipeline'] + transport = kwargs.get('transport') + kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) + kwargs.setdefault("read_timeout", READ_TIMEOUT) + if not transport: + try: + from azure.core.pipeline.transport import AioHttpTransport # pylint: disable=non-abstract-transport-import + except ImportError as exc: + raise ImportError("Unable to create async transport. 
Please check aiohttp is installed.") from exc + transport = AioHttpTransport(**kwargs) + hosts = self._hosts + policies = [ + QueueMessagePolicy(), + config.proxy_policy, + config.user_agent_policy, + StorageContentValidation(), + ContentDecodePolicy(response_encoding="utf-8"), + AsyncRedirectPolicy(**kwargs), + StorageHosts(hosts=hosts, **kwargs), + config.retry_policy, + config.headers_policy, + StorageRequestHook(**kwargs), + self._credential_policy, + config.logging_policy, + AsyncStorageResponseHook(**kwargs), + DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + if kwargs.get("_additional_pipeline_policies"): + policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore + config.transport = transport #type: ignore + return config, AsyncPipeline(transport, policies=policies) #type: ignore + + async def _batch_send( + self, + *reqs: "HttpRequest", + **kwargs: Any + ) -> AsyncList["HttpResponse"]: + """Given a series of request, do a Storage batch call. + + :param HttpRequest reqs: A collection of HttpRequest objects. + :returns: An AsyncList of HttpResponse objects. 
+ :rtype: AsyncList[HttpResponse] + """ + # Pop it here, so requests doesn't feel bad about additional kwarg + raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) + request = self._client._client.post( # pylint: disable=protected-access + url=( + f'{self.scheme}://{self.primary_hostname}/' + f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" + f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" + ), + headers={ + 'x-ms-version': self.api_version + } + ) + + policies = [StorageHeadersPolicy()] + if self._credential_policy: + policies.append(self._credential_policy) # type: ignore + + request.set_multipart_mixed( + *reqs, + policies=policies, + enforce_https=False + ) + + pipeline_response = await self._pipeline.run( + request, **kwargs + ) + response = pipeline_response.http_response + + try: + if response.status_code not in [202]: + raise HttpResponseError(response=response) + parts = response.parts() # Return an AsyncIterator + if raise_on_any_failure: + parts_list = [] + async for part in parts: + parts_list.append(part) + if any(p for p in parts_list if not 200 <= p.status_code < 300): + error = PartialBatchErrorException( + message="There is a partial failure in the batch operation.", + response=response, parts=parts_list + ) + raise error + return AsyncList(parts_list) + return parts # type: ignore [no-any-return] + except HttpResponseError as error: + process_storage_error(error) + +def parse_connection_str( + conn_str: str, + credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], + service: str +) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long + conn_str = conn_str.rstrip(";") + conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] + if any(len(tup) != 2 for tup in conn_settings_list): + raise ValueError("Connection string is either 
blank or malformed.") + conn_settings = dict((key.upper(), val) for key, val in conn_settings_list) + endpoints = _SERVICE_PARAMS[service] + primary = None + secondary = None + if not credential: + try: + credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]} + except KeyError: + credential = conn_settings.get("SHAREDACCESSSIGNATURE") + if endpoints["primary"] in conn_settings: + primary = conn_settings[endpoints["primary"]] + if endpoints["secondary"] in conn_settings: + secondary = conn_settings[endpoints["secondary"]] + else: + if endpoints["secondary"] in conn_settings: + raise ValueError("Connection string specifies only secondary endpoint.") + try: + primary =( + f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" + f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" + ) + secondary = ( + f"{conn_settings['ACCOUNTNAME']}-secondary." + f"{service}.{conn_settings['ENDPOINTSUFFIX']}" + ) + except KeyError: + pass + + if not primary: + try: + primary = ( + f"https://{conn_settings['ACCOUNTNAME']}." + f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}" + ) + except KeyError as exc: + raise ValueError("Connection string missing required connection details.") from exc + if service == "dfs": + primary = primary.replace(".blob.", ".dfs.") + if secondary: + secondary = secondary.replace(".blob.", ".dfs.") + return primary, secondary, credential + +class AsyncTransportWrapper(AsyncHttpTransport): + """Wrapper class that ensures that an inner client created + by a `get_client` method does not close the outer transport for the parent + when used in a context manager. 
+ """ + def __init__(self, async_transport): + self._transport = async_transport + + async def send(self, request, **kwargs): + return await self._transport.send(request, **kwargs) + + async def open(self): + pass + + async def close(self): + pass + + async def __aenter__(self): + pass + + async def __aexit__(self, *args): + pass diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py new file mode 100644 index 00000000..0b4b029a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py @@ -0,0 +1,19 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from .._serialize import _SUPPORTED_API_VERSIONS + + +X_MS_VERSION = _SUPPORTED_API_VERSIONS[-1] + +# Default socket timeouts, in seconds +CONNECTION_TIMEOUT = 20 +READ_TIMEOUT = 60 + +DEFAULT_OAUTH_SCOPE = "/.default" +STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default" + +SERVICE_HOST_BASE = 'core.windows.net' diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py new file mode 100644 index 00000000..403e6b8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py @@ -0,0 +1,585 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +# pylint: disable=too-many-instance-attributes +from enum import Enum +from typing import Optional + +from azure.core import CaseInsensitiveEnumMeta +from azure.core.configuration import Configuration +from azure.core.pipeline.policies import UserAgentPolicy + + +def get_enum_value(value): + if value is None or value in ["None", ""]: + return None + try: + return value.value + except AttributeError: + return value + + +class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + + # Generic storage values + ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists" + ACCOUNT_BEING_CREATED = "AccountBeingCreated" + ACCOUNT_IS_DISABLED = "AccountIsDisabled" + AUTHENTICATION_FAILED = "AuthenticationFailed" + AUTHORIZATION_FAILURE = "AuthorizationFailure" + NO_AUTHENTICATION_INFORMATION = "NoAuthenticationInformation" + CONDITION_HEADERS_NOT_SUPPORTED = "ConditionHeadersNotSupported" + CONDITION_NOT_MET = "ConditionNotMet" + EMPTY_METADATA_KEY = "EmptyMetadataKey" + INSUFFICIENT_ACCOUNT_PERMISSIONS = "InsufficientAccountPermissions" + INTERNAL_ERROR = "InternalError" + INVALID_AUTHENTICATION_INFO = "InvalidAuthenticationInfo" + INVALID_HEADER_VALUE = "InvalidHeaderValue" + INVALID_HTTP_VERB = "InvalidHttpVerb" + INVALID_INPUT = "InvalidInput" + INVALID_MD5 = "InvalidMd5" + INVALID_METADATA = "InvalidMetadata" + INVALID_QUERY_PARAMETER_VALUE = "InvalidQueryParameterValue" + INVALID_RANGE = "InvalidRange" + INVALID_RESOURCE_NAME = "InvalidResourceName" + INVALID_URI = "InvalidUri" + INVALID_XML_DOCUMENT = "InvalidXmlDocument" + INVALID_XML_NODE_VALUE = "InvalidXmlNodeValue" + MD5_MISMATCH = "Md5Mismatch" + METADATA_TOO_LARGE = "MetadataTooLarge" + MISSING_CONTENT_LENGTH_HEADER = "MissingContentLengthHeader" + MISSING_REQUIRED_QUERY_PARAMETER = "MissingRequiredQueryParameter" + MISSING_REQUIRED_HEADER = "MissingRequiredHeader" + MISSING_REQUIRED_XML_NODE = "MissingRequiredXmlNode" + 
MULTIPLE_CONDITION_HEADERS_NOT_SUPPORTED = "MultipleConditionHeadersNotSupported" + OPERATION_TIMED_OUT = "OperationTimedOut" + OUT_OF_RANGE_INPUT = "OutOfRangeInput" + OUT_OF_RANGE_QUERY_PARAMETER_VALUE = "OutOfRangeQueryParameterValue" + REQUEST_BODY_TOO_LARGE = "RequestBodyTooLarge" + RESOURCE_TYPE_MISMATCH = "ResourceTypeMismatch" + REQUEST_URL_FAILED_TO_PARSE = "RequestUrlFailedToParse" + RESOURCE_ALREADY_EXISTS = "ResourceAlreadyExists" + RESOURCE_NOT_FOUND = "ResourceNotFound" + SERVER_BUSY = "ServerBusy" + UNSUPPORTED_HEADER = "UnsupportedHeader" + UNSUPPORTED_XML_NODE = "UnsupportedXmlNode" + UNSUPPORTED_QUERY_PARAMETER = "UnsupportedQueryParameter" + UNSUPPORTED_HTTP_VERB = "UnsupportedHttpVerb" + + # Blob values + APPEND_POSITION_CONDITION_NOT_MET = "AppendPositionConditionNotMet" + BLOB_ACCESS_TIER_NOT_SUPPORTED_FOR_ACCOUNT_TYPE = "BlobAccessTierNotSupportedForAccountType" + BLOB_ALREADY_EXISTS = "BlobAlreadyExists" + BLOB_NOT_FOUND = "BlobNotFound" + BLOB_OVERWRITTEN = "BlobOverwritten" + BLOB_TIER_INADEQUATE_FOR_CONTENT_LENGTH = "BlobTierInadequateForContentLength" + BLOCK_COUNT_EXCEEDS_LIMIT = "BlockCountExceedsLimit" + BLOCK_LIST_TOO_LONG = "BlockListTooLong" + CANNOT_CHANGE_TO_LOWER_TIER = "CannotChangeToLowerTier" + CANNOT_VERIFY_COPY_SOURCE = "CannotVerifyCopySource" + CONTAINER_ALREADY_EXISTS = "ContainerAlreadyExists" + CONTAINER_BEING_DELETED = "ContainerBeingDeleted" + CONTAINER_DISABLED = "ContainerDisabled" + CONTAINER_NOT_FOUND = "ContainerNotFound" + CONTENT_LENGTH_LARGER_THAN_TIER_LIMIT = "ContentLengthLargerThanTierLimit" + COPY_ACROSS_ACCOUNTS_NOT_SUPPORTED = "CopyAcrossAccountsNotSupported" + COPY_ID_MISMATCH = "CopyIdMismatch" + FEATURE_VERSION_MISMATCH = "FeatureVersionMismatch" + INCREMENTAL_COPY_BLOB_MISMATCH = "IncrementalCopyBlobMismatch" + INCREMENTAL_COPY_OF_EARLIER_VERSION_SNAPSHOT_NOT_ALLOWED = "IncrementalCopyOfEarlierVersionSnapshotNotAllowed" + #: Deprecated: Please use 
INCREMENTAL_COPY_OF_EARLIER_VERSION_SNAPSHOT_NOT_ALLOWED instead. + INCREMENTAL_COPY_OF_ERALIER_VERSION_SNAPSHOT_NOT_ALLOWED = "IncrementalCopyOfEarlierVersionSnapshotNotAllowed" + INCREMENTAL_COPY_SOURCE_MUST_BE_SNAPSHOT = "IncrementalCopySourceMustBeSnapshot" + INFINITE_LEASE_DURATION_REQUIRED = "InfiniteLeaseDurationRequired" + INVALID_BLOB_OR_BLOCK = "InvalidBlobOrBlock" + INVALID_BLOB_TIER = "InvalidBlobTier" + INVALID_BLOB_TYPE = "InvalidBlobType" + INVALID_BLOCK_ID = "InvalidBlockId" + INVALID_BLOCK_LIST = "InvalidBlockList" + INVALID_OPERATION = "InvalidOperation" + INVALID_PAGE_RANGE = "InvalidPageRange" + INVALID_SOURCE_BLOB_TYPE = "InvalidSourceBlobType" + INVALID_SOURCE_BLOB_URL = "InvalidSourceBlobUrl" + INVALID_VERSION_FOR_PAGE_BLOB_OPERATION = "InvalidVersionForPageBlobOperation" + LEASE_ALREADY_PRESENT = "LeaseAlreadyPresent" + LEASE_ALREADY_BROKEN = "LeaseAlreadyBroken" + LEASE_ID_MISMATCH_WITH_BLOB_OPERATION = "LeaseIdMismatchWithBlobOperation" + LEASE_ID_MISMATCH_WITH_CONTAINER_OPERATION = "LeaseIdMismatchWithContainerOperation" + LEASE_ID_MISMATCH_WITH_LEASE_OPERATION = "LeaseIdMismatchWithLeaseOperation" + LEASE_ID_MISSING = "LeaseIdMissing" + LEASE_IS_BREAKING_AND_CANNOT_BE_ACQUIRED = "LeaseIsBreakingAndCannotBeAcquired" + LEASE_IS_BREAKING_AND_CANNOT_BE_CHANGED = "LeaseIsBreakingAndCannotBeChanged" + LEASE_IS_BROKEN_AND_CANNOT_BE_RENEWED = "LeaseIsBrokenAndCannotBeRenewed" + LEASE_LOST = "LeaseLost" + LEASE_NOT_PRESENT_WITH_BLOB_OPERATION = "LeaseNotPresentWithBlobOperation" + LEASE_NOT_PRESENT_WITH_CONTAINER_OPERATION = "LeaseNotPresentWithContainerOperation" + LEASE_NOT_PRESENT_WITH_LEASE_OPERATION = "LeaseNotPresentWithLeaseOperation" + MAX_BLOB_SIZE_CONDITION_NOT_MET = "MaxBlobSizeConditionNotMet" + NO_PENDING_COPY_OPERATION = "NoPendingCopyOperation" + OPERATION_NOT_ALLOWED_ON_INCREMENTAL_COPY_BLOB = "OperationNotAllowedOnIncrementalCopyBlob" + PENDING_COPY_OPERATION = "PendingCopyOperation" + PREVIOUS_SNAPSHOT_CANNOT_BE_NEWER = 
"PreviousSnapshotCannotBeNewer" + PREVIOUS_SNAPSHOT_NOT_FOUND = "PreviousSnapshotNotFound" + PREVIOUS_SNAPSHOT_OPERATION_NOT_SUPPORTED = "PreviousSnapshotOperationNotSupported" + SEQUENCE_NUMBER_CONDITION_NOT_MET = "SequenceNumberConditionNotMet" + SEQUENCE_NUMBER_INCREMENT_TOO_LARGE = "SequenceNumberIncrementTooLarge" + SNAPSHOT_COUNT_EXCEEDED = "SnapshotCountExceeded" + SNAPSHOT_OPERATION_RATE_EXCEEDED = "SnapshotOperationRateExceeded" + #: Deprecated: Please use SNAPSHOT_OPERATION_RATE_EXCEEDED instead. + SNAPHOT_OPERATION_RATE_EXCEEDED = "SnapshotOperationRateExceeded" + SNAPSHOTS_PRESENT = "SnapshotsPresent" + SOURCE_CONDITION_NOT_MET = "SourceConditionNotMet" + SYSTEM_IN_USE = "SystemInUse" + TARGET_CONDITION_NOT_MET = "TargetConditionNotMet" + UNAUTHORIZED_BLOB_OVERWRITE = "UnauthorizedBlobOverwrite" + BLOB_BEING_REHYDRATED = "BlobBeingRehydrated" + BLOB_ARCHIVED = "BlobArchived" + BLOB_NOT_ARCHIVED = "BlobNotArchived" + + # Queue values + INVALID_MARKER = "InvalidMarker" + MESSAGE_NOT_FOUND = "MessageNotFound" + MESSAGE_TOO_LARGE = "MessageTooLarge" + POP_RECEIPT_MISMATCH = "PopReceiptMismatch" + QUEUE_ALREADY_EXISTS = "QueueAlreadyExists" + QUEUE_BEING_DELETED = "QueueBeingDeleted" + QUEUE_DISABLED = "QueueDisabled" + QUEUE_NOT_EMPTY = "QueueNotEmpty" + QUEUE_NOT_FOUND = "QueueNotFound" + + # File values + CANNOT_DELETE_FILE_OR_DIRECTORY = "CannotDeleteFileOrDirectory" + CLIENT_CACHE_FLUSH_DELAY = "ClientCacheFlushDelay" + DELETE_PENDING = "DeletePending" + DIRECTORY_NOT_EMPTY = "DirectoryNotEmpty" + FILE_LOCK_CONFLICT = "FileLockConflict" + FILE_SHARE_PROVISIONED_BANDWIDTH_DOWNGRADE_NOT_ALLOWED = "FileShareProvisionedBandwidthDowngradeNotAllowed" + FILE_SHARE_PROVISIONED_IOPS_DOWNGRADE_NOT_ALLOWED = "FileShareProvisionedIopsDowngradeNotAllowed" + INVALID_FILE_OR_DIRECTORY_PATH_NAME = "InvalidFileOrDirectoryPathName" + PARENT_NOT_FOUND = "ParentNotFound" + READ_ONLY_ATTRIBUTE = "ReadOnlyAttribute" + SHARE_ALREADY_EXISTS = "ShareAlreadyExists" + 
SHARE_BEING_DELETED = "ShareBeingDeleted" + SHARE_DISABLED = "ShareDisabled" + SHARE_NOT_FOUND = "ShareNotFound" + SHARING_VIOLATION = "SharingViolation" + SHARE_SNAPSHOT_IN_PROGRESS = "ShareSnapshotInProgress" + SHARE_SNAPSHOT_COUNT_EXCEEDED = "ShareSnapshotCountExceeded" + SHARE_SNAPSHOT_OPERATION_NOT_SUPPORTED = "ShareSnapshotOperationNotSupported" + SHARE_HAS_SNAPSHOTS = "ShareHasSnapshots" + CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed" + + # DataLake values + CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero' + PATH_ALREADY_EXISTS = 'PathAlreadyExists' + INVALID_FLUSH_POSITION = 'InvalidFlushPosition' + INVALID_PROPERTY_NAME = 'InvalidPropertyName' + INVALID_SOURCE_URI = 'InvalidSourceUri' + UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion' + FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound' + PATH_NOT_FOUND = 'PathNotFound' + RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound' + SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound' + DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted' + FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists' + FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted' + INVALID_DESTINATION_PATH = 'InvalidDestinationPath' + INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath' + INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType' + LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken' + LEASE_NAME_MISMATCH = 'LeaseNameMismatch' + PATH_CONFLICT = 'PathConflict' + SOURCE_PATH_IS_BEING_DELETED = 'SourcePathIsBeingDeleted' + + +class DictMixin(object): + + def __setitem__(self, key, item): + self.__dict__[key] = item + + def __getitem__(self, key): + return self.__dict__[key] + + def __repr__(self): + return str(self) + + def __len__(self): + return len(self.keys()) + + def __delitem__(self, key): + self.__dict__[key] = None + + # Compare objects by comparing all attributes. 
+ def __eq__(self, other): + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + # Compare objects by comparing all attributes. + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')}) + + def __contains__(self, key): + return key in self.__dict__ + + def has_key(self, k): + return k in self.__dict__ + + def update(self, *args, **kwargs): + return self.__dict__.update(*args, **kwargs) + + def keys(self): + return [k for k in self.__dict__ if not k.startswith('_')] + + def values(self): + return [v for k, v in self.__dict__.items() if not k.startswith('_')] + + def items(self): + return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')] + + def get(self, key, default=None): + if key in self.__dict__: + return self.__dict__[key] + return default + + +class LocationMode(object): + """ + Specifies the location the request should be sent to. This mode only applies + for RA-GRS accounts which allow secondary read access. All other account types + must use PRIMARY. + """ + + PRIMARY = 'primary' #: Requests should be sent to the primary location. + SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible. + + +class ResourceTypes(object): + """ + Specifies the resource types that are accessible with the account SAS. + + :param bool service: + Access to service-level APIs (e.g., Get/Set Service Properties, + Get Service Stats, List Containers/Queues/Shares) + :param bool container: + Access to container-level APIs (e.g., Create/Delete Container, + Create/Delete Queue, Create/Delete Share, + List Blobs/Files and Directories) + :param bool object: + Access to object-level APIs for blobs, queue messages, and + files(e.g. Put Blob, Query Entity, Get Messages, Create File, etc.) 
+ """ + + service: bool = False + container: bool = False + object: bool = False + _str: str + + def __init__( + self, + service: bool = False, + container: bool = False, + object: bool = False # pylint: disable=redefined-builtin + ) -> None: + self.service = service + self.container = container + self.object = object + self._str = (('s' if self.service else '') + + ('c' if self.container else '') + + ('o' if self.object else '')) + + def __str__(self): + return self._str + + @classmethod + def from_string(cls, string): + """Create a ResourceTypes from a string. + + To specify service, container, or object you need only to + include the first letter of the word in the string. E.g. service and container, + you would provide a string "sc". + + :param str string: Specify service, container, or object in + in the string with the first letter of the word. + :return: A ResourceTypes object + :rtype: ~azure.storage.fileshare.ResourceTypes + """ + res_service = 's' in string + res_container = 'c' in string + res_object = 'o' in string + + parsed = cls(res_service, res_container, res_object) + parsed._str = string + return parsed + + +class AccountSasPermissions(object): + """ + :class:`~ResourceTypes` class to be used with generate_account_sas + function and for the AccessPolicies used with set_*_acl. There are two types of + SAS which may be used to grant resource access. One is to grant access to a + specific resource (resource-specific). Another is to grant access to the + entire service for a specific account and allow certain operations based on + perms found here. + + :param bool read: + Valid for all signed resources types (Service, Container, and Object). + Permits read permissions to the specified resource type. + :param bool write: + Valid for all signed resources types (Service, Container, and Object). + Permits write permissions to the specified resource type. + :param bool delete: + Valid for Container and Object resource types, except for queue messages. 
+ :param bool delete_previous_version: + Delete the previous blob version for the versioning enabled storage account. + :param bool list: + Valid for Service and Container resource types only. + :param bool add: + Valid for the following Object resource types only: queue messages, and append blobs. + :param bool create: + Valid for the following Object resource types only: blobs and files. + Users can create new blobs or files, but may not overwrite existing + blobs or files. + :param bool update: + Valid for the following Object resource types only: queue messages. + :param bool process: + Valid for the following Object resource type only: queue messages. + :keyword bool tag: + To enable set or get tags on the blobs in the container. + :keyword bool filter_by_tags: + To enable get blobs by tags, this should be used together with list permission. + :keyword bool set_immutability_policy: + To enable operations related to set/delete immutability policy. + To get immutability policy, you just need read permission. + :keyword bool permanent_delete: + To enable permanent delete on the blob is permitted. + Valid for Object resource type of Blob only. 
+ """ + + read: bool = False + write: bool = False + delete: bool = False + delete_previous_version: bool = False + list: bool = False + add: bool = False + create: bool = False + update: bool = False + process: bool = False + tag: bool = False + filter_by_tags: bool = False + set_immutability_policy: bool = False + permanent_delete: bool = False + + def __init__( + self, + read: bool = False, + write: bool = False, + delete: bool = False, + list: bool = False, # pylint: disable=redefined-builtin + add: bool = False, + create: bool = False, + update: bool = False, + process: bool = False, + delete_previous_version: bool = False, + **kwargs + ) -> None: + self.read = read + self.write = write + self.delete = delete + self.delete_previous_version = delete_previous_version + self.permanent_delete = kwargs.pop('permanent_delete', False) + self.list = list + self.add = add + self.create = create + self.update = update + self.process = process + self.tag = kwargs.pop('tag', False) + self.filter_by_tags = kwargs.pop('filter_by_tags', False) + self.set_immutability_policy = kwargs.pop('set_immutability_policy', False) + self._str = (('r' if self.read else '') + + ('w' if self.write else '') + + ('d' if self.delete else '') + + ('x' if self.delete_previous_version else '') + + ('y' if self.permanent_delete else '') + + ('l' if self.list else '') + + ('a' if self.add else '') + + ('c' if self.create else '') + + ('u' if self.update else '') + + ('p' if self.process else '') + + ('f' if self.filter_by_tags else '') + + ('t' if self.tag else '') + + ('i' if self.set_immutability_policy else '') + ) + + def __str__(self): + return self._str + + @classmethod + def from_string(cls, permission): + """Create AccountSasPermissions from a string. + + To specify read, write, delete, etc. permissions you need only to + include the first letter of the word in the string. E.g. for read and write + permissions you would provide a string "rw". 
+ + :param str permission: Specify permissions in + the string with the first letter of the word. + :return: An AccountSasPermissions object + :rtype: ~azure.storage.fileshare.AccountSasPermissions + """ + p_read = 'r' in permission + p_write = 'w' in permission + p_delete = 'd' in permission + p_delete_previous_version = 'x' in permission + p_permanent_delete = 'y' in permission + p_list = 'l' in permission + p_add = 'a' in permission + p_create = 'c' in permission + p_update = 'u' in permission + p_process = 'p' in permission + p_tag = 't' in permission + p_filter_by_tags = 'f' in permission + p_set_immutability_policy = 'i' in permission + parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version, + list=p_list, add=p_add, create=p_create, update=p_update, process=p_process, tag=p_tag, + filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy, + permanent_delete=p_permanent_delete) + + return parsed + + +class Services(object): + """Specifies the services accessible with the account SAS. + + :keyword bool blob: + Access for the `~azure.storage.blob.BlobServiceClient`. Default is False. + :keyword bool queue: + Access for the `~azure.storage.queue.QueueServiceClient`. Default is False. + :keyword bool fileshare: + Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False. + """ + + def __init__( + self, *, + blob: bool = False, + queue: bool = False, + fileshare: bool = False + ) -> None: + self.blob = blob + self.queue = queue + self.fileshare = fileshare + self._str = (('b' if self.blob else '') + + ('q' if self.queue else '') + + ('f' if self.fileshare else '')) + + def __str__(self): + return self._str + + @classmethod + def from_string(cls, string): + """Create Services from a string. + + To specify blob, queue, or file you need only to + include the first letter of the word in the string. E.g. for blob and queue + you would provide a string "bq". 
+ + :param str string: Specify blob, queue, or file in + in the string with the first letter of the word. + :return: A Services object + :rtype: ~azure.storage.fileshare.Services + """ + res_blob = 'b' in string + res_queue = 'q' in string + res_file = 'f' in string + + parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file) + parsed._str = string + return parsed + + +class UserDelegationKey(object): + """ + Represents a user delegation key, provided to the user by Azure Storage + based on their Azure Active Directory access token. + + The fields are saved as simple strings since the user does not have to interact with this object; + to generate an identify SAS, the user can simply pass it to the right API. + """ + + signed_oid: Optional[str] = None + """Object ID of this token.""" + signed_tid: Optional[str] = None + """Tenant ID of the tenant that issued this token.""" + signed_start: Optional[str] = None + """The datetime this token becomes valid.""" + signed_expiry: Optional[str] = None + """The datetime this token expires.""" + signed_service: Optional[str] = None + """What service this key is valid for.""" + signed_version: Optional[str] = None + """The version identifier of the REST service that created this token.""" + value: Optional[str] = None + """The user delegation key.""" + + def __init__(self): + self.signed_oid = None + self.signed_tid = None + self.signed_start = None + self.signed_expiry = None + self.signed_service = None + self.signed_version = None + self.value = None + + +class StorageConfiguration(Configuration): + """ + Specifies the configurable values used in Azure Storage. + + :param int max_single_put_size: If the blob size is less than or equal max_single_put_size, then the blob will be + uploaded with only one http PUT request. If the blob size is larger than max_single_put_size, + the blob will be uploaded in chunks. Defaults to 64*1024*1024, or 64MB. 
+ :param int copy_polling_interval: The interval in seconds for polling copy operations. + :param int max_block_size: The maximum chunk size for uploading a block blob in chunks. + Defaults to 4*1024*1024, or 4MB. + :param int min_large_block_upload_threshold: The minimum chunk size required to use the memory efficient + algorithm when uploading a block blob. + :param bool use_byte_buffer: Use a byte buffer for block blob uploads. Defaults to False. + :param int max_page_size: The maximum chunk size for uploading a page blob. Defaults to 4*1024*1024, or 4MB. + :param int min_large_chunk_upload_threshold: The max size for a single put operation. + :param int max_single_get_size: The maximum size for a blob to be downloaded in a single call, + the exceeded part will be downloaded in chunks (could be parallel). Defaults to 32*1024*1024, or 32MB. + :param int max_chunk_get_size: The maximum chunk size used for downloading a blob. Defaults to 4*1024*1024, + or 4MB. + :param int max_range_size: The max range size for file upload. 
+ + """ + + max_single_put_size: int + copy_polling_interval: int + max_block_size: int + min_large_block_upload_threshold: int + use_byte_buffer: bool + max_page_size: int + min_large_chunk_upload_threshold: int + max_single_get_size: int + max_chunk_get_size: int + max_range_size: int + user_agent_policy: UserAgentPolicy + + def __init__(self, **kwargs): + super(StorageConfiguration, self).__init__(**kwargs) + self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) + self.copy_polling_interval = 15 + self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) + self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) + self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) + self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024) diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py new file mode 100644 index 00000000..112c1984 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py @@ -0,0 +1,53 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
# --------------------------------------------------------------------------

from datetime import datetime, timezone
from typing import Optional

EPOCH_AS_FILETIME = 116444736000000000  # January 1, 1970 as MS filetime
HUNDREDS_OF_NANOSECONDS = 10000000  # filetime ticks (100 ns each) per second


def _to_utc_datetime(value: datetime) -> str:
    """Formats a datetime as a ``YYYY-MM-DDTHH:MM:SSZ`` string.

    A timezone-aware value is converted to UTC first, so that the trailing
    ``Z`` designator is accurate. A naive value carries no offset to convert
    from; it is formatted unchanged and is assumed to already be in UTC.

    :param datetime value: The datetime to format.
    :returns: The formatted timestamp, with second precision.
    :rtype: str
    """
    # ``utcoffset()`` is the authoritative "is this aware?" test: a tzinfo
    # object may be attached and still report no offset.
    if value.tzinfo is not None and value.utcoffset() is not None:
        value = value.astimezone(timezone.utc)
    return value.strftime('%Y-%m-%dT%H:%M:%SZ')


def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]:
    """Converts an RFC 1123 date string to a datetime.

    The string is parsed with the ``%Z`` directive, which validates the zone
    name but does not attach ``tzinfo``; the returned datetime is naive.

    :param str rfc_1123: The time and date in RFC 1123 format,
        e.g. ``"Mon, 01 Jan 2024 00:00:00 GMT"``.
    :returns: The parsed datetime, or None when the input is empty.
    :rtype: Optional[datetime]
    :raises ValueError: If the string does not match the RFC 1123 layout.
    """
    if not rfc_1123:
        return None

    return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z")


def _filetime_to_datetime(filetime: str) -> Optional[datetime]:
    """Converts an MS filetime string to a UTC datetime. "0" indicates None.
    If the value is not an integer, tries RFC 1123 as backup.

    :param str filetime: The time and date in MS filetime format
        (or, as a fallback, in RFC 1123 format).
    :returns: A timezone-aware UTC datetime for a filetime input, the result
        of the RFC 1123 parser for a non-numeric input, or None when the
        input is empty or "0".
    :rtype: Optional[datetime]
    :raises ValueError: If the value is neither an integer nor a valid
        RFC 1123 date string.
    """
    if not filetime:
        return None

    # Only the integer parse decides which format we were given. Keeping the
    # ``try`` this narrow means an error raised while building the datetime
    # is reported as-is instead of being masked by the RFC 1123 fallback.
    try:
        ticks = int(filetime)
    except ValueError:
        # Not a filetime: try RFC 1123 as backup.
        return _rfc_1123_to_datetime(filetime)

    if ticks == 0:
        return None

    # Re-base the tick count onto the Unix epoch and scale it to seconds.
    seconds_since_epoch = (ticks - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS
    return datetime.fromtimestamp(seconds_since_epoch, tz=timezone.utc)


# diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py
# new file mode 100644
# index 00000000..ee75cd5a
# --- /dev/null
# +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py
# @@ -0,0 +1,694 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
def encode_base64(data):
    """Return the base64 encoding of ``data`` as text.

    :param data: Bytes to encode; a ``str`` is UTF-8 encoded first.
    :returns: The base64 representation.
    :rtype: str
    """
    raw = data.encode('utf-8') if isinstance(data, str) else data
    return base64.b64encode(raw).decode('utf-8')


def is_exhausted(settings):
    """Report whether any retry budget has dropped below zero.

    Only the 'total', 'connect', 'read' and 'status' counters are looked at,
    and falsy counters (None, False, 0) are ignored.

    :param dict settings: The retry settings.
    :returns: True if at least one considered counter is negative.
    :rtype: bool
    """
    budgets = [
        settings[name]
        for name in ('total', 'connect', 'read', 'status')
        if settings[name]
    ]
    return bool(budgets) and min(budgets) < 0


def retry_hook(settings, **kwargs):
    """Invoke the configured retry hook, if there is one.

    The hook receives ``retry_count`` (``settings['count'] - 1``),
    ``location_mode`` (``settings['mode']``) and any extra keyword arguments.

    :param dict settings: The retry settings holding the 'hook' callable.
    """
    callback = settings['hook']
    if not callback:
        return
    callback(retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)
def is_retry(response, mode):
    """Decide from the HTTP status code whether a response is retryable.

    :param response: A pipeline response exposing ``http_response.status_code``.
    :param mode: The location mode the request was sent with.
    :returns: True when the request should be retried.
    :rtype: bool
    """
    status = response.http_response.status_code
    if status < 300:
        return False
    if status >= 500:
        # Server-side failures are retried, except 501 Not Implemented and
        # 505 Version Not Supported.
        return status not in (501, 505)
    # 3xx/4xx responses are mostly expected outcomes (for example a
    # 409 Conflict or 412 Precondition Failed). Only two cases are retried:
    # a 408 timeout, and a 404 that came from the secondary location.
    if status == 408:
        return True
    return bool(status == 404 and mode == LocationMode.SECONDARY)


def is_checksum_retry(response):
    """Report whether a response failed content-MD5 validation.

    Applies only when the context has 'validate_content' set and the
    response carries a 'content-md5' header. That header is compared with
    the request's 'content-md5' header or, when that is absent, with the MD5
    of the response body.

    :param response: The pipeline response to inspect.
    :returns: True when the checksums differ.
    :rtype: bool
    """
    if not response.context.get('validate_content', False):
        return False
    if not response.http_response.headers.get('content-md5'):
        return False
    expected_md5 = response.http_request.headers.get('content-md5', None)
    if not expected_md5:
        body_digest = StorageContentValidation.get_content_md5(response.http_response.body())
        expected_md5 = encode_base64(body_digest)
    return response.http_response.headers['content-md5'] != expected_md5


def urljoin(base_url, stub_url):
    """Append ``stub_url`` as a further path segment of ``base_url``.

    A '/' separator is always inserted; the other URL components are kept.

    :param str base_url: The URL to extend.
    :param str stub_url: The path segment to append.
    :returns: The extended URL.
    :rtype: str
    """
    pieces = urlparse(base_url)
    extended_path = pieces.path + '/' + stub_url
    return pieces._replace(path=extended_path).geturl()


class QueueMessagePolicy(SansIOHTTPPolicy):
    """Appends the 'queue_message_id' request option to the request URL."""

    def on_request(self, request):
        """Consume 'queue_message_id' and, if set, extend the request URL with it.

        :param request: The pipeline request being prepared.
        """
        message_id = request.context.options.pop('queue_message_id', None)
        if not message_id:
            return
        http_request = request.http_request
        http_request.url = urljoin(http_request.url, message_id)


class StorageHeadersPolicy(HeadersPolicy):
    """Headers policy that also stamps 'x-ms-date' and a client request id."""

    request_id_header_name = 'x-ms-client-request-id'

    def on_request(self, request: "PipelineRequest") -> None:
        """Apply the base headers, then set the date and client request id.

        The request id is taken from the 'client_request_id' option when
        given; otherwise a new ``uuid1`` value is generated.

        :param PipelineRequest request: The pipeline request being prepared.
        """
        super(StorageHeadersPolicy, self).on_request(request)
        headers = request.http_request.headers
        headers['x-ms-date'] = format_date_time(time())

        supplied_id = request.context.options.pop('client_request_id', None)
        headers['x-ms-client-request-id'] = supplied_id or str(uuid.uuid1())

    # NOTE: there is no on_response step; the client request id echoed by the
    # service is not compared against the one that was sent.
class StorageHosts(SansIOHTTPPolicy):
    """Records the known hosts and the location mode on each request."""

    def __init__(self, hosts=None, **kwargs):  # pylint: disable=unused-argument
        """
        :param hosts: Mapping of location mode to network location.
        """
        self.hosts = hosts
        super(StorageHosts, self).__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        """Publish 'hosts' and 'location_mode' options, honouring 'use_location'.

        :param PipelineRequest request: The pipeline request being prepared.
        :raises ValueError: If 'use_location' names a location with no host.
        """
        options = request.context.options
        options['hosts'] = self.hosts
        target = urlparse(request.http_request.url)

        # Work out which location the URL currently points at; PRIMARY when
        # no configured host matches.
        current = LocationMode.PRIMARY
        for mode, netloc in self.hosts.items():
            if target.netloc == netloc:
                current = mode

        requested = options.pop('use_location', None)
        if requested:
            # An explicit location pins the request: no retry to secondary.
            options['retry_to_secondary'] = False
            if requested not in self.hosts:
                raise ValueError(f"Attempting to use undefined host location {requested}")
            if requested != current:
                # Redirect the request to the requested location's host.
                redirected = target._replace(netloc=self.hosts[requested])
                request.http_request.url = redirected.geturl()
                current = requested

        options['location_mode'] = current
class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
    """A policy that logs HTTP request and response to the DEBUG logger.

    This accepts both global configuration, and per-request level with "enable_http_logger"
    """

    def __init__(self, logging_enable: bool = False, **kwargs) -> None:
        """
        :param bool logging_enable: Whether HTTP logging is on by default.
        :keyword bool logging_body: Whether bodies are logged. Defaults to False.
        """
        self.logging_body = kwargs.pop("logging_body", False)
        super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs)

    def on_request(self, request: "PipelineRequest") -> None:
        """Log the outgoing request with credentials scrubbed.

        The 'authorization' header, the SAS signature in the URL and the
        signature inside 'x-ms-copy-source' are masked. Logging failures are
        reported at DEBUG level and never propagate.

        :param PipelineRequest request: The pipeline request being sent.
        """
        http_request = request.http_request
        options = request.context.options
        # Always consume the option so it cannot linger in the request
        # options once the policy-level flag is already set.
        body_requested = options.pop("logging_body", False)
        self.logging_body = self.logging_body or body_requested
        if options.pop("logging_enable", self.enable_http_logger):
            request.context["logging_enable"] = True
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                log_url = http_request.url
                query_params = http_request.query
                # Mask only the signature value. Skipping an empty value
                # matters: str.replace('', ...) would insert the mask between
                # every character of the URL.
                if query_params.get('sig'):
                    log_url = log_url.replace(query_params['sig'], "*****")
                _LOGGER.debug("Request URL: %r", log_url)
                _LOGGER.debug("Request method: %r", http_request.method)
                _LOGGER.debug("Request headers:")
                for header, value in http_request.headers.items():
                    if header.lower() == 'authorization':
                        value = '*****'
                    elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
                        # take the url apart and scrub away the signed signature
                        scheme, netloc, path, params, query, fragment = urlparse(value)
                        parsed_qs = dict(parse_qsl(query))
                        parsed_qs['sig'] = '*****'

                        # the SAS needs to be put back together
                        value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))

                    _LOGGER.debug("    %r: %r", header, value)
                _LOGGER.debug("Request body:")

                if self.logging_body:
                    _LOGGER.debug(str(http_request.body))
                else:
                    # We don't want to log the binary data of a file upload.
                    _LOGGER.debug("Hidden body, please use logging_body to show body")
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log request: %r", err)

    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
        """Log the response status, headers and (optionally) its body.

        :param PipelineRequest request: The pipeline request that was sent.
        :param PipelineResponse response: The pipeline response received.
        """
        if response.context.pop("logging_enable", self.enable_http_logger):
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                _LOGGER.debug("Response status: %r", response.http_response.status_code)
                _LOGGER.debug("Response headers:")
                for res_header, value in response.http_response.headers.items():
                    _LOGGER.debug("    %r: %r", res_header, value)

                # We don't want to log binary data if the response is a file.
                _LOGGER.debug("Response content:")
                pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
                header = response.http_response.headers.get('content-disposition')
                resp_content_type = response.http_response.headers.get("content-type", "")

                if header and pattern.match(header):
                    filename = header.partition('=')[2]
                    _LOGGER.debug("File attachments: %s", filename)
                elif resp_content_type.endswith("octet-stream"):
                    _LOGGER.debug("Body contains binary data.")
                elif resp_content_type.startswith("image"):
                    _LOGGER.debug("Body contains image data.")

                if self.logging_body and resp_content_type.startswith("text"):
                    _LOGGER.debug(response.http_response.text())
                elif self.logging_body:
                    try:
                        _LOGGER.debug(response.http_response.body())
                    except ValueError:
                        _LOGGER.debug("Body is streamable")

            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log response: %s", repr(err))
class StorageRequestHook(SansIOHTTPPolicy):
    """Calls a user-supplied hook with each outgoing pipeline request."""

    def __init__(self, **kwargs):
        # Policy-level default; a per-request 'raw_request_hook' option wins.
        self._request_callback = kwargs.get('raw_request_hook')
        super(StorageRequestHook, self).__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        """Invoke the per-request hook, or the policy default, if one is set.

        :param PipelineRequest request: The pipeline request being sent.
        """
        request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
        if request_callback:
            request_callback(request)


class StorageResponseHook(HTTPPolicy):
    """Tracks transfer progress in the context and calls a response hook."""

    def __init__(self, **kwargs):
        # Policy-level default; a per-request 'raw_response_hook' option wins.
        self._response_callback = kwargs.get('raw_response_hook')
        super(StorageResponseHook, self).__init__()

    def send(self, request: "PipelineRequest") -> "PipelineResponse":
        """Send the request, update the stream counters, and call the hook.

        Each counter is read from the context first and only then from the
        request options, so values stored by an earlier pass are kept.

        :param PipelineRequest request: The pipeline request to send.
        :returns: The response from the next policy.
        :rtype: PipelineResponse
        """
        # Values could be 0
        data_stream_total = request.context.get('data_stream_total')
        if data_stream_total is None:
            data_stream_total = request.context.options.pop('data_stream_total', None)
        download_stream_current = request.context.get('download_stream_current')
        if download_stream_current is None:
            download_stream_current = request.context.options.pop('download_stream_current', None)
        upload_stream_current = request.context.get('upload_stream_current')
        if upload_stream_current is None:
            upload_stream_current = request.context.options.pop('upload_stream_current', None)

        # The `or` short-circuits: when a callback is already stored in the
        # context, 'raw_response_hook' is left in the options untouched.
        response_callback = request.context.get('response_callback') or \
            request.context.options.pop('raw_response_hook', self._response_callback)

        response = self.next.send(request)

        will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
        # Auth error could come from Bearer challenge, in which case this request will be made again
        is_auth_error = response.http_response.status_code == 401
        should_update_counts = not (will_retry or is_auth_error)

        if should_update_counts and download_stream_current is not None:
            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
            if data_stream_total is None:
                content_range = response.http_response.headers.get('Content-Range')
                if content_range:
                    # Takes the part after '/' of "<unit> <range>/<total>".
                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                else:
                    data_stream_total = download_stream_current
        elif should_update_counts and upload_stream_current is not None:
            # Upload progress is measured by the request's Content-Length.
            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
        # Store the counters on both objects that expose a context.
        for pipeline_obj in [request, response]:
            if hasattr(pipeline_obj, 'context'):
                pipeline_obj.context['data_stream_total'] = data_stream_total
                pipeline_obj.context['download_stream_current'] = download_stream_current
                pipeline_obj.context['upload_stream_current'] = upload_stream_current
        if response_callback:
            response_callback(response)
            # Remember the callback in the context for a later pass.
            request.context['response_callback'] = response_callback
        return response
class StorageContentValidation(SansIOHTTPPolicy):
    """A policy that validates content with MD5 checksums.

    When 'validate_content' is set, non-GET requests get a ``Content-MD5``
    header computed from the request data, and responses carrying a
    'content-md5' header are checked against a locally computed value.
    """
    header_name = 'Content-MD5'

    def __init__(self, **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super(StorageContentValidation, self).__init__()

    @staticmethod
    def get_content_md5(data):
        """Compute the raw MD5 digest of bytes or of a seekable stream.

        A stream is read in 4096-byte chunks from its current position and
        then rewound to that position.

        :param data: Bytes-like data, a file-like object, or None/empty.
        :returns: The 16-byte MD5 digest.
        :rtype: bytes
        :raises ValueError: If the data is of an unsupported type, or the
            stream cannot be rewound.
        """
        # Since HTTP does not differentiate between no content and empty content,
        # we have to perform a None check.
        data = data or b""
        md5 = hashlib.md5()  # nosec
        if isinstance(data, (bytes, bytearray, memoryview)):
            md5.update(data)
        elif hasattr(data, 'read'):
            pos = 0
            try:
                pos = data.tell()
            except Exception:  # pylint: disable=broad-except
                # Best effort: a stream without a usable tell() is rewound
                # to position 0 below.
                pass
            for chunk in iter(lambda: data.read(4096), b""):
                md5.update(chunk)
            try:
                data.seek(pos, SEEK_SET)
            except (AttributeError, IOError) as exc:
                raise ValueError("Data should be bytes or a seekable file-like object.") from exc
        else:
            raise ValueError("Data should be bytes or a seekable file-like object.")

        return md5.digest()

    def on_request(self, request: "PipelineRequest") -> None:
        """Attach a Content-MD5 header to non-GET requests when validating.

        The 'validate_content' option is consumed and copied into the
        context for the response-side check.

        :param PipelineRequest request: The pipeline request being sent.
        """
        validate_content = request.context.options.pop('validate_content', False)
        if validate_content and request.http_request.method != 'GET':
            computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
            request.http_request.headers[self.header_name] = computed_md5
            request.context['validate_content_md5'] = computed_md5
        request.context['validate_content'] = validate_content

    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
        """Compare the response's content-md5 header with the expected value.

        The expected value is the MD5 stored at request time, or otherwise
        the MD5 of the response body.

        :param PipelineRequest request: The pipeline request that was sent.
        :param PipelineResponse response: The pipeline response received.
        :raises AzureError: If the two checksums differ.
        """
        if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
            computed_md5 = request.context.get('validate_content_md5') or \
                encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
            if response.http_response.headers['content-md5'] != computed_md5:
                raise AzureError((
                    f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
                    f"computed value is '{computed_md5}'."),
                    response=response.http_response
                )
class StorageRetryPolicy(HTTPPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.
    """

    total_retries: int
    """The max number of retries."""
    connect_retries: int
    """The max number of connect retries."""
    read_retries: int
    """The max number of read retries."""
    status_retries: int
    """The max number of status retries."""
    retry_to_secondary: bool
    """Whether the secondary endpoint should be retried."""

    def __init__(self, **kwargs: Any) -> None:
        """
        :keyword int retry_total: Total retry budget. Defaults to 10.
        :keyword int retry_connect: Connect-error retry budget. Defaults to 3.
        :keyword int retry_read: Read-error retry budget. Defaults to 3.
        :keyword int retry_status: Bad-status retry budget. Defaults to 3.
        :keyword bool retry_to_secondary: Whether retries may switch to the
            secondary location. Defaults to False.
        """
        self.total_retries = kwargs.pop('retry_total', 10)
        self.connect_retries = kwargs.pop('retry_connect', 3)
        self.read_retries = kwargs.pop('retry_read', 3)
        self.status_retries = kwargs.pop('retry_status', 3)
        self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
        super(StorageRetryPolicy, self).__init__()

    def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None:
        """
        A function which sets the next host location on the request, if applicable.

        Nothing happens unless every entry of ``settings['hosts']`` has a
        value; otherwise the mode is flipped between primary and secondary
        and the request URL is pointed at the matching host.

        :param Dict[str, Any] settings: The configurable values pertaining to the next host location.
        :param PipelineRequest request: A pipeline request object.
        """
        if settings['hosts'] and all(settings['hosts'].values()):
            url = urlparse(request.url)
            # If there's more than one possible location, retry to the alternative
            if settings['mode'] == LocationMode.PRIMARY:
                settings['mode'] = LocationMode.SECONDARY
            else:
                settings['mode'] = LocationMode.PRIMARY
            updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
            request.url = updated.geturl()

    def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]:
        """Build the per-request retry settings.

        Per-request options override the policy defaults and are removed
        from the request options. The current position of a stream body is
        recorded so the body can be rewound before a retry.

        :param PipelineRequest request: A pipeline request object.
        :returns: The retry settings for this request.
        :rtype: Dict[str, Any]
        """
        body_position = None
        if hasattr(request.http_request.body, 'read'):
            try:
                body_position = request.http_request.body.tell()
            except (AttributeError, UnsupportedOperation):
                # if body position cannot be obtained, then retries will not work
                pass
        options = request.context.options
        return {
            'total': options.pop("retry_total", self.total_retries),
            'connect': options.pop("retry_connect", self.connect_retries),
            'read': options.pop("retry_read", self.read_retries),
            'status': options.pop("retry_status", self.status_retries),
            'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
            'mode': options.pop("location_mode", LocationMode.PRIMARY),
            'hosts': options.pop("hosts", None),
            'hook': options.pop("retry_hook", None),
            'body_position': body_position,
            'count': 0,
            'history': []
        }

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:  # pylint: disable=unused-argument
        """ Formula for computing the current backoff.
        Should be calculated by child class.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
        :returns: The backoff time.
        :rtype: float
        """
        return 0

    def sleep(self, settings, transport):
        """Sleep for the computed backoff; a zero or negative backoff is skipped.

        :param Dict[str, Any] settings: The retry settings.
        :param transport: The transport whose ``sleep`` method is used.
        """
        backoff = self.get_backoff_time(settings)
        if not backoff or backoff < 0:
            return
        transport.sleep(backoff)

    def increment(
        self, settings: Dict[str, Any],
        request: "PipelineRequest",
        response: Optional["PipelineResponse"] = None,
        error: Optional[AzureError] = None
    ) -> bool:
        """Increment the retry counters.

        :param Dict[str, Any] settings: The configurable values pertaining to the increment operation.
        :param PipelineRequest request: A pipeline request object.
        :param Optional[PipelineResponse] response: A pipeline response object.
        :param Optional[AzureError] error: An error encountered during the request, or
            None if the response was received successfully.
        :returns: True if another attempt should be made, False if the
            retries are exhausted or the request body cannot be rewound.
        :rtype: bool
        """
        settings['total'] -= 1

        if error and isinstance(error, ServiceRequestError):
            # Errors when we're fairly sure that the server did not receive the
            # request, so it should be safe to retry.
            settings['connect'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        elif error and isinstance(error, ServiceResponseError):
            # Errors that occur after the request has been started, so we should
            # assume that the server began processing it.
            settings['read'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        else:
            # Incrementing because of a server error like a 500 in
            # status_forcelist and the given method is in the allowlist
            if response:
                settings['status'] -= 1
                settings['history'].append(RequestHistory(request, http_response=response))

        if not is_exhausted(settings):
            # PUT requests are never redirected to the other location.
            if request.method not in ['PUT'] and settings['retry_secondary']:
                self._set_next_host_location(settings, request)

            # rewind the request body if it is a stream
            if request.body and hasattr(request.body, 'read'):
                # no position was saved, then retry would not work
                if settings['body_position'] is None:
                    return False
                try:
                    # attempt to rewind the body to the initial position
                    request.body.seek(settings['body_position'], SEEK_SET)
                except (UnsupportedOperation, ValueError):
                    # if body is not seekable, then retry would not work
                    return False
            settings['count'] += 1
            return True
        return False

    def send(self, request):
        """Send the request, retrying on retryable responses and errors.

        :param PipelineRequest request: The pipeline request to send.
        :returns: The final pipeline response, with the retry history (if
            any) in its context and ``location_mode`` set on the HTTP response.
        :rtype: PipelineResponse
        :raises AzureError: When an error occurs and no retries remain;
            signing errors are re-raised immediately.
        """
        retries_remaining = True
        response = None
        retry_settings = self.configure_retries(request)
        while retries_remaining:
            try:
                response = self.next.send(request)
                if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
                    retries_remaining = self.increment(
                        retry_settings,
                        request=request.http_request,
                        response=response.http_response)
                    if retries_remaining:
                        retry_hook(
                            retry_settings,
                            request=request.http_request,
                            response=response.http_response,
                            error=None)
                        self.sleep(retry_settings, request.context.transport)
                        continue
                break
            except AzureError as err:
                # Signing failures are never retried.
                if isinstance(err, AzureSigningError):
                    raise
                retries_remaining = self.increment(
                    retry_settings, request=request.http_request, error=err)
                if retries_remaining:
                    retry_hook(
                        retry_settings,
                        request=request.http_request,
                        response=None,
                        error=err)
                    self.sleep(retry_settings, request.context.transport)
                    continue
                raise err
        if retry_settings['history']:
            response.context['history'] = retry_settings['history']
        response.http_response.location_mode = retry_settings['mode']
        return response
class ExponentialRetry(StorageRetryPolicy):
    """Exponential retry."""

    initial_backoff: int
    """The initial backoff interval, in seconds, for the first retry."""
    increment_base: int
    """The base, in seconds, to increment the initial_backoff by after the
    first retry."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, initial_backoff: int = 15,
        increment_base: int = 3,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs an Exponential retry object. The backoff is centred on
        initial_backoff when the retry count is 0, and on
        initial_backoff + increment_base^retry_count otherwise.

        :param int initial_backoff:
            The initial backoff interval, in seconds, for the first retry.
        :param int increment_base:
            The base, in seconds, to increment the initial_backoff by after the
            first retry.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.initial_backoff = initial_backoff
        self.increment_base = increment_base
        self.random_jitter_range = random_jitter_range
        super(ExponentialRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The retry settings; only 'count' is read.
        :returns:
            A value drawn uniformly from the jitter window around the
            backoff; the window's lower bound is 0 when the backoff does
            not exceed the jitter range.
        :rtype: float
        """
        attempt = settings['count']
        growth = pow(self.increment_base, attempt) if attempt != 0 else 0
        centre = self.initial_backoff + growth
        jitter = self.random_jitter_range
        lower = centre - jitter if centre > jitter else 0
        upper = centre + jitter
        return random.Random().uniform(lower, upper)
class LinearRetry(StorageRetryPolicy):
    """Linear retry."""

    backoff: int
    """The backoff interval, in seconds, between retries."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, backoff: int = 15,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs a Linear retry object.

        :param int backoff:
            The backoff interval, in seconds, between retries.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.backoff = backoff
        self.random_jitter_range = random_jitter_range
        super(LinearRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The retry settings (not read here).
        :returns:
            A value drawn uniformly from the jitter window around the
            backoff; the window's lower bound is 0 when the backoff does
            not exceed the jitter range.
        :rtype: float
        """
        random_generator = random.Random()
        # the backoff interval normally does not change, however there is the possibility
        # that it was modified by accessing the property directly after initializing the object
        random_range_start = self.backoff - self.random_jitter_range \
            if self.backoff > self.random_jitter_range else 0
        random_range_end = self.backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)
class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
    """ Custom Bearer token credential policy for following Storage Bearer challenges """

    def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None:
        """
        :param TokenCredential credential: The credential used to get tokens.
        :param str audience: Passed to the base policy as its scope argument.
        """
        super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)

    def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
        """Re-authorize the request from the response's WWW-Authenticate challenge.

        :param PipelineRequest request: The request that was challenged.
        :param PipelineResponse response: The response carrying the challenge.
        :returns: False when parsing the challenge raises ValueError,
            otherwise True once the request has been re-authorized.
        :rtype: bool
        """
        try:
            raw_challenge = response.http_response.headers.get("WWW-Authenticate")
            parsed = StorageHttpChallenge(raw_challenge)
        except ValueError:
            return False
        else:
            # The scope is the challenge's resource id plus the default suffix.
            self.authorize_request(
                request,
                parsed.resource_id + DEFAULT_OAUTH_SCOPE,
                tenant_id=parsed.tenant_id)
            return True
async def retry_hook(settings, **kwargs):
    """Invoke the configured retry hook, awaiting it when it is asynchronous.

    The hook receives ``retry_count`` (``settings['count'] - 1``),
    ``location_mode`` (``settings['mode']``) and any extra keyword arguments.
    Both plain callables and coroutine functions are supported.

    :param dict settings: The retry settings holding the 'hook' callable.
    """
    import inspect  # local import: only needed to detect awaitable results

    hook = settings['hook']
    if not hook:
        return
    # Inspect the *result* of the call: asyncio.iscoroutine() on the hook
    # itself is False for an ``async def`` function, which would leave the
    # coroutine it returns un-awaited.
    outcome = hook(
        retry_count=settings['count'] - 1,
        location_mode=settings['mode'],
        **kwargs)
    if inspect.isawaitable(outcome):
        await outcome


async def is_checksum_retry(response):
    """Report whether a response failed content-MD5 validation.

    Applies only when the context has 'validate_content' set and the
    response carries a 'content-md5' header. The body is loaded first so it
    can be hashed when the request has no 'content-md5' header of its own.

    :param response: The pipeline response to inspect.
    :returns: True when the checksums differ.
    :rtype: bool
    """
    # retry if invalid content md5
    if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
        try:
            await response.http_response.load_body()  # Load the body in memory and close the socket
        except (StreamClosedError, StreamConsumedError):
            # The stream is already closed or consumed; carry on with the
            # comparison below.
            pass
        computed_md5 = response.http_request.headers.get('content-md5', None) or \
            encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
        if response.http_response.headers['content-md5'] != computed_md5:
            return True
    return False
class AsyncStorageResponseHook(AsyncHTTPPolicy):
    """Tracks transfer progress in the context and calls a response hook."""

    def __init__(self, **kwargs):
        # Policy-level default; a per-request 'raw_response_hook' option wins.
        self._response_callback = kwargs.get('raw_response_hook')
        super(AsyncStorageResponseHook, self).__init__()

    async def send(self, request: "PipelineRequest") -> "PipelineResponse":
        """Send the request, update the stream counters, and call the hook.

        Each counter is read from the context first and only then from the
        request options. The hook may be a plain callable or a coroutine
        function; an awaitable result is awaited.

        :param PipelineRequest request: The pipeline request to send.
        :returns: The response from the next policy.
        :rtype: PipelineResponse
        """
        import inspect  # local import: only needed to detect awaitable results

        # Values could be 0
        data_stream_total = request.context.get('data_stream_total')
        if data_stream_total is None:
            data_stream_total = request.context.options.pop('data_stream_total', None)
        download_stream_current = request.context.get('download_stream_current')
        if download_stream_current is None:
            download_stream_current = request.context.options.pop('download_stream_current', None)
        upload_stream_current = request.context.get('upload_stream_current')
        if upload_stream_current is None:
            upload_stream_current = request.context.options.pop('upload_stream_current', None)

        # The `or` short-circuits: when a callback is already stored in the
        # context, 'raw_response_hook' is left in the options untouched.
        response_callback = request.context.get('response_callback') or \
            request.context.options.pop('raw_response_hook', self._response_callback)

        response = await self.next.send(request)

        will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response)
        # Auth error could come from Bearer challenge, in which case this request will be made again
        is_auth_error = response.http_response.status_code == 401
        should_update_counts = not (will_retry or is_auth_error)

        if should_update_counts and download_stream_current is not None:
            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
            if data_stream_total is None:
                content_range = response.http_response.headers.get('Content-Range')
                if content_range:
                    # Takes the part after '/' of "<unit> <range>/<total>".
                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                else:
                    data_stream_total = download_stream_current
        elif should_update_counts and upload_stream_current is not None:
            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
        for pipeline_obj in [request, response]:
            if hasattr(pipeline_obj, 'context'):
                pipeline_obj.context['data_stream_total'] = data_stream_total
                pipeline_obj.context['download_stream_current'] = download_stream_current
                pipeline_obj.context['upload_stream_current'] = upload_stream_current
        if response_callback:
            # Inspect the result of the call: asyncio.iscoroutine() on the
            # callback itself is False for an ``async def`` function, which
            # would leave the coroutine it returns un-awaited.
            outcome = response_callback(response)
            if inspect.isawaitable(outcome):
                await outcome
            request.context['response_callback'] = response_callback
        return response
class AsyncStorageRetryPolicy(StorageRetryPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.

    Async counterpart of the retry policy: ``sleep`` and ``send`` are
    coroutines, while settings handling and counter bookkeeping come from
    ``StorageRetryPolicy``.
    """

    async def sleep(self, settings, transport):
        """Await the transport's sleep for the computed backoff.

        A zero, missing or negative backoff is skipped.
        """
        backoff = self.get_backoff_time(settings)
        if not backoff or backoff < 0:
            return
        await transport.sleep(backoff)

    async def send(self, request):
        """Send the request, retrying on retryable responses and errors.

        :param request: The pipeline request to send.
        :returns: The final pipeline response, with the retry history (if
            any) in its context and ``location_mode`` set on the HTTP response.
        :raises AzureError: When an error occurs and no retries remain;
            signing errors are re-raised immediately.
        """
        retries_remaining = True
        response = None
        retry_settings = self.configure_retries(request)
        while retries_remaining:
            try:
                response = await self.next.send(request)
                if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response):
                    # increment() returns whether another attempt is allowed.
                    retries_remaining = self.increment(
                        retry_settings,
                        request=request.http_request,
                        response=response.http_response)
                    if retries_remaining:
                        await retry_hook(
                            retry_settings,
                            request=request.http_request,
                            response=response.http_response,
                            error=None)
                        await self.sleep(retry_settings, request.context.transport)
                        continue
                # Either the response is acceptable or no retries remain.
                break
            except AzureError as err:
                # Signing failures are never retried.
                if isinstance(err, AzureSigningError):
                    raise
                retries_remaining = self.increment(
                    retry_settings, request=request.http_request, error=err)
                if retries_remaining:
                    await retry_hook(
                        retry_settings,
                        request=request.http_request,
                        response=None,
                        error=err)
                    await self.sleep(retry_settings, request.context.transport)
                    continue
                raise err
        if retry_settings['history']:
            response.context['history'] = retry_settings['history']
        response.http_response.location_mode = retry_settings['mode']
        return response
request=request.http_request, + response=None, + error=err) + await self.sleep(retry_settings, request.context.transport) + continue + raise err + if retry_settings['history']: + response.context['history'] = retry_settings['history'] + response.http_response.location_mode = retry_settings['mode'] + return response + + +class ExponentialRetry(AsyncStorageRetryPolicy): + """Exponential retry.""" + + initial_backoff: int + """The initial backoff interval, in seconds, for the first retry.""" + increment_base: int + """The base, in seconds, to increment the initial_backoff by after the + first retry.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, + initial_backoff: int = 15, + increment_base: int = 3, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, **kwargs + ) -> None: + """ + Constructs an Exponential retry object. The initial_backoff is used for + the first retry. Subsequent retries are retried after initial_backoff + + increment_power^retry_count seconds. For example, by default the first retry + occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the + third after (15+3^2) = 24 seconds. + + :param int initial_backoff: + The initial backoff interval, in seconds, for the first retry. + :param int increment_base: + The base, in seconds, to increment the initial_backoff by after the + first retry. + :param int max_attempts: + The maximum number of retry attempts. + :param bool retry_to_secondary: + Whether the request should be retried to secondary, if able. This should + only be enabled of RA-GRS accounts are used and potentially stale data + can be handled. + :param int random_jitter_range: + A number in seconds which indicates a range to jitter/randomize for the back-off interval. + For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. 
+ """ + self.initial_backoff = initial_backoff + self.increment_base = increment_base + self.random_jitter_range = random_jitter_range + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: + """ + Calculates how long to sleep before retrying. + + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: + An integer indicating how long to wait before retrying the request, + or None to indicate no retry should be performed. + :rtype: int or None + """ + random_generator = random.Random() + backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + random_range_end = backoff + self.random_jitter_range + return random_generator.uniform(random_range_start, random_range_end) + + +class LinearRetry(AsyncStorageRetryPolicy): + """Linear retry.""" + + initial_backoff: int + """The backoff interval, in seconds, between retries.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, backoff: int = 15, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: + """ + Constructs a Linear retry object. + + :param int backoff: + The backoff interval, in seconds, between retries. + :param int max_attempts: + The maximum number of retry attempts. + :param bool retry_to_secondary: + Whether the request should be retried to secondary, if able. This should + only be enabled of RA-GRS accounts are used and potentially stale data + can be handled. + :param int random_jitter_range: + A number in seconds which indicates a range to jitter/randomize for the back-off interval. 
+ For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. + """ + self.backoff = backoff + self.random_jitter_range = random_jitter_range + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: + """ + Calculates how long to sleep before retrying. + + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: + An integer indicating how long to wait before retrying the request, + or None to indicate no retry should be performed. + :rtype: int or None + """ + random_generator = random.Random() + # the backoff interval normally does not change, however there is the possibility + # that it was modified by accessing the property directly after initializing the object + random_range_start = self.backoff - self.random_jitter_range \ + if self.backoff > self.random_jitter_range else 0 + random_range_end = self.backoff + self.random_jitter_range + return random_generator.uniform(random_range_start, random_range_end) + + +class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): + """ Custom Bearer token credential policy for following Storage Bearer challenges """ + + def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + try: + auth_header = response.http_response.headers.get("WWW-Authenticate") + challenge = StorageHttpChallenge(auth_header) + except ValueError: + return False + + scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE + await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + + return True diff --git 
a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py new file mode 100644 index 00000000..54927cc7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py @@ -0,0 +1,270 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import logging +import stat +from io import (SEEK_END, SEEK_SET, UnsupportedOperation) +from os import fstat +from typing import Dict, Optional + +import isodate + + +_LOGGER = logging.getLogger(__name__) + +_REQUEST_DELIMITER_PREFIX = "batch_" +_HTTP1_1_IDENTIFIER = "HTTP/1.1" +_HTTP_LINE_ENDING = "\r\n" + + +def serialize_iso(attr): + """Serialize Datetime object into ISO-8601 formatted string. + + :param Datetime attr: Object to be serialized. + :rtype: str + :raises: ValueError if format invalid. + """ + if not attr: + return None + if isinstance(attr, str): + attr = isodate.parse_datetime(attr) + try: + utc = attr.utctimetuple() + if utc.tm_year > 9999 or utc.tm_year < 1: + raise OverflowError("Hit max or min date") + + date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}" + return date + 'Z' + except (ValueError, OverflowError) as err: + raise ValueError("Unable to serialize datetime object.") from err + except AttributeError as err: + raise TypeError("ISO-8601 object must be valid datetime object.") from err + +def get_length(data): + length = None + # Check if object implements the __len__ method, covers most input cases such as bytearray. 
+ try: + length = len(data) + except: # pylint: disable=bare-except + pass + + if not length: + # Check if the stream is a file-like stream object. + # If so, calculate the size using the file descriptor. + try: + fileno = data.fileno() + except (AttributeError, UnsupportedOperation): + pass + else: + try: + mode = fstat(fileno).st_mode + if stat.S_ISREG(mode) or stat.S_ISLNK(mode): + #st_size only meaningful if regular file or symlink, other types + # e.g. sockets may return misleading sizes like 0 + return fstat(fileno).st_size + except OSError: + # Not a valid fileno, may be possible requests returned + # a socket number? + pass + + # If the stream is seekable and tell() is implemented, calculate the stream size. + try: + current_position = data.tell() + data.seek(0, SEEK_END) + length = data.tell() - current_position + data.seek(current_position, SEEK_SET) + except (AttributeError, OSError, UnsupportedOperation): + pass + + return length + + +def read_length(data): + try: + if hasattr(data, 'read'): + read_data = b'' + for chunk in iter(lambda: data.read(4096), b""): + read_data += chunk + return len(read_data), read_data + if hasattr(data, '__iter__'): + read_data = b'' + for chunk in data: + read_data += chunk + return len(read_data), read_data + except: # pylint: disable=bare-except + pass + raise ValueError("Unable to calculate content length, please specify.") + + +def validate_and_format_range_headers( + start_range, end_range, start_range_required=True, + end_range_required=True, check_content_md5=False, align_to_page=False): + # If end range is provided, start range must be provided + if (start_range_required or end_range is not None) and start_range is None: + raise ValueError("start_range value cannot be None.") + if end_range_required and end_range is None: + raise ValueError("end_range value cannot be None.") + + # Page ranges must be 512 aligned + if align_to_page: + if start_range is not None and start_range % 512 != 0: + raise 
ValueError(f"Invalid page blob start_range: {start_range}. " + "The size must be aligned to a 512-byte boundary.") + if end_range is not None and end_range % 512 != 511: + raise ValueError(f"Invalid page blob end_range: {end_range}. " + "The size must be aligned to a 512-byte boundary.") + + # Format based on whether end_range is present + range_header = None + if end_range is not None: + range_header = f'bytes={start_range}-{end_range}' + elif start_range is not None: + range_header = f"bytes={start_range}-" + + # Content MD5 can only be provided for a complete range less than 4MB in size + range_validation = None + if check_content_md5: + if start_range is None or end_range is None: + raise ValueError("Both start and end range required for MD5 content validation.") + if end_range - start_range > 4 * 1024 * 1024: + raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") + range_validation = 'true' + + return range_header, range_validation + + +def add_metadata_headers(metadata=None): + # type: (Optional[Dict[str, str]]) -> Dict[str, str] + headers = {} + if metadata: + for key, value in metadata.items(): + headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value + return headers + + +def serialize_batch_body(requests, batch_id): + """ + --<delimiter> + <subrequest> + --<delimiter> + <subrequest> (repeated as needed) + --<delimiter>-- + + Serializes the requests in this batch to a single HTTP mixed/multipart body. + + :param List[~azure.core.pipeline.transport.HttpRequest] requests: + a list of sub-request for the batch request + :param str batch_id: + to be embedded in batch sub-request delimiter + :returns: The body bytes for this batch. 
+ :rtype: bytes + """ + + if requests is None or len(requests) == 0: + raise ValueError('Please provide sub-request(s) for this batch request') + + delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8') + newline_bytes = _HTTP_LINE_ENDING.encode('utf-8') + batch_body = [] + + content_index = 0 + for request in requests: + request.headers.update({ + "Content-ID": str(content_index), + "Content-Length": str(0) + }) + batch_body.append(delimiter_bytes) + batch_body.append(_make_body_from_sub_request(request)) + batch_body.append(newline_bytes) + content_index += 1 + + batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8')) + # final line of body MUST have \r\n at the end, or it will not be properly read by the service + batch_body.append(newline_bytes) + + return b"".join(batch_body) + + +def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_dashes=False): + """ + Gets the delimiter used for this batch request's mixed/multipart HTTP format. + + :param str batch_id: + Randomly generated id + :param bool is_prepend_dashes: + Whether to include the starting dashes. Used in the body, but non on defining the delimiter. + :param bool is_append_dashes: + Whether to include the ending dashes. Used in the body on the closing delimiter only. + :returns: The delimiter, WITHOUT a trailing newline. 
+ :rtype: str + """ + + prepend_dashes = '--' if is_prepend_dashes else '' + append_dashes = '--' if is_append_dashes else '' + + return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes + + +def _make_body_from_sub_request(sub_request): + """ + Content-Type: application/http + Content-ID: <sequential int ID> + Content-Transfer-Encoding: <value> (if present) + + <verb> <path><query> HTTP/<version> + <header key>: <header value> (repeated as necessary) + Content-Length: <value> + (newline if content length > 0) + <body> (if content length > 0) + + Serializes an http request. + + :param ~azure.core.pipeline.transport.HttpRequest sub_request: + Request to serialize. + :returns: The serialized sub-request in bytes + :rtype: bytes + """ + + # put the sub-request's headers into a list for efficient str concatenation + sub_request_body = [] + + # get headers for ease of manipulation; remove headers as they are used + headers = sub_request.headers + + # append opening headers + sub_request_body.append("Content-Type: application/http") + sub_request_body.append(_HTTP_LINE_ENDING) + + sub_request_body.append("Content-ID: ") + sub_request_body.append(headers.pop("Content-ID", "")) + sub_request_body.append(_HTTP_LINE_ENDING) + + sub_request_body.append("Content-Transfer-Encoding: binary") + sub_request_body.append(_HTTP_LINE_ENDING) + + # append blank line + sub_request_body.append(_HTTP_LINE_ENDING) + + # append HTTP verb and path and query and HTTP version + sub_request_body.append(sub_request.method) + sub_request_body.append(' ') + sub_request_body.append(sub_request.url) + sub_request_body.append(' ') + sub_request_body.append(_HTTP1_1_IDENTIFIER) + sub_request_body.append(_HTTP_LINE_ENDING) + + # append remaining headers (this will set the Content-Length, as it was set on `sub-request`) + for header_name, header_value in headers.items(): + if header_value is not None: + sub_request_body.append(header_name) + sub_request_body.append(": ") + 
sub_request_body.append(header_value) + sub_request_body.append(_HTTP_LINE_ENDING) + + # append blank line + sub_request_body.append(_HTTP_LINE_ENDING) + + return ''.join(sub_request_body).encode() diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py new file mode 100644 index 00000000..af9a2fcd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py @@ -0,0 +1,200 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import logging +from typing import NoReturn +from xml.etree.ElementTree import Element + +from azure.core.exceptions import ( + ClientAuthenticationError, + DecodeError, + HttpResponseError, + ResourceExistsError, + ResourceModifiedError, + ResourceNotFoundError, +) +from azure.core.pipeline.policies import ContentDecodePolicy + +from .authentication import AzureSigningError +from .models import get_enum_value, StorageErrorCode, UserDelegationKey +from .parser import _to_utc_datetime + + +_LOGGER = logging.getLogger(__name__) + + +class PartialBatchErrorException(HttpResponseError): + """There is a partial failure in batch operations. + + :param str message: The message of the exception. + :param response: Server response to be deserialized. + :param list parts: A list of the parts in multipart response. 
+ """ + + def __init__(self, message, response, parts): + self.parts = parts + super(PartialBatchErrorException, self).__init__(message=message, response=response) + + +# Parses the blob length from the content range header: bytes 1-3/65537 +def parse_length_from_content_range(content_range): + if content_range is None: + return None + + # First, split in space and take the second half: '1-3/65537' + # Next, split on slash and take the second half: '65537' + # Finally, convert to an int: 65537 + return int(content_range.split(' ', 1)[1].split('/', 1)[1]) + + +def normalize_headers(headers): + normalized = {} + for key, value in headers.items(): + if key.startswith('x-ms-'): + key = key[5:] + normalized[key.lower().replace('-', '_')] = get_enum_value(value) + return normalized + + +def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument + try: + raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')} + except AttributeError: + raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')} + return {k[10:]: v for k, v in raw_metadata.items()} + + +def return_response_headers(response, deserialized, response_headers): # pylint: disable=unused-argument + return normalize_headers(response_headers) + + +def return_headers_and_deserialized(response, deserialized, response_headers): # pylint: disable=unused-argument + return normalize_headers(response_headers), deserialized + + +def return_context_and_deserialized(response, deserialized, response_headers): # pylint: disable=unused-argument + return response.http_response.location_mode, deserialized + + +def return_raw_deserialized(response, *_): + return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] + + +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches + raise_error = 
HttpResponseError + serialized = False + if isinstance(storage_error, AzureSigningError): + storage_error.message = storage_error.message + \ + '. This is likely due to an invalid shared key. Please check your shared key and try again.' + if not storage_error.response or storage_error.response.status_code in [200, 204]: + raise storage_error + # If it is one of those three then it has been serialized prior by the generated layer. + if isinstance(storage_error, (PartialBatchErrorException, + ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)): + serialized = True + error_code = storage_error.response.headers.get('x-ms-error-code') + error_message = storage_error.message + additional_data = {} + error_dict = {} + try: + error_body = ContentDecodePolicy.deserialize_from_http_generics(storage_error.response) + try: + if error_body is None or len(error_body) == 0: + error_body = storage_error.response.reason + except AttributeError: + error_body = '' + # If it is an XML response + if isinstance(error_body, Element): + error_dict = { + child.tag.lower(): child.text + for child in error_body + } + # If it is a JSON response + elif isinstance(error_body, dict): + error_dict = error_body.get('error', {}) + elif not error_code: + _LOGGER.warning( + 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) + error_dict = {'message': str(error_body)} + + # If we extracted from a Json or XML response + # There is a chance error_dict is just a string + if error_dict and isinstance(error_dict, dict): + error_code = error_dict.get('code') + error_message = error_dict.get('message') + additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}} + except DecodeError: + pass + + try: + # This check would be unnecessary if we have already serialized the error + if error_code and not serialized: + error_code = StorageErrorCode(error_code) + if error_code in [StorageErrorCode.condition_not_met, 
+ StorageErrorCode.blob_overwritten]: + raise_error = ResourceModifiedError + if error_code in [StorageErrorCode.invalid_authentication_info, + StorageErrorCode.authentication_failed]: + raise_error = ClientAuthenticationError + if error_code in [StorageErrorCode.resource_not_found, + StorageErrorCode.cannot_verify_copy_source, + StorageErrorCode.blob_not_found, + StorageErrorCode.queue_not_found, + StorageErrorCode.container_not_found, + StorageErrorCode.parent_not_found, + StorageErrorCode.share_not_found]: + raise_error = ResourceNotFoundError + if error_code in [StorageErrorCode.account_already_exists, + StorageErrorCode.account_being_created, + StorageErrorCode.resource_already_exists, + StorageErrorCode.resource_type_mismatch, + StorageErrorCode.blob_already_exists, + StorageErrorCode.queue_already_exists, + StorageErrorCode.container_already_exists, + StorageErrorCode.container_being_deleted, + StorageErrorCode.queue_being_deleted, + StorageErrorCode.share_already_exists, + StorageErrorCode.share_being_deleted]: + raise_error = ResourceExistsError + except ValueError: + # Got an unknown error code + pass + + # Error message should include all the error properties + try: + error_message += f"\nErrorCode:{error_code.value}" + except AttributeError: + error_message += f"\nErrorCode:{error_code}" + for name, info in additional_data.items(): + error_message += f"\n{name}:{info}" + + # No need to create an instance if it has already been serialized by the generated layer + if serialized: + storage_error.message = error_message + error = storage_error + else: + error = raise_error(message=error_message, response=storage_error.response) + # Ensure these properties are stored in the error instance as well (not just the error message) + error.error_code = error_code + error.additional_info = additional_data + # error.args is what's surfaced on the traceback - show error message in all cases + error.args = (error.message,) + try: + # `from None` prevents us from double 
printing the exception (suppresses generated layer error context) + exec("raise error from None") # pylint: disable=exec-used # nosec + except SyntaxError as exc: + raise error from exc + + +def parse_to_internal_user_delegation_key(service_user_delegation_key): + internal_user_delegation_key = UserDelegationKey() + internal_user_delegation_key.signed_oid = service_user_delegation_key.signed_oid + internal_user_delegation_key.signed_tid = service_user_delegation_key.signed_tid + internal_user_delegation_key.signed_start = _to_utc_datetime(service_user_delegation_key.signed_start) + internal_user_delegation_key.signed_expiry = _to_utc_datetime(service_user_delegation_key.signed_expiry) + internal_user_delegation_key.signed_service = service_user_delegation_key.signed_service + internal_user_delegation_key.signed_version = service_user_delegation_key.signed_version + internal_user_delegation_key.value = service_user_delegation_key.value + return internal_user_delegation_key diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py new file mode 100644 index 00000000..2ef9921d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py @@ -0,0 +1,243 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from datetime import date + +from .parser import _to_utc_datetime +from .constants import X_MS_VERSION +from . import sign_string, url_quote + +# cspell:ignoreRegExp rsc. 
+# cspell:ignoreRegExp s..?id +class QueryStringConstants(object): + SIGNED_SIGNATURE = 'sig' + SIGNED_PERMISSION = 'sp' + SIGNED_START = 'st' + SIGNED_EXPIRY = 'se' + SIGNED_RESOURCE = 'sr' + SIGNED_IDENTIFIER = 'si' + SIGNED_IP = 'sip' + SIGNED_PROTOCOL = 'spr' + SIGNED_VERSION = 'sv' + SIGNED_CACHE_CONTROL = 'rscc' + SIGNED_CONTENT_DISPOSITION = 'rscd' + SIGNED_CONTENT_ENCODING = 'rsce' + SIGNED_CONTENT_LANGUAGE = 'rscl' + SIGNED_CONTENT_TYPE = 'rsct' + START_PK = 'spk' + START_RK = 'srk' + END_PK = 'epk' + END_RK = 'erk' + SIGNED_RESOURCE_TYPES = 'srt' + SIGNED_SERVICES = 'ss' + SIGNED_OID = 'skoid' + SIGNED_TID = 'sktid' + SIGNED_KEY_START = 'skt' + SIGNED_KEY_EXPIRY = 'ske' + SIGNED_KEY_SERVICE = 'sks' + SIGNED_KEY_VERSION = 'skv' + SIGNED_ENCRYPTION_SCOPE = 'ses' + + # for ADLS + SIGNED_AUTHORIZED_OID = 'saoid' + SIGNED_UNAUTHORIZED_OID = 'suoid' + SIGNED_CORRELATION_ID = 'scid' + SIGNED_DIRECTORY_DEPTH = 'sdd' + + @staticmethod + def to_list(): + return [ + QueryStringConstants.SIGNED_SIGNATURE, + QueryStringConstants.SIGNED_PERMISSION, + QueryStringConstants.SIGNED_START, + QueryStringConstants.SIGNED_EXPIRY, + QueryStringConstants.SIGNED_RESOURCE, + QueryStringConstants.SIGNED_IDENTIFIER, + QueryStringConstants.SIGNED_IP, + QueryStringConstants.SIGNED_PROTOCOL, + QueryStringConstants.SIGNED_VERSION, + QueryStringConstants.SIGNED_CACHE_CONTROL, + QueryStringConstants.SIGNED_CONTENT_DISPOSITION, + QueryStringConstants.SIGNED_CONTENT_ENCODING, + QueryStringConstants.SIGNED_CONTENT_LANGUAGE, + QueryStringConstants.SIGNED_CONTENT_TYPE, + QueryStringConstants.START_PK, + QueryStringConstants.START_RK, + QueryStringConstants.END_PK, + QueryStringConstants.END_RK, + QueryStringConstants.SIGNED_RESOURCE_TYPES, + QueryStringConstants.SIGNED_SERVICES, + QueryStringConstants.SIGNED_OID, + QueryStringConstants.SIGNED_TID, + QueryStringConstants.SIGNED_KEY_START, + QueryStringConstants.SIGNED_KEY_EXPIRY, + QueryStringConstants.SIGNED_KEY_SERVICE, + 
QueryStringConstants.SIGNED_KEY_VERSION, + QueryStringConstants.SIGNED_ENCRYPTION_SCOPE, + # for ADLS + QueryStringConstants.SIGNED_AUTHORIZED_OID, + QueryStringConstants.SIGNED_UNAUTHORIZED_OID, + QueryStringConstants.SIGNED_CORRELATION_ID, + QueryStringConstants.SIGNED_DIRECTORY_DEPTH, + ] + + +class SharedAccessSignature(object): + ''' + Provides a factory for creating account access + signature tokens with an account name and account key. Users can either + use the factory or can construct the appropriate service and use the + generate_*_shared_access_signature method directly. + ''' + + def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION): + ''' + :param str account_name: + The storage account name used to generate the shared access signatures. + :param str account_key: + The access key to generate the shares access signatures. + :param str x_ms_version: + The service version used to generate the shared access signatures. + ''' + self.account_name = account_name + self.account_key = account_key + self.x_ms_version = x_ms_version + + def generate_account( + self, services, + resource_types, + permission, + expiry, + start=None, + ip=None, + protocol=None, + sts_hook=None + ) -> str: + ''' + Generates a shared access signature for the account. + Use the returned signature with the sas_token parameter of the service + or to create a new account object. + + :param Any services: The specified services associated with the shared access signature. + :param ResourceTypes resource_types: + Specifies the resource types that are accessible with the account + SAS. You can combine values to provide access to more than one + resource type. + :param AccountSasPermissions permission: + The permissions associated with the shared access signature. The + user is restricted to operations allowed by the permissions. + Required unless an id is given referencing a stored access policy + which contains this field. 
This field must be omitted if it has been + specified in an associated stored access policy. You can combine + values to provide more than one permission. + :param expiry: + The time at which the shared access signature becomes invalid. + Required unless an id is given referencing a stored access policy + which contains this field. This field must be omitted if it has + been specified in an associated stored access policy. Azure will always + convert values to UTC. If a date is passed in without timezone info, it + is assumed to be UTC. + :type expiry: datetime or str + :param start: + The time at which the shared access signature becomes valid. If + omitted, start time for this call is assumed to be the time when the + storage service receives the request. The provided datetime will always + be interpreted as UTC. + :type start: datetime or str + :param str ip: + Specifies an IP address or a range of IP addresses from which to accept requests. + If the IP address from which the request originates does not match the IP address + or address range specified on the SAS token, the request is not authenticated. + For example, specifying sip=168.1.5.65 or sip=168.1.5.60-168.1.5.70 on the SAS + restricts the request to those IP addresses. + :param str protocol: + Specifies the protocol permitted for a request made. The default value + is https,http. See :class:`~azure.storage.common.models.Protocol` for possible values. + :param sts_hook: + For debugging purposes only. If provided, the hook is called with the string to sign + that was used to generate the SAS. + :type sts_hook: Optional[Callable[[str], None]] + :returns: The generated SAS token for the account. 
+ :rtype: str + ''' + sas = _SharedAccessHelper() + sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version) + sas.add_account(services, resource_types) + sas.add_account_signature(self.account_name, self.account_key) + + if sts_hook is not None: + sts_hook(sas.string_to_sign) + + return sas.get_token() + + +class _SharedAccessHelper(object): + def __init__(self): + self.query_dict = {} + self.string_to_sign = "" + + def _add_query(self, name, val): + if val: + self.query_dict[name] = str(val) if val is not None else None + + def add_base(self, permission, expiry, start, ip, protocol, x_ms_version): + if isinstance(start, date): + start = _to_utc_datetime(start) + + if isinstance(expiry, date): + expiry = _to_utc_datetime(expiry) + + self._add_query(QueryStringConstants.SIGNED_START, start) + self._add_query(QueryStringConstants.SIGNED_EXPIRY, expiry) + self._add_query(QueryStringConstants.SIGNED_PERMISSION, permission) + self._add_query(QueryStringConstants.SIGNED_IP, ip) + self._add_query(QueryStringConstants.SIGNED_PROTOCOL, protocol) + self._add_query(QueryStringConstants.SIGNED_VERSION, x_ms_version) + + def add_resource(self, resource): + self._add_query(QueryStringConstants.SIGNED_RESOURCE, resource) + + def add_id(self, policy_id): + self._add_query(QueryStringConstants.SIGNED_IDENTIFIER, policy_id) + + def add_account(self, services, resource_types): + self._add_query(QueryStringConstants.SIGNED_SERVICES, services) + self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types) + + def add_override_response_headers(self, cache_control, + content_disposition, + content_encoding, + content_language, + content_type): + self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control) + self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition) + self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding) + self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, 
content_language) + self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type) + + def add_account_signature(self, account_name, account_key): + def get_value_to_append(query): + return_value = self.query_dict.get(query) or '' + return return_value + '\n' + + self.string_to_sign = \ + (account_name + '\n' + + get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + + get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + + get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + + get_value_to_append(QueryStringConstants.SIGNED_START) + + get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + + get_value_to_append(QueryStringConstants.SIGNED_IP) + + get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + + get_value_to_append(QueryStringConstants.SIGNED_VERSION) + + '\n' # Signed Encryption Scope - always empty for fileshare + ) + + self._add_query(QueryStringConstants.SIGNED_SIGNATURE, + sign_string(account_key, self.string_to_sign)) + + def get_token(self) -> str: + return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None]) diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py new file mode 100644 index 00000000..b31cfb32 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py @@ -0,0 +1,604 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
# --------------------------------------------------------------------------

from concurrent import futures
from io import BytesIO, IOBase, SEEK_CUR, SEEK_END, SEEK_SET, UnsupportedOperation
from itertools import islice
from math import ceil
from threading import Lock

from azure.core.tracing.common import with_current_context

from .import encode_base64, url_quote
from .request_handlers import get_length
from .response_handlers import return_response_headers


_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE = 4 * 1024 * 1024
_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = "{0} should be a seekable file-like/io.IOBase type stream object."


def _parallel_uploads(executor, uploader, pending, running):
    """Keep the executor fed from ``pending`` and collect every upload result.

    :param executor: Executor exposing ``submit``.
    :param uploader: Callable invoked with one item from ``pending``.
    :param pending: Iterator of work items that have not been submitted yet.
    :param running: Futures that were already submitted by the caller.
    :returns: The results of all futures, in completion order.
    """
    range_ids = []
    # Guard on ``running``: with nothing in flight ``futures.wait`` returns
    # immediately and ``pending`` is never consulted, so an unconditional loop
    # would never terminate.
    while running:
        # Wait for some upload to finish before adding a new one
        done, running = futures.wait(running, return_when=futures.FIRST_COMPLETED)
        range_ids.extend([chunk.result() for chunk in done])
        try:
            for _ in range(0, len(done)):
                next_chunk = next(pending)
                running.add(executor.submit(with_current_context(uploader), next_chunk))
        except StopIteration:
            break

    # Wait for the remaining uploads to finish
    done, _running = futures.wait(running)
    range_ids.extend([chunk.result() for chunk in done])
    return range_ids


def upload_data_chunks(
        service=None,
        uploader_class=None,
        total_size=None,
        chunk_size=None,
        max_concurrency=None,
        stream=None,
        validate_content=None,
        progress_hook=None,
        **kwargs):
    """Upload ``stream`` chunk by chunk through an instance of ``uploader_class``.

    Uploads run on a thread pool when ``max_concurrency`` is greater than 1,
    otherwise sequentially.

    :returns: The second element of every ``(key, value)`` result sorted by
        key when the uploader produced results, else the uploader's
        ``response_headers``.
    """
    parallel = max_concurrency > 1
    if parallel and 'modified_access_conditions' in kwargs:
        # Access conditions do not work with parallelism
        kwargs['modified_access_conditions'] = None

    uploader = uploader_class(
        service=service,
        total_size=total_size,
        chunk_size=chunk_size,
        stream=stream,
        parallel=parallel,
        validate_content=validate_content,
        progress_hook=progress_hook,
        **kwargs)
    if parallel:
        with futures.ThreadPoolExecutor(max_concurrency) as executor:
            upload_tasks = uploader.get_chunk_streams()
            running_futures = [
                executor.submit(with_current_context(uploader.process_chunk), u)
                for u in islice(upload_tasks, 0, max_concurrency)
            ]
            range_ids = _parallel_uploads(executor, uploader.process_chunk, upload_tasks, running_futures)
    else:
        range_ids = [uploader.process_chunk(result) for result in uploader.get_chunk_streams()]
    if any(range_ids):
        return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
    return uploader.response_headers


def upload_substream_blocks(
        service=None,
        uploader_class=None,
        total_size=None,
        chunk_size=None,
        max_concurrency=None,
        stream=None,
        progress_hook=None,
        **kwargs):
    """Upload ``stream`` as ``SubStream`` windows through ``uploader_class``.

    :returns: The sorted block results, or an empty list when there are none.
    """
    parallel = max_concurrency > 1
    if parallel and 'modified_access_conditions' in kwargs:
        # Access conditions do not work with parallelism
        kwargs['modified_access_conditions'] = None
    uploader = uploader_class(
        service=service,
        total_size=total_size,
        chunk_size=chunk_size,
        stream=stream,
        parallel=parallel,
        progress_hook=progress_hook,
        **kwargs)

    if parallel:
        with futures.ThreadPoolExecutor(max_concurrency) as executor:
            upload_tasks = uploader.get_substream_blocks()
            running_futures = [
                executor.submit(with_current_context(uploader.process_substream_block), u)
                for u in islice(upload_tasks, 0, max_concurrency)
            ]
            range_ids = _parallel_uploads(executor, uploader.process_substream_block, upload_tasks, running_futures)
    else:
        range_ids = [uploader.process_substream_block(b) for b in uploader.get_substream_blocks()]
    if any(range_ids):
        return sorted(range_ids)
    return []


class _ChunkUploader(object):  # pylint: disable=too-many-instance-attributes
    """Base class that slices a stream into chunks and reports progress.

    Subclasses implement ``_upload_chunk`` and ``_upload_substream_block``.
    Extra keyword arguments are kept in ``request_options`` and forwarded to
    the service calls made by the subclasses.
    """

    def __init__(
            self, service,
            total_size,
            chunk_size,
            stream,
            parallel,
            encryptor=None,
            padder=None,
            progress_hook=None,
            **kwargs):
        self.service = service
        self.total_size = total_size
        self.chunk_size = chunk_size
        self.stream = stream
        self.parallel = parallel

        # Stream management
        self.stream_lock = Lock() if parallel else None

        # Progress feedback
        self.progress_total = 0
        self.progress_lock = Lock() if parallel else None
        self.progress_hook = progress_hook

        # Encryption
        self.encryptor = encryptor
        self.padder = padder
        self.response_headers = None
        self.etag = None
        self.last_modified = None
        self.request_options = kwargs

    def get_chunk_streams(self):
        """Yield ``(offset, data)`` pairs of at most ``chunk_size`` bytes each.

        :raises TypeError: If ``stream.read`` returns something other than bytes.
        """
        index = 0
        while True:
            data = b""
            read_size = self.chunk_size

            # Buffer until we either reach the end of the stream or get a whole chunk.
            while True:
                if self.total_size:
                    read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
                temp = self.stream.read(read_size)
                if not isinstance(temp, bytes):
                    raise TypeError("Blob data should be of type bytes.")
                data += temp or b""

                # We have read an empty string and so are at the end
                # of the buffer or we have read a full chunk.
                if temp == b"" or len(data) == self.chunk_size:
                    break

            if len(data) == self.chunk_size:
                if self.padder:
                    data = self.padder.update(data)
                if self.encryptor:
                    data = self.encryptor.update(data)
                yield index, data
            else:
                # Short (final) chunk: flush the padder/encryptor as well.
                if self.padder:
                    data = self.padder.update(data) + self.padder.finalize()
                if self.encryptor:
                    data = self.encryptor.update(data) + self.encryptor.finalize()
                if data:
                    yield index, data
                break
            index += len(data)

    def process_chunk(self, chunk_data):
        """Upload one ``(offset, data)`` pair and return the uploader's result."""
        chunk_bytes = chunk_data[1]
        chunk_offset = chunk_data[0]
        return self._upload_chunk_with_progress(chunk_offset, chunk_bytes)

    def _update_progress(self, length):
        """Add ``length`` to the running total and notify ``progress_hook``."""
        if self.progress_lock is not None:
            with self.progress_lock:
                self.progress_total += length
        else:
            self.progress_total += length

        if self.progress_hook:
            self.progress_hook(self.progress_total, self.total_size)

    def _upload_chunk(self, chunk_offset, chunk_data):
        raise NotImplementedError("Must be implemented by child class.")

    def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
        range_id = self._upload_chunk(chunk_offset, chunk_data)
        self._update_progress(len(chunk_data))
        return range_id

    def get_substream_blocks(self):
        """Yield ``(offset, SubStream)`` pairs that together cover the stream.

        :raises ValueError: If the total length cannot be determined.
        """
        assert self.chunk_size is not None
        lock = self.stream_lock
        blob_length = self.total_size

        if blob_length is None:
            blob_length = get_length(self.stream)
            if blob_length is None:
                raise ValueError("Unable to determine content length of upload data.")

        blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
        last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size

        for i in range(blocks):
            index = i * self.chunk_size
            length = last_block_size if i == blocks - 1 else self.chunk_size
            yield index, SubStream(self.stream, index, length, lock)

    def process_substream_block(self, block_data):
        """Upload one ``(offset, SubStream)`` pair and return the uploader's result."""
        return self._upload_substream_block_with_progress(block_data[0], block_data[1])

    def _upload_substream_block(self, index, block_stream):
        raise NotImplementedError("Must be implemented by child class.")

    def _upload_substream_block_with_progress(self, index, block_stream):
        range_id = self._upload_substream_block(index, block_stream)
        self._update_progress(len(block_stream))
        return range_id

    def set_response_properties(self, resp):
        """Copy ``etag`` and ``last_modified`` from ``resp``."""
        self.etag = resp.etag
        self.last_modified = resp.last_modified


class BlockBlobChunkUploader(_ChunkUploader):
    """Stages each chunk as a block via ``service.stage_block``."""

    def __init__(self, *args, **kwargs):
        kwargs.pop("modified_access_conditions", None)
        super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
        self.current_length = None

    def _upload_chunk(self, chunk_offset, chunk_data):
        # TODO: This is incorrect, but works with recording.
        index = f'{chunk_offset:032d}'
        block_id = encode_base64(url_quote(encode_base64(index)))
        self.service.stage_block(
            block_id,
            len(chunk_data),
            chunk_data,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )
        return index, block_id

    def _upload_substream_block(self, index, block_stream):
        try:
            block_id = f'BlockId{(index//self.chunk_size):05}'
            self.service.stage_block(
                block_id,
                len(block_stream),
                block_stream,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )
        finally:
            block_stream.close()
        return block_id


class PageBlobChunkUploader(_ChunkUploader):
    """Uploads non-empty chunks via ``service.upload_pages``."""

    def _is_chunk_empty(self, chunk_data):
        # read until non-zero byte is encountered
        # if reached the end without returning, then chunk_data is all 0's
        return not any(bytearray(chunk_data))

    def _upload_chunk(self, chunk_offset, chunk_data):
        # avoid uploading the empty pages
        if not self._is_chunk_empty(chunk_data):
            chunk_end = chunk_offset + len(chunk_data) - 1
            content_range = f"bytes={chunk_offset}-{chunk_end}"
            computed_md5 = None
            self.response_headers = self.service.upload_pages(
                body=chunk_data,
                content_length=len(chunk_data),
                transactional_content_md5=computed_md5,
                range=content_range,
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )

            # Chain sequential uploads on the etag returned by the previous one.
            if not self.parallel and self.request_options.get('modified_access_conditions'):
                self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']

    def _upload_substream_block(self, index, block_stream):
        pass


class AppendBlobChunkUploader(_ChunkUploader):
    """Appends each chunk via ``service.append_block``."""

    def __init__(self, *args, **kwargs):
        super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
        self.current_length = None

    def _upload_chunk(self, chunk_offset, chunk_data):
        if self.current_length is None:
            # First append: remember the offset the service reports back.
            self.response_headers = self.service.append_block(
                body=chunk_data,
                content_length=len(chunk_data),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )
            self.current_length = int(self.response_headers["blob_append_offset"])
        else:
            self.request_options['append_position_access_conditions'].append_position = \
                self.current_length + chunk_offset
            self.response_headers = self.service.append_block(
                body=chunk_data,
                content_length=len(chunk_data),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )

    def _upload_substream_block(self, index, block_stream):
        pass


class DataLakeFileChunkUploader(_ChunkUploader):
    """Appends each chunk at its offset via ``service.append_data``."""

    def _upload_chunk(self, chunk_offset, chunk_data):
        self.response_headers = self.service.append_data(
            body=chunk_data,
            position=chunk_offset,
            content_length=len(chunk_data),
            cls=return_response_headers,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )

        # Chain sequential uploads on the etag returned by the previous one.
        if not self.parallel and self.request_options.get('modified_access_conditions'):
            self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']

    def _upload_substream_block(self, index, block_stream):
        try:
            self.service.append_data(
                body=block_stream,
                position=index,
                content_length=len(block_stream),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )
        finally:
            block_stream.close()


class FileChunkUploader(_ChunkUploader):
    """Uploads each chunk as a byte range via ``service.upload_range``."""

    def _upload_chunk(self, chunk_offset, chunk_data):
        length = len(chunk_data)
        chunk_end = chunk_offset + length - 1
        response = self.service.upload_range(
            chunk_data,
            chunk_offset,
            length,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )
        return f'bytes={chunk_offset}-{chunk_end}', response

    # TODO: Implement this method.
    def _upload_substream_block(self, index, block_stream):
        pass


class SubStream(IOBase):
    """Read-only, seekable window of ``length`` bytes over a wrapped stream.

    The window starts at ``stream_begin_index`` of the wrapped stream. Reads
    go through an internal buffer of at most
    ``_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE`` bytes; when ``lockObj`` is
    given, seek-and-read on the wrapped stream happens while holding it.
    """

    def __init__(self, wrapped_stream, stream_begin_index, length, lockObj):
        # Python 2.7: file-like objects created with open() typically support seek(), but are not
        # derivations of io.IOBase and thus do not implement seekable().
        # Python > 3.0: file-like objects created with open() are derived from io.IOBase.
        try:
            # only the main thread runs this, so there's no need grabbing the lock
            wrapped_stream.seek(0, SEEK_CUR)
        except Exception as exc:
            raise ValueError("Wrapped stream must support seek().") from exc

        self._lock = lockObj
        self._wrapped_stream = wrapped_stream
        self._position = 0
        self._stream_begin_index = stream_begin_index
        self._length = length
        self._buffer = BytesIO()

        # we must avoid buffering more than necessary, and also not use up too much memory
        # so the max buffer size is capped at 4MB
        self._max_buffer_size = (
            length if length < _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE else _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE
        )
        self._current_buffer_start = 0
        self._current_buffer_size = 0
        super(SubStream, self).__init__()

    def __len__(self):
        return self._length

    def close(self):
        if self._buffer:
            self._buffer.close()
        self._wrapped_stream = None
        IOBase.close(self)

    def fileno(self):
        return self._wrapped_stream.fileno()

    def flush(self):
        pass

    def read(self, size=None):
        """Return up to ``size`` bytes from the window (all remaining if None).

        :raises ValueError: If the stream is closed.
        :raises IOError: If the wrapped stream cannot be positioned (locked path).
        """
        if self.closed:  # pylint: disable=using-constant-test
            raise ValueError("Stream is closed.")

        if size is None:
            size = self._length - self._position

        # adjust if out of bounds
        if size + self._position >= self._length:
            size = self._length - self._position

        # return fast
        if size == 0 or self._buffer.closed:
            return b""

        # attempt first read from the read buffer and update position
        read_buffer = self._buffer.read(size)
        bytes_read = len(read_buffer)
        bytes_remaining = size - bytes_read
        self._position += bytes_read

        # repopulate the read buffer from the underlying stream to fulfill the request
        # ensure the seek and read operations are done atomically (only if a lock is provided)
        if bytes_remaining > 0:
            # Leaving this ``with`` closes the exhausted buffer.
            with self._buffer:
                # either read in the max buffer size specified on the class
                # or read in just enough data for the current block/sub stream
                current_max_buffer_size = min(self._max_buffer_size, self._length - self._position)

                # lock is only defined if max_concurrency > 1 (parallel uploads)
                if self._lock:
                    with self._lock:
                        # reposition the underlying stream to match the start of the data to read
                        absolute_position = self._stream_begin_index + self._position
                        self._wrapped_stream.seek(absolute_position, SEEK_SET)
                        # If we can't seek to the right location, our read will be corrupted so fail fast.
                        if self._wrapped_stream.tell() != absolute_position:
                            raise IOError("Stream failed to seek to the desired location.")
                        buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
                else:
                    absolute_position = self._stream_begin_index + self._position
                    # It's possible that there's connection problem during data transfer,
                    # so when we retry we don't want to read from current position of wrapped stream,
                    # instead we should seek to where we want to read from.
                    if self._wrapped_stream.tell() != absolute_position:
                        self._wrapped_stream.seek(absolute_position, SEEK_SET)

                    buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)

            if buffer_from_stream:
                # update the buffer with new data from the wrapped stream
                # we need to note down the start position and size of the buffer, in case seek is performed later
                self._buffer = BytesIO(buffer_from_stream)
                self._current_buffer_start = self._position
                self._current_buffer_size = len(buffer_from_stream)

                # read the remaining bytes from the new buffer and update position
                second_read_buffer = self._buffer.read(bytes_remaining)
                read_buffer += second_read_buffer
                self._position += len(second_read_buffer)

        return read_buffer

    def readable(self):
        return True

    def readinto(self, b):
        raise UnsupportedOperation

    def seek(self, offset, whence=0):
        """Move to a new position inside the window and return it.

        The result is clamped to ``[0, length]``.

        :raises ValueError: If ``whence`` is not SEEK_SET, SEEK_CUR or SEEK_END.
        """
        # Compare by value: ``is`` on integers only works by interning accident.
        if whence == SEEK_SET:
            start_index = 0
        elif whence == SEEK_CUR:
            start_index = self._position
        elif whence == SEEK_END:
            start_index = self._length
            # NOTE(review): the offset is negated here, so a positive offset
            # moves backwards from the end, unlike io's usual convention.
            # Kept as-is because callers may rely on it.
            offset = -offset
        else:
            raise ValueError("Invalid argument for the 'whence' parameter.")

        pos = start_index + offset

        if pos > self._length:
            pos = self._length
        elif pos < 0:
            pos = 0

        # check if buffer is still valid
        # if not, drop buffer
        if pos < self._current_buffer_start or pos >= self._current_buffer_start + self._current_buffer_size:
            self._buffer.close()
            self._buffer = BytesIO()
        else:  # if yes seek to correct position
            delta = pos - self._current_buffer_start
            self._buffer.seek(delta, SEEK_SET)

        self._position = pos
        return pos

    def seekable(self):
        return True

    def tell(self):
        return self._position

    def write(self, *args, **kwargs):
        # Accept any arguments so callers get UnsupportedOperation, not TypeError.
        raise UnsupportedOperation

    def writelines(self, *args, **kwargs):
        raise UnsupportedOperation

    def writable(self):
        return False

    def writeable(self):
        # Historical misspelling kept for backward compatibility.
        return self.writable()


class IterStreamer(object):
    """
    File-like streaming iterator.
    """

    def __init__(self, generator, encoding="UTF-8"):
        self.generator = generator
        self.iterator = iter(generator)
        self.leftover = b""
        self.encoding = encoding

    def __len__(self):
        return self.generator.__len__()

    def __iter__(self):
        return self.iterator

    def seekable(self):
        return False

    def __next__(self):
        return next(self.iterator)

    def tell(self, *args, **kwargs):
        raise UnsupportedOperation("Data generator does not support tell.")

    def seek(self, *args, **kwargs):
        raise UnsupportedOperation("Data generator is not seekable.")

    def read(self, size):
        """Return up to ``size`` bytes, encoding ``str`` items with ``self.encoding``.

        Bytes pulled from the iterator beyond ``size`` are kept in
        ``self.leftover`` for the next call.
        """
        data = self.leftover
        count = len(self.leftover)
        try:
            while count < size:
                chunk = self.__next__()
                if isinstance(chunk, str):
                    chunk = chunk.encode(self.encoding)
                data += chunk
                count += len(chunk)
        # This means count < size and what's leftover will be returned in this call.
        except StopIteration:
            self.leftover = b""

        if count >= size:
            self.leftover = data[size:]

        return data[:size]
# diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py
# new file mode 100644
# index 00000000..3e102ec5
# --- /dev/null
# +++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py
# @@ -0,0 +1,460 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import asyncio
import inspect
import threading
from asyncio import Lock
from io import UnsupportedOperation
from itertools import islice
from math import ceil
from typing import AsyncGenerator, Union

from .import encode_base64, url_quote
from .request_handlers import get_length
from .response_handlers import return_response_headers
from .uploads import SubStream, IterStreamer  # pylint: disable=unused-import


async def _async_parallel_uploads(uploader, pending, running):
    """Drive uploads whose work items come from an async iterator.

    :param uploader: Coroutine function invoked with one item from ``pending``.
    :param pending: Async iterator of work items not scheduled yet.
    :param running: Futures already scheduled by the caller.
    :returns: The results of all futures, in completion order.
    """
    range_ids = []
    # ``asyncio.wait`` raises ValueError on an empty collection, so only loop
    # while something is actually in flight.
    while running:
        # Wait for some upload to finish before adding a new one
        done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
        range_ids.extend([chunk.result() for chunk in done])
        try:
            for _ in range(0, len(done)):
                next_chunk = await pending.__anext__()
                running.add(asyncio.ensure_future(uploader(next_chunk)))
        except StopAsyncIteration:
            break

    # Wait for the remaining uploads to finish
    if running:
        done, _running = await asyncio.wait(running)
        range_ids.extend([chunk.result() for chunk in done])
    return range_ids


async def _parallel_uploads(uploader, pending, running):
    """Drive uploads whose work items come from a regular iterator.

    Same contract as ``_async_parallel_uploads`` except that ``pending`` is
    advanced with ``next``.
    """
    range_ids = []
    # ``asyncio.wait`` raises ValueError on an empty collection, so only loop
    # while something is actually in flight.
    while running:
        # Wait for some upload to finish before adding a new one
        done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
        range_ids.extend([chunk.result() for chunk in done])
        try:
            for _ in range(0, len(done)):
                next_chunk = next(pending)
                running.add(asyncio.ensure_future(uploader(next_chunk)))
        except StopIteration:
            break

    # Wait for the remaining uploads to finish
    if running:
        done, _running = await asyncio.wait(running)
        range_ids.extend([chunk.result() for chunk in done])
    return range_ids


async def upload_data_chunks(
        service=None,
        uploader_class=None,
        total_size=None,
        chunk_size=None,
        max_concurrency=None,
        stream=None,
        progress_hook=None,
        **kwargs):
    """Upload ``stream`` chunk by chunk through an instance of ``uploader_class``.

    Uploads are scheduled concurrently when ``max_concurrency`` is greater
    than 1, otherwise awaited one after another.

    :returns: The second element of every ``(key, value)`` result sorted by
        key when the uploader produced results, else the uploader's
        ``response_headers``.
    """
    parallel = max_concurrency > 1
    if parallel and 'modified_access_conditions' in kwargs:
        # Access conditions do not work with parallelism
        kwargs['modified_access_conditions'] = None

    uploader = uploader_class(
        service=service,
        total_size=total_size,
        chunk_size=chunk_size,
        stream=stream,
        parallel=parallel,
        progress_hook=progress_hook,
        **kwargs)

    if parallel:
        upload_tasks = uploader.get_chunk_streams()
        running_futures = []
        for _ in range(max_concurrency):
            try:
                chunk = await upload_tasks.__anext__()
                running_futures.append(asyncio.ensure_future(uploader.process_chunk(chunk)))
            except StopAsyncIteration:
                break

        range_ids = await _async_parallel_uploads(uploader.process_chunk, upload_tasks, running_futures)
    else:
        range_ids = []
        async for chunk in uploader.get_chunk_streams():
            range_ids.append(await uploader.process_chunk(chunk))

    if any(range_ids):
        return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
    return uploader.response_headers


async def upload_substream_blocks(
        service=None,
        uploader_class=None,
        total_size=None,
        chunk_size=None,
        max_concurrency=None,
        stream=None,
        progress_hook=None,
        **kwargs):
    """Upload ``stream`` as ``SubStream`` windows through ``uploader_class``.

    :returns: The sorted block results, or an empty list when there are none.
    """
    parallel = max_concurrency > 1
    if parallel and 'modified_access_conditions' in kwargs:
        # Access conditions do not work with parallelism
        kwargs['modified_access_conditions'] = None
    uploader = uploader_class(
        service=service,
        total_size=total_size,
        chunk_size=chunk_size,
        stream=stream,
        parallel=parallel,
        progress_hook=progress_hook,
        **kwargs)

    if parallel:
        upload_tasks = uploader.get_substream_blocks()
        running_futures = [
            asyncio.ensure_future(uploader.process_substream_block(u))
            for u in islice(upload_tasks, 0, max_concurrency)
        ]
        range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures)
    else:
        range_ids = []
        for block in uploader.get_substream_blocks():
            range_ids.append(await uploader.process_substream_block(block))
    if any(range_ids):
        return sorted(range_ids)
    # Empty list (not None) so the result is always iterable, matching the
    # synchronous implementation.
    return []


class _ChunkUploader(object):  # pylint: disable=too-many-instance-attributes
    """Async base class that slices a stream into chunks and reports progress.

    Subclasses implement ``_upload_chunk`` and ``_upload_substream_block``.
    Extra keyword arguments are kept in ``request_options`` and forwarded to
    the service calls made by the subclasses.
    """

    def __init__(
            self, service,
            total_size,
            chunk_size,
            stream,
            parallel,
            encryptor=None,
            padder=None,
            progress_hook=None,
            **kwargs):
        self.service = service
        self.total_size = total_size
        self.chunk_size = chunk_size
        self.stream = stream
        self.parallel = parallel

        # Stream management
        self.stream_lock = threading.Lock() if parallel else None

        # Progress feedback
        self.progress_total = 0
        self.progress_lock = Lock() if parallel else None
        self.progress_hook = progress_hook

        # Encryption
        self.encryptor = encryptor
        self.padder = padder
        self.response_headers = None
        self.etag = None
        self.last_modified = None
        self.request_options = kwargs

    async def get_chunk_streams(self):
        """Yield ``(offset, data)`` pairs of at most ``chunk_size`` bytes each.

        ``stream.read`` may return either bytes or an awaitable of bytes.

        :raises TypeError: If the data read is not bytes.
        """
        index = 0
        while True:
            data = b''
            read_size = self.chunk_size

            # Buffer until we either reach the end of the stream or get a whole chunk.
            while True:
                if self.total_size:
                    read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
                temp = self.stream.read(read_size)
                if inspect.isawaitable(temp):
                    temp = await temp
                if not isinstance(temp, bytes):
                    raise TypeError('Blob data should be of type bytes.')
                data += temp or b""

                # We have read an empty string and so are at the end
                # of the buffer or we have read a full chunk.
                if temp == b'' or len(data) == self.chunk_size:
                    break

            if len(data) == self.chunk_size:
                if self.padder:
                    data = self.padder.update(data)
                if self.encryptor:
                    data = self.encryptor.update(data)
                yield index, data
            else:
                # Short (final) chunk: flush the padder/encryptor as well.
                if self.padder:
                    data = self.padder.update(data) + self.padder.finalize()
                if self.encryptor:
                    data = self.encryptor.update(data) + self.encryptor.finalize()
                if data:
                    yield index, data
                break
            index += len(data)

    async def process_chunk(self, chunk_data):
        """Upload one ``(offset, data)`` pair and return the uploader's result."""
        chunk_bytes = chunk_data[1]
        chunk_offset = chunk_data[0]
        return await self._upload_chunk_with_progress(chunk_offset, chunk_bytes)

    async def _update_progress(self, length):
        """Add ``length`` to the running total and notify ``progress_hook``.

        The hook may be a coroutine function or a plain callable; its result
        is awaited only when it is awaitable.
        """
        if self.progress_lock is not None:
            async with self.progress_lock:
                self.progress_total += length
        else:
            self.progress_total += length

        if self.progress_hook:
            hook_result = self.progress_hook(self.progress_total, self.total_size)
            if inspect.isawaitable(hook_result):
                await hook_result

    async def _upload_chunk(self, chunk_offset, chunk_data):
        raise NotImplementedError("Must be implemented by child class.")

    async def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
        range_id = await self._upload_chunk(chunk_offset, chunk_data)
        await self._update_progress(len(chunk_data))
        return range_id

    def get_substream_blocks(self):
        """Yield ``(offset, SubStream)`` pairs that together cover the stream.

        :raises ValueError: If the total length cannot be determined.
        """
        assert self.chunk_size is not None
        lock = self.stream_lock
        blob_length = self.total_size

        if blob_length is None:
            blob_length = get_length(self.stream)
            if blob_length is None:
                raise ValueError("Unable to determine content length of upload data.")

        blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
        last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size

        for i in range(blocks):
            index = i * self.chunk_size
            length = last_block_size if i == blocks - 1 else self.chunk_size
            yield index, SubStream(self.stream, index, length, lock)

    async def process_substream_block(self, block_data):
        """Upload one ``(offset, SubStream)`` pair and return the uploader's result."""
        return await self._upload_substream_block_with_progress(block_data[0], block_data[1])

    async def _upload_substream_block(self, index, block_stream):
        raise NotImplementedError("Must be implemented by child class.")

    async def _upload_substream_block_with_progress(self, index, block_stream):
        range_id = await self._upload_substream_block(index, block_stream)
        await self._update_progress(len(block_stream))
        return range_id

    def set_response_properties(self, resp):
        """Copy ``etag`` and ``last_modified`` from ``resp``."""
        self.etag = resp.etag
        self.last_modified = resp.last_modified


class BlockBlobChunkUploader(_ChunkUploader):
    """Stages each chunk as a block via ``service.stage_block``."""

    def __init__(self, *args, **kwargs):
        kwargs.pop('modified_access_conditions', None)
        super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
        self.current_length = None

    async def _upload_chunk(self, chunk_offset, chunk_data):
        # TODO: This is incorrect, but works with recording.
        index = f'{chunk_offset:032d}'
        block_id = encode_base64(url_quote(encode_base64(index)))
        await self.service.stage_block(
            block_id,
            len(chunk_data),
            body=chunk_data,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options)
        return index, block_id

    async def _upload_substream_block(self, index, block_stream):
        try:
            block_id = f'BlockId{(index//self.chunk_size):05}'
            await self.service.stage_block(
                block_id,
                len(block_stream),
                block_stream,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options)
        finally:
            block_stream.close()
        return block_id


class PageBlobChunkUploader(_ChunkUploader):
    """Uploads non-empty chunks via ``service.upload_pages``."""

    def _is_chunk_empty(self, chunk_data):
        # True when every byte of chunk_data is zero (or there are no bytes).
        return not any(bytearray(chunk_data))

    async def _upload_chunk(self, chunk_offset, chunk_data):
        # avoid uploading the empty pages
        if not self._is_chunk_empty(chunk_data):
            chunk_end = chunk_offset + len(chunk_data) - 1
            content_range = f'bytes={chunk_offset}-{chunk_end}'
            computed_md5 = None
            self.response_headers = await self.service.upload_pages(
                body=chunk_data,
                content_length=len(chunk_data),
                transactional_content_md5=computed_md5,
                range=content_range,
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options)

            # Chain sequential uploads on the etag returned by the previous one.
            if not self.parallel and self.request_options.get('modified_access_conditions'):
                self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']

    async def _upload_substream_block(self, index, block_stream):
        pass


class AppendBlobChunkUploader(_ChunkUploader):
    """Appends each chunk via ``service.append_block``."""

    def __init__(self, *args, **kwargs):
        super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
        self.current_length = None

    async def _upload_chunk(self, chunk_offset, chunk_data):
        if self.current_length is None:
            # First append: remember the offset the service reports back.
            self.response_headers = await self.service.append_block(
                body=chunk_data,
                content_length=len(chunk_data),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options)
            self.current_length = int(self.response_headers['blob_append_offset'])
        else:
            self.request_options['append_position_access_conditions'].append_position = \
                self.current_length + chunk_offset
            self.response_headers = await self.service.append_block(
                body=chunk_data,
                content_length=len(chunk_data),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options)

    async def _upload_substream_block(self, index, block_stream):
        pass


class DataLakeFileChunkUploader(_ChunkUploader):
    """Appends each chunk at its offset via ``service.append_data``."""

    async def _upload_chunk(self, chunk_offset, chunk_data):
        self.response_headers = await self.service.append_data(
            body=chunk_data,
            position=chunk_offset,
            content_length=len(chunk_data),
            cls=return_response_headers,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )

        # Chain sequential uploads on the etag returned by the previous one.
        if not self.parallel and self.request_options.get('modified_access_conditions'):
            self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']

    async def _upload_substream_block(self, index, block_stream):
        try:
            await self.service.append_data(
                body=block_stream,
                position=index,
                content_length=len(block_stream),
                cls=return_response_headers,
                data_stream_total=self.total_size,
                upload_stream_current=self.progress_total,
                **self.request_options
            )
        finally:
            block_stream.close()


class FileChunkUploader(_ChunkUploader):
    """Uploads each chunk as a byte range via ``service.upload_range``."""

    async def _upload_chunk(self, chunk_offset, chunk_data):
        length = len(chunk_data)
        chunk_end = chunk_offset + length - 1
        response = await self.service.upload_range(
            chunk_data,
            chunk_offset,
            length,
            data_stream_total=self.total_size,
            upload_stream_current=self.progress_total,
            **self.request_options
        )
        range_id = f'bytes={chunk_offset}-{chunk_end}'
        return range_id, response

    # TODO: Implement this method.
    async def _upload_substream_block(self, index, block_stream):
        pass


class AsyncIterStreamer():
    """
    File-like streaming object for AsyncGenerators.
    """
    def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"):
        self.iterator = generator.__aiter__()
        self.leftover = b""
        self.encoding = encoding

    def seekable(self):
        return False

    def tell(self, *args, **kwargs):
        raise UnsupportedOperation("Data generator does not support tell.")

    def seek(self, *args, **kwargs):
        raise UnsupportedOperation("Data generator is not seekable.")

    async def read(self, size: int) -> bytes:
        """Return up to ``size`` bytes, encoding ``str`` items with ``self.encoding``.

        Bytes pulled from the generator beyond ``size`` are kept in
        ``self.leftover`` for the next call.
        """
        data = self.leftover
        count = len(self.leftover)
        try:
            while count < size:
                chunk = await self.iterator.__anext__()
                if isinstance(chunk, str):
                    chunk = chunk.encode(self.encoding)
                data += chunk
                count += len(chunk)
        # This means count < size and what's leftover will be returned in this call.
        except StopAsyncIteration:
            self.leftover = b""

        if count >= size:
            self.leftover = data[size:]

        return data[:size]