# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import base64
import hashlib
import logging
import random
import re
import uuid
from io import SEEK_SET, UnsupportedOperation
from time import time
from typing import Any, Dict, Optional, TYPE_CHECKING
from urllib.parse import (
    parse_qsl,
    urlencode,
    urlparse,
    urlunparse,
)
from wsgiref.handlers import format_date_time

from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError
from azure.core.pipeline.policies import (
    BearerTokenCredentialPolicy,
    HeadersPolicy,
    HTTPPolicy,
    NetworkTraceLoggingPolicy,
    RequestHistory,
    SansIOHTTPPolicy
)

from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .models import LocationMode

if TYPE_CHECKING:
    from azure.core.credentials import TokenCredential
    from azure.core.pipeline.transport import (  # pylint: disable=non-abstract-transport-import
        PipelineRequest,
        PipelineResponse
    )


_LOGGER = logging.getLogger(__name__)


def encode_base64(data):
    """Base64-encode bytes (or a str, which is first UTF-8 encoded) and return the result as a str."""
    if isinstance(data, str):
        data = data.encode('utf-8')
    encoded = base64.b64encode(data)
    return encoded.decode('utf-8')


# Are we out of retries?
def is_exhausted(settings):
    # Any counter that has gone negative means its budget is spent.
    # Counters set to None are ignored (filtered out).
    retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status'])
    retry_counts = list(filter(None, retry_counts))
    if not retry_counts:
        return False
    return min(retry_counts) < 0


def retry_hook(settings, **kwargs):
    """Invoke the user-supplied retry hook (if any) with the current retry count and location mode."""
    if settings['hook']:
        settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)


# Is this method/status code retryable? (Based on allowlists and control
# variables such as the number of total retries to allow, whether to
# respect the Retry-After header, whether this header is present, and
# whether the returned status code is on the list of status codes to
# be retried upon on the presence of the aforementioned header)
def is_retry(response, mode):
    status = response.http_response.status_code
    if 300 <= status < 500:
        # An exception occurred, but in most cases it was expected. Examples could
        # include a 409 Conflict or 412 Precondition Failed.
        if status == 404 and mode == LocationMode.SECONDARY:
            # Response code 404 should be retried if secondary was used.
            return True
        if status == 408:
            # Response code 408 is a timeout and should be retried.
            return True
        return False
    if status >= 500:
        # Response codes above 500 with the exception of 501 Not Implemented and
        # 505 Version Not Supported indicate a server issue and should be retried.
        if status in [501, 505]:
            return False
        return True
    return False


def is_checksum_retry(response):
    # retry if invalid content md5
    if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
        computed_md5 = response.http_request.headers.get('content-md5', None) or \
            encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
        if response.http_response.headers['content-md5'] != computed_md5:
            return True
    return False


def urljoin(base_url, stub_url):
    """Append stub_url as an extra path segment of base_url, preserving query/fragment."""
    parsed = urlparse(base_url)
    parsed = parsed._replace(path=parsed.path + '/' + stub_url)
    return parsed.geturl()


class QueueMessagePolicy(SansIOHTTPPolicy):
    """Appends the 'queue_message_id' request option (if present) to the request URL path."""

    def on_request(self, request):
        message_id = request.context.options.pop('queue_message_id', None)
        if message_id:
            request.http_request.url = urljoin(
                request.http_request.url,
                message_id)


class StorageHeadersPolicy(HeadersPolicy):
    """Adds the x-ms-date and x-ms-client-request-id headers required by the Storage service."""

    request_id_header_name = 'x-ms-client-request-id'

    def on_request(self, request: "PipelineRequest") -> None:
        super(StorageHeadersPolicy, self).on_request(request)

        # RFC 1123 date stamp required by the Storage service.
        current_time = format_date_time(time())
        request.http_request.headers['x-ms-date'] = current_time

        # Honor a caller-supplied client request id; otherwise generate one.
        custom_id = request.context.options.pop('client_request_id', None)
        request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1())

    # def on_response(self, request, response):
    #     # raise exception if the echoed client request id from the service is not identical to the one we sent
    #     if self.request_id_header_name in response.http_response.headers:
    #         client_request_id = request.http_request.headers.get(self.request_id_header_name)
    #         if response.http_response.headers[self.request_id_header_name] != client_request_id:
    #             raise AzureError(
    #                 "Echoed client request ID: {} does not match sent client request ID: {}. "
    #                 "Service request ID: {}".format(
    #                     response.http_response.headers[self.request_id_header_name], client_request_id,
    #                     response.http_response.headers['x-ms-request-id']),
    #                 response=response.http_response
    #             )


