author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py')
 .venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py | 694 ++++++++
 1 file changed, 694 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py b/.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py
new file mode 100644
index 00000000..ee75cd5a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py
@@ -0,0 +1,694 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import base64
+import hashlib
+import logging
+import random
+import re
+import uuid
+from io import SEEK_SET, UnsupportedOperation
+from time import time
+from typing import Any, Dict, Optional, TYPE_CHECKING
+from urllib.parse import (
+    parse_qsl,
+    urlencode,
+    urlparse,
+    urlunparse,
+)
+from wsgiref.handlers import format_date_time
+
+from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError
+from azure.core.pipeline.policies import (
+    BearerTokenCredentialPolicy,
+    HeadersPolicy,
+    HTTPPolicy,
+    NetworkTraceLoggingPolicy,
+    RequestHistory,
+    SansIOHTTPPolicy
+)
+
+from .authentication import AzureSigningError, StorageHttpChallenge
+from .constants import DEFAULT_OAUTH_SCOPE
+from .models import LocationMode
+
+if TYPE_CHECKING:
+    from azure.core.credentials import TokenCredential
+    from azure.core.pipeline.transport import (  # pylint: disable=non-abstract-transport-import
+        PipelineRequest,
+        PipelineResponse
+    )
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def encode_base64(data):
+    if isinstance(data, str):
+        data = data.encode('utf-8')
+    encoded = base64.b64encode(data)
+    return encoded.decode('utf-8')
+
+
+# Are we out of retries?
+def is_exhausted(settings):
+    retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
+    retry_counts = list(filter(None, retry_counts))
+    if not retry_counts:
+        return False
+    return min(retry_counts) < 0
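+
+# For example, settings of {'total': 2, 'connect': -1, 'read': 3, 'status': 3}
+# are exhausted: any single counter dropping below zero stops further retries.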
+
+
+def retry_hook(settings, **kwargs):
+    if settings['hook']:
+        settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)
+
+
+# Is this method/status code retryable? (Based on allowlists and control
+# variables such as the number of total retries to allow, whether to
+# respect the Retry-After header, whether this header is present, and
+# whether the returned status code is on the list of status codes to
+# be retried in the presence of that header)
+def is_retry(response, mode):
+    status = response.http_response.status_code
+    if 300 <= status < 500:
+        # An exception occurred, but in most cases it was expected. Examples could
+        # include a 409 Conflict or 412 Precondition Failed.
+        if status == 404 and mode == LocationMode.SECONDARY:
+            # Response code 404 should be retried if secondary was used.
+            return True
+        if status == 408:
+            # Response code 408 is a timeout and should be retried.
+            return True
+        return False
+    if status >= 500:
+        # Response codes of 500 and above, with the exception of 501 Not Implemented
+        # and 505 Version Not Supported, indicate a server issue and should be retried.
+        if status in [501, 505]:
+            return False
+        return True
+    return False
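+
+# Illustrative outcomes under the rules above:
+#   404 against the secondary endpoint -> retry (replication may be lagging)
+#   408 Request Timeout -> retry
+#   501 Not Implemented, 505 Version Not Supported -> no retry
+#   any other 5xx (500, 502, 503, 504, ...) -> retry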
+
+
+def is_checksum_retry(response):
+    # retry if invalid content md5
+    if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+        computed_md5 = response.http_request.headers.get('content-md5', None) or \
+                            encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+        if response.http_response.headers['content-md5'] != computed_md5:
+            return True
+    return False
+
+
+def urljoin(base_url, stub_url):
+    parsed = urlparse(base_url)
+    parsed = parsed._replace(path=parsed.path + '/' + stub_url)
+    return parsed.geturl()
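+
+# For example (hypothetical queue URL):
+#   urljoin('https://account.queue.core.windows.net/myqueue/messages', 'msg-id')
+#   -> 'https://account.queue.core.windows.net/myqueue/messages/msg-id'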
+
+
+class QueueMessagePolicy(SansIOHTTPPolicy):
+
+    def on_request(self, request):
+        message_id = request.context.options.pop('queue_message_id', None)
+        if message_id:
+            request.http_request.url = urljoin(
+                request.http_request.url,
+                message_id)
+
+
+class StorageHeadersPolicy(HeadersPolicy):
+    request_id_header_name = 'x-ms-client-request-id'
+
+    def on_request(self, request: "PipelineRequest") -> None:
+        super(StorageHeadersPolicy, self).on_request(request)
+        current_time = format_date_time(time())
+        request.http_request.headers['x-ms-date'] = current_time
+
+        custom_id = request.context.options.pop('client_request_id', None)
+        request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1())
+
+    # def on_response(self, request, response):
+    #     # raise exception if the echoed client request id from the service is not identical to the one we sent
+    #     if self.request_id_header_name in response.http_response.headers:
+
+    #         client_request_id = request.http_request.headers.get(self.request_id_header_name)
+
+    #         if response.http_response.headers[self.request_id_header_name] != client_request_id:
+    #             raise AzureError(
+    #                 "Echoed client request ID: {} does not match sent client request ID: {}.  "
+    #                 "Service request ID: {}".format(
+    #                     response.http_response.headers[self.request_id_header_name], client_request_id,
+    #                     response.http_response.headers['x-ms-request-id']),
+    #                 response=response.http_response
+    #             )
+
+
+class StorageHosts(SansIOHTTPPolicy):
+
+    def __init__(self, hosts=None, **kwargs):  # pylint: disable=unused-argument
+        self.hosts = hosts
+        super(StorageHosts, self).__init__()
+
+    def on_request(self, request: "PipelineRequest") -> None:
+        request.context.options['hosts'] = self.hosts
+        parsed_url = urlparse(request.http_request.url)
+
+        # Detect what location mode we're currently requesting with
+        location_mode = LocationMode.PRIMARY
+        for key, value in self.hosts.items():
+            if parsed_url.netloc == value:
+                location_mode = key
+
+        # See if a specific location mode has been specified, and if so, redirect
+        use_location = request.context.options.pop('use_location', None)
+        if use_location:
+            # Lock retries to the specific location
+            request.context.options['retry_to_secondary'] = False
+            if use_location not in self.hosts:
+                raise ValueError(f"Attempting to use undefined host location {use_location}")
+            if use_location != location_mode:
+                # Update request URL to use the specified location
+                updated = parsed_url._replace(netloc=self.hosts[use_location])
+                request.http_request.url = updated.geturl()
+                location_mode = use_location
+
+        request.context.options['location_mode'] = location_mode
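+
+# A hedged sketch of the expected hosts mapping (account name hypothetical):
+#   {LocationMode.PRIMARY: 'account.dfs.core.windows.net',
+#    LocationMode.SECONDARY: 'account-secondary.dfs.core.windows.net'}
+# Passing use_location=LocationMode.SECONDARY rewrites the URL's netloc to the
+# secondary host and pins retries to that location.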
+
+
+class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
+    """A policy that logs HTTP request and response to the DEBUG logger.
+
+    This accepts both global configuration and a per-request override via the "logging_enable" option.
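+
+    A hedged usage sketch (the client and method names are hypothetical):
+        client.get_file_properties(logging_enable=True, logging_body=True)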
+    """
+
+    def __init__(self, logging_enable: bool = False, **kwargs) -> None:
+        self.logging_body = kwargs.pop("logging_body", False)
+        super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs)
+
+    def on_request(self, request: "PipelineRequest") -> None:
+        http_request = request.http_request
+        options = request.context.options
+        self.logging_body = self.logging_body or options.pop("logging_body", False)
+        if options.pop("logging_enable", self.enable_http_logger):
+            request.context["logging_enable"] = True
+            if not _LOGGER.isEnabledFor(logging.DEBUG):
+                return
+
+            try:
+                log_url = http_request.url
+                query_params = http_request.query
+                if 'sig' in query_params:
+                    # replace only the signature value, leaving the 'sig=' key intact
+                    log_url = log_url.replace(query_params['sig'], "*****")
+                _LOGGER.debug("Request URL: %r", log_url)
+                _LOGGER.debug("Request method: %r", http_request.method)
+                _LOGGER.debug("Request headers:")
+                for header, value in http_request.headers.items():
+                    if header.lower() == 'authorization':
+                        value = '*****'
+                    elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
+                        # take the url apart and scrub away the signed signature
+                        scheme, netloc, path, params, query, fragment = urlparse(value)
+                        parsed_qs = dict(parse_qsl(query))
+                        parsed_qs['sig'] = '*****'
+
+                        # the SAS needs to be put back together
+                        value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))
+
+                    _LOGGER.debug("    %r: %r", header, value)
+                _LOGGER.debug("Request body:")
+
+                if self.logging_body:
+                    _LOGGER.debug(str(http_request.body))
+                else:
+                    # We don't want to log the binary data of a file upload.
+                    _LOGGER.debug("Hidden body, please use logging_body to show body")
+            except Exception as err:  # pylint: disable=broad-except
+                _LOGGER.debug("Failed to log request: %r", err)
+
+    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
+        if response.context.pop("logging_enable", self.enable_http_logger):
+            if not _LOGGER.isEnabledFor(logging.DEBUG):
+                return
+
+            try:
+                _LOGGER.debug("Response status: %r", response.http_response.status_code)
+                _LOGGER.debug("Response headers:")
+                for res_header, value in response.http_response.headers.items():
+                    _LOGGER.debug("    %r: %r", res_header, value)
+
+                # We don't want to log binary data if the response is a file.
+                _LOGGER.debug("Response content:")
+                pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
+                header = response.http_response.headers.get('content-disposition')
+                resp_content_type = response.http_response.headers.get("content-type", "")
+
+                if header and pattern.match(header):
+                    filename = header.partition('=')[2]
+                    _LOGGER.debug("File attachments: %s", filename)
+                elif resp_content_type.endswith("octet-stream"):
+                    _LOGGER.debug("Body contains binary data.")
+                elif resp_content_type.startswith("image"):
+                    _LOGGER.debug("Body contains image data.")
+
+                if self.logging_body and resp_content_type.startswith("text"):
+                    _LOGGER.debug(response.http_response.text())
+                elif self.logging_body:
+                    try:
+                        _LOGGER.debug(response.http_response.body())
+                    except ValueError:
+                        _LOGGER.debug("Body is streamable")
+
+            except Exception as err:  # pylint: disable=broad-except
+                _LOGGER.debug("Failed to log response: %s", repr(err))
+
+
+class StorageRequestHook(SansIOHTTPPolicy):
+
+    def __init__(self, **kwargs):
+        self._request_callback = kwargs.get('raw_request_hook')
+        super(StorageRequestHook, self).__init__()
+
+    def on_request(self, request: "PipelineRequest") -> None:
+        request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
+        if request_callback:
+            request_callback(request)
+
+
+class StorageResponseHook(HTTPPolicy):
+
+    def __init__(self, **kwargs):
+        self._response_callback = kwargs.get('raw_response_hook')
+        super(StorageResponseHook, self).__init__()
+
+    def send(self, request: "PipelineRequest") -> "PipelineResponse":
+        # Values could be 0
+        data_stream_total = request.context.get('data_stream_total')
+        if data_stream_total is None:
+            data_stream_total = request.context.options.pop('data_stream_total', None)
+        download_stream_current = request.context.get('download_stream_current')
+        if download_stream_current is None:
+            download_stream_current = request.context.options.pop('download_stream_current', None)
+        upload_stream_current = request.context.get('upload_stream_current')
+        if upload_stream_current is None:
+            upload_stream_current = request.context.options.pop('upload_stream_current', None)
+
+        response_callback = request.context.get('response_callback') or \
+            request.context.options.pop('raw_response_hook', self._response_callback)
+
+        response = self.next.send(request)
+
+        will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
+        # Auth error could come from Bearer challenge, in which case this request will be made again
+        is_auth_error = response.http_response.status_code == 401
+        should_update_counts = not (will_retry or is_auth_error)
+
+        if should_update_counts and download_stream_current is not None:
+            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
+            if data_stream_total is None:
+                content_range = response.http_response.headers.get('Content-Range')
+                if content_range:
+                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
+                else:
+                    data_stream_total = download_stream_current
+        elif should_update_counts and upload_stream_current is not None:
+            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
+        for pipeline_obj in [request, response]:
+            if hasattr(pipeline_obj, 'context'):
+                pipeline_obj.context['data_stream_total'] = data_stream_total
+                pipeline_obj.context['download_stream_current'] = download_stream_current
+                pipeline_obj.context['upload_stream_current'] = upload_stream_current
+        if response_callback:
+            response_callback(response)
+            request.context['response_callback'] = response_callback
+        return response
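+
+# A hedged usage sketch: callers may pass raw_response_hook per request to
+# observe the progress counters maintained above (client method hypothetical):
+#   def track(response):
+#       print(response.context['upload_stream_current'])
+#   client.upload_data(data, raw_response_hook=track)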
+
+
+class StorageContentValidation(SansIOHTTPPolicy):
+    """A simple policy that sends the given headers
+    with the request.
+
+    This will overwrite any headers already defined in the request.
+    """
+    header_name = 'Content-MD5'
+
+    def __init__(self, **kwargs: Any) -> None:  # pylint: disable=unused-argument
+        super(StorageContentValidation, self).__init__()
+
+    @staticmethod
+    def get_content_md5(data):
+        # Since HTTP does not differentiate between no content and empty content,
+        # we have to perform a None check.
+        data = data or b""
+        md5 = hashlib.md5() # nosec
+        if isinstance(data, bytes):
+            md5.update(data)
+        elif hasattr(data, 'read'):
+            pos = 0
+            try:
+                pos = data.tell()
+            except:  # pylint: disable=bare-except
+                pass
+            for chunk in iter(lambda: data.read(4096), b""):
+                md5.update(chunk)
+            try:
+                data.seek(pos, SEEK_SET)
+            except (AttributeError, IOError) as exc:
+                raise ValueError("Data should be bytes or a seekable file-like object.") from exc
+        else:
+            raise ValueError("Data should be bytes or a seekable file-like object.")
+
+        return md5.digest()
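+
+    # Sanity check, derivable with hashlib and base64:
+    #   encode_base64(StorageContentValidation.get_content_md5(b"hello"))
+    #   == 'XUFAKrxLKna5cZ2REBfFkg=='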
+
+    def on_request(self, request: "PipelineRequest") -> None:
+        validate_content = request.context.options.pop('validate_content', False)
+        if validate_content and request.http_request.method != 'GET':
+            computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
+            request.http_request.headers[self.header_name] = computed_md5
+            request.context['validate_content_md5'] = computed_md5
+        request.context['validate_content'] = validate_content
+
+    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
+        if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
+            computed_md5 = request.context.get('validate_content_md5') or \
+                encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
+            if response.http_response.headers['content-md5'] != computed_md5:
+                raise AzureError((
+                    f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
+                    f"computed value is '{computed_md5}'."),
+                    response=response.http_response
+                )
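+
+# A hedged usage sketch: validation is opt-in per request (client method
+# hypothetical):
+#   client.download_file(validate_content=True)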
+
+
+class StorageRetryPolicy(HTTPPolicy):
+    """
+    The base class for Exponential and Linear retries containing shared code.
+    """
+
+    total_retries: int
+    """The max number of retries."""
+    connect_retries: int
+    """The max number of connect retries."""
+    read_retries: int
+    """The max number of read retries."""
+    status_retries: int
+    """The max number of status retries."""
+    retry_to_secondary: bool
+    """Whether the secondary endpoint should be retried."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        self.total_retries = kwargs.pop('retry_total', 10)
+        self.connect_retries = kwargs.pop('retry_connect', 3)
+        self.read_retries = kwargs.pop('retry_read', 3)
+        self.status_retries = kwargs.pop('retry_status', 3)
+        self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
+        super(StorageRetryPolicy, self).__init__()
+
+    def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None:
+        """
+        A function which sets the next host location on the request, if applicable.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the next host location.
+        :param PipelineRequest request: A pipeline request object.
+        """
+        if settings['hosts'] and all(settings['hosts'].values()):
+            url = urlparse(request.url)
+            # If there's more than one possible location, retry to the alternative
+            if settings['mode'] == LocationMode.PRIMARY:
+                settings['mode'] = LocationMode.SECONDARY
+            else:
+                settings['mode'] = LocationMode.PRIMARY
+            updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
+            request.url = updated.geturl()
+
+    def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]:
+        body_position = None
+        if hasattr(request.http_request.body, 'read'):
+            try:
+                body_position = request.http_request.body.tell()
+            except (AttributeError, UnsupportedOperation):
+                # if body position cannot be obtained, then retries will not work
+                pass
+        options = request.context.options
+        return {
+            'total': options.pop("retry_total", self.total_retries),
+            'connect': options.pop("retry_connect", self.connect_retries),
+            'read': options.pop("retry_read", self.read_retries),
+            'status': options.pop("retry_status", self.status_retries),
+            'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
+            'mode': options.pop("location_mode", LocationMode.PRIMARY),
+            'hosts': options.pop("hosts", None),
+            'hook': options.pop("retry_hook", None),
+            'body_position': body_position,
+            'count': 0,
+            'history': []
+        }
+
+    def get_backoff_time(self, settings: Dict[str, Any]) -> float:  # pylint: disable=unused-argument
+        """ Formula for computing the current backoff.
+        Should be calculated by child class.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+        :returns: The backoff time.
+        :rtype: float
+        """
+        return 0
+
+    def sleep(self, settings, transport):
+        backoff = self.get_backoff_time(settings)
+        if not backoff or backoff < 0:
+            return
+        transport.sleep(backoff)
+
+    def increment(
+        self, settings: Dict[str, Any],
+        request: "PipelineRequest",
+        response: Optional["PipelineResponse"] = None,
+        error: Optional[AzureError] = None
+    ) -> bool:
+        """Increment the retry counters.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the increment operation.
+        :param PipelineRequest request: A pipeline request object.
+        :param Optional[PipelineResponse] response: A pipeline response object.
+        :param Optional[AzureError] error: An error encountered during the request, or
+            None if the response was received successfully.
+        :returns: Whether the retry attempts are exhausted.
+        :rtype: bool
+        """
+        settings['total'] -= 1
+
+        if error and isinstance(error, ServiceRequestError):
+            # Errors when we're fairly sure that the server did not receive the
+            # request, so it should be safe to retry.
+            settings['connect'] -= 1
+            settings['history'].append(RequestHistory(request, error=error))
+
+        elif error and isinstance(error, ServiceResponseError):
+            # Errors that occur after the request has been started, so we should
+            # assume that the server began processing it.
+            settings['read'] -= 1
+            settings['history'].append(RequestHistory(request, error=error))
+
+        else:
+            # Incrementing because of a server error such as a 500 in
+            # status_forcelist, where the given method is in the allowlist
+            if response:
+                settings['status'] -= 1
+                settings['history'].append(RequestHistory(request, http_response=response))
+
+        if not is_exhausted(settings):
+            if request.method not in ['PUT'] and settings['retry_secondary']:
+                self._set_next_host_location(settings, request)
+
+            # rewind the request body if it is a stream
+            if request.body and hasattr(request.body, 'read'):
+                # if no position was saved, then retries will not work
+                if settings['body_position'] is None:
+                    return False
+                try:
+                    # attempt to rewind the body to the initial position
+                    request.body.seek(settings['body_position'], SEEK_SET)
+                except (UnsupportedOperation, ValueError):
+                    # if body is not seekable, then retry would not work
+                    return False
+            settings['count'] += 1
+            return True
+        return False
+
+    def send(self, request):
+        retries_remaining = True
+        response = None
+        retry_settings = self.configure_retries(request)
+        while retries_remaining:
+            try:
+                response = self.next.send(request)
+                if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
+                    retries_remaining = self.increment(
+                        retry_settings,
+                        request=request.http_request,
+                        response=response.http_response)
+                    if retries_remaining:
+                        retry_hook(
+                            retry_settings,
+                            request=request.http_request,
+                            response=response.http_response,
+                            error=None)
+                        self.sleep(retry_settings, request.context.transport)
+                        continue
+                break
+            except AzureError as err:
+                if isinstance(err, AzureSigningError):
+                    raise
+                retries_remaining = self.increment(
+                    retry_settings, request=request.http_request, error=err)
+                if retries_remaining:
+                    retry_hook(
+                        retry_settings,
+                        request=request.http_request,
+                        response=None,
+                        error=err)
+                    self.sleep(retry_settings, request.context.transport)
+                    continue
+                raise err
+        if retry_settings['history']:
+            response.context['history'] = retry_settings['history']
+        response.http_response.location_mode = retry_settings['mode']
+        return response
+
+
+class ExponentialRetry(StorageRetryPolicy):
+    """Exponential retry."""
+
+    initial_backoff: int
+    """The initial backoff interval, in seconds, for the first retry."""
+    increment_base: int
+    """The base, in seconds, to increment the initial_backoff by after the
+    first retry."""
+    random_jitter_range: int
+    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+    def __init__(
+        self, initial_backoff: int = 15,
+        increment_base: int = 3,
+        retry_total: int = 3,
+        retry_to_secondary: bool = False,
+        random_jitter_range: int = 3,
+        **kwargs: Any
+    ) -> None:
+        """
+        Constructs an Exponential retry object. The initial_backoff is used for
+        the first retry. Subsequent retries are retried after initial_backoff +
+        increment_base^retry_count seconds.
+
+        :param int initial_backoff:
+            The initial backoff interval, in seconds, for the first retry.
+        :param int increment_base:
+            The base, in seconds, to increment the initial_backoff by after the
+            first retry.
+        :param int retry_total:
+            The maximum number of retry attempts.
+        :param bool retry_to_secondary:
+            Whether the request should be retried to secondary, if able. This should
+            only be enabled if RA-GRS accounts are used and potentially stale data
+            can be handled.
+        :param int random_jitter_range:
+            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+            For example, a random_jitter_range of 3 results in a back-off interval x that varies between x-3 and x+3.
+        """
+        self.initial_backoff = initial_backoff
+        self.increment_base = increment_base
+        self.random_jitter_range = random_jitter_range
+        super(ExponentialRetry, self).__init__(
+            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+        """
+        Calculates how long to sleep before retrying.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+        :returns:
+            A float indicating how long to wait before retrying the request,
+            or None to indicate no retry should be performed.
+        :rtype: float
+        """
+        random_generator = random.Random()
+        backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
+        random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
+        random_range_end = backoff + self.random_jitter_range
+        return random_generator.uniform(random_range_start, random_range_end)
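+
+# With the defaults above (initial_backoff=15, increment_base=3), the formula
+# yields 15, 18, 24, 42, ... seconds for retry counts 0, 1, 2, 3; each value is
+# then drawn uniformly from +/- random_jitter_range around that base.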
+
+
+class LinearRetry(StorageRetryPolicy):
+    """Linear retry."""
+
+    backoff: int
+    """The backoff interval, in seconds, between retries."""
+    random_jitter_range: int
+    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""
+
+    def __init__(
+        self, backoff: int = 15,
+        retry_total: int = 3,
+        retry_to_secondary: bool = False,
+        random_jitter_range: int = 3,
+        **kwargs: Any
+    ) -> None:
+        """
+        Constructs a Linear retry object.
+
+        :param int backoff:
+            The backoff interval, in seconds, between retries.
+        :param int retry_total:
+            The maximum number of retry attempts.
+        :param bool retry_to_secondary:
+            Whether the request should be retried to secondary, if able. This should
+            only be enabled if RA-GRS accounts are used and potentially stale data
+            can be handled.
+        :param int random_jitter_range:
+            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
+            For example, a random_jitter_range of 3 results in a back-off interval x that varies between x-3 and x+3.
+        """
+        self.backoff = backoff
+        self.random_jitter_range = random_jitter_range
+        super(LinearRetry, self).__init__(
+            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)
+
+    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
+        """
+        Calculates how long to sleep before retrying.
+
+        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
+        :returns:
+            A float indicating how long to wait before retrying the request,
+            or None to indicate no retry should be performed.
+        :rtype: float
+        """
+        random_generator = random.Random()
+        # the backoff interval normally does not change; however, it may have been
+        # modified by accessing the attribute directly after initializing the object
+        random_range_start = self.backoff - self.random_jitter_range \
+            if self.backoff > self.random_jitter_range else 0
+        random_range_end = self.backoff + self.random_jitter_range
+        return random_generator.uniform(random_range_start, random_range_end)
+
+
+class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
+    """ Custom Bearer token credential policy for following Storage Bearer challenges """
+
+    def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None:
+        super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)
+
+    def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
+        try:
+            auth_header = response.http_response.headers.get("WWW-Authenticate")
+            challenge = StorageHttpChallenge(auth_header)
+        except ValueError:
+            return False
+
+        scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
+        self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
+
+        return True
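+
+# A hedged sketch of the challenge flow, with hypothetical header values:
+#   WWW-Authenticate: Bearer authorization_uri=https://login.microsoftonline.com/<tenant>/oauth2/authorize resource_id=https://storage.azure.com
+# StorageHttpChallenge parses tenant_id and resource_id; the policy then
+# re-authorizes for challenge.resource_id + DEFAULT_OAUTH_SCOPE and replays
+# the request.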