Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/blob/_shared')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/__init__.py  54
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/authentication.py  245
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/__init__.py  5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io.py  435
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io_async.py  419
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile.py  257
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile_async.py  210
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/schema.py  1178
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client.py  458
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client_async.py  280
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/constants.py  19
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/models.py  585
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/parser.py  53
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies.py  694
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies_async.py  296
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/request_handlers.py  270
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/response_handlers.py  200
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/shared_access_signature.py  252
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads.py  604
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads_async.py  460
20 files changed, 6974 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/__init__.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/__init__.py
new file mode 100644
index 00000000..a8b1a27d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/__init__.py
@@ -0,0 +1,54 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import base64
+import hashlib
+import hmac
+
+try:
+ from urllib.parse import quote, unquote
+except ImportError:
+ from urllib2 import quote, unquote # type: ignore
+
+
+def url_quote(url):
+ return quote(url)
+
+
+def url_unquote(url):
+ return unquote(url)
+
+
+def encode_base64(data):
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+ encoded = base64.b64encode(data)
+ return encoded.decode('utf-8')
+
+
+def decode_base64_to_bytes(data):
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+ return base64.b64decode(data)
+
+
+def decode_base64_to_text(data):
+ decoded_bytes = decode_base64_to_bytes(data)
+ return decoded_bytes.decode('utf-8')
+
+
+def sign_string(key, string_to_sign, key_is_base64=True):
+ if key_is_base64:
+ key = decode_base64_to_bytes(key)
+ else:
+ if isinstance(key, str):
+ key = key.encode('utf-8')
+ if isinstance(string_to_sign, str):
+ string_to_sign = string_to_sign.encode('utf-8')
+ signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256)
+ digest = signed_hmac_sha256.digest()
+ encoded_digest = encode_base64(digest)
+ return encoded_digest
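
For reference, sign_string() above is the primitive behind Shared Key authentication: it decodes the (by default base64-encoded) account key, computes an HMAC-SHA256 over the string to sign, and returns the digest as base64. A minimal usage sketch, assuming the vendored package is importable and using a made-up key:

    import base64

    from azure.storage.blob._shared import sign_string

    # Placeholder key for illustration only -- not a real storage account key.
    fake_key = base64.b64encode(b'not-a-real-account-key').decode('utf-8')

    signature = sign_string(fake_key, 'GET\n\n\n')
    print(signature)  # 44-character base64 encoding of the 32-byte HMAC-SHA256 digest
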
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/authentication.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/authentication.py
new file mode 100644
index 00000000..b41f2391
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/authentication.py
@@ -0,0 +1,245 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import re
+from typing import List, Tuple
+from urllib.parse import unquote, urlparse
+from functools import cmp_to_key
+
+try:
+ from yarl import URL
+except ImportError:
+ pass
+
+try:
+ from azure.core.pipeline.transport import AioHttpTransport # pylint: disable=non-abstract-transport-import
+except ImportError:
+ AioHttpTransport = None
+
+from azure.core.exceptions import ClientAuthenticationError
+from azure.core.pipeline.policies import SansIOHTTPPolicy
+
+from . import sign_string
+
+logger = logging.getLogger(__name__)
+
+
+table_lv0 = [
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x71c, 0x0, 0x71f, 0x721, 0x723, 0x725,
+ 0x0, 0x0, 0x0, 0x72d, 0x803, 0x0, 0x0, 0x733, 0x0, 0xd03, 0xd1a, 0xd1c, 0xd1e,
+ 0xd20, 0xd22, 0xd24, 0xd26, 0xd28, 0xd2a, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0xe02, 0xe09, 0xe0a, 0xe1a, 0xe21, 0xe23, 0xe25, 0xe2c, 0xe32, 0xe35, 0xe36, 0xe48, 0xe51,
+ 0xe70, 0xe7c, 0xe7e, 0xe89, 0xe8a, 0xe91, 0xe99, 0xe9f, 0xea2, 0xea4, 0xea6, 0xea7, 0xea9,
+ 0x0, 0x0, 0x0, 0x743, 0x744, 0x748, 0xe02, 0xe09, 0xe0a, 0xe1a, 0xe21, 0xe23, 0xe25,
+ 0xe2c, 0xe32, 0xe35, 0xe36, 0xe48, 0xe51, 0xe70, 0xe7c, 0xe7e, 0xe89, 0xe8a, 0xe91, 0xe99,
+ 0xe9f, 0xea2, 0xea4, 0xea6, 0xea7, 0xea9, 0x0, 0x74c, 0x0, 0x750, 0x0,
+]
+
+table_lv4 = [
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8012, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8212, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+]
+
+def compare(lhs: str, rhs: str) -> int: # pylint:disable=too-many-return-statements
+ tables = [table_lv0, table_lv4]
+ curr_level, i, j, n = 0, 0, 0, len(tables)
+ lhs_len = len(lhs)
+ rhs_len = len(rhs)
+ while curr_level < n:
+ if curr_level == (n - 1) and i != j:
+ if i > j:
+ return -1
+ if i < j:
+ return 1
+ return 0
+
+ w1 = tables[curr_level][ord(lhs[i])] if i < lhs_len else 0x1
+ w2 = tables[curr_level][ord(rhs[j])] if j < rhs_len else 0x1
+
+ if w1 == 0x1 and w2 == 0x1:
+ i = 0
+ j = 0
+ curr_level += 1
+ elif w1 == w2:
+ i += 1
+ j += 1
+ elif w1 == 0:
+ i += 1
+ elif w2 == 0:
+ j += 1
+ else:
+ if w1 < w2:
+ return -1
+ if w1 > w2:
+ return 1
+ return 0
+ return 0
+
+
+# wraps a given exception with the desired exception type
+def _wrap_exception(ex, desired_type):
+ msg = ""
+ if ex.args:
+ msg = ex.args[0]
+ return desired_type(msg)
+
+# This method attempts to emulate the sorting done by the service
+def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+
+ # Build dict of tuples and list of keys
+ header_dict = {}
+ header_keys = []
+ for k, v in input_headers:
+ header_dict[k] = v
+ header_keys.append(k)
+
+ try:
+ header_keys = sorted(header_keys, key=cmp_to_key(compare))
+ except ValueError as exc:
+ raise ValueError("Illegal character encountered when sorting headers.") from exc
+
+ # Build list of sorted tuples
+ sorted_headers = []
+ for key in header_keys:
+ sorted_headers.append((key, header_dict.pop(key)))
+ return sorted_headers
+
+
+class AzureSigningError(ClientAuthenticationError):
+ """
+ Represents a fatal error when attempting to sign a request.
+ In general, the cause of this exception is user error. For example, the given account key is not valid.
+ Please visit https://learn.microsoft.com/azure/storage/common/storage-create-storage-account for more info.
+ """
+
+
+class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
+
+ def __init__(self, account_name, account_key):
+ self.account_name = account_name
+ self.account_key = account_key
+ super(SharedKeyCredentialPolicy, self).__init__()
+
+ @staticmethod
+ def _get_headers(request, headers_to_sign):
+ headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value)
+ if 'content-length' in headers and headers['content-length'] == '0':
+ del headers['content-length']
+ return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n'
+
+ @staticmethod
+ def _get_verb(request):
+ return request.http_request.method + '\n'
+
+ def _get_canonicalized_resource(self, request):
+ uri_path = urlparse(request.http_request.url).path
+ try:
+ if isinstance(request.context.transport, AioHttpTransport) or \
+ isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) or \
+ isinstance(getattr(getattr(request.context.transport, "_transport", None), "_transport", None),
+ AioHttpTransport):
+ uri_path = URL(uri_path)
+ return '/' + self.account_name + str(uri_path)
+ except TypeError:
+ pass
+ return '/' + self.account_name + uri_path
+
+ @staticmethod
+ def _get_canonicalized_headers(request):
+ string_to_sign = ''
+ x_ms_headers = []
+ for name, value in request.http_request.headers.items():
+ if name.startswith('x-ms-'):
+ x_ms_headers.append((name.lower(), value))
+ x_ms_headers = _storage_header_sort(x_ms_headers)
+ for name, value in x_ms_headers:
+ if value is not None:
+ string_to_sign += ''.join([name, ':', value, '\n'])
+ return string_to_sign
+
+ @staticmethod
+ def _get_canonicalized_resource_query(request):
+ sorted_queries = list(request.http_request.query.items())
+ sorted_queries.sort()
+
+ string_to_sign = ''
+ for name, value in sorted_queries:
+ if value is not None:
+ string_to_sign += '\n' + name.lower() + ':' + unquote(value)
+
+ return string_to_sign
+
+ def _add_authorization_header(self, request, string_to_sign):
+ try:
+ signature = sign_string(self.account_key, string_to_sign)
+ auth_string = 'SharedKey ' + self.account_name + ':' + signature
+ request.http_request.headers['Authorization'] = auth_string
+ except Exception as ex:
+            # Wrap any error that occurred as a signing error.
+            # Doing so helps clarify/locate the source of the problem.
+ raise _wrap_exception(ex, AzureSigningError) from ex
+
+ def on_request(self, request):
+ string_to_sign = \
+ self._get_verb(request) + \
+ self._get_headers(
+ request,
+ [
+ 'content-encoding', 'content-language', 'content-length',
+ 'content-md5', 'content-type', 'date', 'if-modified-since',
+ 'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range'
+ ]
+ ) + \
+ self._get_canonicalized_headers(request) + \
+ self._get_canonicalized_resource(request) + \
+ self._get_canonicalized_resource_query(request)
+
+ self._add_authorization_header(request, string_to_sign)
+ # logger.debug("String_to_sign=%s", string_to_sign)
+
+
+class StorageHttpChallenge(object):
+ def __init__(self, challenge):
+ """ Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """
+ if not challenge:
+ raise ValueError("Challenge cannot be empty")
+
+ self._parameters = {}
+ self.scheme, trimmed_challenge = challenge.strip().split(" ", 1)
+
+ # name=value pairs either comma or space separated with values possibly being
+ # enclosed in quotes
+ for item in re.split('[, ]', trimmed_challenge):
+ comps = item.split("=")
+ if len(comps) == 2:
+ key = comps[0].strip(' "')
+ value = comps[1].strip(' "')
+ if key:
+ self._parameters[key] = value
+
+ # Extract and verify required parameters
+ self.authorization_uri = self._parameters.get('authorization_uri')
+ if not self.authorization_uri:
+ raise ValueError("Authorization Uri not found")
+
+ self.resource_id = self._parameters.get('resource_id')
+ if not self.resource_id:
+ raise ValueError("Resource id not found")
+
+ uri_path = urlparse(self.authorization_uri).path.lstrip("/")
+ self.tenant_id = uri_path.split("/")[0]
+
+ def get_value(self, key):
+ return self._parameters.get(key)
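
For reference, on_request() above concatenates five pieces into the Shared Key string-to-sign: the HTTP verb, the eleven standard headers (newline-joined, with a trailing newline), the sorted x-ms-* headers as name:value lines, the canonicalized resource path prefixed with the account name, and the lowercased, unquoted query parameters. A rough, hand-assembled illustration of that shape, with made-up account, container, and header values:

    verb = 'GET'

    # the eleven standard headers listed in on_request(), all empty in this example
    standard_headers = [''] * 11

    canonicalized_headers = (
        'x-ms-date:Thu, 01 Jan 1970 00:00:00 GMT\n'
        'x-ms-version:2025-01-05\n'
    )
    canonicalized_resource = '/exampleaccount/examplecontainer/exampleblob'
    canonicalized_query = '\ncomp:metadata'

    string_to_sign = (
        verb + '\n'
        + '\n'.join(standard_headers) + '\n'
        + canonicalized_headers
        + canonicalized_resource
        + canonicalized_query
    )
    print(repr(string_to_sign))
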
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/__init__.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/__init__.py
new file mode 100644
index 00000000..5b396cd2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/__init__.py
@@ -0,0 +1,5 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io.py
new file mode 100644
index 00000000..3e46f1fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io.py
@@ -0,0 +1,435 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=docstring-missing-return, docstring-missing-rtype
+
+"""Input/output utilities.
+
+Includes:
+ - i/o-specific constants
+ - i/o-specific exceptions
+ - schema validation
+ - leaf value encoding and decoding
+ - datum reader/writer stuff (?)
+
+Also includes a generic representation for data, which uses the
+following mapping:
+ - Schema records are implemented as dict.
+ - Schema arrays are implemented as list.
+ - Schema maps are implemented as dict.
+ - Schema strings are implemented as unicode.
+ - Schema bytes are implemented as str.
+ - Schema ints are implemented as int.
+ - Schema longs are implemented as long.
+ - Schema floats are implemented as float.
+ - Schema doubles are implemented as float.
+ - Schema booleans are implemented as bool.
+"""
+
+import json
+import logging
+import struct
+import sys
+
+from ..avro import schema
+
+PY3 = sys.version_info[0] == 3
+
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------------------
+# Constants
+
+STRUCT_FLOAT = struct.Struct('<f') # little-endian float
+STRUCT_DOUBLE = struct.Struct('<d') # little-endian double
+
+# ------------------------------------------------------------------------------
+# Exceptions
+
+
+class SchemaResolutionException(schema.AvroException):
+ def __init__(self, fail_msg, writer_schema=None):
+ pretty_writers = json.dumps(json.loads(str(writer_schema)), indent=2)
+ if writer_schema:
+ fail_msg += f"\nWriter's Schema: {pretty_writers}"
+ schema.AvroException.__init__(self, fail_msg)
+
+# ------------------------------------------------------------------------------
+# Decoder
+
+
+class BinaryDecoder(object):
+ """Read leaf values."""
+
+ def __init__(self, reader):
+ """
+ reader is a Python object on which we can call read, seek, and tell.
+ """
+ self._reader = reader
+
+ @property
+ def reader(self):
+ """Reports the reader used by this decoder."""
+ return self._reader
+
+ def read(self, n):
+ """Read n bytes.
+
+ :param int n: Number of bytes to read.
+ :returns: The next n bytes from the input.
+ :rtype: bytes
+ """
+ assert (n >= 0), n
+ input_bytes = self.reader.read(n)
+ if n > 0 and not input_bytes:
+ raise StopIteration
+ assert (len(input_bytes) == n), input_bytes
+ return input_bytes
+
+ @staticmethod
+ def read_null():
+ """
+ null is written as zero bytes
+ """
+ return None
+
+ def read_boolean(self):
+ """
+ a boolean is written as a single byte
+ whose value is either 0 (false) or 1 (true).
+ """
+ b = ord(self.read(1))
+ if b == 1:
+ return True
+ if b == 0:
+ return False
+ fail_msg = f"Invalid value for boolean: {b}"
+ raise schema.AvroException(fail_msg)
+
+ def read_int(self):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ return self.read_long()
+
+ def read_long(self):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ b = ord(self.read(1))
+ n = b & 0x7F
+ shift = 7
+ while (b & 0x80) != 0:
+ b = ord(self.read(1))
+ n |= (b & 0x7F) << shift
+ shift += 7
+ datum = (n >> 1) ^ -(n & 1)
+ return datum
+
+ def read_float(self):
+ """
+ A float is written as 4 bytes.
+ The float is converted into a 32-bit integer using a method equivalent to
+ Java's floatToIntBits and then encoded in little-endian format.
+ """
+ return STRUCT_FLOAT.unpack(self.read(4))[0]
+
+ def read_double(self):
+ """
+ A double is written as 8 bytes.
+ The double is converted into a 64-bit integer using a method equivalent to
+ Java's doubleToLongBits and then encoded in little-endian format.
+ """
+ return STRUCT_DOUBLE.unpack(self.read(8))[0]
+
+ def read_bytes(self):
+ """
+ Bytes are encoded as a long followed by that many bytes of data.
+ """
+ nbytes = self.read_long()
+ assert (nbytes >= 0), nbytes
+ return self.read(nbytes)
+
+ def read_utf8(self):
+ """
+ A string is encoded as a long followed by
+ that many bytes of UTF-8 encoded character data.
+ """
+ input_bytes = self.read_bytes()
+ if PY3:
+ try:
+ return input_bytes.decode('utf-8')
+ except UnicodeDecodeError as exn:
+ logger.error('Invalid UTF-8 input bytes: %r', input_bytes)
+ raise exn
+ else:
+ # PY2
+ return unicode(input_bytes, "utf-8") # pylint: disable=undefined-variable
+
+ def skip_null(self):
+ pass
+
+ def skip_boolean(self):
+ self.skip(1)
+
+ def skip_int(self):
+ self.skip_long()
+
+ def skip_long(self):
+ b = ord(self.read(1))
+ while (b & 0x80) != 0:
+ b = ord(self.read(1))
+
+ def skip_float(self):
+ self.skip(4)
+
+ def skip_double(self):
+ self.skip(8)
+
+ def skip_bytes(self):
+ self.skip(self.read_long())
+
+ def skip_utf8(self):
+ self.skip_bytes()
+
+ def skip(self, n):
+ self.reader.seek(self.reader.tell() + n)
+
+
+# ------------------------------------------------------------------------------
+# DatumReader
+
+
+class DatumReader(object):
+ """Deserialize Avro-encoded data into a Python data structure."""
+
+ def __init__(self, writer_schema=None):
+ """
+ As defined in the Avro specification, we call the schema encoded
+ in the data the "writer's schema".
+ """
+ self._writer_schema = writer_schema
+
+ # read/write properties
+ def set_writer_schema(self, writer_schema):
+ self._writer_schema = writer_schema
+
+ writer_schema = property(lambda self: self._writer_schema,
+ set_writer_schema)
+
+ def read(self, decoder):
+ return self.read_data(self.writer_schema, decoder)
+
+ def read_data(self, writer_schema, decoder):
+ # function dispatch for reading data based on type of writer's schema
+ if writer_schema.type == 'null':
+ result = decoder.read_null()
+ elif writer_schema.type == 'boolean':
+ result = decoder.read_boolean()
+ elif writer_schema.type == 'string':
+ result = decoder.read_utf8()
+ elif writer_schema.type == 'int':
+ result = decoder.read_int()
+ elif writer_schema.type == 'long':
+ result = decoder.read_long()
+ elif writer_schema.type == 'float':
+ result = decoder.read_float()
+ elif writer_schema.type == 'double':
+ result = decoder.read_double()
+ elif writer_schema.type == 'bytes':
+ result = decoder.read_bytes()
+ elif writer_schema.type == 'fixed':
+ result = self.read_fixed(writer_schema, decoder)
+ elif writer_schema.type == 'enum':
+ result = self.read_enum(writer_schema, decoder)
+ elif writer_schema.type == 'array':
+ result = self.read_array(writer_schema, decoder)
+ elif writer_schema.type == 'map':
+ result = self.read_map(writer_schema, decoder)
+ elif writer_schema.type in ['union', 'error_union']:
+ result = self.read_union(writer_schema, decoder)
+ elif writer_schema.type in ['record', 'error', 'request']:
+ result = self.read_record(writer_schema, decoder)
+ else:
+ fail_msg = f"Cannot read unknown schema type: {writer_schema.type}"
+ raise schema.AvroException(fail_msg)
+ return result
+
+ def skip_data(self, writer_schema, decoder):
+ if writer_schema.type == 'null':
+ result = decoder.skip_null()
+ elif writer_schema.type == 'boolean':
+ result = decoder.skip_boolean()
+ elif writer_schema.type == 'string':
+ result = decoder.skip_utf8()
+ elif writer_schema.type == 'int':
+ result = decoder.skip_int()
+ elif writer_schema.type == 'long':
+ result = decoder.skip_long()
+ elif writer_schema.type == 'float':
+ result = decoder.skip_float()
+ elif writer_schema.type == 'double':
+ result = decoder.skip_double()
+ elif writer_schema.type == 'bytes':
+ result = decoder.skip_bytes()
+ elif writer_schema.type == 'fixed':
+ result = self.skip_fixed(writer_schema, decoder)
+ elif writer_schema.type == 'enum':
+ result = self.skip_enum(decoder)
+ elif writer_schema.type == 'array':
+ self.skip_array(writer_schema, decoder)
+ result = None
+ elif writer_schema.type == 'map':
+ self.skip_map(writer_schema, decoder)
+ result = None
+ elif writer_schema.type in ['union', 'error_union']:
+ result = self.skip_union(writer_schema, decoder)
+ elif writer_schema.type in ['record', 'error', 'request']:
+ self.skip_record(writer_schema, decoder)
+ result = None
+ else:
+ fail_msg = f"Unknown schema type: {writer_schema.type}"
+ raise schema.AvroException(fail_msg)
+ return result
+
+ # Fixed instances are encoded using the number of bytes declared in the schema.
+ @staticmethod
+ def read_fixed(writer_schema, decoder):
+ return decoder.read(writer_schema.size)
+
+ @staticmethod
+ def skip_fixed(writer_schema, decoder):
+ return decoder.skip(writer_schema.size)
+
+    # An enum is encoded as an int, representing the zero-based position of the symbol in the schema.
+ @staticmethod
+ def read_enum(writer_schema, decoder):
+ # read data
+ index_of_symbol = decoder.read_int()
+ if index_of_symbol >= len(writer_schema.symbols):
+ fail_msg = f"Can't access enum index {index_of_symbol} for enum with {len(writer_schema.symbols)} symbols"
+ raise SchemaResolutionException(fail_msg, writer_schema)
+ read_symbol = writer_schema.symbols[index_of_symbol]
+ return read_symbol
+
+ @staticmethod
+ def skip_enum(decoder):
+ return decoder.skip_int()
+
+ # Arrays are encoded as a series of blocks.
+
+ # Each block consists of a long count value, followed by that many array items.
+ # A block with count zero indicates the end of the array. Each item is encoded per the array's item schema.
+
+ # If a block's count is negative, then the count is followed immediately by a long block size,
+ # indicating the number of bytes in the block.
+ # The actual count in this case is the absolute value of the count written.
+ def read_array(self, writer_schema, decoder):
+ read_items = []
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ decoder.read_long()
+ for _ in range(block_count):
+ read_items.append(self.read_data(writer_schema.items, decoder))
+ block_count = decoder.read_long()
+ return read_items
+
+ def skip_array(self, writer_schema, decoder):
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_size = decoder.read_long()
+ decoder.skip(block_size)
+ else:
+ for _ in range(block_count):
+ self.skip_data(writer_schema.items, decoder)
+ block_count = decoder.read_long()
+
+ # Maps are encoded as a series of blocks.
+
+ # Each block consists of a long count value, followed by that many key/value pairs.
+ # A block with count zero indicates the end of the map. Each item is encoded per the map's value schema.
+
+ # If a block's count is negative, then the count is followed immediately by a long block size,
+ # indicating the number of bytes in the block.
+ # The actual count in this case is the absolute value of the count written.
+ def read_map(self, writer_schema, decoder):
+ read_items = {}
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ decoder.read_long()
+ for _ in range(block_count):
+ key = decoder.read_utf8()
+ read_items[key] = self.read_data(writer_schema.values, decoder)
+ block_count = decoder.read_long()
+ return read_items
+
+ def skip_map(self, writer_schema, decoder):
+ block_count = decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_size = decoder.read_long()
+ decoder.skip(block_size)
+ else:
+ for _ in range(block_count):
+ decoder.skip_utf8()
+ self.skip_data(writer_schema.values, decoder)
+ block_count = decoder.read_long()
+
+ # A union is encoded by first writing a long value indicating
+ # the zero-based position within the union of the schema of its value.
+ # The value is then encoded per the indicated schema within the union.
+ def read_union(self, writer_schema, decoder):
+ # schema resolution
+ index_of_schema = int(decoder.read_long())
+ if index_of_schema >= len(writer_schema.schemas):
+ fail_msg = (f"Can't access branch index {index_of_schema} "
+ f"for union with {len(writer_schema.schemas)} branches")
+ raise SchemaResolutionException(fail_msg, writer_schema)
+ selected_writer_schema = writer_schema.schemas[index_of_schema]
+
+ # read data
+ return self.read_data(selected_writer_schema, decoder)
+
+ def skip_union(self, writer_schema, decoder):
+ index_of_schema = int(decoder.read_long())
+ if index_of_schema >= len(writer_schema.schemas):
+ fail_msg = (f"Can't access branch index {index_of_schema} "
+ f"for union with {len(writer_schema.schemas)} branches")
+ raise SchemaResolutionException(fail_msg, writer_schema)
+ return self.skip_data(writer_schema.schemas[index_of_schema], decoder)
+
+ # A record is encoded by encoding the values of its fields
+ # in the order that they are declared. In other words, a record
+ # is encoded as just the concatenation of the encodings of its fields.
+ # Field values are encoded per their schema.
+
+ # Schema Resolution:
+ # * the ordering of fields may be different: fields are matched by name.
+ # * schemas for fields with the same name in both records are resolved
+ # recursively.
+ # * if the writer's record contains a field with a name not present in the
+ # reader's record, the writer's value for that field is ignored.
+ # * if the reader's record schema has a field that contains a default value,
+ # and writer's schema does not have a field with the same name, then the
+ # reader should use the default value from its field.
+ # * if the reader's record schema has a field with no default value, and
+ # writer's schema does not have a field with the same name, then the
+ # field's value is unset.
+ def read_record(self, writer_schema, decoder):
+ # schema resolution
+ read_record = {}
+ for field in writer_schema.fields:
+ field_val = self.read_data(field.type, decoder)
+ read_record[field.name] = field_val
+ return read_record
+
+ def skip_record(self, writer_schema, decoder):
+ for field in writer_schema.fields:
+ self.skip_data(field.type, decoder)
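
The zig-zag varint scheme that read_long() above decodes maps 0, -1, 1, -2, 2, ... onto 0, 1, 2, 3, 4, ... and then packs seven bits per byte, using the high bit as a continuation flag. A quick check against BinaryDecoder, assuming the vendored module path is importable:

    import io

    from azure.storage.blob._shared.avro.avro_io import BinaryDecoder

    decoder = BinaryDecoder(io.BytesIO(b'\x02\x03\x80\x01'))
    print(decoder.read_long())  # 0x02      -> 1
    print(decoder.read_long())  # 0x03      -> -2
    print(decoder.read_long())  # 0x80 0x01 -> 64
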
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io_async.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io_async.py
new file mode 100644
index 00000000..8688661b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/avro_io_async.py
@@ -0,0 +1,419 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=docstring-missing-return, docstring-missing-rtype
+
+"""Input/output utilities.
+
+Includes:
+ - i/o-specific constants
+ - i/o-specific exceptions
+ - schema validation
+ - leaf value encoding and decoding
+ - datum reader/writer stuff (?)
+
+Also includes a generic representation for data, which uses the
+following mapping:
+ - Schema records are implemented as dict.
+ - Schema arrays are implemented as list.
+ - Schema maps are implemented as dict.
+ - Schema strings are implemented as unicode.
+ - Schema bytes are implemented as str.
+ - Schema ints are implemented as int.
+ - Schema longs are implemented as long.
+ - Schema floats are implemented as float.
+ - Schema doubles are implemented as float.
+ - Schema booleans are implemented as bool.
+"""
+
+import logging
+import sys
+
+from ..avro import schema
+
+from .avro_io import STRUCT_FLOAT, STRUCT_DOUBLE, SchemaResolutionException
+
+PY3 = sys.version_info[0] == 3
+
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------------------
+# Decoder
+
+
+class AsyncBinaryDecoder(object):
+ """Read leaf values."""
+
+ def __init__(self, reader):
+ """
+ reader is a Python object on which we can call read, seek, and tell.
+ """
+ self._reader = reader
+
+ @property
+ def reader(self):
+ """Reports the reader used by this decoder."""
+ return self._reader
+
+ async def read(self, n):
+ """Read n bytes.
+
+ :param int n: Number of bytes to read.
+ :returns: The next n bytes from the input.
+ :rtype: bytes
+ """
+ assert (n >= 0), n
+ input_bytes = await self.reader.read(n)
+ if n > 0 and not input_bytes:
+ raise StopAsyncIteration
+ assert (len(input_bytes) == n), input_bytes
+ return input_bytes
+
+ @staticmethod
+ def read_null():
+ """
+ null is written as zero bytes
+ """
+ return None
+
+ async def read_boolean(self):
+ """
+ a boolean is written as a single byte
+ whose value is either 0 (false) or 1 (true).
+ """
+ b = ord(await self.read(1))
+ if b == 1:
+ return True
+ if b == 0:
+ return False
+ fail_msg = f"Invalid value for boolean: {b}"
+ raise schema.AvroException(fail_msg)
+
+ async def read_int(self):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ return await self.read_long()
+
+ async def read_long(self):
+ """
+ int and long values are written using variable-length, zig-zag coding.
+ """
+ b = ord(await self.read(1))
+ n = b & 0x7F
+ shift = 7
+ while (b & 0x80) != 0:
+ b = ord(await self.read(1))
+ n |= (b & 0x7F) << shift
+ shift += 7
+ datum = (n >> 1) ^ -(n & 1)
+ return datum
+
+ async def read_float(self):
+ """
+ A float is written as 4 bytes.
+ The float is converted into a 32-bit integer using a method equivalent to
+ Java's floatToIntBits and then encoded in little-endian format.
+ """
+ return STRUCT_FLOAT.unpack(await self.read(4))[0]
+
+ async def read_double(self):
+ """
+ A double is written as 8 bytes.
+ The double is converted into a 64-bit integer using a method equivalent to
+ Java's doubleToLongBits and then encoded in little-endian format.
+ """
+ return STRUCT_DOUBLE.unpack(await self.read(8))[0]
+
+ async def read_bytes(self):
+ """
+ Bytes are encoded as a long followed by that many bytes of data.
+ """
+ nbytes = await self.read_long()
+ assert (nbytes >= 0), nbytes
+ return await self.read(nbytes)
+
+ async def read_utf8(self):
+ """
+ A string is encoded as a long followed by
+ that many bytes of UTF-8 encoded character data.
+ """
+ input_bytes = await self.read_bytes()
+ if PY3:
+ try:
+ return input_bytes.decode('utf-8')
+ except UnicodeDecodeError as exn:
+ logger.error('Invalid UTF-8 input bytes: %r', input_bytes)
+ raise exn
+ else:
+ # PY2
+ return unicode(input_bytes, "utf-8") # pylint: disable=undefined-variable
+
+ def skip_null(self):
+ pass
+
+ async def skip_boolean(self):
+ await self.skip(1)
+
+ async def skip_int(self):
+ await self.skip_long()
+
+ async def skip_long(self):
+ b = ord(await self.read(1))
+ while (b & 0x80) != 0:
+ b = ord(await self.read(1))
+
+ async def skip_float(self):
+ await self.skip(4)
+
+ async def skip_double(self):
+ await self.skip(8)
+
+ async def skip_bytes(self):
+ await self.skip(await self.read_long())
+
+ async def skip_utf8(self):
+ await self.skip_bytes()
+
+ async def skip(self, n):
+ await self.reader.seek(await self.reader.tell() + n)
+
+
+# ------------------------------------------------------------------------------
+# DatumReader
+
+
+class AsyncDatumReader(object):
+ """Deserialize Avro-encoded data into a Python data structure."""
+
+ def __init__(self, writer_schema=None):
+ """
+ As defined in the Avro specification, we call the schema encoded
+ in the data the "writer's schema", and the schema expected by the
+ reader the "reader's schema".
+ """
+ self._writer_schema = writer_schema
+
+ # read/write properties
+ def set_writer_schema(self, writer_schema):
+ self._writer_schema = writer_schema
+
+ writer_schema = property(lambda self: self._writer_schema,
+ set_writer_schema)
+
+ async def read(self, decoder):
+ return await self.read_data(self.writer_schema, decoder)
+
+ async def read_data(self, writer_schema, decoder):
+ # function dispatch for reading data based on type of writer's schema
+ if writer_schema.type == 'null':
+ result = decoder.read_null()
+ elif writer_schema.type == 'boolean':
+ result = await decoder.read_boolean()
+ elif writer_schema.type == 'string':
+ result = await decoder.read_utf8()
+ elif writer_schema.type == 'int':
+ result = await decoder.read_int()
+ elif writer_schema.type == 'long':
+ result = await decoder.read_long()
+ elif writer_schema.type == 'float':
+ result = await decoder.read_float()
+ elif writer_schema.type == 'double':
+ result = await decoder.read_double()
+ elif writer_schema.type == 'bytes':
+ result = await decoder.read_bytes()
+ elif writer_schema.type == 'fixed':
+ result = await self.read_fixed(writer_schema, decoder)
+ elif writer_schema.type == 'enum':
+ result = await self.read_enum(writer_schema, decoder)
+ elif writer_schema.type == 'array':
+ result = await self.read_array(writer_schema, decoder)
+ elif writer_schema.type == 'map':
+ result = await self.read_map(writer_schema, decoder)
+ elif writer_schema.type in ['union', 'error_union']:
+ result = await self.read_union(writer_schema, decoder)
+ elif writer_schema.type in ['record', 'error', 'request']:
+ result = await self.read_record(writer_schema, decoder)
+ else:
+ fail_msg = f"Cannot read unknown schema type: {writer_schema.type}"
+ raise schema.AvroException(fail_msg)
+ return result
+
+ async def skip_data(self, writer_schema, decoder):
+ if writer_schema.type == 'null':
+ result = decoder.skip_null()
+ elif writer_schema.type == 'boolean':
+ result = await decoder.skip_boolean()
+ elif writer_schema.type == 'string':
+ result = await decoder.skip_utf8()
+ elif writer_schema.type == 'int':
+ result = await decoder.skip_int()
+ elif writer_schema.type == 'long':
+ result = await decoder.skip_long()
+ elif writer_schema.type == 'float':
+ result = await decoder.skip_float()
+ elif writer_schema.type == 'double':
+ result = await decoder.skip_double()
+ elif writer_schema.type == 'bytes':
+ result = await decoder.skip_bytes()
+ elif writer_schema.type == 'fixed':
+ result = await self.skip_fixed(writer_schema, decoder)
+ elif writer_schema.type == 'enum':
+ result = await self.skip_enum(decoder)
+ elif writer_schema.type == 'array':
+ await self.skip_array(writer_schema, decoder)
+ result = None
+ elif writer_schema.type == 'map':
+ await self.skip_map(writer_schema, decoder)
+ result = None
+ elif writer_schema.type in ['union', 'error_union']:
+ result = await self.skip_union(writer_schema, decoder)
+ elif writer_schema.type in ['record', 'error', 'request']:
+ await self.skip_record(writer_schema, decoder)
+ result = None
+ else:
+ fail_msg = f"Unknown schema type: {writer_schema.type}"
+ raise schema.AvroException(fail_msg)
+ return result
+
+ # Fixed instances are encoded using the number of bytes declared in the schema.
+ @staticmethod
+ async def read_fixed(writer_schema, decoder):
+ return await decoder.read(writer_schema.size)
+
+ @staticmethod
+ async def skip_fixed(writer_schema, decoder):
+ return await decoder.skip(writer_schema.size)
+
+    # An enum is encoded as an int, representing the zero-based position of the symbol in the schema.
+ @staticmethod
+ async def read_enum(writer_schema, decoder):
+ # read data
+ index_of_symbol = await decoder.read_int()
+ if index_of_symbol >= len(writer_schema.symbols):
+ fail_msg = f"Can't access enum index {index_of_symbol} for enum with {len(writer_schema.symbols)} symbols"
+ raise SchemaResolutionException(fail_msg, writer_schema)
+ read_symbol = writer_schema.symbols[index_of_symbol]
+ return read_symbol
+
+ @staticmethod
+ async def skip_enum(decoder):
+ return await decoder.skip_int()
+
+ # Arrays are encoded as a series of blocks.
+
+ # Each block consists of a long count value, followed by that many array items.
+ # A block with count zero indicates the end of the array. Each item is encoded per the array's item schema.
+
+ # If a block's count is negative, then the count is followed immediately by a long block size,
+ # indicating the number of bytes in the block.
+ # The actual count in this case is the absolute value of the count written.
+ async def read_array(self, writer_schema, decoder):
+ read_items = []
+ block_count = await decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ await decoder.read_long()
+ for _ in range(block_count):
+ read_items.append(await self.read_data(writer_schema.items, decoder))
+ block_count = await decoder.read_long()
+ return read_items
+
+ async def skip_array(self, writer_schema, decoder):
+ block_count = await decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_size = await decoder.read_long()
+ await decoder.skip(block_size)
+ else:
+ for _ in range(block_count):
+ await self.skip_data(writer_schema.items, decoder)
+ block_count = await decoder.read_long()
+
+ # Maps are encoded as a series of blocks.
+
+ # Each block consists of a long count value, followed by that many key/value pairs.
+ # A block with count zero indicates the end of the map. Each item is encoded per the map's value schema.
+
+ # If a block's count is negative, then the count is followed immediately by a long block size,
+ # indicating the number of bytes in the block.
+ # The actual count in this case is the absolute value of the count written.
+ async def read_map(self, writer_schema, decoder):
+ read_items = {}
+ block_count = await decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ await decoder.read_long()
+ for _ in range(block_count):
+ key = await decoder.read_utf8()
+ read_items[key] = await self.read_data(writer_schema.values, decoder)
+ block_count = await decoder.read_long()
+ return read_items
+
+ async def skip_map(self, writer_schema, decoder):
+ block_count = await decoder.read_long()
+ while block_count != 0:
+ if block_count < 0:
+ block_size = await decoder.read_long()
+ await decoder.skip(block_size)
+ else:
+ for _ in range(block_count):
+ await decoder.skip_utf8()
+ await self.skip_data(writer_schema.values, decoder)
+ block_count = await decoder.read_long()
+
+ # A union is encoded by first writing a long value indicating
+ # the zero-based position within the union of the schema of its value.
+ # The value is then encoded per the indicated schema within the union.
+ async def read_union(self, writer_schema, decoder):
+ # schema resolution
+ index_of_schema = int(await decoder.read_long())
+ if index_of_schema >= len(writer_schema.schemas):
+ fail_msg = (f"Can't access branch index {index_of_schema} "
+ f"for union with {len(writer_schema.schemas)} branches")
+ raise SchemaResolutionException(fail_msg, writer_schema)
+ selected_writer_schema = writer_schema.schemas[index_of_schema]
+
+ # read data
+ return await self.read_data(selected_writer_schema, decoder)
+
+ async def skip_union(self, writer_schema, decoder):
+ index_of_schema = int(await decoder.read_long())
+ if index_of_schema >= len(writer_schema.schemas):
+ fail_msg = (f"Can't access branch index {index_of_schema} "
+ f"for union with {len(writer_schema.schemas)} branches")
+ raise SchemaResolutionException(fail_msg, writer_schema)
+ return await self.skip_data(writer_schema.schemas[index_of_schema], decoder)
+
+ # A record is encoded by encoding the values of its fields
+ # in the order that they are declared. In other words, a record
+ # is encoded as just the concatenation of the encodings of its fields.
+ # Field values are encoded per their schema.
+
+ # Schema Resolution:
+ # * the ordering of fields may be different: fields are matched by name.
+ # * schemas for fields with the same name in both records are resolved
+ # recursively.
+ # * if the writer's record contains a field with a name not present in the
+ # reader's record, the writer's value for that field is ignored.
+ # * if the reader's record schema has a field that contains a default value,
+ # and writer's schema does not have a field with the same name, then the
+ # reader should use the default value from its field.
+ # * if the reader's record schema has a field with no default value, and
+ # writer's schema does not have a field with the same name, then the
+ # field's value is unset.
+ async def read_record(self, writer_schema, decoder):
+ # schema resolution
+ read_record = {}
+ for field in writer_schema.fields:
+ field_val = await self.read_data(field.type, decoder)
+ read_record[field.name] = field_val
+ return read_record
+
+ async def skip_record(self, writer_schema, decoder):
+ for field in writer_schema.fields:
+ await self.skip_data(field.type, decoder)
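
AsyncBinaryDecoder above only requires a reader that exposes awaitable read(), seek(), and tell(). A minimal in-memory wrapper is enough to exercise it; the InMemoryAsyncReader class below is an illustration only, not part of the SDK:

    import asyncio
    import io

    from azure.storage.blob._shared.avro.avro_io_async import AsyncBinaryDecoder

    class InMemoryAsyncReader:
        """Wraps a BytesIO buffer behind the awaitable interface the decoder expects."""

        def __init__(self, data):
            self._buffer = io.BytesIO(data)

        async def read(self, n):
            return self._buffer.read(n)

        async def seek(self, offset, whence=0):
            return self._buffer.seek(offset, whence)

        async def tell(self):
            return self._buffer.tell()

    async def main():
        decoder = AsyncBinaryDecoder(InMemoryAsyncReader(b'\x80\x01'))
        print(await decoder.read_long())  # 64

    asyncio.run(main())
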
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile.py
new file mode 100644
index 00000000..757e0329
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile.py
@@ -0,0 +1,257 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=docstring-missing-return, docstring-missing-rtype
+
+"""Read/Write Avro File Object Containers."""
+
+import io
+import logging
+import sys
+import zlib
+
+from ..avro import avro_io
+from ..avro import schema
+
+PY3 = sys.version_info[0] == 3
+
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------------------
+# Constants
+
+# Version of the container file:
+VERSION = 1
+
+if PY3:
+ MAGIC = b'Obj' + bytes([VERSION])
+ MAGIC_SIZE = len(MAGIC)
+else:
+ MAGIC = 'Obj' + chr(VERSION)
+ MAGIC_SIZE = len(MAGIC)
+
+# Size of the synchronization marker, in number of bytes:
+SYNC_SIZE = 16
+
+# Schema of the container header:
+META_SCHEMA = schema.parse("""
+{
+ "type": "record", "name": "org.apache.avro.file.Header",
+ "fields": [{
+ "name": "magic",
+ "type": {"type": "fixed", "name": "magic", "size": %(magic_size)d}
+ }, {
+ "name": "meta",
+ "type": {"type": "map", "values": "bytes"}
+ }, {
+ "name": "sync",
+ "type": {"type": "fixed", "name": "sync", "size": %(sync_size)d}
+ }]
+}
+""" % {
+ 'magic_size': MAGIC_SIZE,
+ 'sync_size': SYNC_SIZE,
+})
+
+# Codecs supported by container files:
+VALID_CODECS = frozenset(['null', 'deflate'])
+
+# Metadata key associated to the schema:
+SCHEMA_KEY = "avro.schema"
+
+
+# ------------------------------------------------------------------------------
+# Exceptions
+
+
+class DataFileException(schema.AvroException):
+ """Problem reading or writing file object containers."""
+
+# ------------------------------------------------------------------------------
+
+
+class DataFileReader(object): # pylint: disable=too-many-instance-attributes
+ """Read files written by DataFileWriter."""
+
+ def __init__(self, reader, datum_reader, **kwargs):
+ """Initializes a new data file reader.
+
+ Args:
+ reader: Open file to read from.
+ datum_reader: Avro datum reader.
+ """
+ self._reader = reader
+ self._raw_decoder = avro_io.BinaryDecoder(reader)
+ self._header_reader = kwargs.pop('header_reader', None)
+ self._header_decoder = None if self._header_reader is None else avro_io.BinaryDecoder(self._header_reader)
+ self._datum_decoder = None # Maybe reset at every block.
+ self._datum_reader = datum_reader
+
+        # In case self._reader only has partial content (without the header),
+        # seek(0, 0) to make sure the (partial) content is read from the beginning.
+ self._reader.seek(0, 0)
+
+ # read the header: magic, meta, sync
+ self._read_header()
+
+ # ensure codec is valid
+ avro_codec_raw = self.get_meta('avro.codec')
+ if avro_codec_raw is None:
+ self.codec = "null"
+ else:
+ self.codec = avro_codec_raw.decode('utf-8')
+ if self.codec not in VALID_CODECS:
+ raise DataFileException(f"Unknown codec: {self.codec}.")
+
+ # get ready to read
+ self._block_count = 0
+
+        # object_position supports reading from the current position on a future read,
+        # so there is no need to download the Avro file from the beginning.
+ if hasattr(self._reader, 'object_position'):
+ self.reader.track_object_position()
+
+ self._cur_object_index = 0
+        # header_reader indicates the reader only has partial content. The reader doesn't have a block header,
+        # so we reuse the block count stored last time.
+        # Also, ChangeFeed only uses codec==null, so using _raw_decoder is fine.
+ if self._header_reader is not None:
+ self._datum_decoder = self._raw_decoder
+
+ self.datum_reader.writer_schema = (
+ schema.parse(self.get_meta(SCHEMA_KEY).decode('utf-8')))
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, data_type, value, traceback):
+ # Perform a close if there's no exception
+ if data_type is None:
+ self.close()
+
+ def __iter__(self):
+ return self
+
+ # read-only properties
+ @property
+ def reader(self):
+ return self._reader
+
+ @property
+ def raw_decoder(self):
+ return self._raw_decoder
+
+ @property
+ def datum_decoder(self):
+ return self._datum_decoder
+
+ @property
+ def datum_reader(self):
+ return self._datum_reader
+
+ @property
+ def sync_marker(self):
+ return self._sync_marker
+
+ @property
+ def meta(self):
+ return self._meta
+
+ # read/write properties
+ @property
+ def block_count(self):
+ return self._block_count
+
+ def get_meta(self, key):
+ """Reports the value of a given metadata key.
+
+ :param str key: Metadata key to report the value of.
+ :returns: Value associated to the metadata key, as bytes.
+ :rtype: bytes
+ """
+ return self._meta.get(key)
+
+ def _read_header(self):
+ header_reader = self._header_reader if self._header_reader else self._reader
+ header_decoder = self._header_decoder if self._header_decoder else self._raw_decoder
+
+ # seek to the beginning of the file to get magic block
+ header_reader.seek(0, 0)
+
+ # read header into a dict
+ header = self.datum_reader.read_data(META_SCHEMA, header_decoder)
+
+ # check magic number
+ if header.get('magic') != MAGIC:
+ fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC!r}."
+ raise schema.AvroException(fail_msg)
+
+ # set metadata
+ self._meta = header['meta']
+
+ # set sync marker
+ self._sync_marker = header['sync']
+
+ def _read_block_header(self):
+ self._block_count = self.raw_decoder.read_long()
+ if self.codec == "null":
+ # Skip a long; we don't need to use the length.
+ self.raw_decoder.skip_long()
+ self._datum_decoder = self._raw_decoder
+ elif self.codec == 'deflate':
+ # Compressed data is stored as (length, data), which
+ # corresponds to how the "bytes" type is encoded.
+ data = self.raw_decoder.read_bytes()
+ # -15 is the log of the window size; negative indicates
+ # "raw" (no zlib headers) decompression. See zlib.h.
+ uncompressed = zlib.decompress(data, -15)
+ self._datum_decoder = avro_io.BinaryDecoder(io.BytesIO(uncompressed))
+ else:
+ raise DataFileException(f"Unknown codec: {self.codec!r}")
+
+ def _skip_sync(self):
+ """
+        Read SYNC_SIZE bytes; if they do not match the expected sync marker,
+        seek back to where reading started.
+ """
+ proposed_sync_marker = self.reader.read(SYNC_SIZE)
+ if SYNC_SIZE > 0 and not proposed_sync_marker:
+ raise StopIteration
+ if proposed_sync_marker != self.sync_marker:
+ self.reader.seek(-SYNC_SIZE, 1)
+
+ def __next__(self):
+ """Return the next datum in the file."""
+ if self.block_count == 0:
+ self._skip_sync()
+
+            # object_position supports reading from the current position on a future read,
+            # so there is no need to download the Avro file from the beginning when this attribute is present.
+ if hasattr(self._reader, 'object_position'):
+ self.reader.track_object_position()
+ self._cur_object_index = 0
+
+ self._read_block_header()
+
+ datum = self.datum_reader.read(self.datum_decoder)
+ self._block_count -= 1
+ self._cur_object_index += 1
+
+        # object_position supports reading from the current position on a future read.
+        # This tracks the index of the next item to be read,
+        # as well as the offset before the next sync marker.
+ if hasattr(self._reader, 'object_position'):
+ if self.block_count == 0:
+ # the next event to be read is at index 0 in the new chunk of blocks,
+ self.reader.track_object_position()
+ self.reader.set_object_index(0)
+ else:
+ self.reader.set_object_index(self._cur_object_index)
+
+ return datum
+
+ def close(self):
+ """Close this reader."""
+ self.reader.close()
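
In _read_block_header() above, the 'deflate' codec stores raw DEFLATE data without a zlib header, which is why decompression passes -15 as the window-bits argument. A standalone round-trip showing the matching compression settings (illustrative only; the container writer is not part of this package):

    import zlib

    payload = b'example avro block bytes'

    # Negative window bits produce/consume raw DEFLATE streams with no zlib header.
    compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
    raw_deflate = compressor.compress(payload) + compressor.flush()

    assert zlib.decompress(raw_deflate, -15) == payload
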
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile_async.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile_async.py
new file mode 100644
index 00000000..85dc5cb5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/datafile_async.py
@@ -0,0 +1,210 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=docstring-missing-return, docstring-missing-rtype
+
+"""Read/Write Avro File Object Containers."""
+
+import logging
+import sys
+
+from ..avro import avro_io_async
+from ..avro import schema
+from .datafile import DataFileException
+from .datafile import MAGIC, SYNC_SIZE, META_SCHEMA, SCHEMA_KEY
+
+
+PY3 = sys.version_info[0] == 3
+
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------------------
+# Constants
+
+# Codecs supported by container files:
+VALID_CODECS = frozenset(['null'])
+
+
+class AsyncDataFileReader(object): # pylint: disable=too-many-instance-attributes
+ """Read files written by DataFileWriter."""
+
+ def __init__(self, reader, datum_reader, **kwargs):
+ """Initializes a new data file reader.
+
+ Args:
+ reader: Open file to read from.
+ datum_reader: Avro datum reader.
+ """
+ self._reader = reader
+ self._raw_decoder = avro_io_async.AsyncBinaryDecoder(reader)
+ self._header_reader = kwargs.pop('header_reader', None)
+ self._header_decoder = None if self._header_reader is None else \
+ avro_io_async.AsyncBinaryDecoder(self._header_reader)
+ self._datum_decoder = None # Maybe reset at every block.
+ self._datum_reader = datum_reader
+ self.codec = "null"
+ self._block_count = 0
+ self._cur_object_index = 0
+ self._meta = None
+ self._sync_marker = None
+
+ async def init(self):
+        # In case self._reader only has partial content (without the header),
+        # seek(0, 0) to make sure the (partial) content is read from the beginning.
+ await self._reader.seek(0, 0)
+
+ # read the header: magic, meta, sync
+ await self._read_header()
+
+ # ensure codec is valid
+ avro_codec_raw = self.get_meta('avro.codec')
+ if avro_codec_raw is None:
+ self.codec = "null"
+ else:
+ self.codec = avro_codec_raw.decode('utf-8')
+ if self.codec not in VALID_CODECS:
+ raise DataFileException(f"Unknown codec: {self.codec}.")
+
+ # get ready to read
+ self._block_count = 0
+
+        # object_position supports reading from the current position on a future read,
+        # so there is no need to download the Avro file from the beginning.
+ if hasattr(self._reader, 'object_position'):
+ self.reader.track_object_position()
+
+        # header_reader indicates the reader only has partial content. The reader doesn't have a block header,
+        # so we reuse the block count stored last time.
+        # Also, ChangeFeed only uses codec==null, so using _raw_decoder is fine.
+ if self._header_reader is not None:
+ self._datum_decoder = self._raw_decoder
+ self.datum_reader.writer_schema = (
+ schema.parse(self.get_meta(SCHEMA_KEY).decode('utf-8')))
+ return self
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, data_type, value, traceback):
+ # Perform a close if there's no exception
+ if data_type is None:
+ self.close()
+
+ def __aiter__(self):
+ return self
+
+ # read-only properties
+ @property
+ def reader(self):
+ return self._reader
+
+ @property
+ def raw_decoder(self):
+ return self._raw_decoder
+
+ @property
+ def datum_decoder(self):
+ return self._datum_decoder
+
+ @property
+ def datum_reader(self):
+ return self._datum_reader
+
+ @property
+ def sync_marker(self):
+ return self._sync_marker
+
+ @property
+ def meta(self):
+ return self._meta
+
+ # read/write properties
+ @property
+ def block_count(self):
+ return self._block_count
+
+ def get_meta(self, key):
+ """Reports the value of a given metadata key.
+
+ :param str key: Metadata key to report the value of.
+ :returns: Value associated to the metadata key, as bytes.
+ :rtype: bytes
+ """
+ return self._meta.get(key)
+
+ async def _read_header(self):
+ header_reader = self._header_reader if self._header_reader else self._reader
+ header_decoder = self._header_decoder if self._header_decoder else self._raw_decoder
+
+ # seek to the beginning of the file to get magic block
+ await header_reader.seek(0, 0)
+
+ # read header into a dict
+ header = await self.datum_reader.read_data(META_SCHEMA, header_decoder)
+
+ # check magic number
+ if header.get('magic') != MAGIC:
+ fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC!r}."
+ raise schema.AvroException(fail_msg)
+
+ # set metadata
+ self._meta = header['meta']
+
+ # set sync marker
+ self._sync_marker = header['sync']
+
+ async def _read_block_header(self):
+ self._block_count = await self.raw_decoder.read_long()
+ if self.codec == "null":
+ # Skip a long; we don't need to use the length.
+ await self.raw_decoder.skip_long()
+ self._datum_decoder = self._raw_decoder
+ else:
+ raise DataFileException(f"Unknown codec: {self.codec!r}")
+
+ async def _skip_sync(self):
+ """
+        Read SYNC_SIZE bytes; if they do not match the expected sync marker,
+        seek back to where reading started.
+ """
+ proposed_sync_marker = await self.reader.read(SYNC_SIZE)
+ if SYNC_SIZE > 0 and not proposed_sync_marker:
+ raise StopAsyncIteration
+ if proposed_sync_marker != self.sync_marker:
+ await self.reader.seek(-SYNC_SIZE, 1)
+
+ async def __anext__(self):
+ """Return the next datum in the file."""
+ if self.block_count == 0:
+ await self._skip_sync()
+
+            # object_position supports reading from the current position on a future read,
+            # so there is no need to download the Avro file from the beginning when this attribute is present.
+ if hasattr(self._reader, 'object_position'):
+ await self.reader.track_object_position()
+ self._cur_object_index = 0
+
+ await self._read_block_header()
+
+ datum = await self.datum_reader.read(self.datum_decoder)
+ self._block_count -= 1
+ self._cur_object_index += 1
+
+        # object_position supports reading from the current position on a future read.
+        # This tracks the index of the next item to be read,
+        # as well as the offset before the next sync marker.
+ if hasattr(self._reader, 'object_position'):
+ if self.block_count == 0:
+ # the next event to be read is at index 0 in the new chunk of blocks,
+ await self.reader.track_object_position()
+ await self.reader.set_object_index(0)
+ else:
+ await self.reader.set_object_index(self._cur_object_index)
+
+ return datum
+
+ def close(self):
+ """Close this reader."""
+ self.reader.close()
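
Unlike DataFileReader, AsyncDataFileReader defers header parsing to init(), so callers await init() before iterating. A usage sketch, where `stream` is a placeholder for any object with awaitable read(), seek(), and tell():

    from azure.storage.blob._shared.avro.avro_io_async import AsyncDatumReader
    from azure.storage.blob._shared.avro.datafile_async import AsyncDataFileReader

    async def dump_records(stream):
        reader = await AsyncDataFileReader(stream, AsyncDatumReader()).init()
        async for record in reader:  # stops when the stream is exhausted
            print(record)
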
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/schema.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/schema.py
new file mode 100644
index 00000000..d5484abc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/avro/schema.py
@@ -0,0 +1,1178 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=docstring-missing-return, docstring-missing-rtype, too-many-lines
+
+"""Representation of Avro schemas.
+
+A schema may be one of:
+ - A record, mapping field names to field value data;
+ - An error, equivalent to a record;
+ - An enum, containing one of a small set of symbols;
+ - An array of values, all of the same schema;
+ - A map containing string/value pairs, each of a declared schema;
+ - A union of other schemas;
+ - A fixed sized binary object;
+ - A unicode string;
+ - A sequence of bytes;
+ - A 32-bit signed int;
+ - A 64-bit signed long;
+ - A 32-bit floating-point float;
+ - A 64-bit floating-point double;
+ - A boolean;
+ - Null.
+"""
+
+import abc
+import json
+import logging
+import re
+logger = logging.getLogger(__name__)
+
+# ------------------------------------------------------------------------------
+# Constants
+
+# Log level more verbose than DEBUG=10, INFO=20, etc.
+DEBUG_VERBOSE = 5
+
+NULL = 'null'
+BOOLEAN = 'boolean'
+STRING = 'string'
+BYTES = 'bytes'
+INT = 'int'
+LONG = 'long'
+FLOAT = 'float'
+DOUBLE = 'double'
+FIXED = 'fixed'
+ENUM = 'enum'
+RECORD = 'record'
+ERROR = 'error'
+ARRAY = 'array'
+MAP = 'map'
+UNION = 'union'
+
+# Request and error unions are part of Avro protocols:
+REQUEST = 'request'
+ERROR_UNION = 'error_union'
+
+PRIMITIVE_TYPES = frozenset([
+ NULL,
+ BOOLEAN,
+ STRING,
+ BYTES,
+ INT,
+ LONG,
+ FLOAT,
+ DOUBLE,
+])
+
+NAMED_TYPES = frozenset([
+ FIXED,
+ ENUM,
+ RECORD,
+ ERROR,
+])
+
+VALID_TYPES = frozenset.union(
+ PRIMITIVE_TYPES,
+ NAMED_TYPES,
+ [
+ ARRAY,
+ MAP,
+ UNION,
+ REQUEST,
+ ERROR_UNION,
+ ],
+)
+
+SCHEMA_RESERVED_PROPS = frozenset([
+ 'type',
+ 'name',
+ 'namespace',
+ 'fields', # Record
+ 'items', # Array
+ 'size', # Fixed
+ 'symbols', # Enum
+ 'values', # Map
+ 'doc',
+])
+
+FIELD_RESERVED_PROPS = frozenset([
+ 'default',
+ 'name',
+ 'doc',
+ 'order',
+ 'type',
+])
+
+VALID_FIELD_SORT_ORDERS = frozenset([
+ 'ascending',
+ 'descending',
+ 'ignore',
+])
+
+
+# ------------------------------------------------------------------------------
+# Exceptions
+
+
+class Error(Exception):
+ """Base class for errors in this module."""
+
+
+class AvroException(Error):
+ """Generic Avro schema error."""
+
+
+class SchemaParseException(AvroException):
+ """Error while parsing a JSON schema descriptor."""
+
+
+class Schema(metaclass=abc.ABCMeta):
+ """Abstract base class for all Schema classes."""
+
+ def __init__(self, data_type, other_props=None):
+ """Initializes a new schema object.
+
+ Args:
+ data_type: Type of the schema to initialize.
+ other_props: Optional dictionary of additional properties.
+ """
+ if data_type not in VALID_TYPES:
+ raise SchemaParseException(f'{data_type!r} is not a valid Avro type.')
+
+ # All properties of this schema, as a map: property name -> property value
+ self._props = {}
+
+ self._props['type'] = data_type
+ self._type = data_type
+
+ if other_props:
+ self._props.update(other_props)
+
+ @property
+ def namespace(self):
+ """Returns: the namespace this schema belongs to, if any, or None."""
+ return self._props.get('namespace', None)
+
+ @property
+ def type(self):
+ """Returns: the type of this schema."""
+ return self._type
+
+ @property
+ def doc(self):
+ """Returns: the documentation associated to this schema, if any, or None."""
+ return self._props.get('doc', None)
+
+ @property
+ def props(self):
+ """Reports all the properties of this schema.
+
+        Includes all properties, reserved and non-reserved.
+ JSON properties of this schema are directly generated from this dict.
+
+ Returns:
+ A dictionary of properties associated to this schema.
+ """
+ return self._props
+
+ @property
+ def other_props(self):
+ """Returns: the dictionary of non-reserved properties."""
+ return dict(filter_keys_out(items=self._props, keys=SCHEMA_RESERVED_PROPS))
+
+ def __str__(self):
+ """Returns: the JSON representation of this schema."""
+ return json.dumps(self.to_json(names=None))
+
+ # Converts the schema object into its AVRO specification representation.
+
+ # Schema types that have names (records, enums, and fixed) must be aware of not
+ # re-defining schemas that are already listed in the parameter names.
+ @abc.abstractmethod
+ def to_json(self, names):
+ ...
+
+
+# ------------------------------------------------------------------------------
+
+
+_RE_NAME = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')
+
+_RE_FULL_NAME = re.compile(
+ r'^'
+ r'[.]?(?:[A-Za-z_][A-Za-z0-9_]*[.])*' # optional namespace
+ r'([A-Za-z_][A-Za-z0-9_]*)' # name
+ r'$'
+)
+
+
+class Name(object):
+ """Representation of an Avro name."""
+
+ def __init__(self, name, namespace=None):
+ """Parses an Avro name.
+
+ Args:
+ name: Avro name to parse (relative or absolute).
+ namespace: Optional explicit namespace if the name is relative.
+ """
+ # Normalize: namespace is always defined as a string, possibly empty.
+ if namespace is None:
+ namespace = ''
+
+ if '.' in name:
+ # name is absolute, namespace is ignored:
+ self._fullname = name
+
+ match = _RE_FULL_NAME.match(self._fullname)
+ if match is None:
+ raise SchemaParseException(
+ f'Invalid absolute schema name: {self._fullname!r}.')
+
+ self._name = match.group(1)
+ self._namespace = self._fullname[:-(len(self._name) + 1)]
+
+ else:
+ # name is relative, combine with explicit namespace:
+ self._name = name
+ self._namespace = namespace
+ self._fullname = (self._name
+ if (not self._namespace) else
+ f'{self._namespace}.{self._name}')
+
+ # Validate the fullname:
+ if _RE_FULL_NAME.match(self._fullname) is None:
+ raise SchemaParseException(f"Invalid schema name {self._fullname!r} inferred from "
+ f"name {self._name!r} and namespace {self._namespace!r}.")
+
+ def __eq__(self, other):
+ if not isinstance(other, Name):
+ return NotImplemented
+ return self.fullname == other.fullname
+
+ @property
+ def simple_name(self):
+ """Returns: the simple name part of this name."""
+ return self._name
+
+ @property
+ def namespace(self):
+        """Returns: this name's namespace, possibly the empty string."""
+ return self._namespace
+
+ @property
+ def fullname(self):
+ """Returns: the full name."""
+ return self._fullname
+
+
+# ------------------------------------------------------------------------------
+
+
+class Names(object):
+ """Tracks Avro named schemas and default namespace during parsing."""
+
+ def __init__(self, default_namespace=None, names=None):
+ """Initializes a new name tracker.
+
+ Args:
+ default_namespace: Optional default namespace.
+ names: Optional initial mapping of known named schemas.
+ """
+ if names is None:
+ names = {}
+ self._names = names
+ self._default_namespace = default_namespace
+
+ @property
+ def names(self):
+ """Returns: the mapping of known named schemas."""
+ return self._names
+
+ @property
+ def default_namespace(self):
+ """Returns: the default namespace, if any, or None."""
+ return self._default_namespace
+
+ def new_with_default_namespace(self, namespace):
+        """Creates a new name tracker from this tracker, but with a new default namespace.
+
+ :param Any namespace: New default namespace to use.
+ :returns: New name tracker with the specified default namespace.
+ :rtype: Names
+ """
+ return Names(names=self._names, default_namespace=namespace)
+
+ def get_name(self, name, namespace=None):
+ """Resolves the Avro name according to this name tracker's state.
+
+ :param Any name: Name to resolve (absolute or relative).
+ :param Optional[Any] namespace: Optional explicit namespace.
+ :returns: The specified name, resolved according to this tracker.
+ :rtype: Name
+ """
+ if namespace is None:
+ namespace = self._default_namespace
+ return Name(name=name, namespace=namespace)
+
+ def get_schema(self, name, namespace=None):
+ """Resolves an Avro schema by name.
+
+ :param Any name: Name (absolute or relative) of the Avro schema to look up.
+ :param Optional[Any] namespace: Optional explicit namespace.
+        :returns: The schema with the specified name, if any, or None.
+ :rtype: Union[Any, None]
+ """
+ avro_name = self.get_name(name=name, namespace=namespace)
+ return self._names.get(avro_name.fullname, None)
+
+    # Given a dict of properties, return a copy with the namespace removed
+    # if it matches this tracker's default namespace.
+ def prune_namespace(self, properties):
+        if self.default_namespace is None:
+            # No default namespace configured -- nothing to prune.
+            return properties
+        if 'namespace' not in properties:
+            # The properties carry no namespace -- nothing to prune.
+            return properties
+        if properties['namespace'] != self.default_namespace:
+            # The namespaces differ -- leave the properties untouched.
+            return properties
+        # The namespace matches the default and is redundant: remove it from a copy.
+ prunable = properties.copy()
+ del prunable['namespace']
+ return prunable
+
+ def register(self, schema):
+ """Registers a new named schema in this tracker.
+
+ :param Any schema: Named Avro schema to register in this tracker.
+ """
+ if schema.fullname in VALID_TYPES:
+ raise SchemaParseException(
+ f'{schema.fullname} is a reserved type name.')
+ if schema.fullname in self.names:
+ raise SchemaParseException(
+ f'Avro name {schema.fullname!r} already exists.')
+
+ logger.log(DEBUG_VERBOSE, 'Register new name for %r', schema.fullname)
+ self._names[schema.fullname] = schema
+
+
+# ------------------------------------------------------------------------------
+
+
+class NamedSchema(Schema):
+ """Abstract base class for named schemas.
+
+ Named schemas are enumerated in NAMED_TYPES.
+ """
+
+ def __init__(
+ self,
+ data_type,
+ name=None,
+ namespace=None,
+ names=None,
+ other_props=None,
+ ):
+ """Initializes a new named schema object.
+
+ Args:
+ data_type: Type of the named schema.
+ name: Name (absolute or relative) of the schema.
+ namespace: Optional explicit namespace if name is relative.
+ names: Tracker to resolve and register Avro names.
+ other_props: Optional map of additional properties of the schema.
+ """
+ assert (data_type in NAMED_TYPES), (f'Invalid named type: {data_type!r}')
+ self._avro_name = names.get_name(name=name, namespace=namespace)
+
+ super(NamedSchema, self).__init__(data_type, other_props)
+
+ names.register(self)
+
+ self._props['name'] = self.name
+ if self.namespace:
+ self._props['namespace'] = self.namespace
+
+ @property
+ def avro_name(self):
+ """Returns: the Name object describing this schema's name."""
+ return self._avro_name
+
+ @property
+ def name(self):
+ return self._avro_name.simple_name
+
+ @property
+ def namespace(self):
+ return self._avro_name.namespace
+
+ @property
+ def fullname(self):
+ return self._avro_name.fullname
+
+ def name_ref(self, names):
+ """Reports this schema name relative to the specified name tracker.
+
+ :param Any names: Avro name tracker to relativize this schema name against.
+ :returns: This schema name, relativized against the specified name tracker.
+ :rtype: Any
+ """
+ if self.namespace == names.default_namespace:
+ return self.name
+ return self.fullname
+
+ # Converts the schema object into its AVRO specification representation.
+
+ # Schema types that have names (records, enums, and fixed) must be aware
+ # of not re-defining schemas that are already listed in the parameter names.
+ @abc.abstractmethod
+ def to_json(self, names):
+ ...
+
+# ------------------------------------------------------------------------------
+
+
+_NO_DEFAULT = object()
+
+
+class Field(object):
+ """Representation of the schema of a field in a record."""
+
+ def __init__(
+ self,
+ data_type,
+ name,
+ index,
+ has_default,
+ default=_NO_DEFAULT,
+ order=None,
+ doc=None,
+ other_props=None
+ ):
+ """Initializes a new Field object.
+
+ Args:
+ data_type: Avro schema of the field.
+ name: Name of the field.
+ index: 0-based position of the field.
+            has_default: Whether the field has a default value.
+            default: Default value of the field, if any.
+            order: Optional sort order ('ascending', 'descending' or 'ignore').
+            doc: Optional documentation string for the field.
+            other_props: Optional map of additional properties of the field.
+ """
+ if (not isinstance(name, str)) or (not name):
+ raise SchemaParseException(f'Invalid record field name: {name!r}.')
+ if (order is not None) and (order not in VALID_FIELD_SORT_ORDERS):
+ raise SchemaParseException(f'Invalid record field order: {order!r}.')
+
+ # All properties of this record field:
+ self._props = {}
+
+ self._has_default = has_default
+ if other_props:
+ self._props.update(other_props)
+
+ self._index = index
+ self._type = self._props['type'] = data_type
+ self._name = self._props['name'] = name
+
+ if has_default:
+ self._props['default'] = default
+
+ if order is not None:
+ self._props['order'] = order
+
+ if doc is not None:
+ self._props['doc'] = doc
+
+ @property
+ def type(self):
+ """Returns: the schema of this field."""
+ return self._type
+
+ @property
+ def name(self):
+ """Returns: this field name."""
+ return self._name
+
+ @property
+ def index(self):
+ """Returns: the 0-based index of this field in the record."""
+ return self._index
+
+ @property
+ def default(self):
+ return self._props['default']
+
+ @property
+ def has_default(self):
+ return self._has_default
+
+ @property
+ def order(self):
+ return self._props.get('order', None)
+
+ @property
+ def doc(self):
+ return self._props.get('doc', None)
+
+ @property
+ def props(self):
+ return self._props
+
+ @property
+ def other_props(self):
+ return filter_keys_out(items=self._props, keys=FIELD_RESERVED_PROPS)
+
+ def __str__(self):
+ return json.dumps(self.to_json())
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = self.props.copy()
+ to_dump['type'] = self.type.to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
+
+
+# ------------------------------------------------------------------------------
+# Primitive Types
+
+
+class PrimitiveSchema(Schema):
+ """Schema of a primitive Avro type.
+
+ Valid primitive types are defined in PRIMITIVE_TYPES.
+ """
+
+ def __init__(self, data_type, other_props=None):
+ """Initializes a new schema object for the specified primitive type.
+
+ Args:
+ data_type: Type of the schema to construct. Must be primitive.
+ """
+ if data_type not in PRIMITIVE_TYPES:
+ raise AvroException(f'{data_type!r} is not a valid primitive type.')
+ super(PrimitiveSchema, self).__init__(data_type, other_props=other_props)
+
+ @property
+ def name(self):
+ """Returns: the simple name of this schema."""
+ # The name of a primitive type is the type itself.
+ return self.type
+
+ @property
+ def fullname(self):
+ """Returns: the fully qualified name of this schema."""
+ # The full name is the simple name for primitive schema.
+ return self.name
+
+ def to_json(self, names=None):
+ if len(self.props) == 1:
+ return self.fullname
+ return self.props
+
+ def __eq__(self, that):
+ return self.props == that.props
+
+
+# ------------------------------------------------------------------------------
+# Complex Types (non-recursive)
+
+
+class FixedSchema(NamedSchema):
+ def __init__(
+ self,
+ name,
+ namespace,
+ size,
+ names=None,
+ other_props=None,
+ ):
+ # Ensure valid ctor args
+ if not isinstance(size, int):
+ fail_msg = 'Fixed Schema requires a valid integer for size property.'
+ raise AvroException(fail_msg)
+
+ super(FixedSchema, self).__init__(
+ data_type=FIXED,
+ name=name,
+ namespace=namespace,
+ names=names,
+ other_props=other_props,
+ )
+ self._props['size'] = size
+
+ @property
+ def size(self):
+ """Returns: the size of this fixed schema, in bytes."""
+ return self._props['size']
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ if self.fullname in names.names:
+ return self.name_ref(names)
+ names.names[self.fullname] = self
+ return names.prune_namespace(self.props)
+
+ def __eq__(self, that):
+ return self.props == that.props
+
+
+# ------------------------------------------------------------------------------
+
+
+class EnumSchema(NamedSchema):
+ def __init__(
+ self,
+ name,
+ namespace,
+ symbols,
+ names=None,
+ doc=None,
+ other_props=None,
+ ):
+ """Initializes a new enumeration schema object.
+
+ Args:
+ name: Simple name of this enumeration.
+ namespace: Optional namespace.
+ symbols: Ordered list of symbols defined in this enumeration.
+            names: Tracker used to resolve and register Avro names.
+            doc: Optional documentation string for the enum.
+            other_props: Optional map of additional properties of the schema.
+ """
+ symbols = tuple(symbols)
+ symbol_set = frozenset(symbols)
+ if (len(symbol_set) != len(symbols)
+ or not all(map(lambda symbol: isinstance(symbol, str), symbols))):
+ raise AvroException(
+ f'Invalid symbols for enum schema: {symbols!r}.')
+
+ super(EnumSchema, self).__init__(
+ data_type=ENUM,
+ name=name,
+ namespace=namespace,
+ names=names,
+ other_props=other_props,
+ )
+
+ self._props['symbols'] = symbols
+ if doc is not None:
+ self._props['doc'] = doc
+
+ @property
+ def symbols(self):
+ """Returns: the symbols defined in this enum."""
+ return self._props['symbols']
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ if self.fullname in names.names:
+ return self.name_ref(names)
+ names.names[self.fullname] = self
+ return names.prune_namespace(self.props)
+
+ def __eq__(self, that):
+ return self.props == that.props
+
+
+# ------------------------------------------------------------------------------
+# Complex Types (recursive)
+
+
+class ArraySchema(Schema):
+ """Schema of an array."""
+
+ def __init__(self, items, other_props=None):
+ """Initializes a new array schema object.
+
+ Args:
+ items: Avro schema of the array items.
+            other_props: Optional map of additional properties of the schema.
+ """
+ super(ArraySchema, self).__init__(
+ data_type=ARRAY,
+ other_props=other_props,
+ )
+ self._items_schema = items
+ self._props['items'] = items
+
+ @property
+ def items(self):
+ """Returns: the schema of the items in this array."""
+ return self._items_schema
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = self.props.copy()
+ item_schema = self.items
+ to_dump['items'] = item_schema.to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
+
+
+# ------------------------------------------------------------------------------
+
+
+class MapSchema(Schema):
+ """Schema of a map."""
+
+ def __init__(self, values, other_props=None):
+ """Initializes a new map schema object.
+
+ Args:
+ values: Avro schema of the map values.
+            other_props: Optional map of additional properties of the schema.
+ """
+ super(MapSchema, self).__init__(
+ data_type=MAP,
+ other_props=other_props,
+ )
+ self._values_schema = values
+ self._props['values'] = values
+
+ @property
+ def values(self):
+ """Returns: the schema of the values in this map."""
+ return self._values_schema
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = self.props.copy()
+ to_dump['values'] = self.values.to_json(names)
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
+
+
+# ------------------------------------------------------------------------------
+
+
+class UnionSchema(Schema):
+ """Schema of a union."""
+
+ def __init__(self, schemas):
+ """Initializes a new union schema object.
+
+ Args:
+ schemas: Ordered collection of schema branches in the union.
+ """
+ super(UnionSchema, self).__init__(data_type=UNION)
+ self._schemas = tuple(schemas)
+
+ # Validate the schema branches:
+
+ # All named schema names are unique:
+ named_branches = tuple(
+ filter(lambda schema: schema.type in NAMED_TYPES, self._schemas))
+ unique_names = frozenset(map(lambda schema: schema.fullname, named_branches))
+ if len(unique_names) != len(named_branches):
+ schemas = ''.join(map(lambda schema: (f'\n\t - {schema}'), self._schemas))
+ raise AvroException(f'Invalid union branches with duplicate schema name:{schemas}')
+
+ # Types are unique within unnamed schemas, and union is not allowed:
+ unnamed_branches = tuple(
+ filter(lambda schema: schema.type not in NAMED_TYPES, self._schemas))
+ unique_types = frozenset(map(lambda schema: schema.type, unnamed_branches))
+ if UNION in unique_types:
+ schemas = ''.join(map(lambda schema: (f'\n\t - {schema}'), self._schemas))
+ raise AvroException(f'Invalid union branches contain other unions:{schemas}')
+ if len(unique_types) != len(unnamed_branches):
+ schemas = ''.join(map(lambda schema: (f'\n\t - {schema}'), self._schemas))
+ raise AvroException(f'Invalid union branches with duplicate type:{schemas}')
+
+ @property
+ def schemas(self):
+ """Returns: the ordered list of schema branches in the union."""
+ return self._schemas
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = []
+ for schema in self.schemas:
+ to_dump.append(schema.to_json(names))
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
+
+
+# ------------------------------------------------------------------------------
+
+
+class ErrorUnionSchema(UnionSchema):
+ """Schema representing the declared errors of a protocol message."""
+
+ def __init__(self, schemas):
+ """Initializes an error-union schema.
+
+ Args:
+            schemas: Collection of error schemas.
+ """
+ # Prepend "string" to handle system errors
+ schemas = [PrimitiveSchema(data_type=STRING)] + list(schemas)
+ super(ErrorUnionSchema, self).__init__(schemas=schemas)
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ to_dump = []
+ for schema in self.schemas:
+ # Don't print the system error schema
+ if schema.type == STRING:
+ continue
+ to_dump.append(schema.to_json(names))
+ return to_dump
+
+
+# ------------------------------------------------------------------------------
+
+
+class RecordSchema(NamedSchema):
+ """Schema of a record."""
+
+ @staticmethod
+ def _make_field(index, field_desc, names):
+        """Builds a single field schema from its JSON descriptor.
+
+ :param int index: 0-based index of the field in the record.
+ :param Any field_desc: JSON descriptors of a record field.
+ :param Any names: The names for this schema.
+ :returns: The field schema.
+ :rtype: Field
+ """
+ field_schema = schema_from_json_data(
+ json_data=field_desc['type'],
+ names=names,
+ )
+ other_props = (
+ dict(filter_keys_out(items=field_desc, keys=FIELD_RESERVED_PROPS)))
+ return Field(
+ data_type=field_schema,
+ name=field_desc['name'],
+ index=index,
+ has_default=('default' in field_desc),
+ default=field_desc.get('default', _NO_DEFAULT),
+ order=field_desc.get('order', None),
+ doc=field_desc.get('doc', None),
+ other_props=other_props,
+ )
+
+ @staticmethod
+ def make_field_list(field_desc_list, names):
+ """Builds field schemas from a list of field JSON descriptors.
+        Field name uniqueness is enforced later, when the field map is built.
+
+        :param Any field_desc_list: Collection of field JSON descriptors.
+        :param Any names: The names for this schema.
+        :returns: The field schemas, yielded in declaration order.
+        :rtype: Iterator[Field]
+ """
+ for index, field_desc in enumerate(field_desc_list):
+ yield RecordSchema._make_field(index, field_desc, names)
+
+ @staticmethod
+ def _make_field_map(fields):
+ """Builds the field map.
+        Guarantees field name uniqueness.
+
+ :param Any fields: Iterable of field schema.
+ :returns: A map of field schemas, indexed by name.
+ :rtype: Dict[Any, Any]
+ """
+ field_map = {}
+ for field in fields:
+ if field.name in field_map:
+ raise SchemaParseException(
+ f'Duplicate record field name {field.name!r}.')
+ field_map[field.name] = field
+ return field_map
+
+ def __init__(
+ self,
+ name,
+ namespace,
+ fields=None,
+ make_fields=None,
+ names=None,
+ record_type=RECORD,
+ doc=None,
+ other_props=None
+ ):
+ """Initializes a new record schema object.
+
+ Args:
+ name: Name of the record (absolute or relative).
+ namespace: Optional namespace the record belongs to, if name is relative.
+ fields: collection of fields to add to this record.
+ Exactly one of fields or make_fields must be specified.
+ make_fields: function creating the fields that belong to the record.
+ The function signature is: make_fields(names) -> ordered field list.
+ Exactly one of fields or make_fields must be specified.
+            names: Tracker used to resolve and register Avro names.
+ record_type: Type of the record: one of RECORD, ERROR or REQUEST.
+ Protocol requests are not named.
+            doc: Optional documentation string for the record.
+            other_props: Optional map of additional properties of the schema.
+ """
+ if record_type == REQUEST:
+ # Protocol requests are not named:
+ super(RecordSchema, self).__init__(
+ data_type=REQUEST,
+ other_props=other_props,
+ )
+ elif record_type in [RECORD, ERROR]:
+ # Register this record name in the tracker:
+ super(RecordSchema, self).__init__(
+ data_type=record_type,
+ name=name,
+ namespace=namespace,
+ names=names,
+ other_props=other_props,
+ )
+ else:
+ raise SchemaParseException(
+ f'Invalid record type: {record_type!r}.')
+
+ nested_names = []
+ if record_type in [RECORD, ERROR]:
+ avro_name = names.get_name(name=name, namespace=namespace)
+ nested_names = names.new_with_default_namespace(namespace=avro_name.namespace)
+ elif record_type == REQUEST:
+ # Protocol request has no name: no need to change default namespace:
+ nested_names = names
+
+ if fields is None:
+ fields = make_fields(names=nested_names)
+ else:
+ assert make_fields is None
+ self._fields = tuple(fields)
+
+ self._field_map = RecordSchema._make_field_map(self._fields)
+
+ self._props['fields'] = fields
+ if doc is not None:
+ self._props['doc'] = doc
+
+ @property
+ def fields(self):
+ """Returns: the field schemas, as an ordered tuple."""
+ return self._fields
+
+ @property
+ def field_map(self):
+        """Returns: a read-only map of the field schemas, indexed by field name."""
+ return self._field_map
+
+ def to_json(self, names=None):
+ if names is None:
+ names = Names()
+ # Request records don't have names
+ if self.type == REQUEST:
+ return [f.to_json(names) for f in self.fields]
+
+ if self.fullname in names.names:
+ return self.name_ref(names)
+ names.names[self.fullname] = self
+
+ to_dump = names.prune_namespace(self.props.copy())
+ to_dump['fields'] = [f.to_json(names) for f in self.fields]
+ return to_dump
+
+ def __eq__(self, that):
+ to_cmp = json.loads(str(self))
+ return to_cmp == json.loads(str(that))
+
+
+# ------------------------------------------------------------------------------
+# Module functions
+
+
+def filter_keys_out(items, keys):
+ """Filters a collection of (key, value) items.
+    Excludes any item whose key belongs to keys.
+
+    :param Dict[Any, Any] items: Dictionary of items to filter the keys out of.
+    :param Container[Any] keys: Collection of keys to exclude.
+    :returns: The filtered items, as (key, value) pairs.
+    :rtype: Iterator[Tuple[Any, Any]]
+ """
+ for key, value in items.items():
+ if key in keys:
+ continue
+ yield key, value
+
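+
+# A small illustrative sketch (editorial addition, not part of the vendored
+# module): reserved properties are dropped while custom properties are kept.
+# The property values are placeholders.
+def _example_filter_keys_out():
+    props = {'type': 'record', 'name': 'X', 'custom': 'kept'}
+    return dict(filter_keys_out(items=props, keys=SCHEMA_RESERVED_PROPS))
+    # -> {'custom': 'kept'}
+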
+
+# ------------------------------------------------------------------------------
+
+
+def _schema_from_json_string(json_string, names):
+ if json_string in PRIMITIVE_TYPES:
+ return PrimitiveSchema(data_type=json_string)
+
+ # Look for a known named schema:
+ schema = names.get_schema(name=json_string)
+ if schema is None:
+ raise SchemaParseException(f"Unknown named schema {json_string!r}, known names: {sorted(names.names)!r}.")
+ return schema
+
+
+def _schema_from_json_array(json_array, names):
+ def MakeSchema(desc):
+ return schema_from_json_data(json_data=desc, names=names)
+
+ return UnionSchema(map(MakeSchema, json_array))
+
+
+def _schema_from_json_object(json_object, names):
+ data_type = json_object.get('type')
+ if data_type is None:
+ raise SchemaParseException(
+ f'Avro schema JSON descriptor has no "type" property: {json_object!r}')
+
+ other_props = dict(
+ filter_keys_out(items=json_object, keys=SCHEMA_RESERVED_PROPS))
+
+ if data_type in PRIMITIVE_TYPES:
+ # FIXME should not ignore other properties
+ result = PrimitiveSchema(data_type, other_props=other_props)
+
+ elif data_type in NAMED_TYPES:
+ name = json_object.get('name')
+ namespace = json_object.get('namespace', names.default_namespace)
+ if data_type == FIXED:
+ size = json_object.get('size')
+ result = FixedSchema(name, namespace, size, names, other_props)
+ elif data_type == ENUM:
+ symbols = json_object.get('symbols')
+ doc = json_object.get('doc')
+ result = EnumSchema(name, namespace, symbols, names, doc, other_props)
+
+ elif data_type in [RECORD, ERROR]:
+ field_desc_list = json_object.get('fields', ())
+
+ def MakeFields(names):
+ return tuple(RecordSchema.make_field_list(field_desc_list, names))
+
+ result = RecordSchema(
+ name=name,
+ namespace=namespace,
+ make_fields=MakeFields,
+ names=names,
+ record_type=data_type,
+ doc=json_object.get('doc'),
+ other_props=other_props,
+ )
+ else:
+ raise ValueError(f'Internal error: unknown type {data_type!r}.')
+
+ elif data_type in VALID_TYPES:
+ # Unnamed, non-primitive Avro type:
+
+ if data_type == ARRAY:
+ items_desc = json_object.get('items')
+ if items_desc is None:
+ raise SchemaParseException(f'Invalid array schema descriptor with no "items" : {json_object!r}.')
+ result = ArraySchema(
+ items=schema_from_json_data(items_desc, names),
+ other_props=other_props,
+ )
+
+ elif data_type == MAP:
+ values_desc = json_object.get('values')
+ if values_desc is None:
+ raise SchemaParseException(f'Invalid map schema descriptor with no "values" : {json_object!r}.')
+ result = MapSchema(
+ values=schema_from_json_data(values_desc, names=names),
+ other_props=other_props,
+ )
+
+ elif data_type == ERROR_UNION:
+ error_desc_list = json_object.get('declared_errors')
+ assert error_desc_list is not None
+ error_schemas = map(
+ lambda desc: schema_from_json_data(desc, names=names),
+ error_desc_list)
+ result = ErrorUnionSchema(schemas=error_schemas)
+
+ else:
+ raise ValueError(f'Internal error: unknown type {data_type!r}.')
+ else:
+ raise SchemaParseException(f'Invalid JSON descriptor for an Avro schema: {json_object!r}')
+ return result
+
+
+# Parsers for the JSON data types:
+_JSONDataParserTypeMap = {
+ str: _schema_from_json_string,
+ list: _schema_from_json_array,
+ dict: _schema_from_json_object,
+}
+
+
+def schema_from_json_data(json_data, names=None):
+ """Builds an Avro Schema from its JSON descriptor.
+ Raises SchemaParseException if the descriptor is invalid.
+
+ :param Any json_data: JSON data representing the descriptor of the Avro schema.
+ :param Any names: Optional tracker for Avro named schemas.
+ :returns: The Avro schema parsed from the JSON descriptor.
+ :rtype: Any
+ """
+ if names is None:
+ names = Names()
+
+ # Select the appropriate parser based on the JSON data type:
+ parser = _JSONDataParserTypeMap.get(type(json_data))
+ if parser is None:
+ raise SchemaParseException(
+ f'Invalid JSON descriptor for an Avro schema: {json_data!r}.')
+ return parser(json_data, names=names)
+
+
+# ------------------------------------------------------------------------------
+
+
+def parse(json_string):
+ """Constructs a Schema from its JSON descriptor in text form.
+ Raises SchemaParseException if a JSON parsing error is met, or if the JSON descriptor is invalid.
+
+ :param str json_string: String representation of the JSON descriptor of the schema.
+ :returns: The parsed schema.
+ :rtype: Any
+ """
+ try:
+ json_data = json.loads(json_string)
+ except Exception as exn:
+ raise SchemaParseException(
+ f'Error parsing schema from JSON: {json_string!r}. '
+ f'Error message: {exn!r}.') from exn
+
+ # Initialize the names object
+ names = Names()
+
+ # construct the Avro Schema object
+ return schema_from_json_data(json_data, names)
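+
+
+# A minimal usage sketch (editorial addition, not part of the vendored module):
+# parsing a record schema descriptor with parse() and inspecting the result.
+# The schema name and fields below are illustrative placeholders.
+def _example_parse_record_schema():
+    example = parse("""
+    {
+      "type": "record",
+      "name": "BlobChangeEvent",
+      "namespace": "com.example",
+      "fields": [
+        {"name": "subject", "type": "string"},
+        {"name": "size", "type": ["null", "long"], "default": null}
+      ]
+    }
+    """)
+    assert example.fullname == "com.example.BlobChangeEvent"
+    assert [f.name for f in example.fields] == ["subject", "size"]
+    return example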
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client.py
new file mode 100644
index 00000000..9dc8d2ec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client.py
@@ -0,0 +1,458 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+import uuid
+from typing import (
+ Any,
+ cast,
+ Dict,
+ Iterator,
+ Optional,
+ Tuple,
+ TYPE_CHECKING,
+ Union,
+)
+from urllib.parse import parse_qs, quote
+
+from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential
+from azure.core.exceptions import HttpResponseError
+from azure.core.pipeline import Pipeline
+from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module
+from azure.core.pipeline.policies import (
+ AzureSasCredentialPolicy,
+ ContentDecodePolicy,
+ DistributedTracingPolicy,
+ HttpLoggingPolicy,
+ ProxyPolicy,
+ RedirectPolicy,
+ UserAgentPolicy,
+)
+
+from .authentication import SharedKeyCredentialPolicy
+from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE
+from .models import LocationMode, StorageConfiguration
+from .policies import (
+ ExponentialRetry,
+ QueueMessagePolicy,
+ StorageBearerTokenCredentialPolicy,
+ StorageContentValidation,
+ StorageHeadersPolicy,
+ StorageHosts,
+ StorageLoggingPolicy,
+ StorageRequestHook,
+ StorageResponseHook,
+)
+from .request_handlers import serialize_batch_body, _get_batch_request_delimiter
+from .response_handlers import PartialBatchErrorException, process_storage_error
+from .shared_access_signature import QueryStringConstants
+from .._version import VERSION
+from .._shared_access_signature import _is_credential_sastoken
+
+if TYPE_CHECKING:
+ from azure.core.credentials_async import AsyncTokenCredential
+ from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756
+
+_LOGGER = logging.getLogger(__name__)
+_SERVICE_PARAMS = {
+ "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"},
+ "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"},
+ "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"},
+ "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"},
+}
+
+
+class StorageAccountHostsMixin(object):
+ _client: Any
+ def __init__(
+ self,
+ parsed_url: Any,
+ service: str,
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long
+ **kwargs: Any
+ ) -> None:
+ self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
+ self._hosts = kwargs.get("_hosts")
+ self.scheme = parsed_url.scheme
+ self._is_localhost = False
+
+ if service not in ["blob", "queue", "file-share", "dfs"]:
+ raise ValueError(f"Invalid service: {service}")
+ service_name = service.split('-')[0]
+ account = parsed_url.netloc.split(f".{service_name}.core.")
+
+ self.account_name = account[0] if len(account) > 1 else None
+ if not self.account_name and parsed_url.netloc.startswith("localhost") \
+ or parsed_url.netloc.startswith("127.0.0.1"):
+ self._is_localhost = True
+ self.account_name = parsed_url.path.strip("/")
+
+ self.credential = _format_shared_key_credential(self.account_name, credential)
+ if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
+ raise ValueError("Token credential is only supported with HTTPS.")
+
+ secondary_hostname = None
+ if hasattr(self.credential, "account_name"):
+ self.account_name = self.credential.account_name
+ secondary_hostname = f"{self.credential.account_name}-secondary.{service_name}.{SERVICE_HOST_BASE}"
+
+ if not self._hosts:
+ if len(account) > 1:
+ secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary")
+ if kwargs.get("secondary_hostname"):
+ secondary_hostname = kwargs["secondary_hostname"]
+ primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
+ self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}
+
+ self._sdk_moniker = f"storage-{service}/{VERSION}"
+ self._config, self._pipeline = self._create_pipeline(self.credential, sdk_moniker=self._sdk_moniker, **kwargs)
+
+ def __enter__(self):
+ self._client.__enter__()
+ return self
+
+ def __exit__(self, *args):
+ self._client.__exit__(*args)
+
+ def close(self):
+        """Closes the sockets opened by the client.
+
+        It does not need to be called when the client is used as a context manager.
+ """
+ self._client.close()
+
+ @property
+ def url(self):
+ """The full endpoint URL to this entity, including SAS token if used.
+
+ This could be either the primary endpoint,
+ or the secondary endpoint depending on the current :func:`location_mode`.
+ :returns: The full endpoint URL to this entity, including SAS token if used.
+ :rtype: str
+ """
+ return self._format_url(self._hosts[self._location_mode])
+
+ @property
+ def primary_endpoint(self):
+ """The full primary endpoint URL.
+
+ :rtype: str
+ """
+ return self._format_url(self._hosts[LocationMode.PRIMARY])
+
+ @property
+ def primary_hostname(self):
+ """The hostname of the primary endpoint.
+
+ :rtype: str
+ """
+ return self._hosts[LocationMode.PRIMARY]
+
+ @property
+ def secondary_endpoint(self):
+ """The full secondary endpoint URL if configured.
+
+        If not available, a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
+ `secondary_hostname` keyword argument on instantiation.
+
+ :rtype: str
+ :raise ValueError:
+ """
+ if not self._hosts[LocationMode.SECONDARY]:
+ raise ValueError("No secondary host configured.")
+ return self._format_url(self._hosts[LocationMode.SECONDARY])
+
+ @property
+ def secondary_hostname(self):
+ """The hostname of the secondary endpoint.
+
+ If not available this will be None. To explicitly specify a secondary hostname, use the optional
+ `secondary_hostname` keyword argument on instantiation.
+
+ :rtype: Optional[str]
+ """
+ return self._hosts[LocationMode.SECONDARY]
+
+ @property
+ def location_mode(self):
+ """The location mode that the client is currently using.
+
+ By default this will be "primary". Options include "primary" and "secondary".
+
+ :rtype: str
+ """
+
+ return self._location_mode
+
+ @location_mode.setter
+ def location_mode(self, value):
+ if self._hosts.get(value):
+ self._location_mode = value
+ self._client._config.url = self.url # pylint: disable=protected-access
+ else:
+ raise ValueError(f"No host URL for location mode: {value}")
+
+ @property
+ def api_version(self):
+ """The version of the Storage API used for requests.
+
+ :rtype: str
+ """
+ return self._client._config.version # pylint: disable=protected-access
+
+ def _format_query_string(
+ self, sas_token: Optional[str],
+ credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long
+ snapshot: Optional[str] = None,
+ share_snapshot: Optional[str] = None
+ ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long
+ query_str = "?"
+ if snapshot:
+ query_str += f"snapshot={snapshot}&"
+ if share_snapshot:
+ query_str += f"sharesnapshot={share_snapshot}&"
+ if sas_token and isinstance(credential, AzureSasCredential):
+ raise ValueError(
+ "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
+ if _is_credential_sastoken(credential):
+ credential = cast(str, credential)
+ query_str += credential.lstrip("?")
+ credential = None
+ elif sas_token:
+ query_str += sas_token
+ return query_str.rstrip("?&"), credential
+
+ def _create_pipeline(
+ self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long
+ **kwargs: Any
+ ) -> Tuple[StorageConfiguration, Pipeline]:
+ self._credential_policy: Any = None
+ if hasattr(credential, "get_token"):
+ if kwargs.get('audience'):
+ audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE
+ else:
+ audience = STORAGE_OAUTH_SCOPE
+ self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience)
+ elif isinstance(credential, SharedKeyCredentialPolicy):
+ self._credential_policy = credential
+ elif isinstance(credential, AzureSasCredential):
+ self._credential_policy = AzureSasCredentialPolicy(credential)
+ elif credential is not None:
+ raise TypeError(f"Unsupported credential: {type(credential)}")
+
+ config = kwargs.get("_configuration") or create_configuration(**kwargs)
+ if kwargs.get("_pipeline"):
+ return config, kwargs["_pipeline"]
+ transport = kwargs.get("transport")
+ kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
+ kwargs.setdefault("read_timeout", READ_TIMEOUT)
+ if not transport:
+ transport = RequestsTransport(**kwargs)
+ policies = [
+ QueueMessagePolicy(),
+ config.proxy_policy,
+ config.user_agent_policy,
+ StorageContentValidation(),
+ ContentDecodePolicy(response_encoding="utf-8"),
+ RedirectPolicy(**kwargs),
+ StorageHosts(hosts=self._hosts, **kwargs),
+ config.retry_policy,
+ config.headers_policy,
+ StorageRequestHook(**kwargs),
+ self._credential_policy,
+ config.logging_policy,
+ StorageResponseHook(**kwargs),
+ DistributedTracingPolicy(**kwargs),
+ HttpLoggingPolicy(**kwargs)
+ ]
+ if kwargs.get("_additional_pipeline_policies"):
+ policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore
+ config.transport = transport # type: ignore
+ return config, Pipeline(transport, policies=policies)
+
+ def _batch_send(
+ self,
+ *reqs: "HttpRequest",
+ **kwargs: Any
+ ) -> Iterator["HttpResponse"]:
+        """Given a series of requests, do a Storage batch call.
+
+ :param HttpRequest reqs: A collection of HttpRequest objects.
+ :returns: An iterator of HttpResponse objects.
+ :rtype: Iterator[HttpResponse]
+ """
+ # Pop it here, so requests doesn't feel bad about additional kwarg
+        # Pop this here so the transport does not receive an unexpected keyword argument.
+ batch_id = str(uuid.uuid1())
+
+ request = self._client._client.post( # pylint: disable=protected-access
+ url=(
+ f'{self.scheme}://{self.primary_hostname}/'
+ f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}"
+ f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}"
+ ),
+ headers={
+ 'x-ms-version': self.api_version,
+ "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False)
+ }
+ )
+
+ policies = [StorageHeadersPolicy()]
+ if self._credential_policy:
+ policies.append(self._credential_policy)
+
+ request.set_multipart_mixed(
+ *reqs,
+ policies=policies,
+ enforce_https=False
+ )
+
+ Pipeline._prepare_multipart_mixed_request(request) # pylint: disable=protected-access
+ body = serialize_batch_body(request.multipart_mixed_info[0], batch_id)
+ request.set_bytes_body(body)
+
+ temp = request.multipart_mixed_info
+ request.multipart_mixed_info = None
+ pipeline_response = self._pipeline.run(
+ request, **kwargs
+ )
+ response = pipeline_response.http_response
+ request.multipart_mixed_info = temp
+
+ try:
+ if response.status_code not in [202]:
+ raise HttpResponseError(response=response)
+ parts = response.parts()
+ if raise_on_any_failure:
+ parts = list(response.parts())
+ if any(p for p in parts if not 200 <= p.status_code < 300):
+ error = PartialBatchErrorException(
+ message="There is a partial failure in the batch operation.",
+ response=response, parts=parts
+ )
+ raise error
+ return iter(parts)
+ return parts # type: ignore [no-any-return]
+ except HttpResponseError as error:
+ process_storage_error(error)
+
+
+class TransportWrapper(HttpTransport):
+ """Wrapper class that ensures that an inner client created
+ by a `get_client` method does not close the outer transport for the parent
+ when used in a context manager.
+ """
+ def __init__(self, transport):
+ self._transport = transport
+
+ def send(self, request, **kwargs):
+ return self._transport.send(request, **kwargs)
+
+ def open(self):
+ pass
+
+ def close(self):
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
+
+
+def _format_shared_key_credential(
+ account_name: Optional[str],
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None # pylint: disable=line-too-long
+) -> Any:
+ if isinstance(credential, str):
+ if not account_name:
+ raise ValueError("Unable to determine account name for shared key credential.")
+ credential = {"account_name": account_name, "account_key": credential}
+ if isinstance(credential, dict):
+ if "account_name" not in credential:
+            raise ValueError("Shared key credential missing 'account_name'")
+ if "account_key" not in credential:
+            raise ValueError("Shared key credential missing 'account_key'")
+ return SharedKeyCredentialPolicy(**credential)
+ if isinstance(credential, AzureNamedKeyCredential):
+ return SharedKeyCredentialPolicy(credential.named_key.name, credential.named_key.key)
+ return credential
+
+
+def parse_connection_str(
+ conn_str: str,
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]],
+ service: str
+) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long
+ conn_str = conn_str.rstrip(";")
+ conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")]
+ if any(len(tup) != 2 for tup in conn_settings_list):
+ raise ValueError("Connection string is either blank or malformed.")
+ conn_settings = dict((key.upper(), val) for key, val in conn_settings_list)
+ endpoints = _SERVICE_PARAMS[service]
+ primary = None
+ secondary = None
+ if not credential:
+ try:
+ credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]}
+ except KeyError:
+ credential = conn_settings.get("SHAREDACCESSSIGNATURE")
+ if endpoints["primary"] in conn_settings:
+ primary = conn_settings[endpoints["primary"]]
+ if endpoints["secondary"] in conn_settings:
+ secondary = conn_settings[endpoints["secondary"]]
+ else:
+ if endpoints["secondary"] in conn_settings:
+ raise ValueError("Connection string specifies only secondary endpoint.")
+ try:
+            primary = (
+ f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://"
+ f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ secondary = (
+ f"{conn_settings['ACCOUNTNAME']}-secondary."
+ f"{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ except KeyError:
+ pass
+
+ if not primary:
+ try:
+ primary = (
+ f"https://{conn_settings['ACCOUNTNAME']}."
+ f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}"
+ )
+ except KeyError as exc:
+ raise ValueError("Connection string missing required connection details.") from exc
+ if service == "dfs":
+ primary = primary.replace(".blob.", ".dfs.")
+ if secondary:
+ secondary = secondary.replace(".blob.", ".dfs.")
+ return primary, secondary, credential
+
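+
+# A minimal usage sketch (editorial addition, not part of the vendored module):
+# parsing a classic account-name/key connection string. The account name and
+# key below are placeholders.
+def _example_parse_connection_str():
+    conn_str = (
+        "DefaultEndpointsProtocol=https;AccountName=myaccount;"
+        "AccountKey=aGVsbG8=;EndpointSuffix=core.windows.net"
+    )
+    primary, secondary, credential = parse_connection_str(conn_str, None, "blob")
+    # primary    == "https://myaccount.blob.core.windows.net"
+    # secondary  == "myaccount-secondary.blob.core.windows.net"
+    # credential == {"account_name": "myaccount", "account_key": "aGVsbG8="}
+    return primary, secondary, credential
+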
+
+def create_configuration(**kwargs: Any) -> StorageConfiguration:
+ # Backwards compatibility if someone is not passing sdk_moniker
+ if not kwargs.get("sdk_moniker"):
+ kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}"
+ config = StorageConfiguration(**kwargs)
+ config.headers_policy = StorageHeadersPolicy(**kwargs)
+ config.user_agent_policy = UserAgentPolicy(**kwargs)
+ config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs)
+ config.logging_policy = StorageLoggingPolicy(**kwargs)
+ config.proxy_policy = ProxyPolicy(**kwargs)
+ return config
+
+
+def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]:
+ sas_values = QueryStringConstants.to_list()
+ parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
+ sas_params = [f"{k}={quote(v, safe='')}" for k, v in parsed_query.items() if k in sas_values]
+ sas_token = None
+ if sas_params:
+ sas_token = "&".join(sas_params)
+
+ snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
+ return snapshot, sas_token
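+
+
+# A minimal usage sketch (editorial addition, not part of the vendored module):
+# splitting a blob URL query string into its snapshot value and the SAS token
+# portion. The parameter values are placeholders; only parameters recognized by
+# QueryStringConstants are kept in the SAS token.
+def _example_parse_query():
+    snapshot, sas_token = parse_query(
+        "snapshot=2024-01-01T00:00:00Z&sv=2025-01-05&sig=abc%3D&comp=list"
+    )
+    # snapshot  == "2024-01-01T00:00:00Z"
+    # sas_token == "sv=2025-01-05&sig=abc%3D"
+    return snapshot, sas_token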
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client_async.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client_async.py
new file mode 100644
index 00000000..6186b29d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/base_client_async.py
@@ -0,0 +1,280 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# mypy: disable-error-code="attr-defined"
+
+import logging
+from typing import Any, cast, Dict, Optional, Tuple, TYPE_CHECKING, Union
+
+from azure.core.async_paging import AsyncList
+from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential
+from azure.core.credentials_async import AsyncTokenCredential
+from azure.core.exceptions import HttpResponseError
+from azure.core.pipeline import AsyncPipeline
+from azure.core.pipeline.policies import (
+ AsyncRedirectPolicy,
+ AzureSasCredentialPolicy,
+ ContentDecodePolicy,
+ DistributedTracingPolicy,
+ HttpLoggingPolicy,
+)
+from azure.core.pipeline.transport import AsyncHttpTransport
+
+from .authentication import SharedKeyCredentialPolicy
+from .base_client import create_configuration
+from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE
+from .models import StorageConfiguration
+from .policies import (
+ QueueMessagePolicy,
+ StorageContentValidation,
+ StorageHeadersPolicy,
+ StorageHosts,
+ StorageRequestHook,
+)
+from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook
+from .response_handlers import PartialBatchErrorException, process_storage_error
+from .._shared_access_signature import _is_credential_sastoken
+
+if TYPE_CHECKING:
+ from azure.core.pipeline.transport import HttpRequest, HttpResponse # pylint: disable=C4756
+_LOGGER = logging.getLogger(__name__)
+
+_SERVICE_PARAMS = {
+ "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"},
+ "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"},
+ "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"},
+ "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"},
+}
+
+
+class AsyncStorageAccountHostsMixin(object):
+
+ def __enter__(self):
+ raise TypeError("Async client only supports 'async with'.")
+
+ def __exit__(self, *args):
+ pass
+
+ async def __aenter__(self):
+ await self._client.__aenter__()
+ return self
+
+ async def __aexit__(self, *args):
+ await self._client.__aexit__(*args)
+
+ async def close(self):
+        """Closes the sockets opened by the client.
+
+        It does not need to be called when the client is used as a context manager.
+ """
+ await self._client.close()
+
+ def _format_query_string(
+ self, sas_token: Optional[str],
+ credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long
+ snapshot: Optional[str] = None,
+ share_snapshot: Optional[str] = None
+ ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long
+ query_str = "?"
+ if snapshot:
+ query_str += f"snapshot={snapshot}&"
+ if share_snapshot:
+ query_str += f"sharesnapshot={share_snapshot}&"
+ if sas_token and isinstance(credential, AzureSasCredential):
+ raise ValueError(
+ "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
+ if _is_credential_sastoken(credential):
+ query_str += credential.lstrip("?") # type: ignore [union-attr]
+ credential = None
+ elif sas_token:
+ query_str += sas_token
+ return query_str.rstrip("?&"), credential
+
+ def _create_pipeline(
+ self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long
+ **kwargs: Any
+ ) -> Tuple[StorageConfiguration, AsyncPipeline]:
+ self._credential_policy: Optional[
+ Union[AsyncStorageBearerTokenCredentialPolicy,
+ SharedKeyCredentialPolicy,
+ AzureSasCredentialPolicy]] = None
+ if hasattr(credential, 'get_token'):
+ if kwargs.get('audience'):
+ audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE
+ else:
+ audience = STORAGE_OAUTH_SCOPE
+ self._credential_policy = AsyncStorageBearerTokenCredentialPolicy(
+ cast(AsyncTokenCredential, credential), audience)
+ elif isinstance(credential, SharedKeyCredentialPolicy):
+ self._credential_policy = credential
+ elif isinstance(credential, AzureSasCredential):
+ self._credential_policy = AzureSasCredentialPolicy(credential)
+ elif credential is not None:
+ raise TypeError(f"Unsupported credential: {type(credential)}")
+ config = kwargs.get('_configuration') or create_configuration(**kwargs)
+ if kwargs.get('_pipeline'):
+ return config, kwargs['_pipeline']
+ transport = kwargs.get('transport')
+ kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
+ kwargs.setdefault("read_timeout", READ_TIMEOUT)
+ if not transport:
+ try:
+ from azure.core.pipeline.transport import AioHttpTransport # pylint: disable=non-abstract-transport-import
+ except ImportError as exc:
+ raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc
+ transport = AioHttpTransport(**kwargs)
+ hosts = self._hosts
+ policies = [
+ QueueMessagePolicy(),
+ config.proxy_policy,
+ config.user_agent_policy,
+ StorageContentValidation(),
+ ContentDecodePolicy(response_encoding="utf-8"),
+ AsyncRedirectPolicy(**kwargs),
+ StorageHosts(hosts=hosts, **kwargs),
+ config.retry_policy,
+ config.headers_policy,
+ StorageRequestHook(**kwargs),
+ self._credential_policy,
+ config.logging_policy,
+ AsyncStorageResponseHook(**kwargs),
+ DistributedTracingPolicy(**kwargs),
+ HttpLoggingPolicy(**kwargs),
+ ]
+ if kwargs.get("_additional_pipeline_policies"):
+ policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore
+ config.transport = transport #type: ignore
+ return config, AsyncPipeline(transport, policies=policies) #type: ignore
+
+ async def _batch_send(
+ self,
+ *reqs: "HttpRequest",
+ **kwargs: Any
+ ) -> AsyncList["HttpResponse"]:
+        """Given a series of requests, do a Storage batch call.
+
+ :param HttpRequest reqs: A collection of HttpRequest objects.
+ :returns: An AsyncList of HttpResponse objects.
+ :rtype: AsyncList[HttpResponse]
+ """
+        # Pop this here so the transport does not receive an unexpected keyword argument.
+ raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
+ request = self._client._client.post( # pylint: disable=protected-access
+ url=(
+ f'{self.scheme}://{self.primary_hostname}/'
+ f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}"
+ f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}"
+ ),
+ headers={
+ 'x-ms-version': self.api_version
+ }
+ )
+
+ policies = [StorageHeadersPolicy()]
+ if self._credential_policy:
+ policies.append(self._credential_policy) # type: ignore
+
+ request.set_multipart_mixed(
+ *reqs,
+ policies=policies,
+ enforce_https=False
+ )
+
+ pipeline_response = await self._pipeline.run(
+ request, **kwargs
+ )
+ response = pipeline_response.http_response
+
+ try:
+ if response.status_code not in [202]:
+ raise HttpResponseError(response=response)
+ parts = response.parts() # Return an AsyncIterator
+ if raise_on_any_failure:
+ parts_list = []
+ async for part in parts:
+ parts_list.append(part)
+ if any(p for p in parts_list if not 200 <= p.status_code < 300):
+ error = PartialBatchErrorException(
+ message="There is a partial failure in the batch operation.",
+ response=response, parts=parts_list
+ )
+ raise error
+ return AsyncList(parts_list)
+ return parts # type: ignore [no-any-return]
+ except HttpResponseError as error:
+ process_storage_error(error)
+
+def parse_connection_str(
+ conn_str: str,
+ credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]],
+ service: str
+) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long
+ conn_str = conn_str.rstrip(";")
+ conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")]
+ if any(len(tup) != 2 for tup in conn_settings_list):
+ raise ValueError("Connection string is either blank or malformed.")
+ conn_settings = dict((key.upper(), val) for key, val in conn_settings_list)
+ endpoints = _SERVICE_PARAMS[service]
+ primary = None
+ secondary = None
+ if not credential:
+ try:
+ credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]}
+ except KeyError:
+ credential = conn_settings.get("SHAREDACCESSSIGNATURE")
+ if endpoints["primary"] in conn_settings:
+ primary = conn_settings[endpoints["primary"]]
+ if endpoints["secondary"] in conn_settings:
+ secondary = conn_settings[endpoints["secondary"]]
+ else:
+ if endpoints["secondary"] in conn_settings:
+ raise ValueError("Connection string specifies only secondary endpoint.")
+ try:
+            primary = (
+ f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://"
+ f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ secondary = (
+ f"{conn_settings['ACCOUNTNAME']}-secondary."
+ f"{service}.{conn_settings['ENDPOINTSUFFIX']}"
+ )
+ except KeyError:
+ pass
+
+ if not primary:
+ try:
+ primary = (
+ f"https://{conn_settings['ACCOUNTNAME']}."
+ f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}"
+ )
+ except KeyError as exc:
+ raise ValueError("Connection string missing required connection details.") from exc
+ if service == "dfs":
+ primary = primary.replace(".blob.", ".dfs.")
+ if secondary:
+ secondary = secondary.replace(".blob.", ".dfs.")
+ return primary, secondary, credential
+
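+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK, and the account name, key, and endpoint suffix are made-up values. It
+# only exercises parse_connection_str as defined above.
+def _example_parse_connection_str():
+    conn = (
+        "DefaultEndpointsProtocol=https;AccountName=myaccount;"
+        "AccountKey=fakekey==;EndpointSuffix=core.windows.net"
+    )
+    primary, secondary, credential = parse_connection_str(conn, credential=None, service="blob")
+    assert primary == "https://myaccount.blob.core.windows.net"
+    assert secondary is None
+    assert credential == {"account_name": "myaccount", "account_key": "fakekey=="}
+
+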
+class AsyncTransportWrapper(AsyncHttpTransport):
+ """Wrapper class that ensures that an inner client created
+ by a `get_client` method does not close the outer transport for the parent
+ when used in a context manager.
+ """
+ def __init__(self, async_transport):
+ self._transport = async_transport
+
+ async def send(self, request, **kwargs):
+ return await self._transport.send(request, **kwargs)
+
+ async def open(self):
+ pass
+
+ async def close(self):
+ pass
+
+ async def __aenter__(self):
+ pass
+
+ async def __aexit__(self, *args):
+ pass
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/constants.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/constants.py
new file mode 100644
index 00000000..0b4b029a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/constants.py
@@ -0,0 +1,19 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from .._serialize import _SUPPORTED_API_VERSIONS
+
+
+X_MS_VERSION = _SUPPORTED_API_VERSIONS[-1]
+
+# Default socket timeouts, in seconds
+CONNECTION_TIMEOUT = 20
+READ_TIMEOUT = 60
+
+DEFAULT_OAUTH_SCOPE = "/.default"
+STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default"
+
+SERVICE_HOST_BASE = 'core.windows.net'
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/models.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/models.py
new file mode 100644
index 00000000..d78cd911
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/models.py
@@ -0,0 +1,585 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=too-many-instance-attributes
+from enum import Enum
+from typing import Optional
+
+from azure.core import CaseInsensitiveEnumMeta
+from azure.core.configuration import Configuration
+from azure.core.pipeline.policies import UserAgentPolicy
+
+
+def get_enum_value(value):
+ if value is None or value in ["None", ""]:
+ return None
+ try:
+ return value.value
+ except AttributeError:
+ return value
+
+
+class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta):
+
+ # Generic storage values
+ ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists"
+ ACCOUNT_BEING_CREATED = "AccountBeingCreated"
+ ACCOUNT_IS_DISABLED = "AccountIsDisabled"
+ AUTHENTICATION_FAILED = "AuthenticationFailed"
+ AUTHORIZATION_FAILURE = "AuthorizationFailure"
+ NO_AUTHENTICATION_INFORMATION = "NoAuthenticationInformation"
+ CONDITION_HEADERS_NOT_SUPPORTED = "ConditionHeadersNotSupported"
+ CONDITION_NOT_MET = "ConditionNotMet"
+ EMPTY_METADATA_KEY = "EmptyMetadataKey"
+ INSUFFICIENT_ACCOUNT_PERMISSIONS = "InsufficientAccountPermissions"
+ INTERNAL_ERROR = "InternalError"
+ INVALID_AUTHENTICATION_INFO = "InvalidAuthenticationInfo"
+ INVALID_HEADER_VALUE = "InvalidHeaderValue"
+ INVALID_HTTP_VERB = "InvalidHttpVerb"
+ INVALID_INPUT = "InvalidInput"
+ INVALID_MD5 = "InvalidMd5"
+ INVALID_METADATA = "InvalidMetadata"
+ INVALID_QUERY_PARAMETER_VALUE = "InvalidQueryParameterValue"
+ INVALID_RANGE = "InvalidRange"
+ INVALID_RESOURCE_NAME = "InvalidResourceName"
+ INVALID_URI = "InvalidUri"
+ INVALID_XML_DOCUMENT = "InvalidXmlDocument"
+ INVALID_XML_NODE_VALUE = "InvalidXmlNodeValue"
+ MD5_MISMATCH = "Md5Mismatch"
+ METADATA_TOO_LARGE = "MetadataTooLarge"
+ MISSING_CONTENT_LENGTH_HEADER = "MissingContentLengthHeader"
+ MISSING_REQUIRED_QUERY_PARAMETER = "MissingRequiredQueryParameter"
+ MISSING_REQUIRED_HEADER = "MissingRequiredHeader"
+ MISSING_REQUIRED_XML_NODE = "MissingRequiredXmlNode"
+ MULTIPLE_CONDITION_HEADERS_NOT_SUPPORTED = "MultipleConditionHeadersNotSupported"
+ OPERATION_TIMED_OUT = "OperationTimedOut"
+ OUT_OF_RANGE_INPUT = "OutOfRangeInput"
+ OUT_OF_RANGE_QUERY_PARAMETER_VALUE = "OutOfRangeQueryParameterValue"
+ REQUEST_BODY_TOO_LARGE = "RequestBodyTooLarge"
+ RESOURCE_TYPE_MISMATCH = "ResourceTypeMismatch"
+ REQUEST_URL_FAILED_TO_PARSE = "RequestUrlFailedToParse"
+ RESOURCE_ALREADY_EXISTS = "ResourceAlreadyExists"
+ RESOURCE_NOT_FOUND = "ResourceNotFound"
+ SERVER_BUSY = "ServerBusy"
+ UNSUPPORTED_HEADER = "UnsupportedHeader"
+ UNSUPPORTED_XML_NODE = "UnsupportedXmlNode"
+ UNSUPPORTED_QUERY_PARAMETER = "UnsupportedQueryParameter"
+ UNSUPPORTED_HTTP_VERB = "UnsupportedHttpVerb"
+
+ # Blob values
+ APPEND_POSITION_CONDITION_NOT_MET = "AppendPositionConditionNotMet"
+ BLOB_ACCESS_TIER_NOT_SUPPORTED_FOR_ACCOUNT_TYPE = "BlobAccessTierNotSupportedForAccountType"
+ BLOB_ALREADY_EXISTS = "BlobAlreadyExists"
+ BLOB_NOT_FOUND = "BlobNotFound"
+ BLOB_OVERWRITTEN = "BlobOverwritten"
+ BLOB_TIER_INADEQUATE_FOR_CONTENT_LENGTH = "BlobTierInadequateForContentLength"
+ BLOCK_COUNT_EXCEEDS_LIMIT = "BlockCountExceedsLimit"
+ BLOCK_LIST_TOO_LONG = "BlockListTooLong"
+ CANNOT_CHANGE_TO_LOWER_TIER = "CannotChangeToLowerTier"
+ CANNOT_VERIFY_COPY_SOURCE = "CannotVerifyCopySource"
+ CONTAINER_ALREADY_EXISTS = "ContainerAlreadyExists"
+ CONTAINER_BEING_DELETED = "ContainerBeingDeleted"
+ CONTAINER_DISABLED = "ContainerDisabled"
+ CONTAINER_NOT_FOUND = "ContainerNotFound"
+ CONTENT_LENGTH_LARGER_THAN_TIER_LIMIT = "ContentLengthLargerThanTierLimit"
+ COPY_ACROSS_ACCOUNTS_NOT_SUPPORTED = "CopyAcrossAccountsNotSupported"
+ COPY_ID_MISMATCH = "CopyIdMismatch"
+ FEATURE_VERSION_MISMATCH = "FeatureVersionMismatch"
+ INCREMENTAL_COPY_BLOB_MISMATCH = "IncrementalCopyBlobMismatch"
+ INCREMENTAL_COPY_OF_EARLIER_VERSION_SNAPSHOT_NOT_ALLOWED = "IncrementalCopyOfEarlierVersionSnapshotNotAllowed"
+ #: Deprecated: Please use INCREMENTAL_COPY_OF_EARLIER_VERSION_SNAPSHOT_NOT_ALLOWED instead.
+ INCREMENTAL_COPY_OF_ERALIER_VERSION_SNAPSHOT_NOT_ALLOWED = "IncrementalCopyOfEarlierVersionSnapshotNotAllowed"
+ INCREMENTAL_COPY_SOURCE_MUST_BE_SNAPSHOT = "IncrementalCopySourceMustBeSnapshot"
+ INFINITE_LEASE_DURATION_REQUIRED = "InfiniteLeaseDurationRequired"
+ INVALID_BLOB_OR_BLOCK = "InvalidBlobOrBlock"
+ INVALID_BLOB_TIER = "InvalidBlobTier"
+ INVALID_BLOB_TYPE = "InvalidBlobType"
+ INVALID_BLOCK_ID = "InvalidBlockId"
+ INVALID_BLOCK_LIST = "InvalidBlockList"
+ INVALID_OPERATION = "InvalidOperation"
+ INVALID_PAGE_RANGE = "InvalidPageRange"
+ INVALID_SOURCE_BLOB_TYPE = "InvalidSourceBlobType"
+ INVALID_SOURCE_BLOB_URL = "InvalidSourceBlobUrl"
+ INVALID_VERSION_FOR_PAGE_BLOB_OPERATION = "InvalidVersionForPageBlobOperation"
+ LEASE_ALREADY_PRESENT = "LeaseAlreadyPresent"
+ LEASE_ALREADY_BROKEN = "LeaseAlreadyBroken"
+ LEASE_ID_MISMATCH_WITH_BLOB_OPERATION = "LeaseIdMismatchWithBlobOperation"
+ LEASE_ID_MISMATCH_WITH_CONTAINER_OPERATION = "LeaseIdMismatchWithContainerOperation"
+ LEASE_ID_MISMATCH_WITH_LEASE_OPERATION = "LeaseIdMismatchWithLeaseOperation"
+ LEASE_ID_MISSING = "LeaseIdMissing"
+ LEASE_IS_BREAKING_AND_CANNOT_BE_ACQUIRED = "LeaseIsBreakingAndCannotBeAcquired"
+ LEASE_IS_BREAKING_AND_CANNOT_BE_CHANGED = "LeaseIsBreakingAndCannotBeChanged"
+ LEASE_IS_BROKEN_AND_CANNOT_BE_RENEWED = "LeaseIsBrokenAndCannotBeRenewed"
+ LEASE_LOST = "LeaseLost"
+ LEASE_NOT_PRESENT_WITH_BLOB_OPERATION = "LeaseNotPresentWithBlobOperation"
+ LEASE_NOT_PRESENT_WITH_CONTAINER_OPERATION = "LeaseNotPresentWithContainerOperation"
+ LEASE_NOT_PRESENT_WITH_LEASE_OPERATION = "LeaseNotPresentWithLeaseOperation"
+ MAX_BLOB_SIZE_CONDITION_NOT_MET = "MaxBlobSizeConditionNotMet"
+ NO_PENDING_COPY_OPERATION = "NoPendingCopyOperation"
+ OPERATION_NOT_ALLOWED_ON_INCREMENTAL_COPY_BLOB = "OperationNotAllowedOnIncrementalCopyBlob"
+ PENDING_COPY_OPERATION = "PendingCopyOperation"
+ PREVIOUS_SNAPSHOT_CANNOT_BE_NEWER = "PreviousSnapshotCannotBeNewer"
+ PREVIOUS_SNAPSHOT_NOT_FOUND = "PreviousSnapshotNotFound"
+ PREVIOUS_SNAPSHOT_OPERATION_NOT_SUPPORTED = "PreviousSnapshotOperationNotSupported"
+ SEQUENCE_NUMBER_CONDITION_NOT_MET = "SequenceNumberConditionNotMet"
+ SEQUENCE_NUMBER_INCREMENT_TOO_LARGE = "SequenceNumberIncrementTooLarge"
+ SNAPSHOT_COUNT_EXCEEDED = "SnapshotCountExceeded"
+ SNAPSHOT_OPERATION_RATE_EXCEEDED = "SnapshotOperationRateExceeded"
+ #: Deprecated: Please use SNAPSHOT_OPERATION_RATE_EXCEEDED instead.
+ SNAPHOT_OPERATION_RATE_EXCEEDED = "SnapshotOperationRateExceeded"
+ SNAPSHOTS_PRESENT = "SnapshotsPresent"
+ SOURCE_CONDITION_NOT_MET = "SourceConditionNotMet"
+ SYSTEM_IN_USE = "SystemInUse"
+ TARGET_CONDITION_NOT_MET = "TargetConditionNotMet"
+ UNAUTHORIZED_BLOB_OVERWRITE = "UnauthorizedBlobOverwrite"
+ BLOB_BEING_REHYDRATED = "BlobBeingRehydrated"
+ BLOB_ARCHIVED = "BlobArchived"
+ BLOB_NOT_ARCHIVED = "BlobNotArchived"
+
+ # Queue values
+ INVALID_MARKER = "InvalidMarker"
+ MESSAGE_NOT_FOUND = "MessageNotFound"
+ MESSAGE_TOO_LARGE = "MessageTooLarge"
+ POP_RECEIPT_MISMATCH = "PopReceiptMismatch"
+ QUEUE_ALREADY_EXISTS = "QueueAlreadyExists"
+ QUEUE_BEING_DELETED = "QueueBeingDeleted"
+ QUEUE_DISABLED = "QueueDisabled"
+ QUEUE_NOT_EMPTY = "QueueNotEmpty"
+ QUEUE_NOT_FOUND = "QueueNotFound"
+
+ # File values
+ CANNOT_DELETE_FILE_OR_DIRECTORY = "CannotDeleteFileOrDirectory"
+ CLIENT_CACHE_FLUSH_DELAY = "ClientCacheFlushDelay"
+ DELETE_PENDING = "DeletePending"
+ DIRECTORY_NOT_EMPTY = "DirectoryNotEmpty"
+ FILE_LOCK_CONFLICT = "FileLockConflict"
+ FILE_SHARE_PROVISIONED_BANDWIDTH_DOWNGRADE_NOT_ALLOWED = "FileShareProvisionedBandwidthDowngradeNotAllowed"
+ FILE_SHARE_PROVISIONED_IOPS_DOWNGRADE_NOT_ALLOWED = "FileShareProvisionedIopsDowngradeNotAllowed"
+ INVALID_FILE_OR_DIRECTORY_PATH_NAME = "InvalidFileOrDirectoryPathName"
+ PARENT_NOT_FOUND = "ParentNotFound"
+ READ_ONLY_ATTRIBUTE = "ReadOnlyAttribute"
+ SHARE_ALREADY_EXISTS = "ShareAlreadyExists"
+ SHARE_BEING_DELETED = "ShareBeingDeleted"
+ SHARE_DISABLED = "ShareDisabled"
+ SHARE_NOT_FOUND = "ShareNotFound"
+ SHARING_VIOLATION = "SharingViolation"
+ SHARE_SNAPSHOT_IN_PROGRESS = "ShareSnapshotInProgress"
+ SHARE_SNAPSHOT_COUNT_EXCEEDED = "ShareSnapshotCountExceeded"
+ SHARE_SNAPSHOT_OPERATION_NOT_SUPPORTED = "ShareSnapshotOperationNotSupported"
+ SHARE_HAS_SNAPSHOTS = "ShareHasSnapshots"
+ CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed"
+
+ # DataLake values
+ CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero'
+ PATH_ALREADY_EXISTS = 'PathAlreadyExists'
+ INVALID_FLUSH_POSITION = 'InvalidFlushPosition'
+ INVALID_PROPERTY_NAME = 'InvalidPropertyName'
+ INVALID_SOURCE_URI = 'InvalidSourceUri'
+ UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion'
+ FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound'
+ PATH_NOT_FOUND = 'PathNotFound'
+ RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound'
+ SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound'
+ DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted'
+ FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists'
+ FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted'
+ INVALID_DESTINATION_PATH = 'InvalidDestinationPath'
+ INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath'
+ INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType'
+ LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken'
+ LEASE_NAME_MISMATCH = 'LeaseNameMismatch'
+ PATH_CONFLICT = 'PathConflict'
+ SOURCE_PATH_IS_BEING_DELETED = 'SourcePathIsBeingDeleted'
+
+
+class DictMixin(object):
+
+ def __setitem__(self, key, item):
+ self.__dict__[key] = item
+
+ def __getitem__(self, key):
+ return self.__dict__[key]
+
+ def __repr__(self):
+ return str(self)
+
+ def __len__(self):
+ return len(self.keys())
+
+ def __delitem__(self, key):
+ self.__dict__[key] = None
+
+ # Compare objects by comparing all attributes.
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ return False
+
+ # Compare objects by comparing all attributes.
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')})
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def has_key(self, k):
+ return k in self.__dict__
+
+ def update(self, *args, **kwargs):
+ return self.__dict__.update(*args, **kwargs)
+
+ def keys(self):
+ return [k for k in self.__dict__ if not k.startswith('_')]
+
+ def values(self):
+ return [v for k, v in self.__dict__.items() if not k.startswith('_')]
+
+ def items(self):
+ return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')]
+
+ def get(self, key, default=None):
+ if key in self.__dict__:
+ return self.__dict__[key]
+ return default
+
+
+class LocationMode(object):
+ """
+ Specifies the location the request should be sent to. This mode only applies
+ for RA-GRS accounts which allow secondary read access. All other account types
+ must use PRIMARY.
+ """
+
+ PRIMARY = 'primary' #: Requests should be sent to the primary location.
+ SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible.
+
+
+class ResourceTypes(object):
+ """
+ Specifies the resource types that are accessible with the account SAS.
+
+ :param bool service:
+ Access to service-level APIs (e.g., Get/Set Service Properties,
+ Get Service Stats, List Containers/Queues/Shares)
+ :param bool container:
+ Access to container-level APIs (e.g., Create/Delete Container,
+ Create/Delete Queue, Create/Delete Share,
+ List Blobs/Files and Directories)
+ :param bool object:
+ Access to object-level APIs for blobs, queue messages, and
+        files (e.g. Put Blob, Query Entity, Get Messages, Create File, etc.)
+ """
+
+ service: bool = False
+ container: bool = False
+ object: bool = False
+ _str: str
+
+ def __init__(
+ self,
+ service: bool = False,
+ container: bool = False,
+ object: bool = False # pylint: disable=redefined-builtin
+ ) -> None:
+ self.service = service
+ self.container = container
+ self.object = object
+ self._str = (('s' if self.service else '') +
+ ('c' if self.container else '') +
+ ('o' if self.object else ''))
+
+ def __str__(self):
+ return self._str
+
+ @classmethod
+ def from_string(cls, string):
+ """Create a ResourceTypes from a string.
+
+        To specify service, container, or object you need only
+        include the first letter of the word in the string. E.g. for service and container,
+        you would provide the string "sc".
+
+        :param str string: Specify service, container, or object in
+        the string with the first letter of the word.
+ :return: A ResourceTypes object
+ :rtype: ~azure.storage.blob.ResourceTypes
+ """
+ res_service = 's' in string
+ res_container = 'c' in string
+ res_object = 'o' in string
+
+ parsed = cls(res_service, res_container, res_object)
+ parsed._str = string
+ return parsed
+
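+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. It shows the from_string/__str__ round trip described in the docstring above.
+def _example_resource_types():
+    rt = ResourceTypes.from_string("sc")
+    assert rt.service and rt.container and not rt.object
+    assert str(rt) == "sc"
+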
+
+class AccountSasPermissions(object):
+ """
+    Permissions class to be used with the generate_account_sas function and for the
+    AccessPolicies used with set_*_acl. There are two types of
+    SAS which may be used to grant resource access. One is to grant access to a
+    specific resource (resource-specific). Another is to grant access to the
+    entire service for a specific account and allow certain operations based on
+    the permissions found here.
+
+ :param bool read:
+ Valid for all signed resources types (Service, Container, and Object).
+ Permits read permissions to the specified resource type.
+ :param bool write:
+ Valid for all signed resources types (Service, Container, and Object).
+ Permits write permissions to the specified resource type.
+ :param bool delete:
+ Valid for Container and Object resource types, except for queue messages.
+ :param bool delete_previous_version:
+ Delete the previous blob version for the versioning enabled storage account.
+ :param bool list:
+ Valid for Service and Container resource types only.
+ :param bool add:
+ Valid for the following Object resource types only: queue messages, and append blobs.
+ :param bool create:
+ Valid for the following Object resource types only: blobs and files.
+ Users can create new blobs or files, but may not overwrite existing
+ blobs or files.
+ :param bool update:
+ Valid for the following Object resource types only: queue messages.
+ :param bool process:
+ Valid for the following Object resource type only: queue messages.
+ :keyword bool tag:
+        To enable setting or getting tags on the blobs in the container.
+    :keyword bool filter_by_tags:
+        To enable getting blobs by tags; this should be used together with the list permission.
+ :keyword bool set_immutability_policy:
+ To enable operations related to set/delete immutability policy.
+ To get immutability policy, you just need read permission.
+ :keyword bool permanent_delete:
+        To enable permanent delete on the blob.
+        Valid for the Object resource type of Blob only.
+ """
+
+ read: bool = False
+ write: bool = False
+ delete: bool = False
+ delete_previous_version: bool = False
+ list: bool = False
+ add: bool = False
+ create: bool = False
+ update: bool = False
+ process: bool = False
+ tag: bool = False
+ filter_by_tags: bool = False
+ set_immutability_policy: bool = False
+ permanent_delete: bool = False
+
+ def __init__(
+ self,
+ read: bool = False,
+ write: bool = False,
+ delete: bool = False,
+ list: bool = False, # pylint: disable=redefined-builtin
+ add: bool = False,
+ create: bool = False,
+ update: bool = False,
+ process: bool = False,
+ delete_previous_version: bool = False,
+ **kwargs
+ ) -> None:
+ self.read = read
+ self.write = write
+ self.delete = delete
+ self.delete_previous_version = delete_previous_version
+ self.permanent_delete = kwargs.pop('permanent_delete', False)
+ self.list = list
+ self.add = add
+ self.create = create
+ self.update = update
+ self.process = process
+ self.tag = kwargs.pop('tag', False)
+ self.filter_by_tags = kwargs.pop('filter_by_tags', False)
+ self.set_immutability_policy = kwargs.pop('set_immutability_policy', False)
+ self._str = (('r' if self.read else '') +
+ ('w' if self.write else '') +
+ ('d' if self.delete else '') +
+ ('x' if self.delete_previous_version else '') +
+ ('y' if self.permanent_delete else '') +
+ ('l' if self.list else '') +
+ ('a' if self.add else '') +
+ ('c' if self.create else '') +
+ ('u' if self.update else '') +
+ ('p' if self.process else '') +
+ ('f' if self.filter_by_tags else '') +
+ ('t' if self.tag else '') +
+ ('i' if self.set_immutability_policy else '')
+ )
+
+ def __str__(self):
+ return self._str
+
+ @classmethod
+ def from_string(cls, permission):
+ """Create AccountSasPermissions from a string.
+
+ To specify read, write, delete, etc. permissions you need only to
+ include the first letter of the word in the string. E.g. for read and write
+ permissions you would provide a string "rw".
+
+ :param str permission: Specify permissions in
+ the string with the first letter of the word.
+ :return: An AccountSasPermissions object
+ :rtype: ~azure.storage.blob.AccountSasPermissions
+ """
+ p_read = 'r' in permission
+ p_write = 'w' in permission
+ p_delete = 'd' in permission
+ p_delete_previous_version = 'x' in permission
+ p_permanent_delete = 'y' in permission
+ p_list = 'l' in permission
+ p_add = 'a' in permission
+ p_create = 'c' in permission
+ p_update = 'u' in permission
+ p_process = 'p' in permission
+ p_tag = 't' in permission
+ p_filter_by_tags = 'f' in permission
+ p_set_immutability_policy = 'i' in permission
+ parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version,
+ list=p_list, add=p_add, create=p_create, update=p_update, process=p_process, tag=p_tag,
+ filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy,
+ permanent_delete=p_permanent_delete)
+
+ return parsed
+
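+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. Each permission letter in the string maps onto the corresponding boolean flag.
+def _example_account_sas_permissions():
+    perms = AccountSasPermissions.from_string("rwl")
+    assert perms.read and perms.write and perms.list
+    assert not perms.delete and not perms.tag
+    assert str(perms) == "rwl"
+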
+
+class Services(object):
+ """Specifies the services accessible with the account SAS.
+
+ :keyword bool blob:
+ Access for the `~azure.storage.blob.BlobServiceClient`. Default is False.
+ :keyword bool queue:
+ Access for the `~azure.storage.queue.QueueServiceClient`. Default is False.
+ :keyword bool fileshare:
+ Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False.
+ """
+
+ def __init__(
+ self, *,
+ blob: bool = False,
+ queue: bool = False,
+ fileshare: bool = False
+ ) -> None:
+ self.blob = blob
+ self.queue = queue
+ self.fileshare = fileshare
+ self._str = (('b' if self.blob else '') +
+ ('q' if self.queue else '') +
+ ('f' if self.fileshare else ''))
+
+ def __str__(self):
+ return self._str
+
+ @classmethod
+ def from_string(cls, string):
+ """Create Services from a string.
+
+ To specify blob, queue, or file you need only to
+ include the first letter of the word in the string. E.g. for blob and queue
+ you would provide a string "bq".
+
+        :param str string: Specify blob, queue, or file in
+        the string with the first letter of the word.
+ :return: A Services object
+ :rtype: ~azure.storage.blob.Services
+ """
+ res_blob = 'b' in string
+ res_queue = 'q' in string
+ res_file = 'f' in string
+
+ parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file)
+ parsed._str = string
+ return parsed
+
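+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. Services built from keyword arguments and from a string render to the
+# same permission string.
+def _example_services():
+    via_kwargs = Services(blob=True, queue=True)
+    via_string = Services.from_string("bq")
+    assert str(via_kwargs) == str(via_string) == "bq"
+    assert via_string.blob and via_string.queue and not via_string.fileshare
+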
+
+class UserDelegationKey(object):
+ """
+ Represents a user delegation key, provided to the user by Azure Storage
+ based on their Azure Active Directory access token.
+
+ The fields are saved as simple strings since the user does not have to interact with this object;
+    to generate an identity SAS, the user can simply pass it to the right API.
+ """
+
+ signed_oid: Optional[str] = None
+ """Object ID of this token."""
+ signed_tid: Optional[str] = None
+ """Tenant ID of the tenant that issued this token."""
+ signed_start: Optional[str] = None
+ """The datetime this token becomes valid."""
+ signed_expiry: Optional[str] = None
+ """The datetime this token expires."""
+ signed_service: Optional[str] = None
+ """What service this key is valid for."""
+ signed_version: Optional[str] = None
+ """The version identifier of the REST service that created this token."""
+ value: Optional[str] = None
+ """The user delegation key."""
+
+ def __init__(self):
+ self.signed_oid = None
+ self.signed_tid = None
+ self.signed_start = None
+ self.signed_expiry = None
+ self.signed_service = None
+ self.signed_version = None
+ self.value = None
+
+
+class StorageConfiguration(Configuration):
+ """
+ Specifies the configurable values used in Azure Storage.
+
+    :param int max_single_put_size: If the blob size is less than or equal to max_single_put_size, then the blob will be
+ uploaded with only one http PUT request. If the blob size is larger than max_single_put_size,
+ the blob will be uploaded in chunks. Defaults to 64*1024*1024, or 64MB.
+ :param int copy_polling_interval: The interval in seconds for polling copy operations.
+ :param int max_block_size: The maximum chunk size for uploading a block blob in chunks.
+ Defaults to 4*1024*1024, or 4MB.
+ :param int min_large_block_upload_threshold: The minimum chunk size required to use the memory efficient
+ algorithm when uploading a block blob.
+ :param bool use_byte_buffer: Use a byte buffer for block blob uploads. Defaults to False.
+ :param int max_page_size: The maximum chunk size for uploading a page blob. Defaults to 4*1024*1024, or 4MB.
+ :param int min_large_chunk_upload_threshold: The max size for a single put operation.
+ :param int max_single_get_size: The maximum size for a blob to be downloaded in a single call,
+        any remaining data will be downloaded in chunks (possibly in parallel). Defaults to 32*1024*1024, or 32MB.
+ :param int max_chunk_get_size: The maximum chunk size used for downloading a blob. Defaults to 4*1024*1024,
+ or 4MB.
+ :param int max_range_size: The max range size for file upload.
+
+ """
+
+ max_single_put_size: int
+ copy_polling_interval: int
+ max_block_size: int
+ min_large_block_upload_threshold: int
+ use_byte_buffer: bool
+ max_page_size: int
+ min_large_chunk_upload_threshold: int
+ max_single_get_size: int
+ max_chunk_get_size: int
+ max_range_size: int
+ user_agent_policy: UserAgentPolicy
+
+ def __init__(self, **kwargs):
+ super(StorageConfiguration, self).__init__(**kwargs)
+ self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024)
+ self.copy_polling_interval = 15
+ self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024)
+ self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1)
+ self.use_byte_buffer = kwargs.pop('use_byte_buffer', False)
+ self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024)
+ self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1)
+ self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024)
+ self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024)
+ self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024)
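+
+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. It assumes, as the clients in this package do, that keywords not consumed
+# here are accepted by azure.core's Configuration. Overriding one value leaves the other
+# defaults documented above intact.
+def _example_storage_configuration():
+    config = StorageConfiguration(max_block_size=8 * 1024 * 1024)
+    assert config.max_block_size == 8 * 1024 * 1024
+    assert config.max_single_put_size == 64 * 1024 * 1024  # default
+    assert config.max_chunk_get_size == 4 * 1024 * 1024  # default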
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/parser.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/parser.py
new file mode 100644
index 00000000..112c1984
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/parser.py
@@ -0,0 +1,53 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from datetime import datetime, timezone
+from typing import Optional
+
+EPOCH_AS_FILETIME = 116444736000000000 # January 1, 1970 as MS filetime
+HUNDREDS_OF_NANOSECONDS = 10000000
+
+
+def _to_utc_datetime(value: datetime) -> str:
+ return value.strftime('%Y-%m-%dT%H:%M:%SZ')
+
+
+def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]:
+ """Converts an RFC 1123 date string to a UTC datetime.
+
+ :param str rfc_1123: The time and date in RFC 1123 format.
+ :returns: The time and date in UTC datetime format.
+ :rtype: datetime
+ """
+ if not rfc_1123:
+ return None
+
+ return datetime.strptime(rfc_1123, "%a, %d %b %Y %H:%M:%S %Z")
+
+
+def _filetime_to_datetime(filetime: str) -> Optional[datetime]:
+ """Converts an MS filetime string to a UTC datetime. "0" indicates None.
+ If parsing MS Filetime fails, tries RFC 1123 as backup.
+
+ :param str filetime: The time and date in MS filetime format.
+ :returns: The time and date in UTC datetime format.
+ :rtype: datetime
+ """
+ if not filetime:
+ return None
+
+ # Try to convert to MS Filetime
+ try:
+ temp_filetime = int(filetime)
+ if temp_filetime == 0:
+ return None
+
+ return datetime.fromtimestamp((temp_filetime - EPOCH_AS_FILETIME) / HUNDREDS_OF_NANOSECONDS, tz=timezone.utc)
+ except ValueError:
+ pass
+
+ # Try RFC 1123 as backup
+ return _rfc_1123_to_datetime(filetime)
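+
+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. It shows that the MS filetime epoch maps onto the Unix epoch, that "0"
+# means None, and that RFC 1123 strings are accepted as the documented fallback.
+def _example_filetime_to_datetime():
+    assert _filetime_to_datetime("0") is None
+    assert _filetime_to_datetime(str(EPOCH_AS_FILETIME)) == datetime(1970, 1, 1, tzinfo=timezone.utc)
+    assert _filetime_to_datetime("Thu, 01 Jan 1970 00:00:00 GMT") == datetime(1970, 1, 1)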
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies.py
new file mode 100644
index 00000000..ee75cd5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies.py
@@ -0,0 +1,694 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import base64
+import hashlib
+import logging
+import random
+import re
+import uuid
+from io import SEEK_SET, UnsupportedOperation
+from time import time
+from typing import Any, Dict, Optional, TYPE_CHECKING
+from urllib.parse import (
+ parse_qsl,
+ urlencode,
+ urlparse,
+ urlunparse,
+)
+from wsgiref.handlers import format_date_time
+
+from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError
+from azure.core.pipeline.policies import (
+ BearerTokenCredentialPolicy,
+ HeadersPolicy,
+ HTTPPolicy,
+ NetworkTraceLoggingPolicy,
+ RequestHistory,
+ SansIOHTTPPolicy
+)
+
+from .authentication import AzureSigningError, StorageHttpChallenge
+from .constants import DEFAULT_OAUTH_SCOPE
+from .models import LocationMode
+
+if TYPE_CHECKING:
+ from azure.core.credentials import TokenCredential
+ from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import
+ PipelineRequest,
+ PipelineResponse
+ )
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def encode_base64(data):
+ if isinstance(data, str):
+ data = data.encode('utf-8')
+ encoded = base64.b64encode(data)
+ return encoded.decode('utf-8')
+
+
+# Are we out of retries?
+def is_exhausted(settings):
+ retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
+ retry_counts = list(filter(None, retry_counts))
+ if not retry_counts:
+ return False
+ return min(retry_counts) < 0
+
+
+def retry_hook(settings, **kwargs):
+ if settings['hook']:
+ settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)
+
+
+# Is this method/status code retryable? (Based on allowlists and control
+# variables such as the number of total retries to allow, whether to
+# respect the Retry-After header, whether this header is present, and
+# whether the returned status code is on the list of status codes to
+# be retried upon on the presence of the aforementioned header)
+def is_retry(response, mode):
+ status = response.http_response.status_code
+ if 300 <= status < 500:
+ # An exception occurred, but in most cases it was expected. Examples could
+        # include a 409 Conflict or 412 Precondition Failed.
+ if status == 404 and mode == LocationMode.SECONDARY:
+ # Response code 404 should be retried if secondary was used.
+ return True
+ if status == 408:
+ # Response code 408 is a timeout and should be retried.
+ return True
+ return False
+ if status >= 500:
+ # Response codes above 500 with the exception of 501 Not Implemented and
+ # 505 Version Not Supported indicate a server issue and should be retried.
+ if status in [501, 505]:
+ return False
+ return True
+ return False
+
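+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. The stand-in response objects only carry the status code that is_retry inspects.
+def _example_is_retry():
+    from types import SimpleNamespace
+
+    def fake_response(status):
+        return SimpleNamespace(http_response=SimpleNamespace(status_code=status))
+
+    assert is_retry(fake_response(503), LocationMode.PRIMARY)      # server busy -> retry
+    assert not is_retry(fake_response(501), LocationMode.PRIMARY)  # not implemented -> no retry
+    assert not is_retry(fake_response(404), LocationMode.PRIMARY)  # missing on primary -> no retry
+    assert is_retry(fake_response(404), LocationMode.SECONDARY)    # may not be replicated yet -> retry
+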
+
+def is_checksum_retry(response):
+ # retry if invalid content md5
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+ computed_md5 = response.http_request.headers.get('content-md5', None) or \
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+ if response.http_response.headers['content-md5'] != computed_md5:
+ return True
+ return False
+
+
+def urljoin(base_url, stub_url):
+ parsed = urlparse(base_url)
+ parsed = parsed._replace(path=parsed.path + '/' + stub_url)
+ return parsed.geturl()
+
+
+class QueueMessagePolicy(SansIOHTTPPolicy):
+
+ def on_request(self, request):
+ message_id = request.context.options.pop('queue_message_id', None)
+ if message_id:
+ request.http_request.url = urljoin(
+ request.http_request.url,
+ message_id)
+
+
+class StorageHeadersPolicy(HeadersPolicy):
+ request_id_header_name = 'x-ms-client-request-id'
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ super(StorageHeadersPolicy, self).on_request(request)
+ current_time = format_date_time(time())
+ request.http_request.headers['x-ms-date'] = current_time
+
+ custom_id = request.context.options.pop('client_request_id', None)
+ request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1())
+
+ # def on_response(self, request, response):
+ # # raise exception if the echoed client request id from the service is not identical to the one we sent
+ # if self.request_id_header_name in response.http_response.headers:
+
+ # client_request_id = request.http_request.headers.get(self.request_id_header_name)
+
+ # if response.http_response.headers[self.request_id_header_name] != client_request_id:
+ # raise AzureError(
+ # "Echoed client request ID: {} does not match sent client request ID: {}. "
+ # "Service request ID: {}".format(
+ # response.http_response.headers[self.request_id_header_name], client_request_id,
+ # response.http_response.headers['x-ms-request-id']),
+ # response=response.http_response
+ # )
+
+
+class StorageHosts(SansIOHTTPPolicy):
+
+ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument
+ self.hosts = hosts
+ super(StorageHosts, self).__init__()
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ request.context.options['hosts'] = self.hosts
+ parsed_url = urlparse(request.http_request.url)
+
+ # Detect what location mode we're currently requesting with
+ location_mode = LocationMode.PRIMARY
+ for key, value in self.hosts.items():
+ if parsed_url.netloc == value:
+ location_mode = key
+
+ # See if a specific location mode has been specified, and if so, redirect
+ use_location = request.context.options.pop('use_location', None)
+ if use_location:
+ # Lock retries to the specific location
+ request.context.options['retry_to_secondary'] = False
+ if use_location not in self.hosts:
+ raise ValueError(f"Attempting to use undefined host location {use_location}")
+ if use_location != location_mode:
+ # Update request URL to use the specified location
+ updated = parsed_url._replace(netloc=self.hosts[use_location])
+ request.http_request.url = updated.geturl()
+ location_mode = use_location
+
+ request.context.options['location_mode'] = location_mode
+
+
+class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
+ """A policy that logs HTTP request and response to the DEBUG logger.
+
+ This accepts both global configuration, and per-request level with "enable_http_logger"
+ """
+
+ def __init__(self, logging_enable: bool = False, **kwargs) -> None:
+ self.logging_body = kwargs.pop("logging_body", False)
+ super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs)
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ http_request = request.http_request
+ options = request.context.options
+ self.logging_body = self.logging_body or options.pop("logging_body", False)
+ if options.pop("logging_enable", self.enable_http_logger):
+ request.context["logging_enable"] = True
+ if not _LOGGER.isEnabledFor(logging.DEBUG):
+ return
+
+ try:
+ log_url = http_request.url
+ query_params = http_request.query
+ if 'sig' in query_params:
+ log_url = log_url.replace(query_params['sig'], "sig=*****")
+ _LOGGER.debug("Request URL: %r", log_url)
+ _LOGGER.debug("Request method: %r", http_request.method)
+ _LOGGER.debug("Request headers:")
+ for header, value in http_request.headers.items():
+ if header.lower() == 'authorization':
+ value = '*****'
+ elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
+ # take the url apart and scrub away the signed signature
+ scheme, netloc, path, params, query, fragment = urlparse(value)
+ parsed_qs = dict(parse_qsl(query))
+ parsed_qs['sig'] = '*****'
+
+ # the SAS needs to be put back together
+ value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))
+
+ _LOGGER.debug(" %r: %r", header, value)
+ _LOGGER.debug("Request body:")
+
+ if self.logging_body:
+ _LOGGER.debug(str(http_request.body))
+ else:
+ # We don't want to log the binary data of a file upload.
+ _LOGGER.debug("Hidden body, please use logging_body to show body")
+ except Exception as err: # pylint: disable=broad-except
+ _LOGGER.debug("Failed to log request: %r", err)
+
+ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
+ if response.context.pop("logging_enable", self.enable_http_logger):
+ if not _LOGGER.isEnabledFor(logging.DEBUG):
+ return
+
+ try:
+ _LOGGER.debug("Response status: %r", response.http_response.status_code)
+ _LOGGER.debug("Response headers:")
+ for res_header, value in response.http_response.headers.items():
+ _LOGGER.debug(" %r: %r", res_header, value)
+
+ # We don't want to log binary data if the response is a file.
+ _LOGGER.debug("Response content:")
+ pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
+ header = response.http_response.headers.get('content-disposition')
+ resp_content_type = response.http_response.headers.get("content-type", "")
+
+ if header and pattern.match(header):
+ filename = header.partition('=')[2]
+ _LOGGER.debug("File attachments: %s", filename)
+ elif resp_content_type.endswith("octet-stream"):
+ _LOGGER.debug("Body contains binary data.")
+ elif resp_content_type.startswith("image"):
+ _LOGGER.debug("Body contains image data.")
+
+ if self.logging_body and resp_content_type.startswith("text"):
+ _LOGGER.debug(response.http_response.text())
+ elif self.logging_body:
+ try:
+ _LOGGER.debug(response.http_response.body())
+ except ValueError:
+ _LOGGER.debug("Body is streamable")
+
+ except Exception as err: # pylint: disable=broad-except
+ _LOGGER.debug("Failed to log response: %s", repr(err))
+
+
+class StorageRequestHook(SansIOHTTPPolicy):
+
+ def __init__(self, **kwargs):
+ self._request_callback = kwargs.get('raw_request_hook')
+ super(StorageRequestHook, self).__init__()
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
+ if request_callback:
+ request_callback(request)
+
+
+class StorageResponseHook(HTTPPolicy):
+
+ def __init__(self, **kwargs):
+ self._response_callback = kwargs.get('raw_response_hook')
+ super(StorageResponseHook, self).__init__()
+
+ def send(self, request: "PipelineRequest") -> "PipelineResponse":
+ # Values could be 0
+ data_stream_total = request.context.get('data_stream_total')
+ if data_stream_total is None:
+ data_stream_total = request.context.options.pop('data_stream_total', None)
+ download_stream_current = request.context.get('download_stream_current')
+ if download_stream_current is None:
+ download_stream_current = request.context.options.pop('download_stream_current', None)
+ upload_stream_current = request.context.get('upload_stream_current')
+ if upload_stream_current is None:
+ upload_stream_current = request.context.options.pop('upload_stream_current', None)
+
+ response_callback = request.context.get('response_callback') or \
+ request.context.options.pop('raw_response_hook', self._response_callback)
+
+ response = self.next.send(request)
+
+ will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
+ # Auth error could come from Bearer challenge, in which case this request will be made again
+ is_auth_error = response.http_response.status_code == 401
+ should_update_counts = not (will_retry or is_auth_error)
+
+ if should_update_counts and download_stream_current is not None:
+ download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
+ if data_stream_total is None:
+ content_range = response.http_response.headers.get('Content-Range')
+ if content_range:
+ data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
+ else:
+ data_stream_total = download_stream_current
+ elif should_update_counts and upload_stream_current is not None:
+ upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
+ for pipeline_obj in [request, response]:
+ if hasattr(pipeline_obj, 'context'):
+ pipeline_obj.context['data_stream_total'] = data_stream_total
+ pipeline_obj.context['download_stream_current'] = download_stream_current
+ pipeline_obj.context['upload_stream_current'] = upload_stream_current
+ if response_callback:
+ response_callback(response)
+ request.context['response_callback'] = response_callback
+ return response
+
+
+class StorageContentValidation(SansIOHTTPPolicy):
+ """A simple policy that sends the given headers
+ with the request.
+
+ This will overwrite any headers already defined in the request.
+ """
+ header_name = 'Content-MD5'
+
+ def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument
+ super(StorageContentValidation, self).__init__()
+
+ @staticmethod
+ def get_content_md5(data):
+ # Since HTTP does not differentiate between no content and empty content,
+ # we have to perform a None check.
+ data = data or b""
+ md5 = hashlib.md5() # nosec
+ if isinstance(data, bytes):
+ md5.update(data)
+ elif hasattr(data, 'read'):
+ pos = 0
+ try:
+ pos = data.tell()
+ except: # pylint: disable=bare-except
+ pass
+ for chunk in iter(lambda: data.read(4096), b""):
+ md5.update(chunk)
+ try:
+ data.seek(pos, SEEK_SET)
+ except (AttributeError, IOError) as exc:
+ raise ValueError("Data should be bytes or a seekable file-like object.") from exc
+ else:
+ raise ValueError("Data should be bytes or a seekable file-like object.")
+
+ return md5.digest()
+
+ def on_request(self, request: "PipelineRequest") -> None:
+ validate_content = request.context.options.pop('validate_content', False)
+ if validate_content and request.http_request.method != 'GET':
+ computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
+ request.http_request.headers[self.header_name] = computed_md5
+ request.context['validate_content_md5'] = computed_md5
+ request.context['validate_content'] = validate_content
+
+ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+ computed_md5 = request.context.get('validate_content_md5') or \
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+ if response.http_response.headers['content-md5'] != computed_md5:
+ raise AzureError((
+ f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
+ f"computed value is '{computed_md5}'."),
+ response=response.http_response
+ )
+
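+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. The Content-MD5 header value produced by this policy is the base64
+# encoding of the body's MD5 digest.
+def _example_content_md5():
+    body = b"hello, world"
+    header_value = encode_base64(StorageContentValidation.get_content_md5(body))
+    assert header_value == base64.b64encode(hashlib.md5(body).digest()).decode('utf-8')  # nosec
+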
+
+class StorageRetryPolicy(HTTPPolicy):
+ """
+ The base class for Exponential and Linear retries containing shared code.
+ """
+
+ total_retries: int
+ """The max number of retries."""
+ connect_retries: int
+ """The max number of connect retries."""
+ retry_read: int
+ """The max number of read retries."""
+ retry_status: int
+ """The max number of status retries."""
+ retry_to_secondary: bool
+ """Whether the secondary endpoint should be retried."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ self.total_retries = kwargs.pop('retry_total', 10)
+ self.connect_retries = kwargs.pop('retry_connect', 3)
+ self.read_retries = kwargs.pop('retry_read', 3)
+ self.status_retries = kwargs.pop('retry_status', 3)
+ self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
+ super(StorageRetryPolicy, self).__init__()
+
+ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None:
+ """
+ A function which sets the next host location on the request, if applicable.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the next host location.
+ :param PipelineRequest request: A pipeline request object.
+ """
+ if settings['hosts'] and all(settings['hosts'].values()):
+ url = urlparse(request.url)
+ # If there's more than one possible location, retry to the alternative
+ if settings['mode'] == LocationMode.PRIMARY:
+ settings['mode'] = LocationMode.SECONDARY
+ else:
+ settings['mode'] = LocationMode.PRIMARY
+ updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
+ request.url = updated.geturl()
+
+ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]:
+ body_position = None
+ if hasattr(request.http_request.body, 'read'):
+ try:
+ body_position = request.http_request.body.tell()
+ except (AttributeError, UnsupportedOperation):
+ # if body position cannot be obtained, then retries will not work
+ pass
+ options = request.context.options
+ return {
+ 'total': options.pop("retry_total", self.total_retries),
+ 'connect': options.pop("retry_connect", self.connect_retries),
+ 'read': options.pop("retry_read", self.read_retries),
+ 'status': options.pop("retry_status", self.status_retries),
+ 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
+ 'mode': options.pop("location_mode", LocationMode.PRIMARY),
+ 'hosts': options.pop("hosts", None),
+ 'hook': options.pop("retry_hook", None),
+ 'body_position': body_position,
+ 'count': 0,
+ 'history': []
+ }
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument
+ """ Formula for computing the current backoff.
+ Should be calculated by child class.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :returns: The backoff time.
+ :rtype: float
+ """
+ return 0
+
+ def sleep(self, settings, transport):
+ backoff = self.get_backoff_time(settings)
+ if not backoff or backoff < 0:
+ return
+ transport.sleep(backoff)
+
+ def increment(
+ self, settings: Dict[str, Any],
+ request: "PipelineRequest",
+ response: Optional["PipelineResponse"] = None,
+ error: Optional[AzureError] = None
+ ) -> bool:
+ """Increment the retry counters.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the increment operation.
+ :param PipelineRequest request: A pipeline request object.
+ :param Optional[PipelineResponse] response: A pipeline response object.
+ :param Optional[AzureError] error: An error encountered during the request, or
+ None if the response was received successfully.
+        :returns: Whether another retry attempt should be made.
+ :rtype: bool
+ """
+ settings['total'] -= 1
+
+ if error and isinstance(error, ServiceRequestError):
+ # Errors when we're fairly sure that the server did not receive the
+ # request, so it should be safe to retry.
+ settings['connect'] -= 1
+ settings['history'].append(RequestHistory(request, error=error))
+
+ elif error and isinstance(error, ServiceResponseError):
+ # Errors that occur after the request has been started, so we should
+ # assume that the server began processing it.
+ settings['read'] -= 1
+ settings['history'].append(RequestHistory(request, error=error))
+
+ else:
+ # Incrementing because of a server error like a 500 in
+            # status_forcelist and the given method is in the allowlist
+ if response:
+ settings['status'] -= 1
+ settings['history'].append(RequestHistory(request, http_response=response))
+
+ if not is_exhausted(settings):
+ if request.method not in ['PUT'] and settings['retry_secondary']:
+ self._set_next_host_location(settings, request)
+
+ # rewind the request body if it is a stream
+ if request.body and hasattr(request.body, 'read'):
+ # no position was saved, then retry would not work
+ if settings['body_position'] is None:
+ return False
+ try:
+ # attempt to rewind the body to the initial position
+ request.body.seek(settings['body_position'], SEEK_SET)
+ except (UnsupportedOperation, ValueError):
+ # if body is not seekable, then retry would not work
+ return False
+ settings['count'] += 1
+ return True
+ return False
+
+ def send(self, request):
+ retries_remaining = True
+ response = None
+ retry_settings = self.configure_retries(request)
+ while retries_remaining:
+ try:
+ response = self.next.send(request)
+ if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
+ retries_remaining = self.increment(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response)
+ if retries_remaining:
+ retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response,
+ error=None)
+ self.sleep(retry_settings, request.context.transport)
+ continue
+ break
+ except AzureError as err:
+ if isinstance(err, AzureSigningError):
+ raise
+ retries_remaining = self.increment(
+ retry_settings, request=request.http_request, error=err)
+ if retries_remaining:
+ retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=None,
+ error=err)
+ self.sleep(retry_settings, request.context.transport)
+ continue
+ raise err
+ if retry_settings['history']:
+ response.context['history'] = retry_settings['history']
+ response.http_response.location_mode = retry_settings['mode']
+ return response
+
+
+class ExponentialRetry(StorageRetryPolicy):
+ """Exponential retry."""
+
+ initial_backoff: int
+ """The initial backoff interval, in seconds, for the first retry."""
+ increment_base: int
+ """The base, in seconds, to increment the initial_backoff by after the
+ first retry."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self, initial_backoff: int = 15,
+ increment_base: int = 3,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3,
+ **kwargs: Any
+ ) -> None:
+ """
+ Constructs an Exponential retry object. The initial_backoff is used for
+ the first retry. Subsequent retries are retried after initial_backoff +
+        increment_base^retry_count seconds.
+
+ :param int initial_backoff:
+ The initial backoff interval, in seconds, for the first retry.
+ :param int increment_base:
+ The base, in seconds, to increment the initial_backoff by after the
+ first retry.
+ :param int retry_total:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+            only be enabled if RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+            For example, a random_jitter_range of 3 results in the back-off interval x varying between x-3 and x+3.
+ """
+ self.initial_backoff = initial_backoff
+ self.increment_base = increment_base
+ self.random_jitter_range = random_jitter_range
+ super(ExponentialRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :returns:
+ A float indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: float
+ """
+ random_generator = random.Random()
+ backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
+ random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
+ random_range_end = backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
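+
+# NOTE: illustrative sketch added for this review; the helper below is not part of the
+# upstream SDK. With jitter set to zero, the backoff is exactly initial_backoff for the
+# first retry and initial_backoff + increment_base**count afterwards, matching the
+# formula in the docstring above.
+def _example_exponential_backoff():
+    policy = ExponentialRetry(initial_backoff=15, increment_base=3, random_jitter_range=0)
+    assert policy.get_backoff_time({'count': 0}) == 15
+    assert policy.get_backoff_time({'count': 1}) == 15 + 3
+    assert policy.get_backoff_time({'count': 2}) == 15 + 9
+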
+
+class LinearRetry(StorageRetryPolicy):
+ """Linear retry."""
+
+    backoff: int
+ """The backoff interval, in seconds, between retries."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self, backoff: int = 15,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3,
+ **kwargs: Any
+ ) -> None:
+ """
+ Constructs a Linear retry object.
+
+ :param int backoff:
+ The backoff interval, in seconds, between retries.
+ :param int retry_total:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+            only be enabled if RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+            For example, a random_jitter_range of 3 results in the back-off interval x varying between x-3 and x+3.
+ """
+ self.backoff = backoff
+ self.random_jitter_range = random_jitter_range
+ super(LinearRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :returns:
+ A float indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: float
+ """
+ random_generator = random.Random()
+ # the backoff interval normally does not change, however there is the possibility
+ # that it was modified by accessing the property directly after initializing the object
+ random_range_start = self.backoff - self.random_jitter_range \
+ if self.backoff > self.random_jitter_range else 0
+ random_range_end = self.backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
+ """ Custom Bearer token credential policy for following Storage Bearer challenges """
+
+ def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None:
+ super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)
+
+ def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
+ try:
+ auth_header = response.http_response.headers.get("WWW-Authenticate")
+ challenge = StorageHttpChallenge(auth_header)
+ except ValueError:
+ return False
+
+ scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
+ self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
+
+ return True
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies_async.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies_async.py
new file mode 100644
index 00000000..86a4b4c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/policies_async.py
@@ -0,0 +1,296 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=invalid-overridden-method
+
+import asyncio
+import logging
+import random
+from typing import Any, Dict, TYPE_CHECKING
+
+from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError
+from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy
+
+from .authentication import AzureSigningError, StorageHttpChallenge
+from .constants import DEFAULT_OAUTH_SCOPE
+from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy
+
+if TYPE_CHECKING:
+ from azure.core.credentials_async import AsyncTokenCredential
+ from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import
+ PipelineRequest,
+ PipelineResponse
+ )
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+async def retry_hook(settings, **kwargs):
+ if settings['hook']:
+ if asyncio.iscoroutine(settings['hook']):
+ await settings['hook'](
+ retry_count=settings['count'] - 1,
+ location_mode=settings['mode'],
+ **kwargs)
+ else:
+ settings['hook'](
+ retry_count=settings['count'] - 1,
+ location_mode=settings['mode'],
+ **kwargs)
+
+
+async def is_checksum_retry(response):
+ # retry if invalid content md5
+ if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+ try:
+ await response.http_response.load_body() # Load the body in memory and close the socket
+ except (StreamClosedError, StreamConsumedError):
+ pass
+ computed_md5 = response.http_request.headers.get('content-md5', None) or \
+ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+ if response.http_response.headers['content-md5'] != computed_md5:
+ return True
+ return False
+
+
+class AsyncStorageResponseHook(AsyncHTTPPolicy):
+
+ def __init__(self, **kwargs):
+ self._response_callback = kwargs.get('raw_response_hook')
+ super(AsyncStorageResponseHook, self).__init__()
+
+ async def send(self, request: "PipelineRequest") -> "PipelineResponse":
+ # Values could be 0
+ data_stream_total = request.context.get('data_stream_total')
+ if data_stream_total is None:
+ data_stream_total = request.context.options.pop('data_stream_total', None)
+ download_stream_current = request.context.get('download_stream_current')
+ if download_stream_current is None:
+ download_stream_current = request.context.options.pop('download_stream_current', None)
+ upload_stream_current = request.context.get('upload_stream_current')
+ if upload_stream_current is None:
+ upload_stream_current = request.context.options.pop('upload_stream_current', None)
+
+ response_callback = request.context.get('response_callback') or \
+ request.context.options.pop('raw_response_hook', self._response_callback)
+
+ response = await self.next.send(request)
+ will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response)
+
+ # Auth error could come from Bearer challenge, in which case this request will be made again
+ is_auth_error = response.http_response.status_code == 401
+ should_update_counts = not (will_retry or is_auth_error)
+
+ if should_update_counts and download_stream_current is not None:
+ download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
+ if data_stream_total is None:
+ content_range = response.http_response.headers.get('Content-Range')
+ if content_range:
+ data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
+ else:
+ data_stream_total = download_stream_current
+ elif should_update_counts and upload_stream_current is not None:
+ upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
+ for pipeline_obj in [request, response]:
+ if hasattr(pipeline_obj, 'context'):
+ pipeline_obj.context['data_stream_total'] = data_stream_total
+ pipeline_obj.context['download_stream_current'] = download_stream_current
+ pipeline_obj.context['upload_stream_current'] = upload_stream_current
+ if response_callback:
+ if asyncio.iscoroutine(response_callback):
+ await response_callback(response) # type: ignore
+ else:
+ response_callback(response)
+ request.context['response_callback'] = response_callback
+ return response
+
+
+class AsyncStorageRetryPolicy(StorageRetryPolicy):
+ """
+ The base class for Exponential and Linear retries containing shared code.
+ """
+
+ async def sleep(self, settings, transport):
+ backoff = self.get_backoff_time(settings)
+ if not backoff or backoff < 0:
+ return
+ await transport.sleep(backoff)
+
+ async def send(self, request):
+ retries_remaining = True
+ response = None
+ retry_settings = self.configure_retries(request)
+ while retries_remaining:
+ try:
+ response = await self.next.send(request)
+ if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response):
+ retries_remaining = self.increment(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response)
+ if retries_remaining:
+ await retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=response.http_response,
+ error=None)
+ await self.sleep(retry_settings, request.context.transport)
+ continue
+ break
+ except AzureError as err:
+ if isinstance(err, AzureSigningError):
+ raise
+ retries_remaining = self.increment(
+ retry_settings, request=request.http_request, error=err)
+ if retries_remaining:
+ await retry_hook(
+ retry_settings,
+ request=request.http_request,
+ response=None,
+ error=err)
+ await self.sleep(retry_settings, request.context.transport)
+ continue
+ raise err
+ if retry_settings['history']:
+ response.context['history'] = retry_settings['history']
+ response.http_response.location_mode = retry_settings['mode']
+ return response
+
+
+class ExponentialRetry(AsyncStorageRetryPolicy):
+ """Exponential retry."""
+
+ initial_backoff: int
+ """The initial backoff interval, in seconds, for the first retry."""
+ increment_base: int
+ """The base, in seconds, to increment the initial_backoff by after the
+ first retry."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self,
+ initial_backoff: int = 15,
+ increment_base: int = 3,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3, **kwargs
+ ) -> None:
+ """
+ Constructs an Exponential retry object. The initial_backoff is used for
+ the first retry. Subsequent retries are retried after initial_backoff +
+ increment_base^retry_count seconds. For example, by default the first retry
+ occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the
+ third after (15+3^2) = 24 seconds.
+
+ :param int initial_backoff:
+ The initial backoff interval, in seconds, for the first retry.
+ :param int increment_base:
+ The base, in seconds, to increment the initial_backoff by after the
+ first retry.
+ :param int retry_total:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+ only be enabled if RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+ For example, a random_jitter_range of 3 results in a back-off interval x that varies between x-3 and x+3.
+ """
+ self.initial_backoff = initial_backoff
+ self.increment_base = increment_base
+ self.random_jitter_range = random_jitter_range
+ super(ExponentialRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :returns:
+ A float indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: float
+ """
+ random_generator = random.Random()
+ backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
+ random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
+ random_range_end = backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class LinearRetry(AsyncStorageRetryPolicy):
+ """Linear retry."""
+
+ backoff: int
+ """The backoff interval, in seconds, between retries."""
+ random_jitter_range: int
+ """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+ def __init__(
+ self, backoff: int = 15,
+ retry_total: int = 3,
+ retry_to_secondary: bool = False,
+ random_jitter_range: int = 3,
+ **kwargs: Any
+ ) -> None:
+ """
+ Constructs a Linear retry object.
+
+ :param int backoff:
+ The backoff interval, in seconds, between retries.
+ :param int retry_total:
+ The maximum number of retry attempts.
+ :param bool retry_to_secondary:
+ Whether the request should be retried to secondary, if able. This should
+ only be enabled if RA-GRS accounts are used and potentially stale data
+ can be handled.
+ :param int random_jitter_range:
+ A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+ For example, a random_jitter_range of 3 results in a back-off interval x that varies between x-3 and x+3.
+ """
+ self.backoff = backoff
+ self.random_jitter_range = random_jitter_range
+ super(LinearRetry, self).__init__(
+ retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+ def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+ """
+ Calculates how long to sleep before retrying.
+
+ :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+ :returns:
+ A float indicating how long to wait before retrying the request,
+ or None to indicate no retry should be performed.
+ :rtype: float
+ """
+ random_generator = random.Random()
+ # The backoff interval normally does not change; however, it may have been
+ # modified by assigning to the attribute directly after the object was initialized.
+ random_range_start = self.backoff - self.random_jitter_range \
+ if self.backoff > self.random_jitter_range else 0
+ random_range_end = self.backoff + self.random_jitter_range
+ return random_generator.uniform(random_range_start, random_range_end)
+
+
+class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
+ """ Custom Bearer token credential policy for following Storage Bearer challenges """
+
+ def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None:
+ super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)
+
+ async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
+ try:
+ auth_header = response.http_response.headers.get("WWW-Authenticate")
+ challenge = StorageHttpChallenge(auth_header)
+ except ValueError:
+ return False
+
+ scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
+ await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
+
+ return True
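As a quick check of the back-off arithmetic documented in ExponentialRetry above, this sketch (placeholder values, jitter zeroed for determinism) reproduces the 15 / 18 / 24 second sequence from the constructor docstring.

    # Sketch: deterministic exponential back-off sequence (jitter disabled).
    from azure.storage.blob._shared.policies_async import ExponentialRetry

    policy = ExponentialRetry(initial_backoff=15, increment_base=3, random_jitter_range=0)
    for count in range(4):
        print(count, policy.get_backoff_time({'count': count}))
    # 0 15.0, 1 18.0 (15 + 3**1), 2 24.0 (15 + 3**2), 3 42.0 (15 + 3**3)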
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/request_handlers.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/request_handlers.py
new file mode 100644
index 00000000..54927cc7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/request_handlers.py
@@ -0,0 +1,270 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import stat
+from io import (SEEK_END, SEEK_SET, UnsupportedOperation)
+from os import fstat
+from typing import Dict, Optional
+
+import isodate
+
+
+_LOGGER = logging.getLogger(__name__)
+
+_REQUEST_DELIMITER_PREFIX = "batch_"
+_HTTP1_1_IDENTIFIER = "HTTP/1.1"
+_HTTP_LINE_ENDING = "\r\n"
+
+
+def serialize_iso(attr):
+ """Serialize Datetime object into ISO-8601 formatted string.
+
+ :param Datetime attr: Object to be serialized.
+ :rtype: str
+ :raises: ValueError if format invalid.
+ """
+ if not attr:
+ return None
+ if isinstance(attr, str):
+ attr = isodate.parse_datetime(attr)
+ try:
+ utc = attr.utctimetuple()
+ if utc.tm_year > 9999 or utc.tm_year < 1:
+ raise OverflowError("Hit max or min date")
+
+ date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}"
+ return date + 'Z'
+ except (ValueError, OverflowError) as err:
+ raise ValueError("Unable to serialize datetime object.") from err
+ except AttributeError as err:
+ raise TypeError("ISO-8601 object must be valid datetime object.") from err
+
+
+def get_length(data):
+ length = None
+ # Check if object implements the __len__ method, covers most input cases such as bytearray.
+ try:
+ length = len(data)
+ except: # pylint: disable=bare-except
+ pass
+
+ if not length:
+ # Check if the stream is a file-like stream object.
+ # If so, calculate the size using the file descriptor.
+ try:
+ fileno = data.fileno()
+ except (AttributeError, UnsupportedOperation):
+ pass
+ else:
+ try:
+ mode = fstat(fileno).st_mode
+ if stat.S_ISREG(mode) or stat.S_ISLNK(mode):
+ #st_size only meaningful if regular file or symlink, other types
+ # e.g. sockets may return misleading sizes like 0
+ return fstat(fileno).st_size
+ except OSError:
+ # Not a valid file descriptor; the stream may be socket-backed
+ # (e.g. from requests), in which case fstat is not meaningful.
+ pass
+
+ # If the stream is seekable and tell() is implemented, calculate the stream size.
+ try:
+ current_position = data.tell()
+ data.seek(0, SEEK_END)
+ length = data.tell() - current_position
+ data.seek(current_position, SEEK_SET)
+ except (AttributeError, OSError, UnsupportedOperation):
+ pass
+
+ return length
+
+
+def read_length(data):
+ try:
+ if hasattr(data, 'read'):
+ read_data = b''
+ for chunk in iter(lambda: data.read(4096), b""):
+ read_data += chunk
+ return len(read_data), read_data
+ if hasattr(data, '__iter__'):
+ read_data = b''
+ for chunk in data:
+ read_data += chunk
+ return len(read_data), read_data
+ except: # pylint: disable=bare-except
+ pass
+ raise ValueError("Unable to calculate content length, please specify.")
+
+
+def validate_and_format_range_headers(
+ start_range, end_range, start_range_required=True,
+ end_range_required=True, check_content_md5=False, align_to_page=False):
+ # If end range is provided, start range must be provided
+ if (start_range_required or end_range is not None) and start_range is None:
+ raise ValueError("start_range value cannot be None.")
+ if end_range_required and end_range is None:
+ raise ValueError("end_range value cannot be None.")
+
+ # Page ranges must be 512 aligned
+ if align_to_page:
+ if start_range is not None and start_range % 512 != 0:
+ raise ValueError(f"Invalid page blob start_range: {start_range}. "
+ "The size must be aligned to a 512-byte boundary.")
+ if end_range is not None and end_range % 512 != 511:
+ raise ValueError(f"Invalid page blob end_range: {end_range}. "
+ "The size must be aligned to a 512-byte boundary.")
+
+ # Format based on whether end_range is present
+ range_header = None
+ if end_range is not None:
+ range_header = f'bytes={start_range}-{end_range}'
+ elif start_range is not None:
+ range_header = f"bytes={start_range}-"
+
+ # Content MD5 can only be provided for a complete range less than 4MB in size
+ range_validation = None
+ if check_content_md5:
+ if start_range is None or end_range is None:
+ raise ValueError("Both start and end range required for MD5 content validation.")
+ if end_range - start_range > 4 * 1024 * 1024:
+ raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.")
+ range_validation = 'true'
+
+ return range_header, range_validation
+
+
+def add_metadata_headers(metadata=None):
+ # type: (Optional[Dict[str, str]]) -> Dict[str, str]
+ headers = {}
+ if metadata:
+ for key, value in metadata.items():
+ headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value
+ return headers
+
+
+def serialize_batch_body(requests, batch_id):
+ """
+ --<delimiter>
+ <subrequest>
+ --<delimiter>
+ <subrequest> (repeated as needed)
+ --<delimiter>--
+
+ Serializes the requests in this batch to a single HTTP mixed/multipart body.
+
+ :param List[~azure.core.pipeline.transport.HttpRequest] requests:
+ a list of sub-requests for the batch request
+ :param str batch_id:
+ The batch id to be embedded in the batch sub-request delimiter.
+ :returns: The body bytes for this batch.
+ :rtype: bytes
+ """
+
+ if requests is None or len(requests) == 0:
+ raise ValueError('Please provide sub-request(s) for this batch request')
+
+ delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8')
+ newline_bytes = _HTTP_LINE_ENDING.encode('utf-8')
+ batch_body = []
+
+ content_index = 0
+ for request in requests:
+ request.headers.update({
+ "Content-ID": str(content_index),
+ "Content-Length": str(0)
+ })
+ batch_body.append(delimiter_bytes)
+ batch_body.append(_make_body_from_sub_request(request))
+ batch_body.append(newline_bytes)
+ content_index += 1
+
+ batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8'))
+ # final line of body MUST have \r\n at the end, or it will not be properly read by the service
+ batch_body.append(newline_bytes)
+
+ return b"".join(batch_body)
+
+
+def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_dashes=False):
+ """
+ Gets the delimiter used for this batch request's mixed/multipart HTTP format.
+
+ :param str batch_id:
+ Randomly generated id
+ :param bool is_prepend_dashes:
+ Whether to include the starting dashes. Used in the body, but not when defining the delimiter.
+ :param bool is_append_dashes:
+ Whether to include the ending dashes. Used in the body on the closing delimiter only.
+ :returns: The delimiter, WITHOUT a trailing newline.
+ :rtype: str
+ """
+
+ prepend_dashes = '--' if is_prepend_dashes else ''
+ append_dashes = '--' if is_append_dashes else ''
+
+ return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes
+
+
+def _make_body_from_sub_request(sub_request):
+ """
+ Content-Type: application/http
+ Content-ID: <sequential int ID>
+ Content-Transfer-Encoding: <value> (if present)
+
+ <verb> <path><query> HTTP/<version>
+ <header key>: <header value> (repeated as necessary)
+ Content-Length: <value>
+ (newline if content length > 0)
+ <body> (if content length > 0)
+
+ Serializes an http request.
+
+ :param ~azure.core.pipeline.transport.HttpRequest sub_request:
+ Request to serialize.
+ :returns: The serialized sub-request in bytes
+ :rtype: bytes
+ """
+
+ # put the sub-request's headers into a list for efficient str concatenation
+ sub_request_body = []
+
+ # get headers for ease of manipulation; remove headers as they are used
+ headers = sub_request.headers
+
+ # append opening headers
+ sub_request_body.append("Content-Type: application/http")
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ sub_request_body.append("Content-ID: ")
+ sub_request_body.append(headers.pop("Content-ID", ""))
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ sub_request_body.append("Content-Transfer-Encoding: binary")
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append blank line
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append HTTP verb and path and query and HTTP version
+ sub_request_body.append(sub_request.method)
+ sub_request_body.append(' ')
+ sub_request_body.append(sub_request.url)
+ sub_request_body.append(' ')
+ sub_request_body.append(_HTTP1_1_IDENTIFIER)
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append remaining headers (this will set the Content-Length, as it was set on `sub-request`)
+ for header_name, header_value in headers.items():
+ if header_value is not None:
+ sub_request_body.append(header_name)
+ sub_request_body.append(": ")
+ sub_request_body.append(header_value)
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ # append blank line
+ sub_request_body.append(_HTTP_LINE_ENDING)
+
+ return ''.join(sub_request_body).encode()
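A short usage sketch for the helpers above (placeholder data, module path as created by this diff): get_length falls back to seek/tell for objects without __len__, and validate_and_format_range_headers enforces the 512-byte page alignment rule.

    # Sketch: length detection, range-header formatting, and metadata headers.
    from io import BytesIO
    from azure.storage.blob._shared.request_handlers import (
        add_metadata_headers, get_length, validate_and_format_range_headers)

    print(get_length(BytesIO(b"hello world")))                      # 11 (via seek/tell fallback)
    print(validate_and_format_range_headers(0, 511, align_to_page=True))
    # ('bytes=0-511', None): both bounds satisfy the 512-byte alignment check
    print(add_metadata_headers({"project": " demo "}))              # {'x-ms-meta-project': 'demo'}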
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/response_handlers.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/response_handlers.py
new file mode 100644
index 00000000..af9a2fcd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/response_handlers.py
@@ -0,0 +1,200 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+from typing import NoReturn
+from xml.etree.ElementTree import Element
+
+from azure.core.exceptions import (
+ ClientAuthenticationError,
+ DecodeError,
+ HttpResponseError,
+ ResourceExistsError,
+ ResourceModifiedError,
+ ResourceNotFoundError,
+)
+from azure.core.pipeline.policies import ContentDecodePolicy
+
+from .authentication import AzureSigningError
+from .models import get_enum_value, StorageErrorCode, UserDelegationKey
+from .parser import _to_utc_datetime
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PartialBatchErrorException(HttpResponseError):
+ """There is a partial failure in batch operations.
+
+ :param str message: The message of the exception.
+ :param response: Server response to be deserialized.
+ :param list parts: A list of the parts in multipart response.
+ """
+
+ def __init__(self, message, response, parts):
+ self.parts = parts
+ super(PartialBatchErrorException, self).__init__(message=message, response=response)
+
+
+# Parses the blob length from the content range header: bytes 1-3/65537
+def parse_length_from_content_range(content_range):
+ if content_range is None:
+ return None
+
+ # First, split in space and take the second half: '1-3/65537'
+ # Next, split on slash and take the second half: '65537'
+ # Finally, convert to an int: 65537
+ return int(content_range.split(' ', 1)[1].split('/', 1)[1])
+
+
+def normalize_headers(headers):
+ normalized = {}
+ for key, value in headers.items():
+ if key.startswith('x-ms-'):
+ key = key[5:]
+ normalized[key.lower().replace('-', '_')] = get_enum_value(value)
+ return normalized
+
+
+def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument
+ try:
+ raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')}
+ except AttributeError:
+ raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')}
+ return {k[10:]: v for k, v in raw_metadata.items()}
+
+
+def return_response_headers(response, deserialized, response_headers): # pylint: disable=unused-argument
+ return normalize_headers(response_headers)
+
+
+def return_headers_and_deserialized(response, deserialized, response_headers): # pylint: disable=unused-argument
+ return normalize_headers(response_headers), deserialized
+
+
+def return_context_and_deserialized(response, deserialized, response_headers): # pylint: disable=unused-argument
+ return response.http_response.location_mode, deserialized
+
+
+def return_raw_deserialized(response, *_):
+ return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME]
+
+
+def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches
+ raise_error = HttpResponseError
+ serialized = False
+ if isinstance(storage_error, AzureSigningError):
+ storage_error.message = storage_error.message + \
+ '. This is likely due to an invalid shared key. Please check your shared key and try again.'
+ if not storage_error.response or storage_error.response.status_code in [200, 204]:
+ raise storage_error
+ # If it is one of those three then it has been serialized prior by the generated layer.
+ if isinstance(storage_error, (PartialBatchErrorException,
+ ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)):
+ serialized = True
+ error_code = storage_error.response.headers.get('x-ms-error-code')
+ error_message = storage_error.message
+ additional_data = {}
+ error_dict = {}
+ try:
+ error_body = ContentDecodePolicy.deserialize_from_http_generics(storage_error.response)
+ try:
+ if error_body is None or len(error_body) == 0:
+ error_body = storage_error.response.reason
+ except AttributeError:
+ error_body = ''
+ # If it is an XML response
+ if isinstance(error_body, Element):
+ error_dict = {
+ child.tag.lower(): child.text
+ for child in error_body
+ }
+ # If it is a JSON response
+ elif isinstance(error_body, dict):
+ error_dict = error_body.get('error', {})
+ elif not error_code:
+ _LOGGER.warning(
+ 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body))
+ error_dict = {'message': str(error_body)}
+
+ # If we extracted from a Json or XML response
+ # There is a chance error_dict is just a string
+ if error_dict and isinstance(error_dict, dict):
+ error_code = error_dict.get('code')
+ error_message = error_dict.get('message')
+ additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}}
+ except DecodeError:
+ pass
+
+ try:
+ # This check would be unnecessary if we have already serialized the error
+ if error_code and not serialized:
+ error_code = StorageErrorCode(error_code)
+ if error_code in [StorageErrorCode.condition_not_met,
+ StorageErrorCode.blob_overwritten]:
+ raise_error = ResourceModifiedError
+ if error_code in [StorageErrorCode.invalid_authentication_info,
+ StorageErrorCode.authentication_failed]:
+ raise_error = ClientAuthenticationError
+ if error_code in [StorageErrorCode.resource_not_found,
+ StorageErrorCode.cannot_verify_copy_source,
+ StorageErrorCode.blob_not_found,
+ StorageErrorCode.queue_not_found,
+ StorageErrorCode.container_not_found,
+ StorageErrorCode.parent_not_found,
+ StorageErrorCode.share_not_found]:
+ raise_error = ResourceNotFoundError
+ if error_code in [StorageErrorCode.account_already_exists,
+ StorageErrorCode.account_being_created,
+ StorageErrorCode.resource_already_exists,
+ StorageErrorCode.resource_type_mismatch,
+ StorageErrorCode.blob_already_exists,
+ StorageErrorCode.queue_already_exists,
+ StorageErrorCode.container_already_exists,
+ StorageErrorCode.container_being_deleted,
+ StorageErrorCode.queue_being_deleted,
+ StorageErrorCode.share_already_exists,
+ StorageErrorCode.share_being_deleted]:
+ raise_error = ResourceExistsError
+ except ValueError:
+ # Got an unknown error code
+ pass
+
+ # Error message should include all the error properties
+ try:
+ error_message += f"\nErrorCode:{error_code.value}"
+ except AttributeError:
+ error_message += f"\nErrorCode:{error_code}"
+ for name, info in additional_data.items():
+ error_message += f"\n{name}:{info}"
+
+ # No need to create an instance if it has already been serialized by the generated layer
+ if serialized:
+ storage_error.message = error_message
+ error = storage_error
+ else:
+ error = raise_error(message=error_message, response=storage_error.response)
+ # Ensure these properties are stored in the error instance as well (not just the error message)
+ error.error_code = error_code
+ error.additional_info = additional_data
+ # error.args is what's surfaced on the traceback - show error message in all cases
+ error.args = (error.message,)
+ try:
+ # `from None` prevents us from double printing the exception (suppresses generated layer error context)
+ exec("raise error from None") # pylint: disable=exec-used # nosec
+ except SyntaxError as exc:
+ raise error from exc
+
+
+def parse_to_internal_user_delegation_key(service_user_delegation_key):
+ internal_user_delegation_key = UserDelegationKey()
+ internal_user_delegation_key.signed_oid = service_user_delegation_key.signed_oid
+ internal_user_delegation_key.signed_tid = service_user_delegation_key.signed_tid
+ internal_user_delegation_key.signed_start = _to_utc_datetime(service_user_delegation_key.signed_start)
+ internal_user_delegation_key.signed_expiry = _to_utc_datetime(service_user_delegation_key.signed_expiry)
+ internal_user_delegation_key.signed_service = service_user_delegation_key.signed_service
+ internal_user_delegation_key.signed_version = service_user_delegation_key.signed_version
+ internal_user_delegation_key.value = service_user_delegation_key.value
+ return internal_user_delegation_key
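A hedged sketch of the two small response parsers above; the header names and values are purely illustrative.

    # Sketch: Content-Range parsing and x-ms-* header normalization.
    from azure.storage.blob._shared.response_handlers import (
        normalize_headers, parse_length_from_content_range)

    print(parse_length_from_content_range("bytes 1-3/65537"))        # 65537
    print(normalize_headers({"x-ms-meta-tag": "v1", "ETag": "0xabc"}))
    # {'meta_tag': 'v1', 'etag': '0xabc'}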
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/shared_access_signature.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/shared_access_signature.py
new file mode 100644
index 00000000..df29222b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/shared_access_signature.py
@@ -0,0 +1,252 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+# pylint: disable=docstring-keyword-should-match-keyword-only
+
+from datetime import date
+
+from .parser import _to_utc_datetime
+from .constants import X_MS_VERSION
+from . import sign_string, url_quote
+
+# cspell:ignoreRegExp rsc.
+# cspell:ignoreRegExp s..?id
+class QueryStringConstants(object):
+ SIGNED_SIGNATURE = 'sig'
+ SIGNED_PERMISSION = 'sp'
+ SIGNED_START = 'st'
+ SIGNED_EXPIRY = 'se'
+ SIGNED_RESOURCE = 'sr'
+ SIGNED_IDENTIFIER = 'si'
+ SIGNED_IP = 'sip'
+ SIGNED_PROTOCOL = 'spr'
+ SIGNED_VERSION = 'sv'
+ SIGNED_CACHE_CONTROL = 'rscc'
+ SIGNED_CONTENT_DISPOSITION = 'rscd'
+ SIGNED_CONTENT_ENCODING = 'rsce'
+ SIGNED_CONTENT_LANGUAGE = 'rscl'
+ SIGNED_CONTENT_TYPE = 'rsct'
+ START_PK = 'spk'
+ START_RK = 'srk'
+ END_PK = 'epk'
+ END_RK = 'erk'
+ SIGNED_RESOURCE_TYPES = 'srt'
+ SIGNED_SERVICES = 'ss'
+ SIGNED_OID = 'skoid'
+ SIGNED_TID = 'sktid'
+ SIGNED_KEY_START = 'skt'
+ SIGNED_KEY_EXPIRY = 'ske'
+ SIGNED_KEY_SERVICE = 'sks'
+ SIGNED_KEY_VERSION = 'skv'
+ SIGNED_ENCRYPTION_SCOPE = 'ses'
+
+ # for ADLS
+ SIGNED_AUTHORIZED_OID = 'saoid'
+ SIGNED_UNAUTHORIZED_OID = 'suoid'
+ SIGNED_CORRELATION_ID = 'scid'
+ SIGNED_DIRECTORY_DEPTH = 'sdd'
+
+ @staticmethod
+ def to_list():
+ return [
+ QueryStringConstants.SIGNED_SIGNATURE,
+ QueryStringConstants.SIGNED_PERMISSION,
+ QueryStringConstants.SIGNED_START,
+ QueryStringConstants.SIGNED_EXPIRY,
+ QueryStringConstants.SIGNED_RESOURCE,
+ QueryStringConstants.SIGNED_IDENTIFIER,
+ QueryStringConstants.SIGNED_IP,
+ QueryStringConstants.SIGNED_PROTOCOL,
+ QueryStringConstants.SIGNED_VERSION,
+ QueryStringConstants.SIGNED_CACHE_CONTROL,
+ QueryStringConstants.SIGNED_CONTENT_DISPOSITION,
+ QueryStringConstants.SIGNED_CONTENT_ENCODING,
+ QueryStringConstants.SIGNED_CONTENT_LANGUAGE,
+ QueryStringConstants.SIGNED_CONTENT_TYPE,
+ QueryStringConstants.START_PK,
+ QueryStringConstants.START_RK,
+ QueryStringConstants.END_PK,
+ QueryStringConstants.END_RK,
+ QueryStringConstants.SIGNED_RESOURCE_TYPES,
+ QueryStringConstants.SIGNED_SERVICES,
+ QueryStringConstants.SIGNED_OID,
+ QueryStringConstants.SIGNED_TID,
+ QueryStringConstants.SIGNED_KEY_START,
+ QueryStringConstants.SIGNED_KEY_EXPIRY,
+ QueryStringConstants.SIGNED_KEY_SERVICE,
+ QueryStringConstants.SIGNED_KEY_VERSION,
+ QueryStringConstants.SIGNED_ENCRYPTION_SCOPE,
+ # for ADLS
+ QueryStringConstants.SIGNED_AUTHORIZED_OID,
+ QueryStringConstants.SIGNED_UNAUTHORIZED_OID,
+ QueryStringConstants.SIGNED_CORRELATION_ID,
+ QueryStringConstants.SIGNED_DIRECTORY_DEPTH,
+ ]
+
+
+class SharedAccessSignature(object):
+ '''
+ Provides a factory for creating account access
+ signature tokens with an account name and account key. Users can either
+ use the factory or can construct the appropriate service and use the
+ generate_*_shared_access_signature method directly.
+ '''
+
+ def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION):
+ '''
+ :param str account_name:
+ The storage account name used to generate the shared access signatures.
+ :param str account_key:
+ The access key to generate the shared access signatures.
+ :param str x_ms_version:
+ The service version used to generate the shared access signatures.
+ '''
+ self.account_name = account_name
+ self.account_key = account_key
+ self.x_ms_version = x_ms_version
+
+ def generate_account(
+ self, services,
+ resource_types,
+ permission,
+ expiry,
+ start=None,
+ ip=None,
+ protocol=None,
+ sts_hook=None,
+ **kwargs
+ ) -> str:
+ '''
+ Generates a shared access signature for the account.
+ Use the returned signature with the sas_token parameter of the service
+ or to create a new account object.
+
+ :param Any services: The storage services (blob, queue, table, file) accessible with the shared access signature.
+ :param ResourceTypes resource_types:
+ Specifies the resource types that are accessible with the account
+ SAS. You can combine values to provide access to more than one
+ resource type.
+ :param AccountSasPermissions permission:
+ The permissions associated with the shared access signature. The
+ user is restricted to operations allowed by the permissions.
+ Required unless an id is given referencing a stored access policy
+ which contains this field. This field must be omitted if it has been
+ specified in an associated stored access policy. You can combine
+ values to provide more than one permission.
+ :param expiry:
+ The time at which the shared access signature becomes invalid.
+ Required unless an id is given referencing a stored access policy
+ which contains this field. This field must be omitted if it has
+ been specified in an associated stored access policy. Azure will always
+ convert values to UTC. If a date is passed in without timezone info, it
+ is assumed to be UTC.
+ :type expiry: datetime or str
+ :param start:
+ The time at which the shared access signature becomes valid. If
+ omitted, start time for this call is assumed to be the time when the
+ storage service receives the request. The provided datetime will always
+ be interpreted as UTC.
+ :type start: datetime or str
+ :param str ip:
+ Specifies an IP address or a range of IP addresses from which to accept requests.
+ If the IP address from which the request originates does not match the IP address
+ or address range specified on the SAS token, the request is not authenticated.
+ For example, specifying sip=168.1.5.65 or sip=168.1.5.60-168.1.5.70 on the SAS
+ restricts the request to those IP addresses.
+ :param str protocol:
+ Specifies the protocol permitted for a request made. The default value
+ is https,http. See :class:`~azure.storage.common.models.Protocol` for possible values.
+ :keyword str encryption_scope:
+ Optional. If specified, this is the encryption scope to use when sending requests
+ authorized with this SAS URI.
+ :param sts_hook:
+ For debugging purposes only. If provided, the hook is called with the string to sign
+ that was used to generate the SAS.
+ :type sts_hook: Optional[Callable[[str], None]]
+ :returns: The generated SAS token for the account.
+ :rtype: str
+ '''
+ sas = _SharedAccessHelper()
+ sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version)
+ sas.add_account(services, resource_types)
+ sas.add_encryption_scope(**kwargs)
+ sas.add_account_signature(self.account_name, self.account_key)
+
+ if sts_hook is not None:
+ sts_hook(sas.string_to_sign)
+
+ return sas.get_token()
+
+
+class _SharedAccessHelper(object):
+ def __init__(self):
+ self.query_dict = {}
+ self.string_to_sign = ""
+
+ def _add_query(self, name, val):
+ if val:
+ self.query_dict[name] = str(val) if val is not None else None
+
+ def add_encryption_scope(self, **kwargs):
+ self._add_query(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE, kwargs.pop('encryption_scope', None))
+
+ def add_base(self, permission, expiry, start, ip, protocol, x_ms_version):
+ if isinstance(start, date):
+ start = _to_utc_datetime(start)
+
+ if isinstance(expiry, date):
+ expiry = _to_utc_datetime(expiry)
+
+ self._add_query(QueryStringConstants.SIGNED_START, start)
+ self._add_query(QueryStringConstants.SIGNED_EXPIRY, expiry)
+ self._add_query(QueryStringConstants.SIGNED_PERMISSION, permission)
+ self._add_query(QueryStringConstants.SIGNED_IP, ip)
+ self._add_query(QueryStringConstants.SIGNED_PROTOCOL, protocol)
+ self._add_query(QueryStringConstants.SIGNED_VERSION, x_ms_version)
+
+ def add_resource(self, resource):
+ self._add_query(QueryStringConstants.SIGNED_RESOURCE, resource)
+
+ def add_id(self, policy_id):
+ self._add_query(QueryStringConstants.SIGNED_IDENTIFIER, policy_id)
+
+ def add_account(self, services, resource_types):
+ self._add_query(QueryStringConstants.SIGNED_SERVICES, services)
+ self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types)
+
+ def add_override_response_headers(self, cache_control,
+ content_disposition,
+ content_encoding,
+ content_language,
+ content_type):
+ self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_LANGUAGE, content_language)
+ self._add_query(QueryStringConstants.SIGNED_CONTENT_TYPE, content_type)
+
+ def add_account_signature(self, account_name, account_key):
+ def get_value_to_append(query):
+ return_value = self.query_dict.get(query) or ''
+ return return_value + '\n'
+
+ string_to_sign = \
+ (account_name + '\n' +
+ get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) +
+ get_value_to_append(QueryStringConstants.SIGNED_SERVICES) +
+ get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) +
+ get_value_to_append(QueryStringConstants.SIGNED_START) +
+ get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) +
+ get_value_to_append(QueryStringConstants.SIGNED_IP) +
+ get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) +
+ get_value_to_append(QueryStringConstants.SIGNED_VERSION) +
+ get_value_to_append(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE))
+
+ self._add_query(QueryStringConstants.SIGNED_SIGNATURE,
+ sign_string(account_key, string_to_sign))
+ self.string_to_sign = string_to_sign
+
+ def get_token(self) -> str:
+ return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None])
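A hedged end-to-end sketch of the account-SAS flow above. The account name, base64 key, and permission/service strings are placeholders; sts_hook=print is used only to surface the string-to-sign for debugging.

    # Sketch: generate an account SAS with placeholder credentials.
    from datetime import datetime, timedelta, timezone
    from azure.storage.blob._shared.shared_access_signature import SharedAccessSignature

    sas = SharedAccessSignature("myaccount", "cGxhY2Vob2xkZXJrZXk=")  # placeholder base64 key
    token = sas.generate_account(
        services="b",            # blob service
        resource_types="sco",    # service, container, object
        permission="rl",         # read + list
        expiry=datetime.now(timezone.utc) + timedelta(hours=1),
        sts_hook=print,          # prints the string-to-sign
    )
    print(token)  # e.g. se=...&sp=rl&sv=...&ss=b&srt=sco&sig=...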
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads.py
new file mode 100644
index 00000000..b31cfb32
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads.py
@@ -0,0 +1,604 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from concurrent import futures
+from io import BytesIO, IOBase, SEEK_CUR, SEEK_END, SEEK_SET, UnsupportedOperation
+from itertools import islice
+from math import ceil
+from threading import Lock
+
+from azure.core.tracing.common import with_current_context
+
+from . import encode_base64, url_quote
+from .request_handlers import get_length
+from .response_handlers import return_response_headers
+
+
+_LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE = 4 * 1024 * 1024
+_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM = "{0} should be a seekable file-like/io.IOBase type stream object."
+
+
+def _parallel_uploads(executor, uploader, pending, running):
+ range_ids = []
+ while True:
+ # Wait for some upload to finish before adding a new one
+ done, running = futures.wait(running, return_when=futures.FIRST_COMPLETED)
+ range_ids.extend([chunk.result() for chunk in done])
+ try:
+ for _ in range(0, len(done)):
+ next_chunk = next(pending)
+ running.add(executor.submit(with_current_context(uploader), next_chunk))
+ except StopIteration:
+ break
+
+ # Wait for the remaining uploads to finish
+ done, _running = futures.wait(running)
+ range_ids.extend([chunk.result() for chunk in done])
+ return range_ids
+
+
+def upload_data_chunks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ validate_content=None,
+ progress_hook=None,
+ **kwargs):
+
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ validate_content=validate_content,
+ progress_hook=progress_hook,
+ **kwargs)
+ if parallel:
+ with futures.ThreadPoolExecutor(max_concurrency) as executor:
+ upload_tasks = uploader.get_chunk_streams()
+ running_futures = [
+ executor.submit(with_current_context(uploader.process_chunk), u)
+ for u in islice(upload_tasks, 0, max_concurrency)
+ ]
+ range_ids = _parallel_uploads(executor, uploader.process_chunk, upload_tasks, running_futures)
+ else:
+ range_ids = [uploader.process_chunk(result) for result in uploader.get_chunk_streams()]
+ if any(range_ids):
+ return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
+ return uploader.response_headers
+
+
+def upload_substream_blocks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ progress_hook=None,
+ **kwargs):
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ progress_hook=progress_hook,
+ **kwargs)
+
+ if parallel:
+ with futures.ThreadPoolExecutor(max_concurrency) as executor:
+ upload_tasks = uploader.get_substream_blocks()
+ running_futures = [
+ executor.submit(with_current_context(uploader.process_substream_block), u)
+ for u in islice(upload_tasks, 0, max_concurrency)
+ ]
+ range_ids = _parallel_uploads(executor, uploader.process_substream_block, upload_tasks, running_futures)
+ else:
+ range_ids = [uploader.process_substream_block(b) for b in uploader.get_substream_blocks()]
+ if any(range_ids):
+ return sorted(range_ids)
+ return []
+
+
+class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes
+
+ def __init__(
+ self, service,
+ total_size,
+ chunk_size,
+ stream,
+ parallel,
+ encryptor=None,
+ padder=None,
+ progress_hook=None,
+ **kwargs):
+ self.service = service
+ self.total_size = total_size
+ self.chunk_size = chunk_size
+ self.stream = stream
+ self.parallel = parallel
+
+ # Stream management
+ self.stream_lock = Lock() if parallel else None
+
+ # Progress feedback
+ self.progress_total = 0
+ self.progress_lock = Lock() if parallel else None
+ self.progress_hook = progress_hook
+
+ # Encryption
+ self.encryptor = encryptor
+ self.padder = padder
+ self.response_headers = None
+ self.etag = None
+ self.last_modified = None
+ self.request_options = kwargs
+
+ def get_chunk_streams(self):
+ index = 0
+ while True:
+ data = b""
+ read_size = self.chunk_size
+
+ # Buffer until we either reach the end of the stream or get a whole chunk.
+ while True:
+ if self.total_size:
+ read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
+ temp = self.stream.read(read_size)
+ if not isinstance(temp, bytes):
+ raise TypeError("Blob data should be of type bytes.")
+ data += temp or b""
+
+ # We have read an empty string and so are at the end
+ # of the buffer or we have read a full chunk.
+ if temp == b"" or len(data) == self.chunk_size:
+ break
+
+ if len(data) == self.chunk_size:
+ if self.padder:
+ data = self.padder.update(data)
+ if self.encryptor:
+ data = self.encryptor.update(data)
+ yield index, data
+ else:
+ if self.padder:
+ data = self.padder.update(data) + self.padder.finalize()
+ if self.encryptor:
+ data = self.encryptor.update(data) + self.encryptor.finalize()
+ if data:
+ yield index, data
+ break
+ index += len(data)
+
+ def process_chunk(self, chunk_data):
+ chunk_bytes = chunk_data[1]
+ chunk_offset = chunk_data[0]
+ return self._upload_chunk_with_progress(chunk_offset, chunk_bytes)
+
+ def _update_progress(self, length):
+ if self.progress_lock is not None:
+ with self.progress_lock:
+ self.progress_total += length
+ else:
+ self.progress_total += length
+
+ if self.progress_hook:
+ self.progress_hook(self.progress_total, self.total_size)
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
+ range_id = self._upload_chunk(chunk_offset, chunk_data)
+ self._update_progress(len(chunk_data))
+ return range_id
+
+ def get_substream_blocks(self):
+ assert self.chunk_size is not None
+ lock = self.stream_lock
+ blob_length = self.total_size
+
+ if blob_length is None:
+ blob_length = get_length(self.stream)
+ if blob_length is None:
+ raise ValueError("Unable to determine content length of upload data.")
+
+ blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
+ last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size
+
+ for i in range(blocks):
+ index = i * self.chunk_size
+ length = last_block_size if i == blocks - 1 else self.chunk_size
+ yield index, SubStream(self.stream, index, length, lock)
+
+ def process_substream_block(self, block_data):
+ return self._upload_substream_block_with_progress(block_data[0], block_data[1])
+
+ def _upload_substream_block(self, index, block_stream):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ def _upload_substream_block_with_progress(self, index, block_stream):
+ range_id = self._upload_substream_block(index, block_stream)
+ self._update_progress(len(block_stream))
+ return range_id
+
+ def set_response_properties(self, resp):
+ self.etag = resp.etag
+ self.last_modified = resp.last_modified
+
+
+class BlockBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ kwargs.pop("modified_access_conditions", None)
+ super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ # TODO: This is incorrect, but works with recording.
+ index = f'{chunk_offset:032d}'
+ block_id = encode_base64(url_quote(encode_base64(index)))
+ self.service.stage_block(
+ block_id,
+ len(chunk_data),
+ chunk_data,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ return index, block_id
+
+ def _upload_substream_block(self, index, block_stream):
+ try:
+ block_id = f'BlockId{(index//self.chunk_size):05}'
+ self.service.stage_block(
+ block_id,
+ len(block_stream),
+ block_stream,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ finally:
+ block_stream.close()
+ return block_id
+
+
+class PageBlobChunkUploader(_ChunkUploader):
+
+ def _is_chunk_empty(self, chunk_data):
+ # read until non-zero byte is encountered
+ # if reached the end without returning, then chunk_data is all 0's
+ return not any(bytearray(chunk_data))
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ # avoid uploading the empty pages
+ if not self._is_chunk_empty(chunk_data):
+ chunk_end = chunk_offset + len(chunk_data) - 1
+ content_range = f"bytes={chunk_offset}-{chunk_end}"
+ computed_md5 = None
+ self.response_headers = self.service.upload_pages(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ transactional_content_md5=computed_md5,
+ range=content_range,
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class AppendBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ if self.current_length is None:
+ self.response_headers = self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ self.current_length = int(self.response_headers["blob_append_offset"])
+ else:
+ self.request_options['append_position_access_conditions'].append_position = \
+ self.current_length + chunk_offset
+ self.response_headers = self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class DataLakeFileChunkUploader(_ChunkUploader):
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ # avoid uploading the empty pages
+ self.response_headers = self.service.append_data(
+ body=chunk_data,
+ position=chunk_offset,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ def _upload_substream_block(self, index, block_stream):
+ try:
+ self.service.append_data(
+ body=block_stream,
+ position=index,
+ content_length=len(block_stream),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ finally:
+ block_stream.close()
+
+
+class FileChunkUploader(_ChunkUploader):
+
+ def _upload_chunk(self, chunk_offset, chunk_data):
+ length = len(chunk_data)
+ chunk_end = chunk_offset + length - 1
+ response = self.service.upload_range(
+ chunk_data,
+ chunk_offset,
+ length,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ return f'bytes={chunk_offset}-{chunk_end}', response
+
+ # TODO: Implement this method.
+ def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class SubStream(IOBase):
+
+ def __init__(self, wrapped_stream, stream_begin_index, length, lockObj):
+ # Python 2.7: file-like objects created with open() typically support seek(), but are not
+ # derivations of io.IOBase and thus do not implement seekable().
+ # Python > 3.0: file-like objects created with open() are derived from io.IOBase.
+ try:
+ # only the main thread runs this, so there's no need to grab the lock
+ wrapped_stream.seek(0, SEEK_CUR)
+ except Exception as exc:
+ raise ValueError("Wrapped stream must support seek().") from exc
+
+ self._lock = lockObj
+ self._wrapped_stream = wrapped_stream
+ self._position = 0
+ self._stream_begin_index = stream_begin_index
+ self._length = length
+ self._buffer = BytesIO()
+
+ # we must avoid buffering more than necessary, and also not use up too much memory
+ # so the max buffer size is capped at 4MB
+ self._max_buffer_size = (
+ length if length < _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE else _LARGE_BLOB_UPLOAD_MAX_READ_BUFFER_SIZE
+ )
+ self._current_buffer_start = 0
+ self._current_buffer_size = 0
+ super(SubStream, self).__init__()
+
+ def __len__(self):
+ return self._length
+
+ def close(self):
+ if self._buffer:
+ self._buffer.close()
+ self._wrapped_stream = None
+ IOBase.close(self)
+
+ def fileno(self):
+ return self._wrapped_stream.fileno()
+
+ def flush(self):
+ pass
+
+ def read(self, size=None):
+ if self.closed: # pylint: disable=using-constant-test
+ raise ValueError("Stream is closed.")
+
+ if size is None:
+ size = self._length - self._position
+
+ # adjust if out of bounds
+ if size + self._position >= self._length:
+ size = self._length - self._position
+
+ # return fast
+ if size == 0 or self._buffer.closed:
+ return b""
+
+ # attempt first read from the read buffer and update position
+ read_buffer = self._buffer.read(size)
+ bytes_read = len(read_buffer)
+ bytes_remaining = size - bytes_read
+ self._position += bytes_read
+
+ # repopulate the read buffer from the underlying stream to fulfill the request
+ # ensure the seek and read operations are done atomically (only if a lock is provided)
+ if bytes_remaining > 0:
+ with self._buffer:
+ # either read in the max buffer size specified on the class
+ # or read in just enough data for the current block/sub stream
+ current_max_buffer_size = min(self._max_buffer_size, self._length - self._position)
+
+ # lock is only defined if max_concurrency > 1 (parallel uploads)
+ if self._lock:
+ with self._lock:
+ # reposition the underlying stream to match the start of the data to read
+ absolute_position = self._stream_begin_index + self._position
+ self._wrapped_stream.seek(absolute_position, SEEK_SET)
+ # If we can't seek to the right location, our read will be corrupted so fail fast.
+ if self._wrapped_stream.tell() != absolute_position:
+ raise IOError("Stream failed to seek to the desired location.")
+ buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
+ else:
+ absolute_position = self._stream_begin_index + self._position
+ # It's possible that there's connection problem during data transfer,
+ # so when we retry we don't want to read from current position of wrapped stream,
+ # instead we should seek to where we want to read from.
+ if self._wrapped_stream.tell() != absolute_position:
+ self._wrapped_stream.seek(absolute_position, SEEK_SET)
+
+ buffer_from_stream = self._wrapped_stream.read(current_max_buffer_size)
+
+ if buffer_from_stream:
+ # update the buffer with new data from the wrapped stream
+ # we need to note down the start position and size of the buffer, in case seek is performed later
+ self._buffer = BytesIO(buffer_from_stream)
+ self._current_buffer_start = self._position
+ self._current_buffer_size = len(buffer_from_stream)
+
+ # read the remaining bytes from the new buffer and update position
+ second_read_buffer = self._buffer.read(bytes_remaining)
+ read_buffer += second_read_buffer
+ self._position += len(second_read_buffer)
+
+ return read_buffer
+
+ def readable(self):
+ return True
+
+ def readinto(self, b):
+ raise UnsupportedOperation
+
+ def seek(self, offset, whence=0):
+ if whence is SEEK_SET:
+ start_index = 0
+ elif whence is SEEK_CUR:
+ start_index = self._position
+ elif whence is SEEK_END:
+ start_index = self._length
+ offset = -offset
+ else:
+ raise ValueError("Invalid argument for the 'whence' parameter.")
+
+ pos = start_index + offset
+
+ if pos > self._length:
+ pos = self._length
+ elif pos < 0:
+ pos = 0
+
+ # check if buffer is still valid
+ # if not, drop buffer
+ if pos < self._current_buffer_start or pos >= self._current_buffer_start + self._current_buffer_size:
+ self._buffer.close()
+ self._buffer = BytesIO()
+ else: # if yes seek to correct position
+ delta = pos - self._current_buffer_start
+ self._buffer.seek(delta, SEEK_SET)
+
+ self._position = pos
+ return pos
+
+ def seekable(self):
+ return True
+
+ def tell(self):
+ return self._position
+
+ def write(self, b):
+ raise UnsupportedOperation
+
+ def writelines(self, lines):
+ raise UnsupportedOperation
+
+ def writable(self):
+ return False
+
+
+class IterStreamer(object):
+ """
+ File-like streaming iterator.
+ """
+
+ def __init__(self, generator, encoding="UTF-8"):
+ self.generator = generator
+ self.iterator = iter(generator)
+ self.leftover = b""
+ self.encoding = encoding
+
+ def __len__(self):
+ return self.generator.__len__()
+
+ def __iter__(self):
+ return self.iterator
+
+ def seekable(self):
+ return False
+
+ def __next__(self):
+ return next(self.iterator)
+
+ def tell(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator does not support tell.")
+
+ def seek(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator is not seekable.")
+
+ def read(self, size):
+ data = self.leftover
+ count = len(self.leftover)
+ try:
+ while count < size:
+ chunk = self.__next__()
+ if isinstance(chunk, str):
+ chunk = chunk.encode(self.encoding)
+ data += chunk
+ count += len(chunk)
+ # This means count < size and what's leftover will be returned in this call.
+ except StopIteration:
+ self.leftover = b""
+
+ if count >= size:
+ self.leftover = data[size:]
+
+ return data[:size]
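A hedged sketch of the two stream wrappers defined above, using arbitrary byte payloads: SubStream exposes a fixed window of a seekable stream (offset 2, length 4 here), while IterStreamer turns an iterator of byte chunks into a minimal read()-able object.

    # Sketch: windowed reads via SubStream and chunked reads via IterStreamer.
    from io import BytesIO
    from azure.storage.blob._shared.uploads import IterStreamer, SubStream

    window = SubStream(BytesIO(b"0123456789"), stream_begin_index=2, length=4, lockObj=None)
    print(window.read())       # b'2345'

    stream = IterStreamer(iter([b"abc", b"defg"]))
    print(stream.read(5))      # b'abcde'
    print(stream.read(5))      # b'fg' (iterator exhausted, remainder returned)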
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads_async.py b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads_async.py
new file mode 100644
index 00000000..3e102ec5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/blob/_shared/uploads_async.py
@@ -0,0 +1,460 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import asyncio
+import inspect
+import threading
+from asyncio import Lock
+from io import UnsupportedOperation
+from itertools import islice
+from math import ceil
+from typing import AsyncGenerator, Union
+
+from . import encode_base64, url_quote
+from .request_handlers import get_length
+from .response_handlers import return_response_headers
+from .uploads import SubStream, IterStreamer # pylint: disable=unused-import
+
+
+async def _async_parallel_uploads(uploader, pending, running):
+ range_ids = []
+ while True:
+ # Wait for some upload to finish before adding a new one
+ done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
+ range_ids.extend([chunk.result() for chunk in done])
+ try:
+ for _ in range(0, len(done)):
+ next_chunk = await pending.__anext__()
+ running.add(asyncio.ensure_future(uploader(next_chunk)))
+ except StopAsyncIteration:
+ break
+
+ # Wait for the remaining uploads to finish
+ if running:
+ done, _running = await asyncio.wait(running)
+ range_ids.extend([chunk.result() for chunk in done])
+ return range_ids
+
+
+async def _parallel_uploads(uploader, pending, running):
+ range_ids = []
+ while True:
+ # Wait for some upload to finish before adding a new one
+ done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
+ range_ids.extend([chunk.result() for chunk in done])
+ try:
+ for _ in range(0, len(done)):
+ next_chunk = next(pending)
+ running.add(asyncio.ensure_future(uploader(next_chunk)))
+ except StopIteration:
+ break
+
+ # Wait for the remaining uploads to finish
+ if running:
+ done, _running = await asyncio.wait(running)
+ range_ids.extend([chunk.result() for chunk in done])
+ return range_ids
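+
+
+# Illustrative sketch (added for clarity, not part of the original module): the
+# windowed-concurrency pattern used by the two helpers above. A fixed number of
+# futures is primed, and every completion admits one more chunk from `pending`
+# until the source is exhausted. The fake uploader and chunk values are hypothetical.
+async def _example_windowed_uploads(max_concurrency=4): # pragma: no cover - illustrative only
+ async def fake_uploader(chunk):
+ await asyncio.sleep(0) # stand-in for a service call
+ return chunk # a real uploader returns a range or block id
+
+ pending = iter(range(10)) # pretend these integers are chunks
+ running = [asyncio.ensure_future(fake_uploader(c)) for c in islice(pending, max_concurrency)]
+ return await _parallel_uploads(fake_uploader, pending, running)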
+
+
+async def upload_data_chunks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ progress_hook=None,
+ **kwargs):
+
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ progress_hook=progress_hook,
+ **kwargs)
+
+ if parallel:
+ upload_tasks = uploader.get_chunk_streams()
+ running_futures = []
+ for _ in range(max_concurrency):
+ try:
+ chunk = await upload_tasks.__anext__()
+ running_futures.append(asyncio.ensure_future(uploader.process_chunk(chunk)))
+ except StopAsyncIteration:
+ break
+
+ range_ids = await _async_parallel_uploads(uploader.process_chunk, upload_tasks, running_futures)
+ else:
+ range_ids = []
+ async for chunk in uploader.get_chunk_streams():
+ range_ids.append(await uploader.process_chunk(chunk))
+
+ if any(range_ids):
+ return [r[1] for r in sorted(range_ids, key=lambda r: r[0])]
+ return uploader.response_headers
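+
+
+# Illustrative sketch (added for clarity, not part of the original module): how an
+# internal caller might drive upload_data_chunks for a block blob. `blob_client`
+# stands in for the generated operations object exposing stage_block, and the
+# sizes are arbitrary.
+async def _example_upload_data_chunks(blob_client, data): # pragma: no cover - illustrative only
+ from io import BytesIO
+ block_ids = await upload_data_chunks(
+ service=blob_client, # must expose an async stage_block(...)
+ uploader_class=BlockBlobChunkUploader,
+ total_size=len(data),
+ chunk_size=4 * 1024 * 1024, # 4 MiB chunks, chosen only for the sketch
+ max_concurrency=4, # >1 enables the parallel path above
+ stream=BytesIO(data),
+ )
+ return block_ids # ordered ids a caller would then commit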
+
+
+async def upload_substream_blocks(
+ service=None,
+ uploader_class=None,
+ total_size=None,
+ chunk_size=None,
+ max_concurrency=None,
+ stream=None,
+ progress_hook=None,
+ **kwargs):
+ parallel = max_concurrency > 1
+ if parallel and 'modified_access_conditions' in kwargs:
+ # Access conditions do not work with parallelism
+ kwargs['modified_access_conditions'] = None
+ uploader = uploader_class(
+ service=service,
+ total_size=total_size,
+ chunk_size=chunk_size,
+ stream=stream,
+ parallel=parallel,
+ progress_hook=progress_hook,
+ **kwargs)
+
+ if parallel:
+ upload_tasks = uploader.get_substream_blocks()
+ running_futures = [
+ asyncio.ensure_future(uploader.process_substream_block(u))
+ for u in islice(upload_tasks, 0, max_concurrency)
+ ]
+ range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures)
+ else:
+ range_ids = []
+ for block in uploader.get_substream_blocks():
+ range_ids.append(await uploader.process_substream_block(block))
+ if any(range_ids):
+ return sorted(range_ids)
+ return
+
+
+class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes
+
+ def __init__(
+ self, service,
+ total_size,
+ chunk_size,
+ stream,
+ parallel,
+ encryptor=None,
+ padder=None,
+ progress_hook=None,
+ **kwargs):
+ self.service = service
+ self.total_size = total_size
+ self.chunk_size = chunk_size
+ self.stream = stream
+ self.parallel = parallel
+
+ # Stream management
+ self.stream_lock = threading.Lock() if parallel else None
+
+ # Progress feedback
+ self.progress_total = 0
+ self.progress_lock = Lock() if parallel else None
+ self.progress_hook = progress_hook
+
+ # Encryption
+ self.encryptor = encryptor
+ self.padder = padder
+ self.response_headers = None
+ self.etag = None
+ self.last_modified = None
+ self.request_options = kwargs
+
+ async def get_chunk_streams(self):
+ index = 0
+ while True:
+ data = b''
+ read_size = self.chunk_size
+
+ # Buffer until we either reach the end of the stream or get a whole chunk.
+ while True:
+ if self.total_size:
+ read_size = min(self.chunk_size - len(data), self.total_size - (index + len(data)))
+ temp = self.stream.read(read_size)
+ if inspect.isawaitable(temp):
+ temp = await temp
+ if not isinstance(temp, bytes):
+ raise TypeError('Blob data should be of type bytes.')
+ data += temp or b""
+
+ # We have read an empty string and so are at the end
+ # of the buffer or we have read a full chunk.
+ if temp == b'' or len(data) == self.chunk_size:
+ break
+
+ if len(data) == self.chunk_size:
+ if self.padder:
+ data = self.padder.update(data)
+ if self.encryptor:
+ data = self.encryptor.update(data)
+ yield index, data
+ else:
+ if self.padder:
+ data = self.padder.update(data) + self.padder.finalize()
+ if self.encryptor:
+ data = self.encryptor.update(data) + self.encryptor.finalize()
+ if data:
+ yield index, data
+ break
+ index += len(data)
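+
+ # Note (added for clarity): the inner loop above tolerates short reads; for
+ # example, with chunk_size = 4 MiB and a socket-like stream that returns
+ # 64 KiB per read(), each yielded chunk is assembled from 64 reads before
+ # being padded/encrypted (when configured) and handed to process_chunk.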
+
+ async def process_chunk(self, chunk_data):
+ chunk_bytes = chunk_data[1]
+ chunk_offset = chunk_data[0]
+ return await self._upload_chunk_with_progress(chunk_offset, chunk_bytes)
+
+ async def _update_progress(self, length):
+ if self.progress_lock is not None:
+ async with self.progress_lock:
+ self.progress_total += length
+ else:
+ self.progress_total += length
+
+ if self.progress_hook:
+ await self.progress_hook(self.progress_total, self.total_size)
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ async def _upload_chunk_with_progress(self, chunk_offset, chunk_data):
+ range_id = await self._upload_chunk(chunk_offset, chunk_data)
+ await self._update_progress(len(chunk_data))
+ return range_id
+
+ def get_substream_blocks(self):
+ assert self.chunk_size is not None
+ lock = self.stream_lock
+ blob_length = self.total_size
+
+ if blob_length is None:
+ blob_length = get_length(self.stream)
+ if blob_length is None:
+ raise ValueError("Unable to determine content length of upload data.")
+
+ blocks = int(ceil(blob_length / (self.chunk_size * 1.0)))
+ last_block_size = self.chunk_size if blob_length % self.chunk_size == 0 else blob_length % self.chunk_size
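+ # Worked example of the arithmetic above (added for clarity, values hypothetical):
+ # with blob_length = 10 MiB + 1 byte (10485761) and chunk_size = 4 MiB (4194304),
+ # blocks = ceil(10485761 / 4194304) = 3 and last_block_size = 2097153, so the
+ # SubStreams yielded below cover offsets 0, 4194304 and 8388608 with lengths
+ # 4194304, 4194304 and 2097153 bytes respectively.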
+
+ for i in range(blocks):
+ index = i * self.chunk_size
+ length = last_block_size if i == blocks - 1 else self.chunk_size
+ yield index, SubStream(self.stream, index, length, lock)
+
+ async def process_substream_block(self, block_data):
+ return await self._upload_substream_block_with_progress(block_data[0], block_data[1])
+
+ async def _upload_substream_block(self, index, block_stream):
+ raise NotImplementedError("Must be implemented by child class.")
+
+ async def _upload_substream_block_with_progress(self, index, block_stream):
+ range_id = await self._upload_substream_block(index, block_stream)
+ await self._update_progress(len(block_stream))
+ return range_id
+
+ def set_response_properties(self, resp):
+ self.etag = resp.etag
+ self.last_modified = resp.last_modified
+
+
+class BlockBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ kwargs.pop('modified_access_conditions', None)
+ super(BlockBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ # TODO: This is incorrect, but works with recording.
+ index = f'{chunk_offset:032d}'
+ block_id = encode_base64(url_quote(encode_base64(index)))
+ await self.service.stage_block(
+ block_id,
+ len(chunk_data),
+ body=chunk_data,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+ return index, block_id
+
+ async def _upload_substream_block(self, index, block_stream):
+ try:
+ block_id = f'BlockId{(index//self.chunk_size):05}'
+ await self.service.stage_block(
+ block_id,
+ len(block_stream),
+ block_stream,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+ finally:
+ block_stream.close()
+ return block_id
+
+
+class PageBlobChunkUploader(_ChunkUploader):
+
+ def _is_chunk_empty(self, chunk_data):
+ # read until non-zero byte is encountered
+ # if reached the end without returning, then chunk_data is all 0's
+ for each_byte in chunk_data:
+ if each_byte not in [0, b'\x00']:
+ return False
+ return True
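+ # Note (added for clarity): iterating over bytes yields ints in Python 3, so
+ # the membership test above matches zero bytes via 0; e.g. a page of
+ # b"\x00" * 512 is reported empty and upload_pages is skipped for it, while a
+ # single non-zero byte anywhere causes the page to be uploaded.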
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ # avoid uploading the empty pages
+ if not self._is_chunk_empty(chunk_data):
+ chunk_end = chunk_offset + len(chunk_data) - 1
+ content_range = f'bytes={chunk_offset}-{chunk_end}'
+ computed_md5 = None
+ self.response_headers = await self.service.upload_pages(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ transactional_content_md5=computed_md5,
+ range=content_range,
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ async def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class AppendBlobChunkUploader(_ChunkUploader):
+
+ def __init__(self, *args, **kwargs):
+ super(AppendBlobChunkUploader, self).__init__(*args, **kwargs)
+ self.current_length = None
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ if self.current_length is None:
+ self.response_headers = await self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+ self.current_length = int(self.response_headers['blob_append_offset'])
+ else:
+ self.request_options['append_position_access_conditions'].append_position = \
+ self.current_length + chunk_offset
+ self.response_headers = await self.service.append_block(
+ body=chunk_data,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options)
+
+ async def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class DataLakeFileChunkUploader(_ChunkUploader):
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ self.response_headers = await self.service.append_data(
+ body=chunk_data,
+ position=chunk_offset,
+ content_length=len(chunk_data),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+
+ if not self.parallel and self.request_options.get('modified_access_conditions'):
+ self.request_options['modified_access_conditions'].if_match = self.response_headers['etag']
+
+ async def _upload_substream_block(self, index, block_stream):
+ try:
+ await self.service.append_data(
+ body=block_stream,
+ position=index,
+ content_length=len(block_stream),
+ cls=return_response_headers,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ finally:
+ block_stream.close()
+
+
+class FileChunkUploader(_ChunkUploader):
+
+ async def _upload_chunk(self, chunk_offset, chunk_data):
+ length = len(chunk_data)
+ chunk_end = chunk_offset + length - 1
+ response = await self.service.upload_range(
+ chunk_data,
+ chunk_offset,
+ length,
+ data_stream_total=self.total_size,
+ upload_stream_current=self.progress_total,
+ **self.request_options
+ )
+ range_id = f'bytes={chunk_offset}-{chunk_end}'
+ return range_id, response
+
+ # TODO: Implement this method.
+ async def _upload_substream_block(self, index, block_stream):
+ pass
+
+
+class AsyncIterStreamer():
+ """
+ File-like streaming object for AsyncGenerators.
+ """
+ def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"):
+ self.iterator = generator.__aiter__()
+ self.leftover = b""
+ self.encoding = encoding
+
+ def seekable(self):
+ return False
+
+ def tell(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator does not support tell.")
+
+ def seek(self, *args, **kwargs):
+ raise UnsupportedOperation("Data generator is not seekable.")
+
+ async def read(self, size: int) -> bytes:
+ data = self.leftover
+ count = len(self.leftover)
+ try:
+ while count < size:
+ chunk = await self.iterator.__anext__()
+ if isinstance(chunk, str):
+ chunk = chunk.encode(self.encoding)
+ data += chunk
+ count += len(chunk)
+ # This means count < size and what's leftover will be returned in this call.
+ except StopAsyncIteration:
+ self.leftover = b""
+
+ if count >= size:
+ self.leftover = data[size:]
+
+ return data[:size]
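+
+
+# Illustrative sketch (added for clarity, not part of the original module): the
+# async counterpart of IterStreamer in uploads.py, adapting a hypothetical async
+# generator of chunks for the async uploaders.
+async def _example_async_iter_streamer_usage(): # pragma: no cover - illustrative only
+ async def chunks():
+ yield b"hello "
+ yield b"world"
+
+ wrapped = AsyncIterStreamer(chunks())
+ assert await wrapped.read(5) == b"hello" # trailing b" " kept as leftover
+ assert await wrapped.read(6) == b" world" # leftover plus the next chunk
+ assert await wrapped.read(1) == b"" # exhausted generator returns b""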