class StorageHosts(SansIOHTTPPolicy):
    """Tracks the primary/secondary host mapping and redirects requests when a
    specific location ('use_location') has been requested."""

    def __init__(self, hosts=None, **kwargs):  # pylint: disable=unused-argument
        self.hosts = hosts
        super(StorageHosts, self).__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        request.context.options['hosts'] = self.hosts
        parsed_url = urlparse(request.http_request.url)

        # Detect what location mode we're currently requesting with
        location_mode = LocationMode.PRIMARY
        for key, value in self.hosts.items():
            if parsed_url.netloc == value:
                location_mode = key

        # See if a specific location mode has been specified, and if so, redirect
        use_location = request.context.options.pop('use_location', None)
        if use_location:
            # Lock retries to the specific location
            request.context.options['retry_to_secondary'] = False
            if use_location not in self.hosts:
                raise ValueError(f"Attempting to use undefined host location {use_location}")
            if use_location != location_mode:
                # Update request URL to use the specified location
                updated = parsed_url._replace(netloc=self.hosts[use_location])
                request.http_request.url = updated.geturl()
                location_mode = use_location

        request.context.options['location_mode'] = location_mode


class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
    """A policy that logs HTTP request and response to the DEBUG logger.

    This accepts both global configuration, and per-request level with "enable_http_logger"
    """

    def __init__(self, logging_enable: bool = False, **kwargs) -> None:
        self.logging_body = kwargs.pop("logging_body", False)
        super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs)

    def on_request(self, request: "PipelineRequest") -> None:
        http_request = request.http_request
        options = request.context.options
        self.logging_body = self.logging_body or options.pop("logging_body", False)
        if options.pop("logging_enable", self.enable_http_logger):
            request.context["logging_enable"] = True
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                log_url = http_request.url
                query_params = http_request.query
                # Scrub the SAS signature from the logged URL.
                if 'sig' in query_params:
                    log_url = log_url.replace(query_params['sig'], "sig=*****")
                _LOGGER.debug("Request URL: %r", log_url)
                _LOGGER.debug("Request method: %r", http_request.method)
                _LOGGER.debug("Request headers:")
                for header, value in http_request.headers.items():
                    if header.lower() == 'authorization':
                        value = '*****'
                    elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
                        # take the url apart and scrub away the signed signature
                        scheme, netloc, path, params, query, fragment = urlparse(value)
                        parsed_qs = dict(parse_qsl(query))
                        parsed_qs['sig'] = '*****'

                        # the SAS needs to be put back together
                        value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))

                    _LOGGER.debug(" %r: %r", header, value)
                _LOGGER.debug("Request body:")

                if self.logging_body:
                    _LOGGER.debug(str(http_request.body))
                else:
                    # We don't want to log the binary data of a file upload.
                    _LOGGER.debug("Hidden body, please use logging_body to show body")
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log request: %r", err)

    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
        if response.context.pop("logging_enable", self.enable_http_logger):
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                _LOGGER.debug("Response status: %r", response.http_response.status_code)
                _LOGGER.debug("Response headers:")
                for res_header, value in response.http_response.headers.items():
                    _LOGGER.debug(" %r: %r", res_header, value)

                # We don't want to log binary data if the response is a file.
                _LOGGER.debug("Response content:")
                pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)
                header = response.http_response.headers.get('content-disposition')
                resp_content_type = response.http_response.headers.get("content-type", "")

                if header and pattern.match(header):
                    filename = header.partition('=')[2]
                    _LOGGER.debug("File attachments: %s", filename)
                elif resp_content_type.endswith("octet-stream"):
                    _LOGGER.debug("Body contains binary data.")
                elif resp_content_type.startswith("image"):
                    _LOGGER.debug("Body contains image data.")

                if self.logging_body and resp_content_type.startswith("text"):
                    _LOGGER.debug(response.http_response.text())
                elif self.logging_body:
                    try:
                        _LOGGER.debug(response.http_response.body())
                    except ValueError:
                        _LOGGER.debug("Body is streamable")
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log response: %s", repr(err))


