aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py54
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py244
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py458
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py280
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py19
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py585
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py53
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py694
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies_async.py296
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py270
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py200
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py243
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py604
-rw-r--r--.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py460
14 files changed, 4460 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py
new file mode 100644
index 00000000..a8b1a27d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/__init__.py
@@ -0,0 +1,54 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import base64
+import hashlib
+import hmac
+
+try:
+ from urllib.parse import quote, unquote
+except ImportError:
+ from urllib2 import quote, unquote # type: ignore
+
+
+def url_quote(url):
+ return quote(url)
+
+
+def url_unquote(url):
+ return unquote(url)
+
+
+def encode_base64(data):
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+ encoded = base64.b64encode(data)
+ return encoded.decode('utf-8')
+
+
+def decode_base64_to_bytes(data):
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+ return base64.b64decode(data)
+
+
+def decode_base64_to_text(data):
+ decoded_bytes = decode_base64_to_bytes(data)
+ return decoded_bytes.decode('utf-8')
+
+
+def sign_string(key, string_to_sign, key_is_base64=True):
+ if key_is_base64:
+ key = decode_base64_to_bytes(key)
+ else:
+ if isinstance(key, str):
+ key = key.encode('utf-8')
+ if isinstance(string_to_sign, str):
+ string_to_sign = string_to_sign.encode('utf-8')
+ signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256)
+ digest = signed_hmac_sha256.digest()
+ encoded_digest = encode_base64(digest)
+ return encoded_digest
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py
new file mode 100644
index 00000000..44c563d8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/authentication.py
@@ -0,0 +1,244 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import re
+from typing import List, Tuple
+from urllib.parse import unquote, urlparse
+from functools import cmp_to_key
+
+try:
+ from yarl import URL
+except ImportError:
+ pass
+
+try:
+ from azure.core.pipeline.transport import AioHttpTransport # pylint: disable=non-abstract-transport-import
+except ImportError:
+ AioHttpTransport = None
+
+from azure.core.exceptions import ClientAuthenticationError
+from azure.core.pipeline.policies import SansIOHTTPPolicy
+
+from . import sign_string
+
+logger = logging.getLogger(__name__)
+
+table_lv0 = [
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x71c, 0x0, 0x71f, 0x721, 0x723, 0x725,
+ 0x0, 0x0, 0x0, 0x72d, 0x803, 0x0, 0x0, 0x733, 0x0, 0xd03, 0xd1a, 0xd1c, 0xd1e,
+ 0xd20, 0xd22, 0xd24, 0xd26, 0xd28, 0xd2a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0xe02, 0xe09, 0xe0a, 0xe1a, 0xe21, 0xe23, 0xe25, 0xe2c, 0xe32, 0xe35, 0xe36, 0xe48, 0xe51,
+ 0xe70, 0xe7c, 0xe7e, 0xe89, 0xe8a, 0xe91, 0xe99, 0xe9f, 0xea2, 0xea4, 0xea6, 0xea7, 0xea9,
+ 0x0, 0x0, 0x0, 0x743, 0x744, 0x748, 0xe02, 0xe09, 0xe0a, 0xe1a, 0xe21, 0xe23, 0xe25,
+ 0xe2c, 0xe32, 0xe35, 0xe36, 0xe48, 0xe51, 0xe70, 0xe7c, 0xe7e, 0xe89, 0xe8a, 0xe91, 0xe99,
+ 0xe9f, 0xea2, 0xea4, 0xea6, 0xea7, 0xea9, 0x0, 0x74c, 0x0, 0x750, 0x0,
+]
+
+table_lv4 = [
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8012, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8212, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+]
+
+def compare(lhs: str, rhs: str) -> int: # pylint:disable=too-many-return-statements
+ tables = [table_lv0, table_lv4]
+ curr_level, i, j, n = 0, 0, 0, len(tables)
+ lhs_len = len(lhs)
+ rhs_len = len(rhs)
+ while curr_level < n:
+ if curr_level == (n - 1) and i != j:
+ if i > j:
+ return -1
+ if i < j:
+ return 1
+ return 0
+
+ w1 = tables[curr_level][ord(lhs[i])] if i < lhs_len else 0x1
+ w2 = tables[curr_level][ord(rhs[j])] if j < rhs_len else 0x1
+
+ if w1 == 0x1 and w2 == 0x1:
+ i = 0
+ j = 0
+ curr_level += 1
+ elif w1 == w2:
+ i += 1
+ j += 1
+ elif w1 == 0:
+ i += 1
+ elif w2 == 0:
+ j += 1
+ else:
+ if w1 < w2:
+ return -1
+ if w1 > w2:
+ return 1
+ return 0
+ return 0
+
+
+# wraps a given exception with the desired exception type
+def _wrap_exception(ex, desired_type):
+ msg = ""
+ if ex.args:
+ msg = ex.args[0]
+ return desired_type(msg)
+
+# This method attempts to emulate the sorting done by the service
+def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+
+ # Build dict of tuples and list of keys
+ header_dict = {}
+ header_keys = []
+ for k, v in input_headers:
+ header_dict[k] = v
+ header_keys.append(k)
+
+ try:
+ header_keys = sorted(header_keys, key=cmp_to_key(compare))
+ except ValueError as exc:
+ raise ValueError("Illegal character encountered when sorting headers.") from exc
+
+ # Build list of sorted tuples
+ sorted_headers = []
+ for key in header_keys:
+ sorted_headers.append((key, header_dict.pop(key)))
+ return sorted_headers
+
+
+class AzureSigningError(ClientAuthenticationError):
+ """
+ Represents a fatal error when attempting to sign a request.
+ In general, the cause of this exception is user error. For example, the given account key is not valid.
+ Please visit https://learn.microsoft.com/azure/storage/common/storage-create-storage-account for more info.
+ """
+
+
+class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
+
+ def __init__(self, account_name, account_key):
+ self.account_name = account_name
+ self.account_key = account_key
+ super(SharedKeyCredentialPolicy, self).__init__()
+
+ @staticmethod
+ def _get_headers(request, headers_to_sign):
+ headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value)
+ if 'content-length' in headers and headers['content-length'] == '0':
+ del headers['content-length']
+ return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n'
+
+ @staticmethod
+ def _get_verb(request):
+ return request.http_request.method + '\n'
+
+ def _get_canonicalized_resource(self, request):
+ uri_path = urlparse(request.http_request.url).path
+ try:
+ if isinstance(request.context.transport, AioHttpTransport) or \
+ isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) or \
+ isinstance(getattr(getattr(request.context.transport, "_transport", None), "_transport", None),
+ AioHttpTransport):
+ uri_path = URL(uri_path)
+ return '/' + self.account_name + str(uri_path)
+ except TypeError:
+ pass
+ return '/' + self.account_name + uri_path
+
+ @staticmethod
+ def _get_canonicalized_headers(request):
+ string_to_sign = ''
+ x_ms_headers = []
+ for name, value in request.http_request.headers.items():
+ if name.startswith('x-ms-'):
+ x_ms_headers.append((name.lower(), value))
+ x_ms_headers = _storage_header_sort(x_ms_headers)
+ for name, value in x_ms_headers:
+ if value is not None:
+ string_to_sign += ''.join([name, ':', value, '\n'])
+ return string_to_sign
+
+ @staticmethod
+ def _get_canonicalized_resource_query(request):
+ sorted_queries = list(request.http_request.query.items())
+ sorted_queries.sort()
+
+ string_to_sign = ''
+ for name, value in sorted_queries:
+ if value is not None:
+ string_to_sign += '\n' + name.lower() + ':' + unquote(value)
+
+ return string_to_sign
+
+ def _add_authorization_header(self, request, string_to_sign):
+ try:
+ signature = sign_string(self.account_key, string_to_sign)
+ auth_string = 'SharedKey ' + self.account_name + ':' + signature
+ request.http_request.headers['Authorization'] = auth_string
+ except Exception as ex:
+ # Wrap any error that occurred as signing error
+ # Doing so will clarify/locate the source of problem
+ raise _wrap_exception(ex, AzureSigningError) from ex
+
+ def on_request(self, request):
+ string_to_sign = \
+ self._get_verb(request) + \
+ self._get_headers(
+ request,
+ [
+ 'content-encoding', 'content-language', 'content-length',
+ 'content-md5', 'content-type', 'date', 'if-modified-since',
+ 'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range'
+ ]
+ ) + \
+ self._get_canonicalized_headers(request) + \
+ self._get_canonicalized_resource(request) + \
+ self._get_canonicalized_resource_query(request)
+
+ self._add_authorization_header(request, string_to_sign)
+ # logger.debug("String_to_sign=%s", string_to_sign)
+
+
+class StorageHttpChallenge(object):
+ def __init__(self, challenge):
+ """ Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """
+ if not challenge:
+ raise ValueError("Challenge cannot be empty")
+
+ self._parameters = {}
+ self.scheme, trimmed_challenge = challenge.strip().split(" ", 1)
+
+ # name=value pairs either comma or space separated with values possibly being
+ # enclosed in quotes
+ for item in re.split('[, ]', trimmed_challenge):
+ comps = item.split("=")
+ if len(comps) == 2:
+ key = comps[0].strip(' "')
+ value = comps[1].strip(' "')
+ if key:
+ self._parameters[key] = value
+
+ # Extract and verify required parameters
+ self.authorization_uri = self._parameters.get('authorization_uri')
+ if not self.authorization_uri:
+ raise ValueError("Authorization Uri not found")
+
+ self.resource_id = self._parameters.get('resource_id')
+ if not self.resource_id:
+ raise ValueError("Resource id not found")
+
+ uri_path = urlparse(self.authorization_uri).path.lstrip("/")
+ self.tenant_id = uri_path.split("/")[0]
+
+ def get_value(self, key):
+ return self._parameters.get(key)
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py
new file mode 100644
index 00000000..9dc8d2ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client.py
@@ -0,0 +1,458 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+import uuid
+from typing import (
+ Any,
+ cast,
+ Dict,
+ Iterator,
+ Optional,
+ Tuple,
+ TYPE_CHECKING,
+ Union,
+)
+from urllib.parse import parse_qs, quote
+
+from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential
+from azure.core.exceptions import HttpResponseError
+from azure.core.pipeline import Pipeline
+from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module
+from azure.core.pipeline.policies import (
+ AzureSasCredentialPolicy,
+ ContentDecodePolicy,
+ DistributedTracingPolicy,
+ HttpLoggingPolicy,
+ ProxyPolicy,
+ RedirectPolicy,
+ UserAgentPolicy,
+)
+
+from .authentication import SharedKeyCredentialPolicy
+from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE
+from .models import LocationMode, StorageConfiguration
+from .policies import (
+ ExponentialRetry,
+ QueueMessagePolicy,
+ StorageBearerTokenCredentialPolicy,
+ StorageContentValidation,
+ StorageHeadersPolicy,
+ StorageHosts,
+ StorageLoggingPolicy,
+ StorageRequestHook,
+ StorageResponseHook,
+)
+from .request_handlers import serialize_batch_body, _get_batch_request_delimiter
+from .response_handlers import PartialBatchErrorException, process_storage_error
+from .shared_access_signature import QueryStringConstants
+from .._version import VERSION
+from .._shared_access_signature import _is_credential_sastoken
+
+if TYPE_CHECKING:
+ from azure.core.credentials_async import AsyncTokenCredential
+ from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756
+
+_LOGGER = logging.getLogger(__name__)
+_SERVICE_PARAMS = {
+ "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"},
+ "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"},
+ "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"},
+ "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"},
+}
+
+
+class StorageAccountHostsMixin(object):
+ _client: Any
+ def __init__(
+ self,
+ parsed_url: Any,
+ service: str,
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long
+ **kwargs: Any
+ ) -> None:
+ self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
+ self._hosts = kwargs.get("_hosts")
+ self.scheme = parsed_url.scheme
+ self._is_localhost = False
+
+ if service not in ["blob", "queue", "file-share", "dfs"]:
+ raise ValueError(f"Invalid service: {service}")
+ service_name = service.split('-')[0]
+ account = parsed_url.netloc.split(f".{service_name}.core.")
+
+ self.account_name = account[0] if len(account) > 1 else None
+ if not self.account_name and parsed_url.netloc.startswith("localhost") \
+ or parsed_url.netloc.startswith("127.0.0.1"):
+ self._is_localhost = True
+ self.account_name = parsed_url.path.strip("/")
+
+ self.credential = _format_shared_key_credential(self.account_name, credential)
+ if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
+ raise ValueError("Token credential is only supported with HTTPS.")
+
+ secondary_hostname = None
+ if hasattr(self.credential, "account_name"):
+ self.account_name = self.credential.account_name
+ secondary_hostname = f"{self.credential.account_name}-secondary.{service_name}.{SERVICE_HOST_BASE}"
+
+ if not self._hosts:
+ if len(account) > 1:
+ secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary")
+ if kwargs.get("secondary_hostname"):
+ secondary_hostname = kwargs["secondary_hostname"]
+ primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
+ self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}
+
+ self._sdk_moniker = f"storage-{service}/{VERSION}"
+ self._config, self._pipeline = self._create_pipeline(self.credential, sdk_moniker=self._sdk_moniker, **kwargs)
+
+ def __enter__(self):
+ self._client.__enter__()
+ return self
+
+ def __exit__(self, *args):
+ self._client.__exit__(*args)
+
+ def close(self):
+ """ This method is to close the sockets opened by the client.
+ It need not be used when using with a context manager.
+ """
+ self._client.close()
+
+ @property
+ def url(self):
+ """The full endpoint URL to this entity, including SAS token if used.
+
+ This could be either the primary endpoint,
+ or the secondary endpoint depending on the current :func:`location_mode`.
+ :returns: The full endpoint URL to this entity, including SAS token if used.
+ :rtype: str
+ """
+ return self._format_url(self._hosts[self._location_mode])
+
+ @property
+ def primary_endpoint(self):
+ """The full primary endpoint URL.
+
+ :rtype: str
+ """
+ return self._format_url(self._hosts[LocationMode.PRIMARY])
+
+ @property
+ def primary_hostname(self):
+ """The hostname of the primary endpoint.
+
+ :rtype: str
+ """
+ return self._hosts[LocationMode.PRIMARY]
+
+ @property
+ def secondary_endpoint(self):
+ """The full secondary endpoint URL if configured.
+
+ If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
+ `secondary_hostname` keyword argument on instantiation.
+
+ :rtype: str
+ :raise ValueError:
+ """
+ if not self._hosts[LocationMode.SECONDARY]:
+ raise ValueError("No secondary host configured.")
+ return self._format_url(self._hosts[LocationMode.SECONDARY])
+
+ @property
+ def secondary_hostname(self):
+ """The hostname of the secondary endpoint.
+
+ If not available this will be None. To explicitly specify a secondary hostname, use the optional
+ `secondary_hostname` keyword argument on instantiation.
+
+ :rtype: Optional[str]
+ """
+ return self._hosts[LocationMode.SECONDARY]
+
+ @property
+ def location_mode(self):
+ """The location mode that the client is currently using.
+
+ By default this will be "primary". Options include "primary" and "secondary".
+
+ :rtype: str
+ """
+
+ return self._location_mode
+
+ @location_mode.setter
+ def location_mode(self, value):
+ if self._hosts.get(value):
+ self._location_mode = value
+ self._client._config.url = self.url # pylint: disable=protected-access
+ else:
+ raise ValueError(f"No host URL for location mode: {value}")
+
+ @property
+ def api_version(self):
+ """The version of the Storage API used for requests.
+
+ :rtype: str
+ """
+ return self._client._config.version # pylint: disable=protected-access
+
+ def _format_query_string(
+ self, sas_token: Optional[str],
+ credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long
+ snapshot: Optional[str] = None,
+ share_snapshot: Optional[str] = None
+ ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long
+ query_str = "?"
+ if snapshot:
+ query_str += f"snapshot={snapshot}&"
+ if share_snapshot:
+ query_str += f"sharesnapshot={share_snapshot}&"
+ if sas_token and isinstance(credential, AzureSasCredential):
+ raise ValueError(
+ "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
+ if _is_credential_sastoken(credential):
+ credential = cast(str, credential)
+ query_str += credential.lstrip("?")
+ credential = None
+ elif sas_token:
+ query_str += sas_token
+ return query_str.rstrip("?&"), credential
+
+ def _create_pipeline(
+ self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long
+ **kwargs: Any
+ ) -> Tuple[StorageConfiguration, Pipeline]:
+ self._credential_policy: Any = None
+ if hasattr(credential, "get_token"):
+ if kwargs.get('audience'):
+ audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE
+ else:
+ audience = STORAGE_OAUTH_SCOPE
+ self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience)
+ elif isinstance(credential, SharedKeyCredentialPolicy):
+ self._credential_policy = credential
+ elif isinstance(credential, AzureSasCredential):
+ self._credential_policy = AzureSasCredentialPolicy(credential)
+ elif credential is not None:
+ raise TypeError(f"Unsupported credential: {type(credential)}")
+
+ config = kwargs.get("_configuration") or create_configuration(**kwargs)
+ if kwargs.get("_pipeline"):
+ return config, kwargs["_pipeline"]
+ transport = kwargs.get("transport")
+ kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
+ kwargs.setdefault("read_timeout", READ_TIMEOUT)
+ if not transport:
+ transport = RequestsTransport(**kwargs)
+ policies = [
+ QueueMessagePolicy(),
+ config.proxy_policy,
+ config.user_agent_policy,
+ StorageContentValidation(),
+ ContentDecodePolicy(response_encoding="utf-8"),
+ RedirectPolicy(**kwargs),
+ StorageHosts(hosts=self._hosts, **kwargs),
+ config.retry_policy,
+ config.headers_policy,
+ StorageRequestHook(**kwargs),
+ self._credential_policy,
+ config.logging_policy,
+ StorageResponseHook(**kwargs),
+ DistributedTracingPolicy(**kwargs),
+ HttpLoggingPolicy(**kwargs)
+ ]
+ if kwargs.get("_additional_pipeline_policies"):
+ policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore
+ config.transport = transport # type: ignore
+ return config, Pipeline(transport, policies=policies)
+
+ def _batch_send(
+ self,
+ *reqs: "HttpRequest",
+ **kwargs: Any
+ ) -> Iterator["HttpResponse"]:
+ """Given a series of request, do a Storage batch call.
+
+ :param HttpRequest reqs: A collection of HttpRequest objects.
+ :returns: An iterator of HttpResponse objects.
+ :rtype: Iterator[HttpResponse]
+ """
+ # Pop it here, so requests doesn't feel bad about additional kwarg
+ raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
+ batch_id = str(uuid.uuid1())
+
+ request = self._client._client.post( # pylint: disable=protected-access
+ url=(
+ f'{self.scheme}://{self.primary_hostname}/'
+ f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}"
+ f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}"
+ ),
+ headers={
+ 'x-ms-version': self.api_version,
+ "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False)
+ }
+ )
+
+ policies = [StorageHeadersPolicy()]
+ if self._credential_policy:
+ policies.append(self._credential_policy)
+
+ request.set_multipart_mixed(
+ *reqs,
+ policies=policies,
+ enforce_https=False
+ )
+
+ Pipeline._prepare_multipart_mixed_request(request) # pylint: disable=protected-access
+ body = serialize_batch_body(request.multipart_mixed_info[0], batch_id)
+ request.set_bytes_body(body)
+
+ temp = request.multipart_mixed_info
+ request.multipart_mixed_info = None
+ pipeline_response = self._pipeline.run(
+ request, **kwargs
+ )
+ response = pipeline_response.http_response
+ request.multipart_mixed_info = temp
+
+ try:
+ if response.status_code not in [202]:
+ raise HttpResponseError(response=response)
+ parts = response.parts()
+ if raise_on_any_failure:
+ parts = list(response.parts())
+ if any(p for p in parts if not 200 <= p.status_code < 300):
+ error = PartialBatchErrorException(
+ message="There is a partial failure in the batch operation.",
+ response=response, parts=parts
+ )
+ raise error
+ return iter(parts)
+ return parts # type: ignore [no-any-return]
+ except HttpResponseError as error:
+ process_storage_error(error)
+
+
+class TransportWrapper(HttpTransport):
+ """Wrapper class that ensures that an inner client created
+ by a `get_client` method does not close the outer transport for the parent
+ when used in a context manager.
+ """
+ def __init__(self, transport):
+ self._transport = transport
+
+ def send(self, request, **kwargs):
+ return self._transport.send(request, **kwargs)
+
+ def open(self):
+ pass
+
+ def close(self):
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
+
+
+def _format_shared_key_credential(
+ account_name: Optional[str],
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None # pylint: disable=line-too-long
+) -> Any:
+ if isinstance(credential, str):
+ if not account_name:
+ raise ValueError("Unable to determine account name for shared key credential.")
+ credential = {"account_name": account_name, "account_key": credential}
+ if isinstance(credential, dict):
+ if "account_name" not in credential:
+ raise ValueError("Shared key credential missing 'account_name")
+ if "account_key" not in credential:
+ raise ValueError("Shared key credential missing 'account_key")
+ return SharedKeyCredentialPolicy(**credential)
+ if isinstance(credential, AzureNamedKeyCredential):
+ return SharedKeyCredentialPolicy(credential.named_key.name, credential.named_key.key)
+ return credential
+
+
+def parse_connection_str(
+ conn_str: str,
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]],
+ service: str
+) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long
+ conn_str = conn_str.rstrip(";")
+ conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")]
+ if any(len(tup) != 2 for tup in conn_settings_list):
+ raise ValueError("Connection string is either blank or malformed.")
+ conn_settings = dict((key.upper(), val) for key, val in conn_settings_list)
+ endpoints = _SERVICE_PARAMS[service]
+ primary = None
+ secondary = None
+ if not credential:
+ try:
+ credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]}
+ except KeyError:
+ credential = conn_settings.get("SHAREDACCESSSIGNATURE")
+ if endpoints["primary"] in conn_settings:
+ primary = conn_settings[endpoints["primary"]]
+ if endpoints["secondary"] in conn_settings:
+ secondary = conn_settings[endpoints["secondary"]]
+ else:
+ if endpoints["secondary"] in conn_settings:
+ raise ValueError("Connection string specifies only secondary endpoint.")
+ try:
+ primary =(
+ f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://"
+ f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ secondary = (
+ f"{conn_settings['ACCOUNTNAME']}-secondary."
+ f"{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ except KeyError:
+ pass
+
+ if not primary:
+ try:
+ primary = (
+ f"https://{conn_settings['ACCOUNTNAME']}."
+ f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}"
+ )
+ except KeyError as exc:
+ raise ValueError("Connection string missing required connection details.") from exc
+ if service == "dfs":
+ primary = primary.replace(".blob.", ".dfs.")
+ if secondary:
+ secondary = secondary.replace(".blob.", ".dfs.")
+ return primary, secondary, credential
+
+
+def create_configuration(**kwargs: Any) -> StorageConfiguration:
+ # Backwards compatibility if someone is not passing sdk_moniker
+ if not kwargs.get("sdk_moniker"):
+ kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}"
+ config = StorageConfiguration(**kwargs)
+ config.headers_policy = StorageHeadersPolicy(**kwargs)
+ config.user_agent_policy = UserAgentPolicy(**kwargs)
+ config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs)
+ config.logging_policy = StorageLoggingPolicy(**kwargs)
+ config.proxy_policy = ProxyPolicy(**kwargs)
+ return config
+
+
+def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]:
+ sas_values = QueryStringConstants.to_list()
+ parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
+ sas_params = [f"{k}={quote(v, safe='')}" for k, v in parsed_query.items() if k in sas_values]
+ sas_token = None
+ if sas_params:
+ sas_token = "&".join(sas_params)
+
+ snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
+ return snapshot, sas_token
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py
new file mode 100644
index 00000000..6186b29d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/base_client_async.py
@@ -0,0 +1,280 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# mypy: disable-error-code="attr-defined"
+
+import logging
+from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union
+
+from azure.core.async_paging import AsyncList
+from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential
+from azure.core.credentials_async import AsyncTokenCredential
+from azure.core.exceptions import HttpResponseError
+from azure.core.pipeline import AsyncPipeline
+from azure.core.pipeline.policies import (
+ AsyncRedirectPolicy,
+ AzureSasCredentialPolicy,
+ ContentDecodePolicy,
+ DistributedTracingPolicy,
+ HttpLoggingPolicy,
+)
+from azure.core.pipeline.transport import AsyncHttpTransport
+
+from .authentication import SharedKeyCredentialPolicy
+from .base_client import create_configuration
+from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE
+from .models import StorageConfiguration
+from .policies import (
+ QueueMessagePolicy,
+ StorageContentValidation,
+ StorageHeadersPolicy,
+ StorageHosts,
+ StorageRequestHook,
+)
+from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook
+from .response_handlers import PartialBatchErrorException, process_storage_error
+from .._shared_access_signature import _is_credential_sastoken
+
+if TYPE_CHECKING:
+ from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756
+_LOGGER = logging.getLogger(__name__)
+
+_SERVICE_PARAMS = {
+ "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"},
+ "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"},
+ "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"},
+ "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"},
+}
+
+
+class AsyncStorageAccountHostsMixin(object):
+
+ def __enter__(self):
+ raise TypeError("Async client only supports 'async with'.")
+
+ def __exit__(self, *args):
+ pass
+
+ async def __aenter__(self):
+ await self._client.__aenter__()
+ return self
+
+ async def __aexit__(self, *args):
+ await self._client.__aexit__(*args)
+
+ async def close(self):
+ """ This method is to close the sockets opened by the client.
+ It need not be used when using with a context manager.
+ """
+ await self._client.close()
+
+ def _format_query_string(
+ self, sas_token: Optional[str],
+ credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long
+ snapshot: Optional[str] = None,
+ share_snapshot: Optional[str] = None
+ ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long
+ query_str = "?"
+ if snapshot:
+ query_str += f"snapshot={snapshot}&"
+ if share_snapshot:
+ query_str += f"sharesnapshot={share_snapshot}&"
+ if sas_token and isinstance(credential, AzureSasCredential):
+ raise ValueError(
+ "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
+ if _is_credential_sastoken(credential):
+ query_str += credential.lstrip("?") # type: ignore [union-attr]
+ credential = None
+ elif sas_token:
+ query_str += sas_token
+ return query_str.rstrip("?&"), credential
+
+ def _create_pipeline(
+ self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long
+ **kwargs: Any
+ ) -> Tuple[StorageConfiguration, AsyncPipeline]:
+ self._credential_policy: Optional[
+ Union[AsyncStorageBearerTokenCredentialPolicy,
+ SharedKeyCredentialPolicy,
+ AzureSasCredentialPolicy]] = None
+ if hasattr(credential, 'get_token'):
+ if kwargs.get('audience'):
+ audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE
+ else:
+ audience = STORAGE_OAUTH_SCOPE
+ self._credential_policy = AsyncStorageBearerTokenCredentialPolicy(
+ cast(AsyncTokenCredential, credential), audience)
+ elif isinstance(credential, SharedKeyCredentialPolicy):
+ self._credential_policy = credential
+ elif isinstance(credential, AzureSasCredential):
+ self._credential_policy = AzureSasCredentialPolicy(credential)
+ elif credential is not None:
+ raise TypeError(f"Unsupported credential: {type(credential)}")
+ config = kwargs.get('_configuration') or create_configuration(**kwargs)
+ if kwargs.get('_pipeline'):
+ return config, kwargs['_pipeline']
+ transport = kwargs.get('transport')
+ kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
+ kwargs.setdefault("read_timeout", READ_TIMEOUT)
+ if not transport:
+ try:
+ from azure.core.pipeline.transport import AioHttpTransport # pylint: disable=non-abstract-transport-import
+ except ImportError as exc:
+ raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc
+ transport = AioHttpTransport(**kwargs)
+ hosts = self._hosts
+ policies = [
+ QueueMessagePolicy(),
+ config.proxy_policy,
+ config.user_agent_policy,
+ StorageContentValidation(),
+ ContentDecodePolicy(response_encoding="utf-8"),
+ AsyncRedirectPolicy(**kwargs),
+ StorageHosts(hosts=hosts, **kwargs),
+ config.retry_policy,
+ config.headers_policy,
+ StorageRequestHook(**kwargs),
+ self._credential_policy,
+ config.logging_policy,
+ AsyncStorageResponseHook(**kwargs),
+ DistributedTracingPolicy(**kwargs),
+ HttpLoggingPolicy(**kwargs),
+ ]
+ if kwargs.get("_additional_pipeline_policies"):
+ policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore
+ config.transport = transport #type: ignore
+ return config, AsyncPipeline(transport, policies=policies) #type: ignore
+
+ async def _batch_send(
+ self,
+ *reqs: "HttpRequest",
+ **kwargs: Any
+ ) -> AsyncList["HttpResponse"]:
+ """Given a series of request, do a Storage batch call.
+
+ :param HttpRequest reqs: A collection of HttpRequest objects.
+ :returns: An AsyncList of HttpResponse objects.
+ :rtype: AsyncList[HttpResponse]
+ """
+ # Pop it here, so requests doesn't feel bad about additional kwarg
+ raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
+ request = self._client._client.post( # pylint: disable=protected-access
+ url=(
+ f'{self.scheme}://{self.primary_hostname}/'
+ f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}"
+ f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}"
+ ),
+ headers={
+ 'x-ms-version': self.api_version
+ }
+ )
+
+ policies = [StorageHeadersPolicy()]
+ if self._credential_policy:
+ policies.append(self._credential_policy) # type: ignore
+
+ request.set_multipart_mixed(
+ *reqs,
+ policies=policies,
+ enforce_https=False
+ )
+
+ pipeline_response = await self._pipeline.run(
+ request, **kwargs
+ )
+ response = pipeline_response.http_response
+
+ try:
+ if response.status_code not in [202]:
+ raise HttpResponseError(response=response)
+ parts = response.parts() # Return an AsyncIterator
+ if raise_on_any_failure:
+ parts_list = []
+ async for part in parts:
+ parts_list.append(part)
+ if any(p for p in parts_list if not 200 <= p.status_code < 300):
+ error = PartialBatchErrorException(
+ message="There is a partial failure in the batch operation.",
+ response=response, parts=parts_list
+ )
+ raise error
+ return AsyncList(parts_list)
+ return parts # type: ignore [no-any-return]
+ except HttpResponseError as error:
+ process_storage_error(error)
+
+def parse_connection_str(
+ conn_str: str,
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]],
+ service: str
+) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long
+ conn_str = conn_str.rstrip(";")
+ conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")]
+ if any(len(tup) != 2 for tup in conn_settings_list):
+ raise ValueError("Connection string is either blank or malformed.")
+ conn_settings = dict((key.upper(), val) for key, val in conn_settings_list)
+ endpoints = _SERVICE_PARAMS[service]
+ primary = None
+ secondary = None
+ if not credential:
+ try:
+ credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]}
+ except KeyError:
+ credential = conn_settings.get("SHAREDACCESSSIGNATURE")
+ if endpoints["primary"] in conn_settings:
+ primary = conn_settings[endpoints["primary"]]
+ if endpoints["secondary"] in conn_settings:
+ secondary = conn_settings[endpoints["secondary"]]
+ else:
+ if endpoints["secondary"] in conn_settings:
+ raise ValueError("Connection string specifies only secondary endpoint.")
+ try:
+ primary =(
+ f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://"
+ f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ secondary = (
+ f"{conn_settings['ACCOUNTNAME']}-secondary."
+ f"{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ except KeyError:
+ pass
+
+ if not primary:
+ try:
+ primary = (
+ f"https://{conn_settings['ACCOUNTNAME']}."
+ f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}"
+ )
+ except KeyError as exc:
+ raise ValueError("Connection string missing required connection details.") from exc
+ if service == "dfs":
+ primary = primary.replace(".blob.", ".dfs.")
+ if secondary:
+ secondary = secondary.replace(".blob.", ".dfs.")
+ return primary, secondary, credential
+
+class AsyncTransportWrapper(AsyncHttpTransport):
+ """Wrapper class that ensures that an inner client created
+ by a `get_client` method does not close the outer transport for the parent
+ when used in a context manager.
+ """
+ def __init__(self, async_transport):
+ self._transport = async_transport
+
+ async def send(self, request, **kwargs):
+ return await self._transport.send(request, **kwargs)
+
+ async def open(self):
+ pass
+
+ async def close(self):
+ pass
+
+ async def __aenter__(self):
+ pass
+
+ async def __aexit__(self, *args):
+ pass
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py
new file mode 100644
index 00000000..0b4b029a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/constants.py
@@ -0,0 +1,19 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from .._serialize import _SUPPORTED_API_VERSIONS
+
+
+X_MS_VERSION = _SUPPORTED_API_VERSIONS[-1]
+
+# Default socket timeouts, in seconds
+CONNECTION_TIMEOUT = 20
+READ_TIMEOUT = 60
+
+DEFAULT_OAUTH_SCOPE = "/.default"
+STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default"
+
+SERVICE_HOST_BASE = 'core.windows.net'
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py
new file mode 100644
index 00000000..403e6b8b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/models.py
@@ -0,0 +1,585 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=too-many-instance-attributes
+from enum import Enum
+from typing import Optional
+
+from azure.core import CaseInsensitiveEnumMeta
+from azure.core.configuration import Configuration
+from azure.core.pipeline.policies import UserAgentPolicy
+
+
+def get_enum_value(value):
+ if value is None or value in ["None", ""]:
+ return None
+ try:
+ return value.value
+ except AttributeError:
+ return value
+
+
+class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta):
+
+ # Generic storage values
+ ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists"
+ ACCOUNT_BEING_CREATED = "AccountBeingCreated"
+ ACCOUNT_IS_DISABLED = "AccountIsDisabled"
+ AUTHENTICATION_FAILED = "AuthenticationFailed"
+ AUTHORIZATION_FAILURE = "AuthorizationFailure"
+ NO_AUTHENTICATION_INFORMATION = "NoAuthenticationInformation"
+ CONDITION_HEADERS_NOT_SUPPORTED = "ConditionHeadersNotSupported"
+ CONDITION_NOT_MET = "ConditionNotMet"
+ EMPTY_METADATA_KEY = "EmptyMetadataKey"
+ INSUFFICIENT_ACCOUNT_PERMISSIONS = "InsufficientAccountPermissions"
+ INTERNAL_ERROR = "InternalError"
+ INVALID_AUTHENTICATION_INFO = "InvalidAuthenticationInfo"
+ INVALID_HEADER_VALUE = "InvalidHeaderValue"
+ INVALID_HTTP_VERB = "InvalidHttpVerb"
+ INVALID_INPUT = "InvalidInput"
+ INVALID_MD5 = "InvalidMd5"
+ INVALID_METADATA = "InvalidMetadata"
+ INVALID_QUERY_PARAMETER_VALUE = "InvalidQueryParameterValue"
+ INVALID_RANGE = "InvalidRange"
+ INVALID_RESOURCE_NAME = "InvalidResourceName"
+ INVALID_URI = "InvalidUri"
+ INVALID_XML_DOCUMENT = "InvalidXmlDocument"
+ INVALID_XML_NODE_VALUE = "InvalidXmlNodeValue"
+ MD5_MISMATCH = "Md5Mismatch"
+ METADATA_TOO_LARGE = "MetadataTooLarge"
+ MISSING_CONTENT_LENGTH_HEADER = "MissingContentLengthHeader"
+ MISSING_REQUIRED_QUERY_PARAMETER = "MissingRequiredQueryParameter"
+ MISSING_REQUIRED_HEADER = "MissingRequiredHeader"
+ MISSING_REQUIRED_XML_NODE = "MissingRequiredXmlNode"
+ MULTIPLE_CONDITION_HEADERS_NOT_SUPPORTED = "MultipleConditionHeadersNotSupported"
+ OPERATION_TIMED_OUT = "OperationTimedOut"
+ OUT_OF_RANGE_INPUT = "OutOfRangeInput"
+ OUT_OF_RANGE_QUERY_PARAMETER_VALUE = "OutOfRangeQueryParameterValue"
+ REQUEST_BODY_TOO_LARGE = "RequestBodyTooLarge"
+ RESOURCE_TYPE_MISMATCH = "ResourceTypeMismatch"
+ REQUEST_URL_FAILED_TO_PARSE = "RequestUrlFailedToParse"
+ RESOURCE_ALREADY_EXISTS = "ResourceAlreadyExists"
+ RESOURCE_NOT_FOUND = "ResourceNotFound"
+ SERVER_BUSY = "ServerBusy"
+ UNSUPPORTED_HEADER = "UnsupportedHeader"
+ UNSUPPORTED_XML_NODE = "UnsupportedXmlNode"
+ UNSUPPORTED_QUERY_PARAMETER = "UnsupportedQueryParameter"
+ UNSUPPORTED_HTTP_VERB = "UnsupportedHttpVerb"
+
+ # Blob values
+ APPEND_POSITION_CONDITION_NOT_MET = "AppendPositionConditionNotMet"
+ BLOB_ACCESS_TIER_NOT_SUPPORTED_FOR_ACCOUNT_TYPE = "BlobAccessTierNotSupportedForAccountType"
+ BLOB_ALREADY_EXISTS = "BlobAlreadyExists"
+ BLOB_NOT_FOUND = "BlobNotFound"
+ BLOB_OVERWRITTEN = "BlobOverwritten"
+ BLOB_TIER_INADEQUATE_FOR_CONTENT_LENGTH = "BlobTierInadequateForContentLength"
+ BLOCK_COUNT_EXCEEDS_LIMIT = "BlockCountExceedsLimit"
+ BLOCK_LIST_TOO_LONG = "BlockListTooLong"
+ CANNOT_CHANGE_TO_LOWER_TIER = "CannotChangeToLowerTier"
+ CANNOT_VERIFY_COPY_SOURCE = "CannotVerifyCopySource"
+ CONTAINER_ALREADY_EXISTS = "ContainerAlreadyExists"
+ CONTAINER_BEING_DELETED = "ContainerBeingDeleted"
+ CONTAINER_DISABLED = "ContainerDisabled"
+ CONTAINER_NOT_FOUND = "ContainerNotFound"
+ CONTENT_LENGTH_LARGER_THAN_TIER_LIMIT = "ContentLengthLargerThanTierLimit"
+ COPY_ACROSS_ACCOUNTS_NOT_SUPPORTED = "CopyAcrossAccountsNotSupported"
+ COPY_ID_MISMATCH = "CopyIdMismatch"
+ FEATURE_VERSION_MISMATCH = "FeatureVersionMismatch"
+ INCREMENTAL_COPY_BLOB_MISMATCH = "IncrementalCopyBlobMismatch"
+ INCREMENTAL_COPY_OF_EARLIER_VERSION_SNAPSHOT_NOT_ALLOWED = "IncrementalCopyOfEarlierVersionSnapshotNotAllowed"
+ #: Deprecated: Please use INCREMENTAL_COPY_OF_EARLIER_VERSION_SNAPSHOT_NOT_ALLOWED instead.
+ INCREMENTAL_COPY_OF_ERALIER_VERSION_SNAPSHOT_NOT_ALLOWED = "IncrementalCopyOfEarlierVersionSnapshotNotAllowed"
+ INCREMENTAL_COPY_SOURCE_MUST_BE_SNAPSHOT = "IncrementalCopySourceMustBeSnapshot"
+ INFINITE_LEASE_DURATION_REQUIRED = "InfiniteLeaseDurationRequired"
+ INVALID_BLOB_OR_BLOCK = "InvalidBlobOrBlock"
+ INVALID_BLOB_TIER = "InvalidBlobTier"
+ INVALID_BLOB_TYPE = "InvalidBlobType"
+ INVALID_BLOCK_ID = "InvalidBlockId"
+ INVALID_BLOCK_LIST = "InvalidBlockList"
+ INVALID_OPERATION = "InvalidOperation"
+ INVALID_PAGE_RANGE = "InvalidPageRange"
+ INVALID_SOURCE_BLOB_TYPE = "InvalidSourceBlobType"
+ INVALID_SOURCE_BLOB_URL = "InvalidSourceBlobUrl"
+ INVALID_VERSION_FOR_PAGE_BLOB_OPERATION = "InvalidVersionForPageBlobOperation"
+ LEASE_ALREADY_PRESENT = "LeaseAlreadyPresent"
+ LEASE_ALREADY_BROKEN = "LeaseAlreadyBroken"
+ LEASE_ID_MISMATCH_WITH_BLOB_OPERATION = "LeaseIdMismatchWithBlobOperation"
+ LEASE_ID_MISMATCH_WITH_CONTAINER_OPERATION = "LeaseIdMismatchWithContainerOperation"
+ LEASE_ID_MISMATCH_WITH_LEASE_OPERATION = "LeaseIdMismatchWithLeaseOperation"
+ LEASE_ID_MISSING = "LeaseIdMissing"
+ LEASE_IS_BREAKING_AND_CANNOT_BE_ACQUIRED = "LeaseIsBreakingAndCannotBeAcquired"
+ LEASE_IS_BREAKING_AND_CANNOT_BE_CHANGED = "LeaseIsBreakingAndCannotBeChanged"
+ LEASE_IS_BROKEN_AND_CANNOT_BE_RENEWED = "LeaseIsBrokenAndCannotBeRenewed"
+ LEASE_LOST = "LeaseLost"
+ LEASE_NOT_PRESENT_WITH_BLOB_OPERATION = "LeaseNotPresentWithBlobOperation"
+ LEASE_NOT_PRESENT_WITH_CONTAINER_OPERATION = "LeaseNotPresentWithContainerOperation"
+ LEASE_NOT_PRESENT_WITH_LEASE_OPERATION = "LeaseNotPresentWithLeaseOperation"
+ MAX_BLOB_SIZE_CONDITION_NOT_MET = "MaxBlobSizeConditionNotMet"
+ NO_PENDING_COPY_OPERATION = "NoPendingCopyOperation"
+ OPERATION_NOT_ALLOWED_ON_INCREMENTAL_COPY_BLOB = "OperationNotAllowedOnIncrementalCopyBlob"
+ PENDING_COPY_OPERATION = "PendingCopyOperation"
+ PREVIOUS_SNAPSHOT_CANNOT_BE_NEWER = "PreviousSnapshotCannotBeNewer"
+ PREVIOUS_SNAPSHOT_NOT_FOUND = "PreviousSnapshotNotFound"
+ PREVIOUS_SNAPSHOT_OPERATION_NOT_SUPPORTED = "PreviousSnapshotOperationNotSupported"
+ SEQUENCE_NUMBER_CONDITION_NOT_MET = "SequenceNumberConditionNotMet"
+ SEQUENCE_NUMBER_INCREMENT_TOO_LARGE = "SequenceNumberIncrementTooLarge"
+ SNAPSHOT_COUNT_EXCEEDED = "SnapshotCountExceeded"
+ SNAPSHOT_OPERATION_RATE_EXCEEDED = "SnapshotOperationRateExceeded"
+ #: Deprecated: Please use SNAPSHOT_OPERATION_RATE_EXCEEDED instead.
+ SNAPHOT_OPERATION_RATE_EXCEEDED = "SnapshotOperationRateExceeded"
+ SNAPSHOTS_PRESENT = "SnapshotsPresent"
+ SOURCE_CONDITION_NOT_MET = "SourceConditionNotMet"
+ SYSTEM_IN_USE = "SystemInUse"
+ TARGET_CONDITION_NOT_MET = "TargetConditionNotMet"
+ UNAUTHORIZED_BLOB_OVERWRITE = "UnauthorizedBlobOverwrite"
+ BLOB_BEING_REHYDRATED = "BlobBeingRehydrated"
+ BLOB_ARCHIVED = "BlobArchived"
+ BLOB_NOT_ARCHIVED = "BlobNotArchived"
+
+ # Queue values
+ INVALID_MARKER = "InvalidMarker"
+ MESSAGE_NOT_FOUND = "MessageNotFound"
+ MESSAGE_TOO_LARGE = "MessageTooLarge"
+ POP_RECEIPT_MISMATCH = "PopReceiptMismatch"
+ QUEUE_ALREADY_EXISTS = "QueueAlreadyExists"
+ QUEUE_BEING_DELETED = "QueueBeingDeleted"
+ QUEUE_DISABLED = "QueueDisabled"
+ QUEUE_NOT_EMPTY = "QueueNotEmpty"
+ QUEUE_NOT_FOUND = "QueueNotFound"
+
+ # File values
+ CANNOT_DELETE_FILE_OR_DIRECTORY = "CannotDeleteFileOrDirectory"
+ CLIENT_CACHE_FLUSH_DELAY = "ClientCacheFlushDelay"
+ DELETE_PENDING = "DeletePending"
+ DIRECTORY_NOT_EMPTY = "DirectoryNotEmpty"
+ FILE_LOCK_CONFLICT = "FileLockConflict"
+ FILE_SHARE_PROVISIONED_BANDWIDTH_DOWNGRADE_NOT_ALLOWED = "FileShareProvisionedBandwidthDowngradeNotAllowed"
+ FILE_SHARE_PROVISIONED_IOPS_DOWNGRADE_NOT_ALLOWED = "FileShareProvisionedIopsDowngradeNotAllowed"
+ INVALID_FILE_OR_DIRECTORY_PATH_NAME = "InvalidFileOrDirectoryPathName"
+ PARENT_NOT_FOUND = "ParentNotFound"
+ READ_ONLY_ATTRIBUTE = "ReadOnlyAttribute"
+ SHARE_ALREADY_EXISTS = "ShareAlreadyExists"
+ SHARE_BEING_DELETED = "ShareBeingDeleted"
+ SHARE_DISABLED = "ShareDisabled"
+ SHARE_NOT_FOUND = "ShareNotFound"
+ SHARING_VIOLATION = "SharingViolation"
+ SHARE_SNAPSHOT_IN_PROGRESS = "ShareSnapshotInProgress"
+ SHARE_SNAPSHOT_COUNT_EXCEEDED = "ShareSnapshotCountExceeded"
+ SHARE_SNAPSHOT_OPERATION_NOT_SUPPORTED = "ShareSnapshotOperationNotSupported"
+ SHARE_HAS_SNAPSHOTS = "ShareHasSnapshots"
+ CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed"
+
+ # DataLake values
+ CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero'
+ PATH_ALREADY_EXISTS = 'PathAlreadyExists'
+ INVALID_FLUSH_POSITION = 'InvalidFlushPosition'
+ INVALID_PROPERTY_NAME = 'InvalidPropertyName'
+ INVALID_SOURCE_URI = 'InvalidSourceUri'
+ UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion'
+ FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound'
+ PATH_NOT_FOUND = 'PathNotFound'
+ RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound'
+ SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound'
+ DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted'
+ FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists'
+ FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted'
+ INVALID_DESTINATION_PATH = 'InvalidDestinationPath'
+ INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath'
+ INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType'
+ LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken'
+ LEASE_NAME_MISMATCH = 'LeaseNameMismatch'
+ PATH_CONFLICT = 'PathConflict'
+ SOURCE_PATH_IS_BEING_DELETED = 'SourcePathIsBeingDeleted'
+
+
+class DictMixin(object):
+
+ def __setitem__(self, key, item):
+ self.__dict__[key] = item
+
+ def __getitem__(self, key):
+ return self.__dict__[key]
+
+ def __repr__(self):
+ return str(self)
+
+ def __len__(self):
+ return len(self.keys())
+
+ def __delitem__(self, key):
+ self.__dict__[key] = None
+
+ # Compare objects by comparing all attributes.
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ return False
+
+ # Compare objects by comparing all attributes.
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')})
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def has_key(self, k):
+ return k in self.__dict__
+
+ def update(self, *args, **kwargs):
+ return self.__dict__.update(*args, **kwargs)
+
+ def keys(self):
+ return [k for k in self.__dict__ if not k.startswith('_')]
+
+ def values(self):
+ return [v for k, v in self.__dict__.items() if not k.startswith('_')]
+
+ def items(self):
+ return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')]
+
+ def get(self, key, default=None):
+ if key in self.__dict__:
+ return self.__dict__[key]
+ return default
+
+
+class LocationMode(object):
+ """
+ Specifies the location the request should be sent to. This mode only applies
+ for RA-GRS accounts which allow secondary read access. All other account types
+ must use PRIMARY.
+ """
+
+ PRIMARY = 'primary' #: Requests should be sent to the primary location.
+ SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible.
+
+
+class ResourceTypes(object):
+ """
+ Specifies the resource types that are accessible with the account SAS.
+
+ :param bool service:
+ Access to service-level APIs (e.g., Get/Set Service Properties,
+ Get Service Stats, List Containers/Queues/Shares)
+ :param bool container:
+ Access to container-level APIs (e.g., Create/Delete Container,
+ Create/Delete Queue, Create/Delete Share,
+ List Blobs/Files and Directories)
+ :param bool object:
+ Access to object-level APIs for blobs, queue messages, and
+ files(e.g. Put Blob, Query Entity, Get Messages, Create File, etc.)
+ """
+
+ service: bool = False
+ container: bool = False
+ object: bool = False
+ _str: str
+
+ def __init__(
+ self,
+ service: bool = False,
+ container: bool = False,
+ object: bool = False # pylint: disable=redefined-builtin
+ ) -> None:
+ self.service = service
+ self.container = container
+ self.object = object
+ self._str = (('s' if self.service else '') +
+ ('c' if self.container else '') +
+ ('o' if self.object else ''))
+
+ def __str__(self):
+ return self._str
+
+ @classmethod
+ def from_string(cls, string):
+ """Create a ResourceTypes from a string.
+
+ To specify service, container, or object you need only to
+ include the first letter of the word in the string. E.g. service and container,
+ you would provide a string "sc".
+
+ :param str string: Specify service, container, or object in
+ in the string with the first letter of the word.
+ :return: A ResourceTypes object
+ :rtype: ~azure.storage.fileshare.ResourceTypes
+ """
+ res_service = 's' in string
+ res_container = 'c' in string
+ res_object = 'o' in string
+
+ parsed = cls(res_service, res_container, res_object)
+ parsed._str = string
+ return parsed
+
+
+class AccountSasPermissions(object):
+ """
+ :class:`~ResourceTypes` class to be used with generate_account_sas
+ function and for the AccessPolicies used with set_*_acl. There are two types of
+ SAS which may be used to grant resource access. One is to grant access to a
+ specific resource (resource-specific). Another is to grant access to the
+ entire service for a specific account and allow certain operations based on
+ perms found here.
+
+ :param bool read:
+ Valid for all signed resources types (Service, Container, and Object).
+ Permits read permissions to the specified resource type.
+ :param bool write:
+ Valid for all signed resources types (Service, Container, and Object).
+ Permits write permissions to the specified resource type.
+ :param bool delete:
+ Valid for Container and Object resource types, except for queue messages.
+ :param bool delete_previous_version:
+ Delete the previous blob version for the versioning enabled storage account.
+ :param bool list:
+ Valid for Service and Container resource types only.
+ :param bool add:
+ Valid for the following Object resource types only: queue messages, and append blobs.
+ :param bool create:
+ Valid for the following Object resource types only: blobs and files.
+ Users can create new blobs or files, but may not overwrite existing
+ blobs or files.
+ :param bool update:
+ Valid for the following Object resource types only: queue messages.
+ :param bool process:
+ Valid for the following Object resource type only: queue messages.
+ :keyword bool tag:
+ To enable set or get tags on the blobs in the container.
+ :keyword bool filter_by_tags:
+ To enable get blobs by tags, this should be used together with list permission.
+ :keyword bool set_immutability_policy:
+ To enable operations related to set/delete immutability policy.
+ To get immutability policy, you just need read permission.
+ :keyword bool permanent_delete:
+ To enable permanent delete on the blob is permitted.
+ Valid for Object resource type of Blob only.
+ """
+
+ read: bool = False
+ write: bool = False
+ delete: bool = False
+ delete_previous_version: bool = False
+ list: bool = False
+ add: bool = False
+ create: bool = False
+ update: bool = False
+ process: bool = False
+ tag: bool = False
+ filter_by_tags: bool = False
+ set_immutability_policy: bool = False
+ permanent_delete: bool = False
+
+ def __init__(
+ self,
+ read: bool = False,
+ write: bool = False,
+ delete: bool = False,
+ list: bool = False, # pylint: disable=redefined-builtin
+ add: bool = False,
+ create: bool = False,
+ update: bool = False,
+ process: bool = False,
+ delete_previous_version: bool = False,
+ **kwargs
+ ) -> None:
+ self.read = read
+ self.write = write
+ self.delete = delete
+ self.delete_previous_version = delete_previous_version
+ self.permanent_delete = kwargs.pop('permanent_delete', False)
+ self.list = list
+ self.add = add
+ self.create = create
+ self.update = update
+ self.process = process
+ self.tag = kwargs.pop('tag', False)
+ self.filter_by_tags = kwargs.pop('filter_by_tags', False)
+ self.set_immutability_policy = kwargs.pop('set_immutability_policy', False)
+ self._str = (('r' if self.read else '') +
+ ('w' if self.write else '') +
+ ('d' if self.delete else '') +
+ ('x' if self.delete_previous_version else '') +
+ ('y' if self.permanent_delete else '') +
+ ('l' if self.list else '') +
+ ('a' if self.add else '') +
+ ('c' if self.create else '') +
+ ('u' if self.update else '') +
+ ('p' if self.process else '') +
+ ('f' if self.filter_by_tags else '') +
+ ('t' if self.tag else '') +
+ ('i' if self.set_immutability_policy else '')
+ )
+
+ def __str__(self):
+ return self._str
+
+ @classmethod
+ def from_string(cls, permission):
+ """Create AccountSasPermissions from a string.
+
+ To specify read, write, delete, etc. permissions you need only to
+ include the first letter of the word in the string. E.g. for read and write
+ permissions you would provide a string "rw".
+
+ :param str permission: Specify permissions in
+ the string with the first letter of the word.
+ :return: An AccountSasPermissions object
+ :rtype: ~azure.storage.fileshare.AccountSasPermissions
+ """
+ p_read = 'r' in permission
+ p_write = 'w' in permission
+ p_delete = 'd' in permission
+ p_delete_previous_version = 'x' in permission
+ p_permanent_delete = 'y' in permission
+ p_list = 'l' in permission
+ p_add = 'a' in permission
+ p_create = 'c' in permission
+ p_update = 'u' in permission
+ p_process = 'p' in permission
+ p_tag = 't' in permission
+ p_filter_by_tags = 'f' in permission
+ p_set_immutability_policy = 'i' in permission
+ parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version,
+ list=p_list, add=p_add, create=p_create, update=p_update, process=p_process, tag=p_tag,
+ filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy,
+ permanent_delete=p_permanent_delete)
+
+ return parsed
+
+
+class Services(object):
+ """Specifies the services accessible with the account SAS.
+
+ :keyword bool blob:
+ Access for the `~azure.storage.blob.BlobServiceClient`. Default is False.
+ :keyword bool queue:
+ Access for the `~azure.storage.queue.QueueServiceClient`. Default is False.
+ :keyword bool fileshare:
+ Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False.
+ """
+
+ def __init__(
+ self, *,
+ blob: bool = False,
+ queue: bool = False,
+ fileshare: bool = False
+ ) -> None:
+ self.blob = blob
+ self.queue = queue
+ self.fileshare = fileshare
+ self._str = (('b' if self.blob else '') +
+ ('q' if self.queue else '') +
+ ('f' if self.fileshare else ''))
+
+ def __str__(self):
+ return self._str
+
+ @classmethod
+ def from_string(cls, string):
+ """Create Services from a string.
+
+ To specify blob, queue, or file you need only to
+ include the first letter of the word in the string. E.g. for blob and queue
+ you would provide a string "bq".
+
+ :param str string: Specify blob, queue, or file in
+ in the string with the first letter of the word.
+ :return: A Services object
+ :rtype: ~azure.storage.fileshare.Services
+ """
+ res_blob = 'b' in string
+ res_queue = 'q' in string
+ res_file = 'f' in string
+
+ parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file)
+ parsed._str = string
+ return parsed
+
+
+class UserDelegationKey(object):
+ """
+ Represents a user delegation key, provided to the user by Azure Storage
+ based on their Azure Active Directory access token.
+
+ The fields are saved as simple strings since the user does not have to interact with this object;
+ to generate an identify SAS, the user can simply pass it to the right API.
+ """
+
+ signed_oid: Optional[str] = None
+ """Object ID of this token."""
+ signed_tid: Optional[str] = None
+ """Tenant ID of the tenant that issued this token."""
+ signed_start: Optional[str] = None
+ """The datetime this token becomes valid."""
+ signed_expiry: Optional[str] = None
+ """The datetime this token expires."""
+ signed_service: Optional[str] = None
+ """What service this key is valid for."""
+ signed_version: Optional[str] = None
+ """The version identifier of the REST service that created this token."""
+ value: Optional[str] = None
+ """The user delegation key."""
+
+ def __init__(self):
+ self.signed_oid = None
+ self.signed_tid = None
+ self.signed_start = None
+ self.signed_expiry = None
+ self.signed_service = None
+ self.signed_version = None
+ self.value = None
+
+
+class StorageConfiguration(Configuration):
+ """
+ Specifies the configurable values used in Azure Storage.
+
+ :param int max_single_put_size: If the blob size is less than or equal max_single_put_size, then the blob will be
+ uploaded with only one http PUT request. If the blob size is larger than max_single_put_size,
+ the blob will be uploaded in chunks. Defaults to 64*1024*1024, or 64MB.
+ :param int copy_polling_interval: The interval in seconds for polling copy operations.
+ :param int max_block_size: The maximum chunk size for uploading a block blob in chunks.
+ Defaults to 4*1024*1024, or 4MB.
+ :param int min_large_block_upload_threshold: The minimum chunk size required to use the memory efficient
+ algorithm when uploading a block blob.
+ :param bool use_byte_buffer: Use a byte buffer for block blob uploads. Defaults to False.
+ :param int max_page_size: The maximum chunk size for uploading a page blob. Defaults to 4*1024*1024, or 4MB.
+ :param int min_large_chunk_upload_threshold: The max size for a single put operation.
+ :param int max_single_get_size: The maximum size for a blob to be downloaded in a single call,
+ the exceeded part will be downloaded in chunks (could be parallel). Defaults to 32*1024*1024, or 32MB.
+ :param int max_chunk_get_size: The maximum chunk size used for downloading a blob. Defaults to 4*1024*1024,
+ or 4MB.
+ :param int max_range_size: The max range size for file upload.
+
+ """
+
+ max_single_put_size: int
+ copy_polling_interval: int
+ max_block_size: int
+ min_large_block_upload_threshold: int
+ use_byte_buffer: bool
+ max_page_size: int
+ min_large_chunk_upload_threshold: int
+ max_single_get_size: int
+ max_chunk_get_size: int
+ max_range_size: int
+ user_agent_policy: UserAgentPolicy
+
+ def __init__(self, **kwargs):
+ super(StorageConfiguration, self).__init__(**kwargs)
+ self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024)
+ self.copy_polling_interval = 15
+ self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024)
+ self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1)
+ self.use_byte_buffer = kwargs.pop('use_byte_buffer', False)
+ self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024)
+ self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1)
+ self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024)
+ self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024)
+ self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024)
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py
new file mode 100644
index 00000000..112c1984
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/parser.py
@@ -0,0 +1,53 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from datetime import datetime, timezone
+from typing import Optional
+
+EPOCH_AS_FILETIME = 116444736000000000 # January 1, 1970 as MS filetime
+HUNDREDS_OF_NANOSECONDS = 10000000
+
+
+def _to_utc_datetime(value: datetime) -> str:
+ return value.strftime('%Y-%m-%dT%H:%M:%SZ')
+
+
+def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]:
+ """Converts an RFC 1123 date string to a UTC datetime.
+
+ :param str rfc_1123: The time and date in RFC 1123 format.
+ :returns: The time and date in UTC datetime format.
+ :rtype: datetime
+ """
+ if not rfc_1123:
+ return None
+
+ return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z")
+
+
+def _filetime_to_datetime(filetime: str) -> Optional[datetime]:
+ """Converts an MS filetime string to a UTC datetime. "0" indicates None.
+ If parsing MS Filetime fails, tries RFC 1123 as backup.
+
+ :param str filetime: The time and date in MS filetime format.
+ :returns: The time and date in UTC datetime format.
+ :rtype: datetime
+ """
+ if not filetime:
+ return None
+
+ # Try to convert to MS Filetime
+ try:
+ temp_filetime = int(filetime)
+ if temp_filetime == 0:
+ return None
+
+ return datetime.fromtimestamp((temp_filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc)
+ except ValueError:
+ pass
+
+ # Try RFC 1123 as backup
+ return _rfc_1123_to_datetime(filetime)
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py
new file mode 100644
index 00000000..ee75cd5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies.py
@@ -0,0 +1,694 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import base64
+import hashlib
+import logging
+import random
+import re
+import uuid
+from io import SEEK_SET, UnsupportedOperation
+from time import time
+from typing import Any, Dict, Optional, TYPE_CHECKING
+from urllib.parse import (
+ parse_qsl,
+ urlencode,
+ urlparse,
+ urlunparse,
+)
+from wsgiref.handlers import format_date_time
+
+from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError
+from azure.core.pipeline.policies import (
+ BearerTokenCredentialPolicy,
+ HeadersPolicy,
+ HTTPPolicy,
+ NetworkTraceLoggingPolicy,
+ RequestHistory,
+ SansIOHTTPPolicy
+)
+
+from .authentication import AzureSigningError, StorageHttpChallenge
+from .constants import DEFAULT_OAUTH_SCOPE
+from .models import LocationMode
+
+if TYPE_CHECKING:
+ from azure.core.credentials import TokenCredential
+ from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import
+ PipelineRequest,
+ PipelineResponse
+ )
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def encode_base64(data):
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+ encoded = base64.b64encode(data)
+ return encoded.decode('utf-8')
+
+
+# Are we out of retries?
+def is_exhausted(settings):
+ retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
+ retry_counts = list(filter(None, retry_counts))
+ if not retry_counts:
+ return False
+ return min(retry_counts) < 0
+
+
+def retry_hook(settings, **kwargs):
+ if settings['hook']:
+ settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)
+
+
+# Is this method/status code retryable? (Based on allowlists and control
+# variables such as the number of total retries to allow, whether to
+# respect the Retry-After header, whether this header is present, and
+# whether the returned status code is on the list of status codes to
+# be retried upon on the presence of the aforementioned header)
+def is_retry(response, mode):
+ status = response.http_response.status_code
+ if 300 <= status < 500:
+ # An exception occurred, but in most cases it was expected. Examples could
+ # include a 309 Conflict or 412 Precondition Failed.
+ if status == 404 and mode == LocationMode.SECONDARY:
+ # Response code 404 should be retried if secondary was used.
+ return True
+ if status == 408:
+ # Response code 408 is a timeout and should be retried.
+ return True
+ return False
+ if status >= 500:
+ # Response codes above 500 with the exception of 501 Not Implemented and
+ # 505 Version Not Supported indicate a server issue and should be retried.
+ if status in [501, 505]:
+ return False
+ return True
+ return False
+
+
+def is_checksum_retry(response):
+ # retry if invalid content md5
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+ computed_md5 = response.http_request.headers.get('content-md5', None) or \
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+ if response.http_response.headers['content-md5'] != computed_md5:
+ return True
+ return False
+
+
+def urljoin(base_url, stub_url):
+ parsed = urlparse(base_url)
+ parsed = parsed._replace(path=parsed.path + '/' + stub_url)
+ return parsed.geturl()
+
+
+class QueueMessagePolicy(SansIOHTTPPolicy):
+
+ def on_request(self, request):
+ message_id = request.context.options.pop('queue_message_id', None)
+ if message_id:
+ request.http_request.url = urljoin(
+ request.http_request.url,
+ message_id)
+
+
+class StorageHeadersPolicy(HeadersPolicy):
+ request_id_header_name = 'x-ms-client-request-id'
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ super(StorageHeadersPolicy, self).on_request(request)
+ current_time = format_date_time(time())
+ request.http_request.headers['x-ms-date'] = current_time
+
+ custom_id = request.context.options.pop('client_request_id', None)
+ request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1())
+
+ # def on_response(self, request, response):
+ # # raise exception if the echoed client request id from the service is not identical to the one we sent
+ # if self.request_id_header_name in response.http_response.headers:
+
+ # client_request_id = request.http_request.headers.get(self.request_id_header_name)
+
+ # if response.http_response.headers[self.request_id_header_name] != client_request_id:
+ # raise AzureError(
+ # "Echoed client request ID: {} does not match sent client request ID: {}. "
+ # "Service request ID: {}".format(
+ # response.http_response.headers[self.request_id_header_name], client_request_id,
+ # response.http_response.headers['x-ms-request-id']),
+ # response=response.http_response
+ # )
+
+
+class StorageHosts(SansIOHTTPPolicy):
+
+ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument
+ self.hosts = hosts
+ super(StorageHosts, self).__init__()
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ request.context.options['hosts'] = self.hosts
+ parsed_url = urlparse(request.http_request.url)
+
+ # Detect what location mode we're currently requesting with
+ location_mode = LocationMode.PRIMARY
+ for key, value in self.hosts.items():
+ if parsed_url.netloc == value:
+ location_mode = key
+
+ # See if a specific location mode has been specified, and if so, redirect
+ use_location = request.context.options.pop('use_location', None)
+ if use_location:
+ # Lock retries to the specific location
+ request.context.options['retry_to_secondary'] = False
+ if use_location not in self.hosts:
+ raise ValueError(f"Attempting to use undefined host location {use_location}")
+ if use_location != location_mode:
+ # Update request URL to use the specified location
+ updated = parsed_url._replace(netloc=self.hosts[use_location])
+ request.http_request.url = updated.geturl()
+ location_mode = use_location
+
+ request.context.options['location_mode'] = location_mode
+
+
+class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
+ """A policy that logs HTTP request and response to the DEBUG logger.
+
+ This accepts both global configuration, and per-request level with "enable_http_logger"
+ """
+
+ def __init__(self, logging_enable: bool = False, **kwargs) -> None:
+ self.logging_body = kwargs.pop("logging_body", False)
+ super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs)
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ http_request = request.http_request
+ options = request.context.options
+ self.logging_body = self.logging_body or options.pop("logging_body", False)
+ if options.pop("logging_enable", self.enable_http_logger):
+ request.context["logging_enable"] = True
+ if not _LOGGER.isEnabledFor(logging.DEBUG):
+ return
+
+ try:
+ log_url = http_request.url
+ query_params = http_request.query
+ if 'sig' in query_params:
+ log_url = log_url.replace(query_params['sig'], "sig=*****")
+ _LOGGER.debug("Request URL: %r", log_url)
+ _LOGGER.debug("Request method: %r", http_request.method)
+ _LOGGER.debug("Request headers:")
+ for header, value in http_request.headers.items():
+ if header.lower() == 'authorization':
+ value = '*****'
+ elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
+ # take the url apart and scrub away the signed signature
+ scheme, netloc, path, params, query, fragment = urlparse(value)
+ parsed_qs = dict(parse_qsl(query))
+ parsed_qs['sig'] = '*****'
+
+ # the SAS needs to be put back together
+ value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))
+
+ _LOGGER.debug(" %r: %r", header, value)
+ _LOGGER.debug("Request body:")
+
+ if self.logging_body:
+ _LOGGER.debug(str(http_request.body))
+ else:
+ # We don't want to log the binary data of a file upload.
+ _LOGGER.debug("Hidden body, please use logging_body to show body")
+ except Exception as err: # pylint: disable=broad-except
+ _LOGGER.debug("Failed to log request: %r", err)
+
+ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
+ if response.context.pop("logging_enable", self.enable_http_logger):
+ if not _LOGGER.isEnabledFor(logging.DEBUG):
+ return
+
+ try:
+ _LOGGER.debug("Response status: %r", response.http_response.status_code)
+ _LOGGER.debug("Response headers:")
+ for res_header, value in response.http_response.headers.items():
+ _LOGGER.debug(" %r: %r", res_header, value)
+
+ # We don't want to log binary data if the response is a file.
+ _LOGGER.debug("Response content:")
+ pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
+ header = response.http_response.headers.get('content-disposition')
+ resp_content_type = response.http_response.headers.get("content-type", "")
+
+ if header and pattern.match(header):
+ filename = header.partition('=')[2]
+ _LOGGER.debug("File attachments: %s", filename)
+ elif resp_content_type.endswith("octet-stream"):
+ _LOGGER.debug("Body contains binary data.")
+ elif resp_content_type.startswith("image"):
+ _LOGGER.debug("Body contains image data.")
+
+ if self.logging_body and resp_content_type.startswith("text"):
+ _LOGGER.debug(response.http_response.text())
+ elif self.logging_body:
+ try:
+ _LOGGER.debug(response.http_response.body())
+ except ValueError:
+ _LOGGER.debug("Body is streamable")
+
+ except Exception as err: # pylint: disable=broad-except
+ _LOGGER.debug("Failed to log response: %s", repr(err))
+
+
+class StorageRequestHook(SansIOHTTPPolicy):
+
+ def __init__(self, **kwargs):
+ self._request_callback = kwargs.get('raw_request_hook')
+ super(StorageRequestHook, self).__init__()
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
+ if request_callback:
+ request_callback(request)
+
+
+class StorageResponseHook(HTTPPolicy):
+
+ def __init__(self, **kwargs):
+ self._response_callback = kwargs.get('raw_response_hook')
+ super(StorageResponseHook, self).__init__()
+
+ def send(self, request: "PipelineRequest") -> "PipelineResponse":
+ # Values could be 0
+ data_stream_total = request.context.get('data_stream_total')
+ if data_stream_total is None:
+ data_stream_total = request.context.options.pop('data_stream_total', None)
+ download_stream_current = request.context.get('download_stream_current')
+ if download_stream_current is None:
+ download_stream_current = request.context.options.pop('download_stream_current', None)
+ upload_stream_current = request.context.get('upload_stream_current')
+ if upload_stream_current is None:
+ upload_stream_current = request.context.options.pop('upload_stream_current', None)
+
+ response_callback = request.context.get('response_callback') or \
+ request.context.options.pop('raw_response_hook', self._response_callback)
+
+ response = self.next.send(request)
+
+ will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
+ # Auth error could come from Bearer challenge, in which case this request will be made again
+ is_auth_error = response.http_response.status_code == 401
+ should_update_counts = not (will_retry or is_auth_error)
+
+ if should_update_counts and download_stream_current is not None:
+ download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
+ if data_stream_total is None:
+ content_range = response.http_response.headers.get('Content-Range')
+ if content_range:
+ data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
+ else:
+ data_stream_total = download_stream_current
+ elif should_update_counts and upload_stream_current is not None:
+ upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
+ for pipeline_obj in [request, response]:
+ if hasattr(pipeline_obj, 'context'):
+ pipeline_obj.context['data_stream_total'] = data_stream_total
+ pipeline_obj.context['download_stream_current'] = download_stream_current
+ pipeline_obj.context['upload_stream_current'] = upload_stream_current
+ if response_callback:
+ response_callback(response)
+ request.context['response_callback'] = response_callback
+ return response
+
+
+class StorageContentValidation(SansIOHTTPPolicy):
+ """A simple policy that sends the given headers
+ with the request.
+
+ This will overwrite any headers already defined in the request.
+ """
+ header_name = 'Content-MD5'
+
+ def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument
+ super(StorageContentValidation, self).__init__()
+
+ @staticmethod
+ def get_content_md5(data):
+ # Since HTTP does not differentiate between no content and empty content,
+ # we have to perform a None check.
+ data = data or b""
+ md5 = hashlib.md5() # nosec
+ if isinstance(data, bytes):
+ md5.update(data)
+ elif hasattr(data, 'read'):
+ pos = 0
+ try:
+ pos = data.tell()
+ except: # pylint: disable=bare-except
+ pass
+ for chunk in iter(lambda: data.read(4096), b""):
+ md5.update(chunk)
+ try:
+ data.seek(pos, SEEK_SET)
+ except (AttributeError, IOError) as exc:
+ raise ValueError("Data should be bytes or a seekable file-like object.") from exc
+ else:
+ raise ValueError("Data should be bytes or a seekable file-like object.")
+
+ return md5.digest()
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ validate_content = request.context.options.pop('validate_content', False)
+ if validate_content and request.http_request.method != 'GET':
+ computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
+ request.http_request.headers[self.header_name] = computed_md5
+ request.context['validate_content_md5'] = computed_md5
+ request.context['validate_content'] = validate_content
+
+ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+ computed_md5 = request.context.get('validate_content_md5') or \
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+ if response.http_response.headers['content-md5'] != computed_md5:
+ raise AzureError((
+ f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
+ f"computed value is '{computed_md5}'."),
+ response=response.http_response
+ )
+
+
+class StorageRetryPolicy(HTTPPolicy):
+ """
+ The base class for Exponential and Linear retries containing shared code.
+ """
+
+ total_retries: int
+ """The max number of retries."""
+ connect_retries: int
+ """The max number of connect retries."""
+ retry_read: int
+ """The max number of read retries."""
+ retry_status: int
+ """The max number of status retries."""
+ retry_to_secondary: bool
+ """Whether the secondary endpoint should be retried."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ self.total_retries = kwargs.pop('retry_total', 10)
+ self.connect_retries = kwargs.pop('retry_connect', 3)
+ self.read_retries = kwargs.pop('retry_read', 3)
+ self.status_retries = kwargs.pop('retry_status', 3)
+ self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
+ super(StorageRetryPolicy, self).__init__()
+
+ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None:
+ """
+ A function which sets the next host location on the request, if applicable.
+
+ :param Dict[str, Any]] settings: The configurable values pertaining to the next host location.
+ :param PipelineRequest request: A pipeline request object.
+ """
+ if settings['hosts'] and all(settings['hosts'].values()):
+ url = urlparse(request.url)
+ # If there's more than one possible location, retry to the alternative
+ if settings['mode'] == LocationMode.PRIMARY:
+ settings['mode'] = LocationMode.SECONDARY
+ else:
+ settings['mode'] = LocationMode.PRIMARY
+ updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
+ request.url = updated.geturl()
+
+ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]:
+ body_position = None
+ if hasattr(request.http_request.body, 'read'):
+ try:
+ body_position = request.http_request.body.tell()
+ except (AttributeError, UnsupportedOperation):
+ # if body position cannot be obtained, then retries will not work
+ pass
+ options = request.context.options
+ return {
+ 'total': options.pop("retry_total", self.total_retries),
+ 'connect': options.pop("retry_connect", self.connect_retries),
+ 'read': options.pop("retry_read", self.read_retries),
+ 'status': options.pop("retry_status", self.status_retries),
+ 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
+ 'mode': options.pop("location_mode", LocationMode.PRIMARY),
+ 'hosts': options.pop("hosts", None),
+ 'hook': options.pop("retry_hook", None),
+ 'body_position': body_position,
+ 'count': 0,
+ 'history': []
+ }
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument
+ """ Formula for computing the current backoff.
+ Should be calculated by child class.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :returns: The backoff time.
+ :rtype: float
+ """
+ return 0
+
+ def sleep(self, settings, transport):
+ backoff = self.get_backoff_time(settings)
+ if not backoff or backoff < 0:
+ return
+ transport.sleep(backoff)
+
+ def increment(
+ self, settings: Dict[str, Any],
+ request: "PipelineRequest",
+ response: Optional["PipelineResponse"] = None,
+ error: Optional[AzureError] = None
+ ) -> bool:
+ """Increment the retry counters.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the increment operation.
+ :param PipelineRequest request: A pipeline request object.
+ :param Optional[PipelineResponse] response: A pipeline response object.
+ :param Optional[AzureError] error: An error encountered during the request, or
+ None if the response was received successfully.
+ :returns: Whether the retry attempts are exhausted.
+ :rtype: bool
+ """
+ settings['total'] -= 1
+
+ if error and isinstance(error, ServiceRequestError):
+ # Errors when we're fairly sure that the server did not receive the
+ # request, so it should be safe to retry.
+ settings['connect'] -= 1
+ settings['history'].append(RequestHistory(request, error=error))
+
+ elif error and isinstance(error, ServiceResponseError):
+ # Errors that occur after the request has been started, so we should
+ # assume that the server began processing it.
+ settings['read'] -= 1
+ settings['history'].append(RequestHistory(request, error=error))
+
+ else:
+ # Incrementing because of a server error like a 500 in
+ # status_forcelist and a the given method is in the allowlist
+ if response:
+ settings['status'] -= 1
+ settings['history'].append(RequestHistory(request, http_response=response))
+
+ if not is_exhausted(settings):
+ if request.method not in ['PUT'] and settings['retry_secondary']:
+ self._set_next_host_location(settings, request)
+
+ # rewind the request body if it is a stream
+ if request.body and hasattr(request.body, 'read'):
+ # no position was saved, then retry would not work
+ if settings['body_position'] is None:
+ return False
+ try:
+ # attempt to rewind the body to the initial position
+ request.body.seek(settings['body_position'], SEEK_SET)
+ except (UnsupportedOperation, ValueError):
+ # if body is not seekable, then retry would not work
+ return False
+ settings['count'] += 1
+ return True
+ return False
+
+ def send(self, request):
+ retries_remaining = True
+ response = None
+ retry_settings = self.configure_retries(request)
+ while retries_remaining:
+ try:
+ response = self.next.send(request)
+ if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
+ retries_remaining = self.increment(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response)
+ if retries_remaining:
+ retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response,
+ error=None)
+ self.sleep(retry_settings, request.context.transport)
+ continue
+ break
+ except AzureError as err:
+ if isinstance(err, AzureSigningError):
+ raise
+ retries_remaining = self.increment(
+ retry_settings, request=request.http_request, error=err)
+ if retries_remaining:
+ retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=None,
+ error=err)
+ self.sleep(retry_settings, request.context.transport)
+ continue
+ raise err
+ if retry_settings['history']:
+ response.context['history'] = retry_settings['history']
+ response.http_response.location_mode = retry_settings['mode']
+ return response
+
+
+class ExponentialRetry(StorageRetryPolicy):
+ """Exponential retry."""
+
+ initial_backoff: int
+ """The initial backoff interval, in seconds, for the first retry."""
+ increment_base: int
+ """The base, in seconds, to increment the initial_backoff by after the
+ first retry."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self, initial_backoff: int = 15,
+ increment_base: int = 3,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3,
+ **kwargs: Any
+ ) -> None:
+ """
+ Constructs an Exponential retry object. The initial_backoff is used for
+ the first retry. Subsequent retries are retried after initial_backoff +
+ increment_power^retry_count seconds.
+
+ :param int initial_backoff:
+ The initial backoff interval, in seconds, for the first retry.
+ :param int increment_base:
+ The base, in seconds, to increment the initial_backoff by after the
+ first retry.
+ :param int retry_total:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+ only be enabled of RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+ For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
+ """
+ self.initial_backoff = initial_backoff
+ self.increment_base = increment_base
+ self.random_jitter_range = random_jitter_range
+ super(ExponentialRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+ :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time.
+ :returns:
+ A float indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: float
+ """
+ random_generator = random.Random()
+ backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
+ random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
+ random_range_end = backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class LinearRetry(StorageRetryPolicy):
+ """Linear retry."""
+
+ initial_backoff: int
+ """The backoff interval, in seconds, between retries."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self, backoff: int = 15,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3,
+ **kwargs: Any
+ ) -> None:
+ """
+ Constructs a Linear retry object.
+
+ :param int backoff:
+ The backoff interval, in seconds, between retries.
+ :param int retry_total:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+ only be enabled of RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+ For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
+ """
+ self.backoff = backoff
+ self.random_jitter_range = random_jitter_range
+ super(LinearRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+ :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time.
+ :returns:
+ A float indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: float
+ """
+ random_generator = random.Random()
+ # the backoff interval normally does not change, however there is the possibility
+ # that it was modified by accessing the property directly after initializing the object
+ random_range_start = self.backoff - self.random_jitter_range \
+ if self.backoff > self.random_jitter_range else 0
+ random_range_end = self.backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
+ """ Custom Bearer token credential policy for following Storage Bearer challenges """
+
+ def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None:
+ super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)
+
+ def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
+ try:
+ auth_header = response.http_response.headers.get("WWW-Authenticate")
+ challenge = StorageHttpChallenge(auth_header)
+ except ValueError:
+ return False
+
+ scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
+ self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
+
+ return True
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies_async.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies_async.py
new file mode 100644
index 00000000..1c030a82
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/policies_async.py
@@ -0,0 +1,296 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=invalid-overridden-method
+
+import asyncio
+import logging
+import random
+from typing import Any, Dict, TYPE_CHECKING
+
+from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError
+from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy
+
+from .authentication import AzureSigningError, StorageHttpChallenge
+from .constants import DEFAULT_OAUTH_SCOPE
+from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy
+
+if TYPE_CHECKING:
+ from azure.core.credentials_async import AsyncTokenCredential
+ from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import
+ PipelineRequest,
+ PipelineResponse
+ )
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+async def retry_hook(settings, **kwargs):
+ if settings['hook']:
+ if asyncio.iscoroutine(settings['hook']):
+ await settings['hook'](
+ retry_count=settings['count'] - 1,
+ location_mode=settings['mode'],
+ **kwargs)
+ else:
+ settings['hook'](
+ retry_count=settings['count'] - 1,
+ location_mode=settings['mode'],
+ **kwargs)
+
+
+async def is_checksum_retry(response):
+ # retry if invalid content md5
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+ try:
+ await response.http_response.load_body() # Load the body in memory and close the socket
+ except (StreamClosedError, StreamConsumedError):
+ pass
+ computed_md5 = response.http_request.headers.get('content-md5', None) or \
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+ if response.http_response.headers['content-md5'] != computed_md5:
+ return True
+ return False
+
+
+class AsyncStorageResponseHook(AsyncHTTPPolicy):
+
+ def __init__(self, **kwargs):
+ self._response_callback = kwargs.get('raw_response_hook')
+ super(AsyncStorageResponseHook, self).__init__()
+
+ async def send(self, request: "PipelineRequest") -> "PipelineResponse":
+ # Values could be 0
+ data_stream_total = request.context.get('data_stream_total')
+ if data_stream_total is None:
+ data_stream_total = request.context.options.pop('data_stream_total', None)
+ download_stream_current = request.context.get('download_stream_current')
+ if download_stream_current is None:
+ download_stream_current = request.context.options.pop('download_stream_current', None)
+ upload_stream_current = request.context.get('upload_stream_current')
+ if upload_stream_current is None:
+ upload_stream_current = request.context.options.pop('upload_stream_current', None)
+
+ response_callback = request.context.get('response_callback') or \
+ request.context.options.pop('raw_response_hook', self._response_callback)
+
+ response = await self.next.send(request)
+
+ will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response)
+ # Auth error could come from Bearer challenge, in which case this request will be made again
+ is_auth_error = response.http_response.status_code == 401
+ should_update_counts = not (will_retry or is_auth_error)
+
+ if should_update_counts and download_stream_current is not None:
+ download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
+ if data_stream_total is None:
+ content_range = response.http_response.headers.get('Content-Range')
+ if content_range:
+ data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
+ else:
+ data_stream_total = download_stream_current
+ elif should_update_counts and upload_stream_current is not None:
+ upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
+ for pipeline_obj in [request, response]:
+ if hasattr(pipeline_obj, 'context'):
+ pipeline_obj.context['data_stream_total'] = data_stream_total
+ pipeline_obj.context['download_stream_current'] = download_stream_current
+ pipeline_obj.context['upload_stream_current'] = upload_stream_current
+ if response_callback:
+ if asyncio.iscoroutine(response_callback):
+ await response_callback(response) # type: ignore
+ else:
+ response_callback(response)
+ request.context['response_callback'] = response_callback
+ return response
+
+class AsyncStorageRetryPolicy(StorageRetryPolicy):
+ """
+ The base class for Exponential and Linear retries containing shared code.
+ """
+
+ async def sleep(self, settings, transport):
+ backoff = self.get_backoff_time(settings)
+ if not backoff or backoff < 0:
+ return
+ await transport.sleep(backoff)
+
+ async def send(self, request):
+ retries_remaining = True
+ response = None
+ retry_settings = self.configure_retries(request)
+ while retries_remaining:
+ try:
+ response = await self.next.send(request)
+ if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response):
+ retries_remaining = self.increment(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response)
+ if retries_remaining:
+ await retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response,
+ error=None)
+ await self.sleep(retry_settings, request.context.transport)
+ continue
+ break
+ except AzureError as err:
+ if isinstance(err, AzureSigningError):
+ raise
+ retries_remaining = self.increment(
+ retry_settings, request=request.http_request, error=err)
+ if retries_remaining:
+ await retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=None,
+ error=err)
+ await self.sleep(retry_settings, request.context.transport)
+ continue
+ raise err
+ if retry_settings['history']:
+ response.context['history'] = retry_settings['history']
+ response.http_response.location_mode = retry_settings['mode']
+ return response
+
+
+class ExponentialRetry(AsyncStorageRetryPolicy):
+ """Exponential retry."""
+
+ initial_backoff: int
+ """The initial backoff interval, in seconds, for the first retry."""
+ increment_base: int
+ """The base, in seconds, to increment the initial_backoff by after the
+ first retry."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self,
+ initial_backoff: int = 15,
+ increment_base: int = 3,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3, **kwargs
+ ) -> None:
+ """
+ Constructs an Exponential retry object. The initial_backoff is used for
+ the first retry. Subsequent retries are retried after initial_backoff +
+ increment_power^retry_count seconds. For example, by default the first retry
+ occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the
+ third after (15+3^2) = 24 seconds.
+
+ :param int initial_backoff:
+ The initial backoff interval, in seconds, for the first retry.
+ :param int increment_base:
+ The base, in seconds, to increment the initial_backoff by after the
+ first retry.
+ :param int max_attempts:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+ only be enabled of RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+ For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
+ """
+ self.initial_backoff = initial_backoff
+ self.increment_base = increment_base
+ self.random_jitter_range = random_jitter_range
+ super(ExponentialRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :return:
+ An integer indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: int or None
+ """
+ random_generator = random.Random()
+ backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
+ random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
+ random_range_end = backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class LinearRetry(AsyncStorageRetryPolicy):
+ """Linear retry."""
+
+ initial_backoff: int
+ """The backoff interval, in seconds, between retries."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self, backoff: int = 15,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3,
+ **kwargs: Any
+ ) -> None:
+ """
+ Constructs a Linear retry object.
+
+ :param int backoff:
+ The backoff interval, in seconds, between retries.
+ :param int max_attempts:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+ only be enabled of RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+ For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
+ """
+ self.backoff = backoff
+ self.random_jitter_range = random_jitter_range
+ super(LinearRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :return:
+ An integer indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: int or None
+ """
+ random_generator = random.Random()
+ # the backoff interval normally does not change, however there is the possibility
+ # that it was modified by accessing the property directly after initializing the object
+ random_range_start = self.backoff - self.random_jitter_range \
+ if self.backoff > self.random_jitter_range else 0
+ random_range_end = self.backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
+ """ Custom Bearer token credential policy for following Storage Bearer challenges """
+
+ def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None:
+ super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)
+
+ async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
+ try:
+ auth_header = response.http_response.headers.get("WWW-Authenticate")
+ challenge = StorageHttpChallenge(auth_header)
+ except ValueError:
+ return False
+
+ scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
+ await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
+
+ return True
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py
new file mode 100644
index 00000000..54927cc7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/request_handlers.py
@@ -0,0 +1,270 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import stat
+from io import (SEEK_END, SEEK_SET, UnsupportedOperation)
+from os import fstat
+from typing import Dict, Optional
+
+import isodate
+
+
+_LOGGER = logging.getLogger(__name__)
+
+_REQUEST_DELIMITER_PREFIX = "batch_"
+_HTTP1_1_IDENTIFIER = "HTTP/1.1"
+_HTTP_LINE_ENDING = "\r\n"
+
+
+def serialize_iso(attr):
+ """Serialize Datetime object into ISO-8601 formatted string.
+
+ :param Datetime attr: Object to be serialized.
+ :rtype: str
+ :raises: ValueError if format invalid.
+ """
+ if not attr:
+ return None
+ if isinstance(attr, str):
+ attr = isodate.parse_datetime(attr)
+ try:
+ utc = attr.utctimetuple()
+ if utc.tm_year > 9999 or utc.tm_year < 1:
+ raise OverflowError("Hit max or min date")
+
+ date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}"
+ return date + 'Z'
+ except (ValueError, OverflowError) as err:
+ raise ValueError("Unable to serialize datetime object.") from err
+ except AttributeError as err:
+ raise TypeError("ISO-8601 object must be valid datetime object.") from err
+
+def get_length(data):
+ length = None
+ # Check if object implements the __len__ method, covers most input cases such as bytearray.
+ try:
+ length = len(data)
+ except: # pylint: disable=bare-except
+ pass
+
+ if not length:
+ # Check if the stream is a file-like stream object.
+ # If so, calculate the size using the file descriptor.
+ try:
+ fileno = data.fileno()
+ except (AttributeError, UnsupportedOperation):
+ pass
+ else:
+ try:
+ mode = fstat(fileno).st_mode
+ if stat.S_ISREG(mode) or stat.S_ISLNK(mode):
+ #st_size only meaningful if regular file or symlink, other types
+ # e.g. sockets may return misleading sizes like 0
+ return fstat(fileno).st_size
+ except OSError:
+ # Not a valid fileno, may be possible requests returned
+ # a socket number?
+ pass
+
+ # If the stream is seekable and tell() is implemented, calculate the stream size.
+ try:
+ current_position = data.tell()
+ data.seek(0, SEEK_END)
+ length = data.tell() - current_position
+ data.seek(current_position, SEEK_SET)
+ except (AttributeError, OSError, UnsupportedOperation):
+ pass
+
+ return length
+
+
+def read_length(data):
+ try:
+ if hasattr(data, 'read'):
+ read_data = b''
+ for chunk in iter(lambda: data.read(4096), b""):
+ read_data += chunk
+ return len(read_data), read_data
+ if hasattr(data, '__iter__'):
+ read_data = b''
+ for chunk in data:
+ read_data += chunk
+ return len(read_data), read_data
+ except: # pylint: disable=bare-except
+ pass
+ raise ValueError("Unable to calculate content length, please specify.")
+
+
+def validate_and_format_range_headers(
+ start_range, end_range, start_range_required=True,
+ end_range_required=True, check_content_md5=False, align_to_page=False):
+ # If end range is provided, start range must be provided
+ if (start_range_required or end_range is not None) and start_range is None:
+ raise ValueError("start_range value cannot be None.")
+ if end_range_required and end_range is None:
+ raise ValueError("end_range value cannot be None.")
+
+ # Page ranges must be 512 aligned
+ if align_to_page:
+ if start_range is not None and start_range % 512 != 0:
+ raise ValueError(f"Invalid page blob start_range: {start_range}. "
+ "The size must be aligned to a 512-byte boundary.")
+ if end_range is not None and end_range % 512 != 511:
+ raise ValueError(f"Invalid page blob end_range: {end_range}. "
+ "The size must be aligned to a 512-byte boundary.")
+
+ # Format based on whether end_range is present
+ range_header = None
+ if end_range is not None:
+ range_header = f'bytes={start_range}-{end_range}'
+ elif start_range is not None:
+ range_header = f"bytes={start_range}-"
+
+ # Content MD5 can only be provided for a complete range less than 4MB in size
+ range_validation = None
+ if check_content_md5:
+ if start_range is None or end_range is None:
+ raise ValueError("Both start and end range required for MD5 content validation.")
+ if end_range - start_range > 4 * 1024 * 1024:
+ raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.")
+ range_validation = 'true'
+
+ return range_header, range_validation
+
+
+def add_metadata_headers(metadata=None):
+ # type: (Optional[Dict[str, str]]) -> Dict[str, str]
+ headers = {}
+ if metadata:
+ for key, value in metadata.items():
+ headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value
+ return headers
+
+
+def serialize_batch_body(requests, batch_id):
+ """
+ --<delimiter>
+ <subrequest>
+ --<delimiter>
+ <subrequest> (repeated as needed)
+ --<delimiter>--
+
+ Serializes the requests in this batch to a single HTTP mixed/multipart body.
+
+ :param List[~azure.core.pipeline.transport.HttpRequest] requests:
+ a list of sub-request for the batch request
+ :param str batch_id:
+ to be embedded in batch sub-request delimiter
+ :returns: The body bytes for this batch.
+ :rtype: bytes
+ """
+
+ if requests is None or len(requests) == 0:
+ raise ValueError('Please provide sub-request(s) for this batch request')
+
+ delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8')
+ newline_bytes = _HTTP_LINE_ENDING.encode('utf-8')
+ batch_body = []
+
+ content_index = 0
+ for request in requests:
+ request.headers.update({
+ "Content-ID": str(content_index),
+ "Content-Length": str(0)
+ })
+ batch_body.append(delimiter_bytes)
+ batch_body.append(_make_body_from_sub_request(request))
+ batch_body.append(newline_bytes)
+ content_index += 1
+
+ batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8'))
+ # final line of body MUST have \r\n at the end, or it will not be properly read by the service
+ batch_body.append(newline_bytes)
+
+ return b"".join(batch_body)
+
+
+def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_dashes=False):
+ """
+ Gets the delimiter used for this batch request's mixed/multipart HTTP format.
+
+ :param str batch_id:
+ Randomly generated id
+ :param bool is_prepend_dashes:
+ Whether to include the starting dashes. Used in the body, but non on defining the delimiter.
+ :param bool is_append_dashes:
+ Whether to include the ending dashes. Used in the body on the closing delimiter only.
+ :returns: The delimiter, WITHOUT a trailing newline.
+ :rtype: str
+ """
+
+ prepend_dashes = '--' if is_prepend_dashes else ''
+ append_dashes = '--' if is_append_dashes else ''
+
+ return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes
+
+
+def _make_body_from_sub_request(sub_request):
+ """
+ Content-Type: application/http
+ Content-ID: <sequential int ID>
+ Content-Transfer-Encoding: <value> (if present)
+
+ <verb> <path><query> HTTP/<version>
+ <header key>: <header value> (repeated as necessary)
+ Content-Length: <value>
+ (newline if content length > 0)
+ <body> (if content length > 0)
+
+ Serializes an http request.
+
+ :param ~azure.core.pipeline.transport.HttpRequest sub_request:
+ Request to serialize.
+ :returns: The serialized sub-request in bytes
+ :rtype: bytes
+ """
+
+ # put the sub-request's headers into a list for efficient str concatenation
+ sub_request_body = []
+
+ # get headers for ease of manipulation; remove headers as they are used
+ headers = sub_request.headers
+
+ # append opening headers
+ sub_request_body.append("Content-Type: application/http")
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ sub_request_body.append("Content-ID: ")
+ sub_request_body.append(headers.pop("Content-ID", ""))
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ sub_request_body.append("Content-Transfer-Encoding: binary")
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append blank line
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append HTTP verb and path and query and HTTP version
+ sub_request_body.append(sub_request.method)
+ sub_request_body.append(' ')
+ sub_request_body.append(sub_request.url)
+ sub_request_body.append(' ')
+ sub_request_body.append(_HTTP1_1_IDENTIFIER)
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append remaining headers (this will set the Content-Length, as it was set on `sub-request`)
+ for header_name, header_value in headers.items():
+ if header_value is not None:
+ sub_request_body.append(header_name)
+ sub_request_body.append(": ")
+ sub_request_body.append(header_value)
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append blank line
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ return ''.join(sub_request_body).encode()
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py
new file mode 100644
index 00000000..af9a2fcd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/response_handlers.py
@@ -0,0 +1,200 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+from typing import NoReturn
+from xml.etree.ElementTree import Element
+
+from azure.core.exceptions import (
+ ClientAuthenticationError,
+ DecodeError,
+ HttpResponseError,
+ ResourceExistsError,
+ ResourceModifiedError,
+ ResourceNotFoundError,
+)
+from azure.core.pipeline.policies import ContentDecodePolicy
+
+from .authentication import AzureSigningError
+from .models import get_enum_value, StorageErrorCode, UserDelegationKey
+from .parser import _to_utc_datetime
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PartialBatchErrorException(HttpResponseError):
+ """There is a partial failure in batch operations.
+
+ :param str message: The message of the exception.
+ :param response: Server response to be deserialized.
+ :param list parts: A list of the parts in multipart response.
+ """
+
+ def __init__(self, message, response, parts):
+ self.parts = parts
+ super(PartialBatchErrorException, self).__init__(message=message, response=response)
+
+
+# Parses the blob length from the content range header: bytes 1-3/65537
+def parse_length_from_content_range(content_range):
+ if content_range is None:
+ return None
+
+ # First, split in space and take the second half: '1-3/65537'
+ # Next, split on slash and take the second half: '65537'
+ # Finally, convert to an int: 65537
+ return int(content_range.split(' ', 1)[1].split('/', 1)[1])
+
+
+def normalize_headers(headers):
+ normalized = {}
+ for key, value in headers.items():
+ if key.startswith('x-ms-'):
+ key = key[5:]
+ normalized[key.lower().replace('-', '_')] = get_enum_value(value)
+ return normalized
+
+
+def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument
+ try:
+ raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')}
+ except AttributeError:
+ raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')}
+ return {k[10:]: v for k, v in raw_metadata.items()}
+
+
+def return_response_headers(response, deserialized, response_headers): # pylint: disable=unused-argument
+ return normalize_headers(response_headers)
+
+
+def return_headers_and_deserialized(response, deserialized, response_headers): # pylint: disable=unused-argument
+ return normalize_headers(response_headers), deserialized
+
+
+def return_context_and_deserialized(response, deserialized, response_headers): # pylint: disable=unused-argument
+ return response.http_response.location_mode, deserialized
+
+
+def return_raw_deserialized(response, *_):
+ return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]
+
+
+def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
+ raise_error = HttpResponseError
+ serialized = False
+ if isinstance(storage_error, AzureSigningError):
+ storage_error.message = storage_error.message + \
+ '. This is likely due to an invalid shared key. Please check your shared key and try again.'
+ if not storage_error.response or storage_error.response.status_code in [200, 204]:
+ raise storage_error
+ # If it is one of those three then it has been serialized prior by the generated layer.
+ if isinstance(storage_error, (PartialBatchErrorException,
+ ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)):
+ serialized = True
+ error_code = storage_error.response.headers.get('x-ms-error-code')
+ error_message = storage_error.message
+ additional_data = {}
+ error_dict = {}
+ try:
+ error_body = ContentDecodePolicy.deserialize_from_http_generics(storage_error.response)
+ try:
+ if error_body is None or len(error_body) == 0:
+ error_body = storage_error.response.reason
+ except AttributeError:
+ error_body = ''
+ # If it is an XML response
+ if isinstance(error_body, Element):
+ error_dict = {
+ child.tag.lower(): child.text
+ for child in error_body
+ }
+ # If it is a JSON response
+ elif isinstance(error_body, dict):
+ error_dict = error_body.get('error', {})
+ elif not error_code:
+ _LOGGER.warning(
+ 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body))
+ error_dict = {'message': str(error_body)}
+
+ # If we extracted from a Json or XML response
+ # There is a chance error_dict is just a string
+ if error_dict and isinstance(error_dict, dict):
+ error_code = error_dict.get('code')
+ error_message = error_dict.get('message')
+ additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}}
+ except DecodeError:
+ pass
+
+ try:
+ # This check would be unnecessary if we have already serialized the error
+ if error_code and not serialized:
+ error_code = StorageErrorCode(error_code)
+ if error_code in [StorageErrorCode.condition_not_met,
+ StorageErrorCode.blob_overwritten]:
+ raise_error = ResourceModifiedError
+ if error_code in [StorageErrorCode.invalid_authentication_info,
+ StorageErrorCode.authentication_failed]:
+ raise_error = ClientAuthenticationError
+ if error_code in [StorageErrorCode.resource_not_found,
+ StorageErrorCode.cannot_verify_copy_source,
+ StorageErrorCode.blob_not_found,
+ StorageErrorCode.queue_not_found,
+ StorageErrorCode.container_not_found,
+ StorageErrorCode.parent_not_found,
+ StorageErrorCode.share_not_found]:
+ raise_error = ResourceNotFoundError
+ if error_code in [StorageErrorCode.account_already_exists,
+ StorageErrorCode.account_being_created,
+ StorageErrorCode.resource_already_exists,
+ StorageErrorCode.resource_type_mismatch,
+ StorageErrorCode.blob_already_exists,
+ StorageErrorCode.queue_already_exists,
+ StorageErrorCode.container_already_exists,
+ StorageErrorCode.container_being_deleted,
+ StorageErrorCode.queue_being_deleted,
+ StorageErrorCode.share_already_exists,
+ StorageErrorCode.share_being_deleted]:
+ raise_error = ResourceExistsError
+ except ValueError:
+ # Got an unknown error code
+ pass
+
+ # Error message should include all the error properties
+ try:
+ error_message += f"\nErrorCode:{error_code.value}"
+ except AttributeError:
+ error_message += f"\nErrorCode:{error_code}"
+ for name, info in additional_data.items():
+ error_message += f"\n{name}:{info}"
+
+ # No need to create an instance if it has already been serialized by the generated layer
+ if serialized:
+ storage_error.message = error_message
+ error = storage_error
+ else:
+ error = raise_error(message=error_message, response=storage_error.response)
+ # Ensure these properties are stored in the error instance as well (not just the error message)
+ error.error_code = error_code
+ error.additional_info = additional_data
+ # error.args is what's surfaced on the traceback - show error message in all cases
+ error.args = (error.message,)
+ try:
+ # `from None` prevents us from double printing the exception (suppresses generated layer error context)
+ exec("raise error from None") # pylint: disable=exec-used # nosec
+ except SyntaxError as exc:
+ raise error from exc
+
+
+def parse_to_internal_user_delegation_key(service_user_delegation_key):
+ internal_user_delegation_key = UserDelegationKey()
+ internal_user_delegation_key.signed_oid = service_user_delegation_key.signed_oid
+ internal_user_delegation_key.signed_tid = service_user_delegation_key.signed_tid
+ internal_user_delegation_key.signed_start = _to_utc_datetime(service_user_delegation_key.signed_start)
+ internal_user_delegation_key.signed_expiry = _to_utc_datetime(service_user_delegation_key.signed_expiry)
+ internal_user_delegation_key.signed_service = service_user_delegation_key.signed_service
+ internal_user_delegation_key.signed_version = service_user_delegation_key.signed_version
+ internal_user_delegation_key.value = service_user_delegation_key.value
+ return internal_user_delegation_key
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py
new file mode 100644
index 00000000..2ef9921d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/shared_access_signature.py
@@ -0,0 +1,243 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from datetime import date
+
+from .parser import _to_utc_datetime
+from .constants import X_MS_VERSION
+from . import sign_string, url_quote
+
+# cspell:ignoreRegExp rsc.
+# cspell:ignoreRegExp s..?id
+class QueryStringConstants(object):
+ SIGNED_SIGNATURE = 'sig'
+ SIGNED_PERMISSION = 'sp'
+ SIGNED_START = 'st'
+ SIGNED_EXPIRY = 'se'
+ SIGNED_RESOURCE = 'sr'
+ SIGNED_IDENTIFIER = 'si'
+ SIGNED_IP = 'sip'
+ SIGNED_PROTOCOL = 'spr'
+ SIGNED_VERSION = 'sv'
+ SIGNED_CACHE_CONTROL = 'rscc'
+ SIGNED_CONTENT_DISPOSITION = 'rscd'
+ SIGNED_CONTENT_ENCODING = 'rsce'
+ SIGNED_CONTENT_LANGUAGE = 'rscl'
+ SIGNED_CONTENT_TYPE = 'rsct'
+ START_PK = 'spk'
+ START_RK = 'srk'
+ END_PK = 'epk'
+ END_RK = 'erk'
+ SIGNED_RESOURCE_TYPES = 'srt'
+ SIGNED_SERVICES = 'ss'
+ SIGNED_OID = 'skoid'
+ SIGNED_TID = 'sktid'
+ SIGNED_KEY_START = 'skt'
+ SIGNED_KEY_EXPIRY = 'ske'
+ SIGNED_KEY_SERVICE = 'sks'
+ SIGNED_KEY_VERSION = 'skv'
+ SIGNED_ENCRYPTION_SCOPE = 'ses'
+
+ # for ADLS
+ SIGNED_AUTHORIZED_OID = 'saoid'
+ SIGNED_UNAUTHORIZED_OID = 'suoid'
+ SIGNED_CORRELATION_ID = 'scid'
+ SIGNED_DIRECTORY_DEPTH = 'sdd'
+
+ @staticmethod
+ def to_list():
+ return [
+ QueryStringConstants.SIGNED_SIGNATURE,
+ QueryStringConstants.SIGNED_PERMISSION,
+ QueryStringConstants.SIGNED_START,
+ QueryStringConstants.SIGNED_EXPIRY,
+ QueryStringConstants.SIGNED_RESOURCE,
+ QueryStringConstants.SIGNED_IDENTIFIER,
+ QueryStringConstants.SIGNED_IP,
+ QueryStringConstants.SIGNED_PROTOCOL,
+ QueryStringConstants.SIGNED_VERSION,
+ QueryStringConstants.SIGNED_CACHE_CONTROL,
+ QueryStringConstants.SIGNED_CONTENT_DISPOSITION,
+ QueryStringConstants.SIGNED_CONTENT_ENCODING,
+ QueryStringConstants.SIGNED_CONTENT_LANGUAGE,
+ QueryStringConstants.SIGNED_CONTENT_TYPE,
+ QueryStringConstants.START_PK,
+ QueryStringConstants.START_RK,
+ QueryStringConstants.END_PK,
+ QueryStringConstants.END_RK,
+ QueryStringConstants.SIGNED_RESOURCE_TYPES,
+ QueryStringConstants.SIGNED_SERVICES,
+ QueryStringConstants.SIGNED_OID,
+ QueryStringConstants.SIGNED_TID,
+ QueryStringConstants.SIGNED_KEY_START,
+ QueryStringConstants.SIGNED_KEY_EXPIRY,
+ QueryStringConstants.SIGNED_KEY_SERVICE,
+ QueryStringConstants.SIGNED_KEY_VERSION,
+ QueryStringConstants.SIGNED_ENCRYPTION_SCOPE,
+ # for ADLS
+ QueryStringConstants.SIGNED_AUTHORIZED_OID,
+ QueryStringConstants.SIGNED_UNAUTHORIZED_OID,
+ QueryStringConstants.SIGNED_CORRELATION_ID,
+ QueryStringConstants.SIGNED_DIRECTORY_DEPTH,
+ ]
+
+
+class SharedAccessSignature(object):
+ '''
+ Provides a factory for creating account access
+ signature tokens with an account name and account key. Users can either
+ use the factory or can construct the appropriate service and use the
+ generate_*_shared_access_signature method directly.
+ '''
+
+ def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION):
+ '''
+ :param str account_name:
+ The storage account name used to generate the shared access signatures.
+ :param str account_key:
+ The access key to generate the shares access signatures.
+ :param str x_ms_version:
+ The service version used to generate the shared access signatures.
+ '''
+ self.account_name = account_name
+ self.account_key = account_key
+ self.x_ms_version = x_ms_version
+
+ def generate_account(
+ self, services,
+ resource_types,
+ permission,
+ expiry,
+ start=None,
+ ip=None,
+ protocol=None,
+ sts_hook=None
+ ) -> str:
+ '''
+ Generates a shared access signature for the account.
+ Use the returned signature with the sas_token parameter of the service
+ or to create a new account object.
+
+ :param Any services: The specified services associated with the shared access signature.
+ :param ResourceTypes resource_types:
+ Specifies the resource types that are accessible with the account
+ SAS. You can combine values to provide access to more than one
+ resource type.
+ :param AccountSasPermissions permission:
+ The permissions associated with the shared access signature. The
+ user is restricted to operations allowed by the permissions.
+ Required unless an id is given referencing a stored access policy
+ which contains this field. This field must be omitted if it has been
+ specified in an associated stored access policy. You can combine
+ values to provide more than one permission.
+ :param expiry:
+ The time at which the shared access signature becomes invalid.
+ Required unless an id is given referencing a stored access policy
+ which contains this field. This field must be omitted if it has
+ been specified in an associated stored access policy. Azure will always
+ convert values to UTC. If a date is passed in without timezone info, it
+ is assumed to be UTC.
+ :type expiry: datetime or str
+ :param start:
+ The time at which the shared access signature becomes valid. If
+ omitted, start time for this call is assumed to be the time when the
+ storage service receives the request. The provided datetime will always
+ be interpreted as UTC.
+ :type start: datetime or str
+ :param str ip:
+ Specifies an IP address or a range of IP addresses from which to accept requests.
+ If the IP address from which the request originates does not match the IP address
+ or address range specified on the SAS token, the request is not authenticated.
+ For example, specifying sip=168.1.5.65 or sip=168.1.5.60-168.1.5.70 on the SAS
+ restricts the request to those IP addresses.
+ :param str protocol:
+ Specifies the protocol permitted for a request made. The default value
+ is https,http. See :class:`~azure.storage.common.models.Protocol` for possible values.
+ :param sts_hook:
+ For debugging purposes only. If provided, the hook is called with the string to sign
+ that was used to generate the SAS.
+ :type sts_hook: Optional[Callable[[str], None]]
+ :returns: The generated SAS token for the account.
+ :rtype: str
+ '''
+ sas = _SharedAccessHelper()
+ sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version)
+ sas.add_account(services, resource_types)
+ sas.add_account_signature(self.account_name, self.account_key)
+
+ if sts_hook is not None:
+ sts_hook(sas.string_to_sign)
+
+ return sas.get_token()
+
+
+class _SharedAccessHelper(object):
+ def __init__(self):
+ self.query_dict = {}
+ self.string_to_sign = ""
+
+ def _add_query(self, name, val):
+ if val:
+ self.query_dict[name] = str(val) if val is not None else None
+
+ def add_base(self, permission, expiry, start, ip, protocol, x_ms_version):
+ if isinstance(start, date):
+ start = _to_utc_datetime(start)
+
+ if isinstance(expiry, date):
+ expiry = _to_utc_datetime(expiry)
+
+ self._add_query(QueryStringConstants.SIGNED_START, start)
+ self._add_query(QueryStringConstants.SIGNED_EXPIRY, expiry)
+ self._add_query(QueryStringConstants.SIGNED_PERMISSION, permission)
+ self._add_query(QueryStringConstants.SIGNED_IP, ip)
+ self._add_query(QueryStringConstants.SIGNED_PROTOCOL, protocol)
+ self._add_query(QueryStringConstants.SIGNED_VERSION, x_ms_version)
+
+ def add_resource(self, resource):
+ self._add_query(QueryStringConstants.SIGNED_RESOURCE, resource)
+
+ def add_id(self, policy_id):
+ self._add_query(QueryStringConstants.SIGNED_IDENTIFIER, policy_id)
+
+ def add_account(self, services, resource_types):
+ self._add_query(QueryStringConstants.SIGNED_SERVICES, services)
+ self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types)
+
+ def add_override_response_headers(self, cache_control,
+ content_disposition,
+ content_encoding,
+ content_language,
+ content_type):
+ self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, content_language)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type)
+
+ def add_account_signature(self, account_name, account_key):
+ def get_value_to_append(query):
+ return_value = self.query_dict.get(query) or ''
+ return return_value + '\n'
+
+ self.string_to_sign = \
+ (account_name + '\n' +
+ get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) +
+ get_value_to_append(QueryStringConstants.SIGNED_SERVICES) +
+ get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) +
+ get_value_to_append(QueryStringConstants.SIGNED_START) +
+ get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) +
+ get_value_to_append(QueryStringConstants.SIGNED_IP) +
+ get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) +
+ get_value_to_append(QueryStringConstants.SIGNED_VERSION) +
+ '\n' # Signed Encryption Scope - always empty for fileshare
+ )
+
+ self._add_query(QueryStringConstants.SIGNED_SIGNATURE,
+ sign_string(account_key, self.string_to_sign))
+
+ def get_token(self) -> str:
+ return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None])
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py
new file mode 100644
index 00000000..b31cfb32
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads.py
@@ -0,0 +1,604 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from concurrent import futures
+from io import BytesIO, IOBase, SEEK_CUR, SEEK_END, SEEK_SET, UnsupportedOperation
+from itertools import islice
+from math import ceil
+from threading import Lock
+
+from azure.core.tracing.common import with_current_context
+
+from .import encode_base64, url_quote
+from .request_handlers import get_length
+from .response_handlers import return_response_headers
+
+
+_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE = 4 * 1024 * 1024
+_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = "{0} should be a seekable file-like/io.IOBase type stream object."
+
+
+def _parallel_uploads(executor, uploader, pending, running):
+ range_ids = []
+ while True:
+ # Wait for some download to finish before adding a new one
+ done, running = futures.wait(running, return_when=futures.FIRST_COMPLETED)
+ range_ids.extend([chunk.result() for chunk in done])
+ try:
+ for _ in range(0, len(done)):
+ next_chunk = next(pending)
+ running.add(executor.submit(with_current_context(uploader), next_chunk))
+ except StopIteration:
+ break
+
+ # Wait for the remaining uploads to finish
+ done, _running = futures.wait(running)
+ range_ids.extend([chunk.result() for chunk in done])
+ return range_ids
+
+
+def upload_data_chunks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ validate_content=None,
+ progress_hook=None,
+ **kwargs):
+
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ validate_content=validate_content,
+ progress_hook=progress_hook,
+ **kwargs)
+ if parallel:
+ with futures.ThreadPoolExecutor(max_concurrency) as executor:
+ upload_tasks = uploader.get_chunk_streams()
+ running_futures = [
+ executor.submit(with_current_context(uploader.process_chunk), u)
+ for u in islice(upload_tasks, 0, max_concurrency)
+ ]
+ range_ids = _parallel_uploads(executor, uploader.process_chunk, upload_tasks, running_futures)
+ else:
+ range_ids = [uploader.process_chunk(result) for result in uploader.get_chunk_streams()]
+ if any(range_ids):
+ return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
+ return uploader.response_headers
+
+
+def upload_substream_blocks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ progress_hook=None,
+ **kwargs):
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ progress_hook=progress_hook,
+ **kwargs)
+
+ if parallel:
+ with futures.ThreadPoolExecutor(max_concurrency) as executor:
+ upload_tasks = uploader.get_substream_blocks()
+ running_futures = [
+ executor.submit(with_current_context(uploader.process_substream_block), u)
+ for u in islice(upload_tasks, 0, max_concurrency)
+ ]
+ range_ids = _parallel_uploads(executor, uploader.process_substream_block, upload_tasks, running_futures)
+ else:
+ range_ids = [uploader.process_substream_block(b) for b in uploader.get_substream_blocks()]
+ if any(range_ids):
+ return sorted(range_ids)
+ return []
+
+
+class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes
+
+ def __init__(
+ self, service,
+ total_size,
+ chunk_size,
+ stream,
+ parallel,
+ encryptor=None,
+ padder=None,
+ progress_hook=None,
+ **kwargs):
+ self.service = service
+ self.total_size = total_size
+ self.chunk_size = chunk_size
+ self.stream = stream
+ self.parallel = parallel
+
+ # Stream management
+ self.stream_lock = Lock() if parallel else None
+
+ # Progress feedback
+ self.progress_total = 0
+ self.progress_lock = Lock() if parallel else None
+ self.progress_hook = progress_hook
+
+ # Encryption
+ self.encryptor = encryptor
+ self.padder = padder
+ self.response_headers = None
+ self.etag = None
+ self.last_modified = None
+ self.request_options = kwargs
+
+ def get_chunk_streams(self):
+ index = 0
+ while True:
+ data = b""
+ read_size = self.chunk_size
+
+ # Buffer until we either reach the end of the stream or get a whole chunk.
+ while True:
+ if self.total_size:
+ read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
+ temp = self.stream.read(read_size)
+ if not isinstance(temp, bytes):
+ raise TypeError("Blob data should be of type bytes.")
+ data += temp or b""
+
+ # We have read an empty string and so are at the end
+ # of the buffer or we have read a full chunk.
+ if temp == b"" or len(data) == self.chunk_size:
+ break
+
+ if len(data) == self.chunk_size:
+ if self.padder:
+ data = self.padder.update(data)
+ if self.encryptor:
+ data = self.encryptor.update(data)
+ yield index, data
+ else:
+ if self.padder:
+ data = self.padder.update(data) + self.padder.finalize()
+ if self.encryptor:
+ data = self.encryptor.update(data) + self.encryptor.finalize()
+ if data:
+ yield index, data
+ break
+ index += len(data)
+
+ def process_chunk(self, chunk_data):
+ chunk_bytes = chunk_data[1]
+ chunk_offset = chunk_data[0]
+ return self._upload_chunk_with_progress(chunk_offset, chunk_bytes)
+
+ def _update_progress(self, length):
+ if self.progress_lock is not None:
+ with self.progress_lock:
+ self.progress_total += length
+ else:
+ self.progress_total += length
+
+ if self.progress_hook:
+ self.progress_hook(self.progress_total, self.total_size)
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
+ range_id = self._upload_chunk(chunk_offset, chunk_data)
+ self._update_progress(len(chunk_data))
+ return range_id
+
+ def get_substream_blocks(self):
+ assert self.chunk_size is not None
+ lock = self.stream_lock
+ blob_length = self.total_size
+
+ if blob_length is None:
+ blob_length = get_length(self.stream)
+ if blob_length is None:
+ raise ValueError("Unable to determine content length of upload data.")
+
+ blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
+ last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size
+
+ for i in range(blocks):
+ index = i * self.chunk_size
+ length = last_block_size if i == blocks - 1 else self.chunk_size
+ yield index, SubStream(self.stream, index, length, lock)
+
+ def process_substream_block(self, block_data):
+ return self._upload_substream_block_with_progress(block_data[0], block_data[1])
+
+ def _upload_substream_block(self, index, block_stream):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ def _upload_substream_block_with_progress(self, index, block_stream):
+ range_id = self._upload_substream_block(index, block_stream)
+ self._update_progress(len(block_stream))
+ return range_id
+
+ def set_response_properties(self, resp):
+ self.etag = resp.etag
+ self.last_modified = resp.last_modified
+
+
+class BlockBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ kwargs.pop("modified_access_conditions", None)
+ super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ # TODO: This is incorrect, but works with recording.
+ index = f'{chunk_offset:032d}'
+ block_id = encode_base64(url_quote(encode_base64(index)))
+ self.service.stage_block(
+ block_id,
+ len(chunk_data),
+ chunk_data,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ return index, block_id
+
+ def _upload_substream_block(self, index, block_stream):
+ try:
+ block_id = f'BlockId{(index//self.chunk_size):05}'
+ self.service.stage_block(
+ block_id,
+ len(block_stream),
+ block_stream,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ finally:
+ block_stream.close()
+ return block_id
+
+
+class PageBlobChunkUploader(_ChunkUploader):
+
+ def _is_chunk_empty(self, chunk_data):
+ # read until non-zero byte is encountered
+ # if reached the end without returning, then chunk_data is all 0's
+ return not any(bytearray(chunk_data))
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ # avoid uploading the empty pages
+ if not self._is_chunk_empty(chunk_data):
+ chunk_end = chunk_offset + len(chunk_data) - 1
+ content_range = f"bytes={chunk_offset}-{chunk_end}"
+ computed_md5 = None
+ self.response_headers = self.service.upload_pages(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ transactional_content_md5=computed_md5,
+ range=content_range,
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class AppendBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ if self.current_length is None:
+ self.response_headers = self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ self.current_length = int(self.response_headers["blob_append_offset"])
+ else:
+ self.request_options['append_position_access_conditions'].append_position = \
+ self.current_length + chunk_offset
+ self.response_headers = self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class DataLakeFileChunkUploader(_ChunkUploader):
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ # avoid uploading the empty pages
+ self.response_headers = self.service.append_data(
+ body=chunk_data,
+ position=chunk_offset,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ def _upload_substream_block(self, index, block_stream):
+ try:
+ self.service.append_data(
+ body=block_stream,
+ position=index,
+ content_length=len(block_stream),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ finally:
+ block_stream.close()
+
+
+class FileChunkUploader(_ChunkUploader):
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ length = len(chunk_data)
+ chunk_end = chunk_offset + length - 1
+ response = self.service.upload_range(
+ chunk_data,
+ chunk_offset,
+ length,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ return f'bytes={chunk_offset}-{chunk_end}', response
+
+ # TODO: Implement this method.
+ def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class SubStream(IOBase):
+
+ def __init__(self, wrapped_stream, stream_begin_index, length, lockObj):
+ # Python 2.7: file-like objects created with open() typically support seek(), but are not
+ # derivations of io.IOBase and thus do not implement seekable().
+ # Python > 3.0: file-like objects created with open() are derived from io.IOBase.
+ try:
+ # only the main thread runs this, so there's no need grabbing the lock
+ wrapped_stream.seek(0, SEEK_CUR)
+ except Exception as exc:
+ raise ValueError("Wrapped stream must support seek().") from exc
+
+ self._lock = lockObj
+ self._wrapped_stream = wrapped_stream
+ self._position = 0
+ self._stream_begin_index = stream_begin_index
+ self._length = length
+ self._buffer = BytesIO()
+
+ # we must avoid buffering more than necessary, and also not use up too much memory
+ # so the max buffer size is capped at 4MB
+ self._max_buffer_size = (
+ length if length < _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE else _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE
+ )
+ self._current_buffer_start = 0
+ self._current_buffer_size = 0
+ super(SubStream, self).__init__()
+
+ def __len__(self):
+ return self._length
+
+ def close(self):
+ if self._buffer:
+ self._buffer.close()
+ self._wrapped_stream = None
+ IOBase.close(self)
+
+ def fileno(self):
+ return self._wrapped_stream.fileno()
+
+ def flush(self):
+ pass
+
+ def read(self, size=None):
+ if self.closed: # pylint: disable=using-constant-test
+ raise ValueError("Stream is closed.")
+
+ if size is None:
+ size = self._length - self._position
+
+ # adjust if out of bounds
+ if size + self._position >= self._length:
+ size = self._length - self._position
+
+ # return fast
+ if size == 0 or self._buffer.closed:
+ return b""
+
+ # attempt first read from the read buffer and update position
+ read_buffer = self._buffer.read(size)
+ bytes_read = len(read_buffer)
+ bytes_remaining = size - bytes_read
+ self._position += bytes_read
+
+ # repopulate the read buffer from the underlying stream to fulfill the request
+ # ensure the seek and read operations are done atomically (only if a lock is provided)
+ if bytes_remaining > 0:
+ with self._buffer:
+ # either read in the max buffer size specified on the class
+ # or read in just enough data for the current block/sub stream
+ current_max_buffer_size = min(self._max_buffer_size, self._length - self._position)
+
+ # lock is only defined if max_concurrency > 1 (parallel uploads)
+ if self._lock:
+ with self._lock:
+ # reposition the underlying stream to match the start of the data to read
+ absolute_position = self._stream_begin_index + self._position
+ self._wrapped_stream.seek(absolute_position, SEEK_SET)
+ # If we can't seek to the right location, our read will be corrupted so fail fast.
+ if self._wrapped_stream.tell() != absolute_position:
+ raise IOError("Stream failed to seek to the desired location.")
+ buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
+ else:
+ absolute_position = self._stream_begin_index + self._position
+ # It's possible that there's connection problem during data transfer,
+ # so when we retry we don't want to read from current position of wrapped stream,
+ # instead we should seek to where we want to read from.
+ if self._wrapped_stream.tell() != absolute_position:
+ self._wrapped_stream.seek(absolute_position, SEEK_SET)
+
+ buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
+
+ if buffer_from_stream:
+ # update the buffer with new data from the wrapped stream
+ # we need to note down the start position and size of the buffer, in case seek is performed later
+ self._buffer = BytesIO(buffer_from_stream)
+ self._current_buffer_start = self._position
+ self._current_buffer_size = len(buffer_from_stream)
+
+ # read the remaining bytes from the new buffer and update position
+ second_read_buffer = self._buffer.read(bytes_remaining)
+ read_buffer += second_read_buffer
+ self._position += len(second_read_buffer)
+
+ return read_buffer
+
+ def readable(self):
+ return True
+
+ def readinto(self, b):
+ raise UnsupportedOperation
+
+ def seek(self, offset, whence=0):
+ if whence is SEEK_SET:
+ start_index = 0
+ elif whence is SEEK_CUR:
+ start_index = self._position
+ elif whence is SEEK_END:
+ start_index = self._length
+ offset = -offset
+ else:
+ raise ValueError("Invalid argument for the 'whence' parameter.")
+
+ pos = start_index + offset
+
+ if pos > self._length:
+ pos = self._length
+ elif pos < 0:
+ pos = 0
+
+ # check if buffer is still valid
+ # if not, drop buffer
+ if pos < self._current_buffer_start or pos >= self._current_buffer_start + self._current_buffer_size:
+ self._buffer.close()
+ self._buffer = BytesIO()
+ else: # if yes seek to correct position
+ delta = pos - self._current_buffer_start
+ self._buffer.seek(delta, SEEK_SET)
+
+ self._position = pos
+ return pos
+
+ def seekable(self):
+ return True
+
+ def tell(self):
+ return self._position
+
+ def write(self):
+ raise UnsupportedOperation
+
+ def writelines(self):
+ raise UnsupportedOperation
+
+ def writeable(self):
+ return False
+
+
+class IterStreamer(object):
+ """
+ File-like streaming iterator.
+ """
+
+ def __init__(self, generator, encoding="UTF-8"):
+ self.generator = generator
+ self.iterator = iter(generator)
+ self.leftover = b""
+ self.encoding = encoding
+
+ def __len__(self):
+ return self.generator.__len__()
+
+ def __iter__(self):
+ return self.iterator
+
+ def seekable(self):
+ return False
+
+ def __next__(self):
+ return next(self.iterator)
+
+ def tell(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator does not support tell.")
+
+ def seek(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator is not seekable.")
+
+ def read(self, size):
+ data = self.leftover
+ count = len(self.leftover)
+ try:
+ while count < size:
+ chunk = self.__next__()
+ if isinstance(chunk, str):
+ chunk = chunk.encode(self.encoding)
+ data += chunk
+ count += len(chunk)
+ # This means count < size and what's leftover will be returned in this call.
+ except StopIteration:
+ self.leftover = b""
+
+ if count >= size:
+ self.leftover = data[size:]
+
+ return data[:size]
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py
new file mode 100644
index 00000000..3e102ec5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/fileshare/_shared/uploads_async.py
@@ -0,0 +1,460 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import asyncio
+import inspect
+import threading
+from asyncio import Lock
+from io import UnsupportedOperation
+from itertools import islice
+from math import ceil
+from typing import AsyncGenerator, Union
+
+from .import encode_base64, url_quote
+from .request_handlers import get_length
+from .response_handlers import return_response_headers
+from .uploads import SubStream, IterStreamer # pylint: disable=unused-import
+
+
+async def _async_parallel_uploads(uploader, pending, running):
+ range_ids = []
+ while True:
+ # Wait for some download to finish before adding a new one
+ done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
+ range_ids.extend([chunk.result() for chunk in done])
+ try:
+ for _ in range(0, len(done)):
+ next_chunk = await pending.__anext__()
+ running.add(asyncio.ensure_future(uploader(next_chunk)))
+ except StopAsyncIteration:
+ break
+
+ # Wait for the remaining uploads to finish
+ if running:
+ done, _running = await asyncio.wait(running)
+ range_ids.extend([chunk.result() for chunk in done])
+ return range_ids
+
+
+async def _parallel_uploads(uploader, pending, running):
+ range_ids = []
+ while True:
+ # Wait for some download to finish before adding a new one
+ done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
+ range_ids.extend([chunk.result() for chunk in done])
+ try:
+ for _ in range(0, len(done)):
+ next_chunk = next(pending)
+ running.add(asyncio.ensure_future(uploader(next_chunk)))
+ except StopIteration:
+ break
+
+ # Wait for the remaining uploads to finish
+ if running:
+ done, _running = await asyncio.wait(running)
+ range_ids.extend([chunk.result() for chunk in done])
+ return range_ids
+
+
+async def upload_data_chunks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ progress_hook=None,
+ **kwargs):
+
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ progress_hook=progress_hook,
+ **kwargs)
+
+ if parallel:
+ upload_tasks = uploader.get_chunk_streams()
+ running_futures = []
+ for _ in range(max_concurrency):
+ try:
+ chunk = await upload_tasks.__anext__()
+ running_futures.append(asyncio.ensure_future(uploader.process_chunk(chunk)))
+ except StopAsyncIteration:
+ break
+
+ range_ids = await _async_parallel_uploads(uploader.process_chunk, upload_tasks, running_futures)
+ else:
+ range_ids = []
+ async for chunk in uploader.get_chunk_streams():
+ range_ids.append(await uploader.process_chunk(chunk))
+
+ if any(range_ids):
+ return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
+ return uploader.response_headers
+
+
+async def upload_substream_blocks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ progress_hook=None,
+ **kwargs):
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ progress_hook=progress_hook,
+ **kwargs)
+
+ if parallel:
+ upload_tasks = uploader.get_substream_blocks()
+ running_futures = [
+ asyncio.ensure_future(uploader.process_substream_block(u))
+ for u in islice(upload_tasks, 0, max_concurrency)
+ ]
+ range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures)
+ else:
+ range_ids = []
+ for block in uploader.get_substream_blocks():
+ range_ids.append(await uploader.process_substream_block(block))
+ if any(range_ids):
+ return sorted(range_ids)
+ return
+
+
+class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes
+
+ def __init__(
+ self, service,
+ total_size,
+ chunk_size,
+ stream,
+ parallel,
+ encryptor=None,
+ padder=None,
+ progress_hook=None,
+ **kwargs):
+ self.service = service
+ self.total_size = total_size
+ self.chunk_size = chunk_size
+ self.stream = stream
+ self.parallel = parallel
+
+ # Stream management
+ self.stream_lock = threading.Lock() if parallel else None
+
+ # Progress feedback
+ self.progress_total = 0
+ self.progress_lock = Lock() if parallel else None
+ self.progress_hook = progress_hook
+
+ # Encryption
+ self.encryptor = encryptor
+ self.padder = padder
+ self.response_headers = None
+ self.etag = None
+ self.last_modified = None
+ self.request_options = kwargs
+
+ async def get_chunk_streams(self):
+ index = 0
+ while True:
+ data = b''
+ read_size = self.chunk_size
+
+ # Buffer until we either reach the end of the stream or get a whole chunk.
+ while True:
+ if self.total_size:
+ read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
+ temp = self.stream.read(read_size)
+ if inspect.isawaitable(temp):
+ temp = await temp
+ if not isinstance(temp, bytes):
+ raise TypeError('Blob data should be of type bytes.')
+ data += temp or b""
+
+ # We have read an empty string and so are at the end
+ # of the buffer or we have read a full chunk.
+ if temp == b'' or len(data) == self.chunk_size:
+ break
+
+ if len(data) == self.chunk_size:
+ if self.padder:
+ data = self.padder.update(data)
+ if self.encryptor:
+ data = self.encryptor.update(data)
+ yield index, data
+ else:
+ if self.padder:
+ data = self.padder.update(data) + self.padder.finalize()
+ if self.encryptor:
+ data = self.encryptor.update(data) + self.encryptor.finalize()
+ if data:
+ yield index, data
+ break
+ index += len(data)
+
+ async def process_chunk(self, chunk_data):
+ chunk_bytes = chunk_data[1]
+ chunk_offset = chunk_data[0]
+ return await self._upload_chunk_with_progress(chunk_offset, chunk_bytes)
+
+ async def _update_progress(self, length):
+ if self.progress_lock is not None:
+ async with self.progress_lock:
+ self.progress_total += length
+ else:
+ self.progress_total += length
+
+ if self.progress_hook:
+ await self.progress_hook(self.progress_total, self.total_size)
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ async def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
+ range_id = await self._upload_chunk(chunk_offset, chunk_data)
+ await self._update_progress(len(chunk_data))
+ return range_id
+
+ def get_substream_blocks(self):
+ assert self.chunk_size is not None
+ lock = self.stream_lock
+ blob_length = self.total_size
+
+ if blob_length is None:
+ blob_length = get_length(self.stream)
+ if blob_length is None:
+ raise ValueError("Unable to determine content length of upload data.")
+
+ blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
+ last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size
+
+ for i in range(blocks):
+ index = i * self.chunk_size
+ length = last_block_size if i == blocks - 1 else self.chunk_size
+ yield index, SubStream(self.stream, index, length, lock)
+
+ async def process_substream_block(self, block_data):
+ return await self._upload_substream_block_with_progress(block_data[0], block_data[1])
+
+ async def _upload_substream_block(self, index, block_stream):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ async def _upload_substream_block_with_progress(self, index, block_stream):
+ range_id = await self._upload_substream_block(index, block_stream)
+ await self._update_progress(len(block_stream))
+ return range_id
+
+ def set_response_properties(self, resp):
+ self.etag = resp.etag
+ self.last_modified = resp.last_modified
+
+
+class BlockBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ kwargs.pop('modified_access_conditions', None)
+ super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ # TODO: This is incorrect, but works with recording.
+ index = f'{chunk_offset:032d}'
+ block_id = encode_base64(url_quote(encode_base64(index)))
+ await self.service.stage_block(
+ block_id,
+ len(chunk_data),
+ body=chunk_data,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+ return index, block_id
+
+ async def _upload_substream_block(self, index, block_stream):
+ try:
+ block_id = f'BlockId{(index//self.chunk_size):05}'
+ await self.service.stage_block(
+ block_id,
+ len(block_stream),
+ block_stream,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+ finally:
+ block_stream.close()
+ return block_id
+
+
+class PageBlobChunkUploader(_ChunkUploader):
+
+ def _is_chunk_empty(self, chunk_data):
+ # read until non-zero byte is encountered
+ # if reached the end without returning, then chunk_data is all 0's
+ for each_byte in chunk_data:
+ if each_byte not in [0, b'\x00']:
+ return False
+ return True
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ # avoid uploading the empty pages
+ if not self._is_chunk_empty(chunk_data):
+ chunk_end = chunk_offset + len(chunk_data) - 1
+ content_range = f'bytes={chunk_offset}-{chunk_end}'
+ computed_md5 = None
+ self.response_headers = await self.service.upload_pages(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ transactional_content_md5=computed_md5,
+ range=content_range,
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ async def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class AppendBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ if self.current_length is None:
+ self.response_headers = await self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+ self.current_length = int(self.response_headers['blob_append_offset'])
+ else:
+ self.request_options['append_position_access_conditions'].append_position = \
+ self.current_length + chunk_offset
+ self.response_headers = await self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+
+ async def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class DataLakeFileChunkUploader(_ChunkUploader):
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ self.response_headers = await self.service.append_data(
+ body=chunk_data,
+ position=chunk_offset,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ async def _upload_substream_block(self, index, block_stream):
+ try:
+ await self.service.append_data(
+ body=block_stream,
+ position=index,
+ content_length=len(block_stream),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ finally:
+ block_stream.close()
+
+
+class FileChunkUploader(_ChunkUploader):
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ length = len(chunk_data)
+ chunk_end = chunk_offset + length - 1
+ response = await self.service.upload_range(
+ chunk_data,
+ chunk_offset,
+ length,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ range_id = f'bytes={chunk_offset}-{chunk_end}'
+ return range_id, response
+
+ # TODO: Implement this method.
+ async def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class AsyncIterStreamer():
+ """
+ File-like streaming object for AsyncGenerators.
+ """
+ def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"):
+ self.iterator = generator.__aiter__()
+ self.leftover = b""
+ self.encoding = encoding
+
+ def seekable(self):
+ return False
+
+ def tell(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator does not support tell.")
+
+ def seek(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator is not seekable.")
+
+ async def read(self, size: int) -> bytes:
+ data = self.leftover
+ count = len(self.leftover)
+ try:
+ while count < size:
+ chunk = await self.iterator.__anext__()
+ if isinstance(chunk, str):
+ chunk = chunk.encode(self.encoding)
+ data += chunk
+ count += len(chunk)
+ # This means count < size and what's leftover will be returned in this call.
+ except StopAsyncIteration:
+ self.leftover = b""
+
+ if count >= size:
+ self.leftover = data[size:]
+
+ return data[:size]