diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py | 694 |
1 files changed, 694 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py b/.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py new file mode 100644 index 00000000..ee75cd5a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/policies.py @@ -0,0 +1,694 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import base64 +import hashlib +import logging +import random +import re +import uuid +from io import SEEK_SET, UnsupportedOperation +from time import time +from typing import Any, Dict, Optional, TYPE_CHECKING +from urllib.parse import ( + parse_qsl, + urlencode, + urlparse, + urlunparse, +) +from wsgiref.handlers import format_date_time + +from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError +from azure.core.pipeline.policies import ( + BearerTokenCredentialPolicy, + HeadersPolicy, + HTTPPolicy, + NetworkTraceLoggingPolicy, + RequestHistory, + SansIOHTTPPolicy +) + +from .authentication import AzureSigningError, StorageHttpChallenge +from .constants import DEFAULT_OAUTH_SCOPE +from .models import LocationMode + +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + PipelineRequest, + PipelineResponse + ) + + +_LOGGER = logging.getLogger(__name__) + + +def encode_base64(data): + if isinstance(data, str): + data = data.encode('utf-8') + encoded = base64.b64encode(data) + return encoded.decode('utf-8') + + +# Are we out of retries? +def is_exhausted(settings): + retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status']) + retry_counts = list(filter(None, retry_counts)) + if not retry_counts: + return False + return min(retry_counts) < 0 + + +def retry_hook(settings, **kwargs): + if settings['hook']: + settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs) + + +# Is this method/status code retryable? (Based on allowlists and control +# variables such as the number of total retries to allow, whether to +# respect the Retry-After header, whether this header is present, and +# whether the returned status code is on the list of status codes to +# be retried upon on the presence of the aforementioned header) +def is_retry(response, mode): + status = response.http_response.status_code + if 300 <= status < 500: + # An exception occurred, but in most cases it was expected. Examples could + # include a 309 Conflict or 412 Precondition Failed. + if status == 404 and mode == LocationMode.SECONDARY: + # Response code 404 should be retried if secondary was used. + return True + if status == 408: + # Response code 408 is a timeout and should be retried. + return True + return False + if status >= 500: + # Response codes above 500 with the exception of 501 Not Implemented and + # 505 Version Not Supported indicate a server issue and should be retried. + if status in [501, 505]: + return False + return True + return False + + +def is_checksum_retry(response): + # retry if invalid content md5 + if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): + computed_md5 = response.http_request.headers.get('content-md5', None) or \ + encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) + if response.http_response.headers['content-md5'] != computed_md5: + return True + return False + + +def urljoin(base_url, stub_url): + parsed = urlparse(base_url) + parsed = parsed._replace(path=parsed.path + '/' + stub_url) + return parsed.geturl() + + +class QueueMessagePolicy(SansIOHTTPPolicy): + + def on_request(self, request): + message_id = request.context.options.pop('queue_message_id', None) + if message_id: + request.http_request.url = urljoin( + request.http_request.url, + message_id) + + +class StorageHeadersPolicy(HeadersPolicy): + request_id_header_name = 'x-ms-client-request-id' + + def on_request(self, request: "PipelineRequest") -> None: + super(StorageHeadersPolicy, self).on_request(request) + current_time = format_date_time(time()) + request.http_request.headers['x-ms-date'] = current_time + + custom_id = request.context.options.pop('client_request_id', None) + request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1()) + + # def on_response(self, request, response): + # # raise exception if the echoed client request id from the service is not identical to the one we sent + # if self.request_id_header_name in response.http_response.headers: + + # client_request_id = request.http_request.headers.get(self.request_id_header_name) + + # if response.http_response.headers[self.request_id_header_name] != client_request_id: + # raise AzureError( + # "Echoed client request ID: {} does not match sent client request ID: {}. " + # "Service request ID: {}".format( + # response.http_response.headers[self.request_id_header_name], client_request_id, + # response.http_response.headers['x-ms-request-id']), + # response=response.http_response + # ) + + +class StorageHosts(SansIOHTTPPolicy): + + def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument + self.hosts = hosts + super(StorageHosts, self).__init__() + + def on_request(self, request: "PipelineRequest") -> None: + request.context.options['hosts'] = self.hosts + parsed_url = urlparse(request.http_request.url) + + # Detect what location mode we're currently requesting with + location_mode = LocationMode.PRIMARY + for key, value in self.hosts.items(): + if parsed_url.netloc == value: + location_mode = key + + # See if a specific location mode has been specified, and if so, redirect + use_location = request.context.options.pop('use_location', None) + if use_location: + # Lock retries to the specific location + request.context.options['retry_to_secondary'] = False + if use_location not in self.hosts: + raise ValueError(f"Attempting to use undefined host location {use_location}") + if use_location != location_mode: + # Update request URL to use the specified location + updated = parsed_url._replace(netloc=self.hosts[use_location]) + request.http_request.url = updated.geturl() + location_mode = use_location + + request.context.options['location_mode'] = location_mode + + +class StorageLoggingPolicy(NetworkTraceLoggingPolicy): + """A policy that logs HTTP request and response to the DEBUG logger. + + This accepts both global configuration, and per-request level with "enable_http_logger" + """ + + def __init__(self, logging_enable: bool = False, **kwargs) -> None: + self.logging_body = kwargs.pop("logging_body", False) + super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + + def on_request(self, request: "PipelineRequest") -> None: + http_request = request.http_request + options = request.context.options + self.logging_body = self.logging_body or options.pop("logging_body", False) + if options.pop("logging_enable", self.enable_http_logger): + request.context["logging_enable"] = True + if not _LOGGER.isEnabledFor(logging.DEBUG): + return + + try: + log_url = http_request.url + query_params = http_request.query + if 'sig' in query_params: + log_url = log_url.replace(query_params['sig'], "sig=*****") + _LOGGER.debug("Request URL: %r", log_url) + _LOGGER.debug("Request method: %r", http_request.method) + _LOGGER.debug("Request headers:") + for header, value in http_request.headers.items(): + if header.lower() == 'authorization': + value = '*****' + elif header.lower() == 'x-ms-copy-source' and 'sig' in value: + # take the url apart and scrub away the signed signature + scheme, netloc, path, params, query, fragment = urlparse(value) + parsed_qs = dict(parse_qsl(query)) + parsed_qs['sig'] = '*****' + + # the SAS needs to be put back together + value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + + _LOGGER.debug(" %r: %r", header, value) + _LOGGER.debug("Request body:") + + if self.logging_body: + _LOGGER.debug(str(http_request.body)) + else: + # We don't want to log the binary data of a file upload. + _LOGGER.debug("Hidden body, please use logging_body to show body") + except Exception as err: # pylint: disable=broad-except + _LOGGER.debug("Failed to log request: %r", err) + + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + if response.context.pop("logging_enable", self.enable_http_logger): + if not _LOGGER.isEnabledFor(logging.DEBUG): + return + + try: + _LOGGER.debug("Response status: %r", response.http_response.status_code) + _LOGGER.debug("Response headers:") + for res_header, value in response.http_response.headers.items(): + _LOGGER.debug(" %r: %r", res_header, value) + + # We don't want to log binary data if the response is a file. + _LOGGER.debug("Response content:") + pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) + header = response.http_response.headers.get('content-disposition') + resp_content_type = response.http_response.headers.get("content-type", "") + + if header and pattern.match(header): + filename = header.partition('=')[2] + _LOGGER.debug("File attachments: %s", filename) + elif resp_content_type.endswith("octet-stream"): + _LOGGER.debug("Body contains binary data.") + elif resp_content_type.startswith("image"): + _LOGGER.debug("Body contains image data.") + + if self.logging_body and resp_content_type.startswith("text"): + _LOGGER.debug(response.http_response.text()) + elif self.logging_body: + try: + _LOGGER.debug(response.http_response.body()) + except ValueError: + _LOGGER.debug("Body is streamable") + + except Exception as err: # pylint: disable=broad-except + _LOGGER.debug("Failed to log response: %s", repr(err)) + + +class StorageRequestHook(SansIOHTTPPolicy): + + def __init__(self, **kwargs): + self._request_callback = kwargs.get('raw_request_hook') + super(StorageRequestHook, self).__init__() + + def on_request(self, request: "PipelineRequest") -> None: + request_callback = request.context.options.pop('raw_request_hook', self._request_callback) + if request_callback: + request_callback(request) + + +class StorageResponseHook(HTTPPolicy): + + def __init__(self, **kwargs): + self._response_callback = kwargs.get('raw_response_hook') + super(StorageResponseHook, self).__init__() + + def send(self, request: "PipelineRequest") -> "PipelineResponse": + # Values could be 0 + data_stream_total = request.context.get('data_stream_total') + if data_stream_total is None: + data_stream_total = request.context.options.pop('data_stream_total', None) + download_stream_current = request.context.get('download_stream_current') + if download_stream_current is None: + download_stream_current = request.context.options.pop('download_stream_current', None) + upload_stream_current = request.context.get('upload_stream_current') + if upload_stream_current is None: + upload_stream_current = request.context.options.pop('upload_stream_current', None) + + response_callback = request.context.get('response_callback') or \ + request.context.options.pop('raw_response_hook', self._response_callback) + + response = self.next.send(request) + + will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response) + # Auth error could come from Bearer challenge, in which case this request will be made again + is_auth_error = response.http_response.status_code == 401 + should_update_counts = not (will_retry or is_auth_error) + + if should_update_counts and download_stream_current is not None: + download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + if data_stream_total is None: + content_range = response.http_response.headers.get('Content-Range') + if content_range: + data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + else: + data_stream_total = download_stream_current + elif should_update_counts and upload_stream_current is not None: + upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + for pipeline_obj in [request, response]: + if hasattr(pipeline_obj, 'context'): + pipeline_obj.context['data_stream_total'] = data_stream_total + pipeline_obj.context['download_stream_current'] = download_stream_current + pipeline_obj.context['upload_stream_current'] = upload_stream_current + if response_callback: + response_callback(response) + request.context['response_callback'] = response_callback + return response + + +class StorageContentValidation(SansIOHTTPPolicy): + """A simple policy that sends the given headers + with the request. + + This will overwrite any headers already defined in the request. + """ + header_name = 'Content-MD5' + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super(StorageContentValidation, self).__init__() + + @staticmethod + def get_content_md5(data): + # Since HTTP does not differentiate between no content and empty content, + # we have to perform a None check. + data = data or b"" + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, 'read'): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError("Data should be bytes or a seekable file-like object.") from exc + else: + raise ValueError("Data should be bytes or a seekable file-like object.") + + return md5.digest() + + def on_request(self, request: "PipelineRequest") -> None: + validate_content = request.context.options.pop('validate_content', False) + if validate_content and request.http_request.method != 'GET': + computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) + request.http_request.headers[self.header_name] = computed_md5 + request.context['validate_content_md5'] = computed_md5 + request.context['validate_content'] = validate_content + + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): + computed_md5 = request.context.get('validate_content_md5') or \ + encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) + if response.http_response.headers['content-md5'] != computed_md5: + raise AzureError(( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'."), + response=response.http_response + ) + + +class StorageRetryPolicy(HTTPPolicy): + """ + The base class for Exponential and Linear retries containing shared code. + """ + + total_retries: int + """The max number of retries.""" + connect_retries: int + """The max number of connect retries.""" + retry_read: int + """The max number of read retries.""" + retry_status: int + """The max number of status retries.""" + retry_to_secondary: bool + """Whether the secondary endpoint should be retried.""" + + def __init__(self, **kwargs: Any) -> None: + self.total_retries = kwargs.pop('retry_total', 10) + self.connect_retries = kwargs.pop('retry_connect', 3) + self.read_retries = kwargs.pop('retry_read', 3) + self.status_retries = kwargs.pop('retry_status', 3) + self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) + super(StorageRetryPolicy, self).__init__() + + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + """ + A function which sets the next host location on the request, if applicable. + + :param Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param PipelineRequest request: A pipeline request object. + """ + if settings['hosts'] and all(settings['hosts'].values()): + url = urlparse(request.url) + # If there's more than one possible location, retry to the alternative + if settings['mode'] == LocationMode.PRIMARY: + settings['mode'] = LocationMode.SECONDARY + else: + settings['mode'] = LocationMode.PRIMARY + updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) + request.url = updated.geturl() + + def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: + body_position = None + if hasattr(request.http_request.body, 'read'): + try: + body_position = request.http_request.body.tell() + except (AttributeError, UnsupportedOperation): + # if body position cannot be obtained, then retries will not work + pass + options = request.context.options + return { + 'total': options.pop("retry_total", self.total_retries), + 'connect': options.pop("retry_connect", self.connect_retries), + 'read': options.pop("retry_read", self.read_retries), + 'status': options.pop("retry_status", self.status_retries), + 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary), + 'mode': options.pop("location_mode", LocationMode.PRIMARY), + 'hosts': options.pop("hosts", None), + 'hook': options.pop("retry_hook", None), + 'body_position': body_position, + 'count': 0, + 'history': [] + } + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + """ Formula for computing the current backoff. + Should be calculated by child class. + + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :returns: The backoff time. + :rtype: float + """ + return 0 + + def sleep(self, settings, transport): + backoff = self.get_backoff_time(settings) + if not backoff or backoff < 0: + return + transport.sleep(backoff) + + def increment( + self, settings: Dict[str, Any], + request: "PipelineRequest", + response: Optional["PipelineResponse"] = None, + error: Optional[AzureError] = None + ) -> bool: + """Increment the retry counters. + + :param Dict[str, Any] settings: The configurable values pertaining to the increment operation. + :param PipelineRequest request: A pipeline request object. + :param Optional[PipelineResponse] response: A pipeline response object. + :param Optional[AzureError] error: An error encountered during the request, or + None if the response was received successfully. + :returns: Whether the retry attempts are exhausted. + :rtype: bool + """ + settings['total'] -= 1 + + if error and isinstance(error, ServiceRequestError): + # Errors when we're fairly sure that the server did not receive the + # request, so it should be safe to retry. + settings['connect'] -= 1 + settings['history'].append(RequestHistory(request, error=error)) + + elif error and isinstance(error, ServiceResponseError): + # Errors that occur after the request has been started, so we should + # assume that the server began processing it. + settings['read'] -= 1 + settings['history'].append(RequestHistory(request, error=error)) + + else: + # Incrementing because of a server error like a 500 in + # status_forcelist and a the given method is in the allowlist + if response: + settings['status'] -= 1 + settings['history'].append(RequestHistory(request, http_response=response)) + + if not is_exhausted(settings): + if request.method not in ['PUT'] and settings['retry_secondary']: + self._set_next_host_location(settings, request) + + # rewind the request body if it is a stream + if request.body and hasattr(request.body, 'read'): + # no position was saved, then retry would not work + if settings['body_position'] is None: + return False + try: + # attempt to rewind the body to the initial position + request.body.seek(settings['body_position'], SEEK_SET) + except (UnsupportedOperation, ValueError): + # if body is not seekable, then retry would not work + return False + settings['count'] += 1 + return True + return False + + def send(self, request): + retries_remaining = True + response = None + retry_settings = self.configure_retries(request) + while retries_remaining: + try: + response = self.next.send(request) + if is_retry(response, retry_settings['mode']) or is_checksum_retry(response): + retries_remaining = self.increment( + retry_settings, + request=request.http_request, + response=response.http_response) + if retries_remaining: + retry_hook( + retry_settings, + request=request.http_request, + response=response.http_response, + error=None) + self.sleep(retry_settings, request.context.transport) + continue + break + except AzureError as err: + if isinstance(err, AzureSigningError): + raise + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err) + if retries_remaining: + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err) + self.sleep(retry_settings, request.context.transport) + continue + raise err + if retry_settings['history']: + response.context['history'] = retry_settings['history'] + response.http_response.location_mode = retry_settings['mode'] + return response + + +class ExponentialRetry(StorageRetryPolicy): + """Exponential retry.""" + + initial_backoff: int + """The initial backoff interval, in seconds, for the first retry.""" + increment_base: int + """The base, in seconds, to increment the initial_backoff by after the + first retry.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, initial_backoff: int = 15, + increment_base: int = 3, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: + """ + Constructs an Exponential retry object. The initial_backoff is used for + the first retry. Subsequent retries are retried after initial_backoff + + increment_power^retry_count seconds. + + :param int initial_backoff: + The initial backoff interval, in seconds, for the first retry. + :param int increment_base: + The base, in seconds, to increment the initial_backoff by after the + first retry. + :param int retry_total: + The maximum number of retry attempts. + :param bool retry_to_secondary: + Whether the request should be retried to secondary, if able. This should + only be enabled of RA-GRS accounts are used and potentially stale data + can be handled. + :param int random_jitter_range: + A number in seconds which indicates a range to jitter/randomize for the back-off interval. + For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. + """ + self.initial_backoff = initial_backoff + self.increment_base = increment_base + self.random_jitter_range = random_jitter_range + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: + """ + Calculates how long to sleep before retrying. + + :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time. + :returns: + A float indicating how long to wait before retrying the request, + or None to indicate no retry should be performed. + :rtype: float + """ + random_generator = random.Random() + backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + random_range_end = backoff + self.random_jitter_range + return random_generator.uniform(random_range_start, random_range_end) + + +class LinearRetry(StorageRetryPolicy): + """Linear retry.""" + + initial_backoff: int + """The backoff interval, in seconds, between retries.""" + random_jitter_range: int + """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" + + def __init__( + self, backoff: int = 15, + retry_total: int = 3, + retry_to_secondary: bool = False, + random_jitter_range: int = 3, + **kwargs: Any + ) -> None: + """ + Constructs a Linear retry object. + + :param int backoff: + The backoff interval, in seconds, between retries. + :param int retry_total: + The maximum number of retry attempts. + :param bool retry_to_secondary: + Whether the request should be retried to secondary, if able. This should + only be enabled of RA-GRS accounts are used and potentially stale data + can be handled. + :param int random_jitter_range: + A number in seconds which indicates a range to jitter/randomize for the back-off interval. + For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3. + """ + self.backoff = backoff + self.random_jitter_range = random_jitter_range + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + + def get_backoff_time(self, settings: Dict[str, Any]) -> float: + """ + Calculates how long to sleep before retrying. + + :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :returns: + A float indicating how long to wait before retrying the request, + or None to indicate no retry should be performed. + :rtype: float + """ + random_generator = random.Random() + # the backoff interval normally does not change, however there is the possibility + # that it was modified by accessing the property directly after initializing the object + random_range_start = self.backoff - self.random_jitter_range \ + if self.backoff > self.random_jitter_range else 0 + random_range_end = self.backoff + self.random_jitter_range + return random_generator.uniform(random_range_start, random_range_end) + + +class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): + """ Custom Bearer token credential policy for following Storage Bearer challenges """ + + def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + try: + auth_header = response.http_response.headers.get("WWW-Authenticate") + challenge = StorageHttpChallenge(auth_header) + except ValueError: + return False + + scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE + self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + + return True |