class StorageRequestHook(SansIOHTTPPolicy):
    """Invokes an optional user callback ('raw_request_hook') with the outgoing pipeline request."""

    def __init__(self, **kwargs):
        self._request_callback = kwargs.get('raw_request_hook')
        super(StorageRequestHook, self).__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        request_callback = request.context.options.pop('raw_request_hook', self._request_callback)
        if request_callback:
            request_callback(request)


class StorageResponseHook(HTTPPolicy):
    """Tracks upload/download progress counters on the pipeline context and
    invokes an optional user callback ('raw_response_hook') with the response."""

    def __init__(self, **kwargs):
        self._response_callback = kwargs.get('raw_response_hook')
        super(StorageResponseHook, self).__init__()

    def send(self, request: "PipelineRequest") -> "PipelineResponse":
        # Values could be 0
        data_stream_total = request.context.get('data_stream_total')
        if data_stream_total is None:
            data_stream_total = request.context.options.pop('data_stream_total', None)
        download_stream_current = request.context.get('download_stream_current')
        if download_stream_current is None:
            download_stream_current = request.context.options.pop('download_stream_current', None)
        upload_stream_current = request.context.get('upload_stream_current')
        if upload_stream_current is None:
            upload_stream_current = request.context.options.pop('upload_stream_current', None)

        response_callback = request.context.get('response_callback') or \
            request.context.options.pop('raw_response_hook', self._response_callback)

        response = self.next.send(request)

        will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
        # Auth error could come from Bearer challenge, in which case this request will be made again
        is_auth_error = response.http_response.status_code == 401
        should_update_counts = not (will_retry or is_auth_error)

        if should_update_counts and download_stream_current is not None:
            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
            if data_stream_total is None:
                content_range = response.http_response.headers.get('Content-Range')
                if content_range:
                    # Total size is the denominator of "bytes start-end/total".
                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                else:
                    data_stream_total = download_stream_current
        elif should_update_counts and upload_stream_current is not None:
            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
        # Propagate the counters on both the request and response contexts.
        for pipeline_obj in [request, response]:
            if hasattr(pipeline_obj, 'context'):
                pipeline_obj.context['data_stream_total'] = data_stream_total
                pipeline_obj.context['download_stream_current'] = download_stream_current
                pipeline_obj.context['upload_stream_current'] = upload_stream_current
        if response_callback:
            response_callback(response)
            request.context['response_callback'] = response_callback
        return response


class StorageContentValidation(SansIOHTTPPolicy):
    """A simple policy that sends the given headers with the request.

    This will overwrite any headers already defined in the request.
    """
    header_name = 'Content-MD5'

    def __init__(self, **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super(StorageContentValidation, self).__init__()

    @staticmethod
    def get_content_md5(data):
        """Compute the MD5 digest of bytes or a seekable file-like object.

        :param data: The content to hash. May be bytes, a seekable stream, or None
            (treated as empty content).
        :returns: The raw MD5 digest.
        :rtype: bytes
        :raises ValueError: If data is neither bytes nor a seekable file-like object.
        """
        # Since HTTP does not differentiate between no content and empty content,
        # we have to perform a None check.
        data = data or b""
        md5 = hashlib.md5()  # nosec
        if isinstance(data, bytes):
            md5.update(data)
        elif hasattr(data, 'read'):
            pos = 0
            try:
                pos = data.tell()
            except:  # pylint: disable=bare-except
                pass
            for chunk in iter(lambda: data.read(4096), b""):
                md5.update(chunk)
            try:
                # Restore the stream position so the body can still be sent/read.
                data.seek(pos, SEEK_SET)
            except (AttributeError, IOError) as exc:
                raise ValueError("Data should be bytes or a seekable file-like object.") from exc
        else:
            raise ValueError("Data should be bytes or a seekable file-like object.")

        return md5.digest()

    def on_request(self, request: "PipelineRequest") -> None:
        validate_content = request.context.options.pop('validate_content', False)
        if validate_content and request.http_request.method != 'GET':
            computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
            request.http_request.headers[self.header_name] = computed_md5
            request.context['validate_content_md5'] = computed_md5
        request.context['validate_content'] = validate_content

    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
        if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
            computed_md5 = request.context.get('validate_content_md5') or \
                encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
            if response.http_response.headers['content-md5'] != computed_md5:
                raise AzureError((
                    f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
                    f"computed value is '{computed_md5}'."),
                    response=response.http_response
                )


class StorageRetryPolicy(HTTPPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.
    """

    total_retries: int
    """The max number of retries."""
    connect_retries: int
    """The max number of connect retries."""
    read_retries: int
    """The max number of read retries."""
    status_retries: int
    """The max number of status retries."""
    retry_to_secondary: bool
    """Whether the secondary endpoint should be retried."""

    def __init__(self, **kwargs: Any) -> None:
        self.total_retries = kwargs.pop('retry_total', 10)
        self.connect_retries = kwargs.pop('retry_connect', 3)
        self.read_retries = kwargs.pop('retry_read', 3)
        self.status_retries = kwargs.pop('retry_status', 3)
        self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
        super(StorageRetryPolicy, self).__init__()

    def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None:
        """
        A function which sets the next host location on the request, if applicable.

        :param Dict[str, Any] settings: The configurable values pertaining to the next host location.
        :param PipelineRequest request: A pipeline request object.
        """
        if settings['hosts'] and all(settings['hosts'].values()):
            url = urlparse(request.url)
            # If there's more than one possible location, retry to the alternative
            if settings['mode'] == LocationMode.PRIMARY:
                settings['mode'] = LocationMode.SECONDARY
            else:
                settings['mode'] = LocationMode.PRIMARY
            updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
            request.url = updated.geturl()

    def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]:
        """Build the per-request retry settings dict from request options and policy defaults."""
        body_position = None
        if hasattr(request.http_request.body, 'read'):
            try:
                body_position = request.http_request.body.tell()
            except (AttributeError, UnsupportedOperation):
                # if body position cannot be obtained, then retries will not work
                pass
        options = request.context.options
        return {
            'total': options.pop("retry_total", self.total_retries),
            'connect': options.pop("retry_connect", self.connect_retries),
            'read': options.pop("retry_read", self.read_retries),
            'status': options.pop("retry_status", self.status_retries),
            'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
            'mode': options.pop("location_mode", LocationMode.PRIMARY),
            'hosts': options.pop("hosts", None),
            'hook': options.pop("retry_hook", None),
            'body_position': body_position,
            'count': 0,
            'history': []
        }

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:  # pylint: disable=unused-argument
        """ Formula for computing the current backoff.
        Should be calculated by child class.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
        :returns: The backoff time.
        :rtype: float
        """
        return 0

    def sleep(self, settings, transport):
        """Sleep (via the transport) for the computed backoff interval, if positive."""
        backoff = self.get_backoff_time(settings)
        if not backoff or backoff < 0:
            return
        transport.sleep(backoff)

    def increment(
        self, settings: Dict[str, Any],
        request: "PipelineRequest",
        response: Optional["PipelineResponse"] = None,
        error: Optional[AzureError] = None
    ) -> bool:
        """Increment the retry counters.

        :param Dict[str, Any] settings: The configurable values pertaining to the increment operation.
        :param PipelineRequest request: A pipeline request object.
        :param Optional[PipelineResponse] response: A pipeline response object.
        :param Optional[AzureError] error: An error encountered during the request, or
            None if the response was received successfully.
        :returns: Whether the retry attempts are exhausted.
        :rtype: bool
        """
        settings['total'] -= 1

        if error and isinstance(error, ServiceRequestError):
            # Errors when we're fairly sure that the server did not receive the
            # request, so it should be safe to retry.
            settings['connect'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        elif error and isinstance(error, ServiceResponseError):
            # Errors that occur after the request has been started, so we should
            # assume that the server began processing it.
            settings['read'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        else:
            # Incrementing because of a server error like a 500 in
            # status_forcelist and the given method is in the allowlist
            if response:
                settings['status'] -= 1
                settings['history'].append(RequestHistory(request, http_response=response))

        if not is_exhausted(settings):
            if request.method not in ['PUT'] and settings['retry_secondary']:
                self._set_next_host_location(settings, request)

            # rewind the request body if it is a stream
            if request.body and hasattr(request.body, 'read'):
                # no position was saved, then retry would not work
                if settings['body_position'] is None:
                    return False
                try:
                    # attempt to rewind the body to the initial position
                    request.body.seek(settings['body_position'], SEEK_SET)
                except (UnsupportedOperation, ValueError):
                    # if body is not seekable, then retry would not work
                    return False
            settings['count'] += 1
            return True
        return False

    def send(self, request):
        """Send the request, retrying on retryable status codes, checksum
        mismatches, and transport errors until the retry budget is exhausted."""
        retries_remaining = True
        response = None
        retry_settings = self.configure_retries(request)
        while retries_remaining:
            try:
                response = self.next.send(request)
                if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
                    retries_remaining = self.increment(
                        retry_settings,
                        request=request.http_request,
                        response=response.http_response)
                    if retries_remaining:
                        retry_hook(
                            retry_settings,
                            request=request.http_request,
                            response=response.http_response,
                            error=None)
                        self.sleep(retry_settings, request.context.transport)
                        continue
                break
            except AzureError as err:
                if isinstance(err, AzureSigningError):
                    raise
                retries_remaining = self.increment(
                    retry_settings, request=request.http_request, error=err)
                if retries_remaining:
                    retry_hook(
                        retry_settings,
                        request=request.http_request,
                        response=None,
                        error=err)
                    self.sleep(retry_settings, request.context.transport)
                    continue
                raise err
        if retry_settings['history']:
            response.context['history'] = retry_settings['history']
        response.http_response.location_mode = retry_settings['mode']
        return response


class ExponentialRetry(StorageRetryPolicy):
    """Exponential retry."""

    initial_backoff: int
    """The initial backoff interval, in seconds, for the first retry."""
    increment_base: int
    """The base, in seconds, to increment the initial_backoff by after the first retry."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, initial_backoff: int = 15,
        increment_base: int = 3,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs an Exponential retry object. The initial_backoff is used for
        the first retry. Subsequent retries are retried after initial_backoff +
        increment_power^retry_count seconds.

        :param int initial_backoff:
            The initial backoff interval, in seconds, for the first retry.
        :param int increment_base:
            The base, in seconds, to increment the initial_backoff by after the
            first retry.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.initial_backoff = initial_backoff
        self.increment_base = increment_base
        self.random_jitter_range = random_jitter_range
        super(ExponentialRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The configurable values pertaining to get backoff time.
        :returns:
            A float indicating how long to wait before retrying the request,
            or None to indicate no retry should be performed.
        :rtype: float
        """
        random_generator = random.Random()
        backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
        random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
        random_range_end = backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)


class LinearRetry(StorageRetryPolicy):
    """Linear retry."""

    initial_backoff: int
    """The backoff interval, in seconds, between retries."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, backoff: int = 15,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs a Linear retry object.

        :param int backoff:
            The backoff interval, in seconds, between retries.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.backoff = backoff
        self.random_jitter_range = random_jitter_range
        super(LinearRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
        :returns:
            A float indicating how long to wait before retrying the request,
            or None to indicate no retry should be performed.
        :rtype: float
        """
        random_generator = random.Random()
        # the backoff interval normally does not change, however there is the possibility
        # that it was modified by accessing the property directly after initializing the object
        random_range_start = self.backoff - self.random_jitter_range \
            if self.backoff > self.random_jitter_range else 0
        random_range_end = self.backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)


class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
    """ Custom Bearer token credential policy for following Storage Bearer challenges """

    def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None:
        super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs)

    def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
        try:
            auth_header = response.http_response.headers.get("WWW-Authenticate")
            challenge = StorageHttpChallenge(auth_header)
        except ValueError:
            return False

        scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
        self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
        return True