diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/base_client.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/base_client.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/base_client.py | 458 |
1 files changed, 458 insertions, 0 deletions
# Source: .venv/lib/python3.12/site-packages/azure/storage/filedatalake/_shared/base_client.py
# (diff header: new file mode 100644, index 00000000..ceb75bf0, @@ -0,0 +1,458 @@)
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import uuid
from typing import (
    Any,
    cast,
    Dict,
    Iterator,
    Optional,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from urllib.parse import parse_qs, quote

from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpTransport, RequestsTransport  # pylint: disable=non-abstract-transport-import, no-name-in-module
from azure.core.pipeline.policies import (
    AzureSasCredentialPolicy,
    ContentDecodePolicy,
    DistributedTracingPolicy,
    HttpLoggingPolicy,
    ProxyPolicy,
    RedirectPolicy,
    UserAgentPolicy,
)

from .authentication import SharedKeyCredentialPolicy
from .constants import CONNECTION_TIMEOUT, DEFAULT_OAUTH_SCOPE, READ_TIMEOUT, SERVICE_HOST_BASE, STORAGE_OAUTH_SCOPE
from .models import LocationMode, StorageConfiguration
from .policies import (
    ExponentialRetry,
    QueueMessagePolicy,
    StorageBearerTokenCredentialPolicy,
    StorageContentValidation,
    StorageHeadersPolicy,
    StorageHosts,
    StorageLoggingPolicy,
    StorageRequestHook,
    StorageResponseHook,
)
from .request_handlers import serialize_batch_body, _get_batch_request_delimiter
from .response_handlers import PartialBatchErrorException, process_storage_error
from .shared_access_signature import QueryStringConstants
from .._version import VERSION
from .._shared_access_signature import _is_credential_sastoken

if TYPE_CHECKING:
    from azure.core.credentials_async import AsyncTokenCredential
    from azure.core.pipeline.transport import HttpRequest, HttpResponse  # pylint: disable=C4756

_LOGGER = logging.getLogger(__name__)

# Maps a service name to the connection-string keys that carry its explicit
# primary/secondary endpoints. Note that "dfs" reads the BLOB endpoint keys for
# both; parse_connection_str() rewrites ".blob." to ".dfs." afterwards.
_SERVICE_PARAMS = {
    "blob": {"primary": "BLOBENDPOINT", "secondary": "BLOBSECONDARYENDPOINT"},
    "queue": {"primary": "QUEUEENDPOINT", "secondary": "QUEUESECONDARYENDPOINT"},
    "file": {"primary": "FILEENDPOINT", "secondary": "FILESECONDARYENDPOINT"},
    "dfs": {"primary": "BLOBENDPOINT", "secondary": "BLOBENDPOINT"},
}


class StorageAccountHostsMixin(object):
    """Mixin that resolves account hosts, credentials and the HTTP pipeline.

    The mixin itself never assigns ``self._client`` and never defines
    ``self._format_url``; both are used below, so the concrete class mixing
    this in is expected to provide them.
    """

    _client: Any

    def __init__(
        self,
        parsed_url: Any,
        service: str,
        credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None,  # pylint: disable=line-too-long
        **kwargs: Any
    ) -> None:
        """Derive account name, credential, hosts and pipeline from a parsed URL.

        :param parsed_url: Parsed account URL; ``scheme``, ``netloc`` and ``path`` are read.
        :param str service: One of "blob", "queue", "file-share" or "dfs".
        :param credential: Raw credential; normalized by ``_format_shared_key_credential``.
        :keyword _location_mode: Initial location mode (defaults to ``LocationMode.PRIMARY``).
        :keyword _hosts: Pre-computed host mapping; when given, hosts are not re-derived.
        :keyword str secondary_hostname: Explicit secondary hostname override.
        :raises ValueError: If ``service`` is not recognized, or a token credential
            is combined with a non-HTTPS scheme.
        """
        self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
        self._hosts = kwargs.get("_hosts")
        self.scheme = parsed_url.scheme
        self._is_localhost = False

        if service not in ["blob", "queue", "file-share", "dfs"]:
            raise ValueError(f"Invalid service: {service}")
        # "file-share" -> "file": only the part before the dash appears in hostnames.
        service_name = service.split('-')[0]
        account = parsed_url.netloc.split(f".{service_name}.core.")

        self.account_name = account[0] if len(account) > 1 else None
        # Emulator-style URLs carry the account name in the path, not the host.
        # The tuple form of startswith() keeps the "no account name yet" guard
        # applied to both prefixes (the former `A and B or C` did not).
        if not self.account_name and parsed_url.netloc.startswith(("localhost", "127.0.0.1")):
            self._is_localhost = True
            self.account_name = parsed_url.path.strip("/")

        self.credential = _format_shared_key_credential(self.account_name, credential)
        if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
            raise ValueError("Token credential is only supported with HTTPS.")

        secondary_hostname = None
        if hasattr(self.credential, "account_name"):
            self.account_name = self.credential.account_name
            secondary_hostname = f"{self.credential.account_name}-secondary.{service_name}.{SERVICE_HOST_BASE}"

        if not self._hosts:
            if len(account) > 1:
                # account[0] is the leading label of netloc; replace only that first
                # occurrence so account names such as "core" or "blob" do not also
                # rewrite later labels of the hostname.
                secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary", 1)
            if kwargs.get("secondary_hostname"):
                secondary_hostname = kwargs["secondary_hostname"]
            primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
            self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}

        self._sdk_moniker = f"storage-{service}/{VERSION}"
        self._config, self._pipeline = self._create_pipeline(self.credential, sdk_moniker=self._sdk_moniker, **kwargs)

    def __enter__(self):
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self):
        """ This method is to close the sockets opened by the client.
        It need not be used when using with a context manager.
        """
        self._client.close()

    @property
    def url(self):
        """The full endpoint URL to this entity, including SAS token if used.

        This could be either the primary endpoint,
        or the secondary endpoint depending on the current :func:`location_mode`.
        :returns: The full endpoint URL to this entity, including SAS token if used.
        :rtype: str
        """
        return self._format_url(self._hosts[self._location_mode])

    @property
    def primary_endpoint(self):
        """The full primary endpoint URL.

        :rtype: str
        """
        return self._format_url(self._hosts[LocationMode.PRIMARY])

    @property
    def primary_hostname(self):
        """The hostname of the primary endpoint.

        :rtype: str
        """
        return self._hosts[LocationMode.PRIMARY]

    @property
    def secondary_endpoint(self):
        """The full secondary endpoint URL if configured.

        If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :rtype: str
        :raise ValueError:
        """
        if not self._hosts[LocationMode.SECONDARY]:
            raise ValueError("No secondary host configured.")
        return self._format_url(self._hosts[LocationMode.SECONDARY])

    @property
    def secondary_hostname(self):
        """The hostname of the secondary endpoint.

        If not available this will be None. To explicitly specify a secondary hostname, use the optional
        `secondary_hostname` keyword argument on instantiation.

        :rtype: Optional[str]
        """
        return self._hosts[LocationMode.SECONDARY]

    @property
    def location_mode(self):
        """The location mode that the client is currently using.

        By default this will be "primary". Options include "primary" and "secondary".

        :rtype: str
        """

        return self._location_mode

    @location_mode.setter
    def location_mode(self, value):
        # Only switch when a host is configured for the requested mode, and keep
        # the generated client's base URL in sync with the new mode.
        if self._hosts.get(value):
            self._location_mode = value
            self._client._config.url = self.url  # pylint: disable=protected-access
        else:
            raise ValueError(f"No host URL for location mode: {value}")

    @property
    def api_version(self):
        """The version of the Storage API used for requests.

        :rtype: str
        """
        return self._client._config.version  # pylint: disable=protected-access

    def _format_query_string(
        self, sas_token: Optional[str],
        credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]],  # pylint: disable=line-too-long
        snapshot: Optional[str] = None,
        share_snapshot: Optional[str] = None
    ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]:  # pylint: disable=line-too-long
        """Build the URL query string from snapshot identifiers and a SAS token.

        :param sas_token: SAS token taken from the resource URI, if any.
        :param credential: Caller-supplied credential. When
            ``_is_credential_sastoken`` accepts it, it is appended to the query
            string and ``None`` is returned in its place.
        :param snapshot: Value for the ``snapshot`` query parameter.
        :param share_snapshot: Value for the ``sharesnapshot`` query parameter.
        :returns: ``(query_string, credential)``; the query string has trailing
            ``?``/``&`` stripped, so it is empty when nothing was added.
        :raises ValueError: If both a URI SAS token and an AzureSasCredential are given.
        """
        query_str = "?"
        if snapshot:
            query_str += f"snapshot={snapshot}&"
        if share_snapshot:
            query_str += f"sharesnapshot={share_snapshot}&"
        if sas_token and isinstance(credential, AzureSasCredential):
            raise ValueError(
                "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
        if _is_credential_sastoken(credential):
            credential = cast(str, credential)
            query_str += credential.lstrip("?")
            credential = None
        elif sas_token:
            query_str += sas_token
        return query_str.rstrip("?&"), credential

    def _create_pipeline(
        self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None,  # pylint: disable=line-too-long
        **kwargs: Any
    ) -> Tuple[StorageConfiguration, Pipeline]:
        """Select the credential policy and assemble the request pipeline.

        Side effect: sets ``self._credential_policy``.

        :param credential: Normalized credential (token credential,
            SharedKeyCredentialPolicy, AzureSasCredential or None).
        :keyword audience: Optional OAuth audience; ``DEFAULT_OAUTH_SCOPE`` is appended to it.
        :keyword _configuration: Pre-built configuration to reuse.
        :keyword _pipeline: Pre-built pipeline; returned as-is when provided.
        :keyword transport: Transport to use instead of a new RequestsTransport.
        :keyword _additional_pipeline_policies: Policies appended after the defaults.
        :returns: ``(configuration, pipeline)``.
        :raises TypeError: If the credential is of an unsupported type.
        """
        self._credential_policy: Any = None
        if hasattr(credential, "get_token"):
            if kwargs.get('audience'):
                audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE
            else:
                audience = STORAGE_OAUTH_SCOPE
            self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience)
        elif isinstance(credential, SharedKeyCredentialPolicy):
            self._credential_policy = credential
        elif isinstance(credential, AzureSasCredential):
            self._credential_policy = AzureSasCredentialPolicy(credential)
        elif credential is not None:
            raise TypeError(f"Unsupported credential: {type(credential)}")

        config = kwargs.get("_configuration") or create_configuration(**kwargs)
        if kwargs.get("_pipeline"):
            return config, kwargs["_pipeline"]
        transport = kwargs.get("transport")
        kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
        kwargs.setdefault("read_timeout", READ_TIMEOUT)
        if not transport:
            transport = RequestsTransport(**kwargs)
        policies = [
            QueueMessagePolicy(),
            config.proxy_policy,
            config.user_agent_policy,
            StorageContentValidation(),
            ContentDecodePolicy(response_encoding="utf-8"),
            RedirectPolicy(**kwargs),
            StorageHosts(hosts=self._hosts, **kwargs),
            config.retry_policy,
            config.headers_policy,
            StorageRequestHook(**kwargs),
            self._credential_policy,
            config.logging_policy,
            StorageResponseHook(**kwargs),
            DistributedTracingPolicy(**kwargs),
            HttpLoggingPolicy(**kwargs)
        ]
        if kwargs.get("_additional_pipeline_policies"):
            policies = policies + kwargs.get("_additional_pipeline_policies")  # type: ignore
        config.transport = transport  # type: ignore
        return config, Pipeline(transport, policies=policies)

    def _batch_send(
        self,
        *reqs: "HttpRequest",
        **kwargs: Any
    ) -> Iterator["HttpResponse"]:
        """Given a series of request, do a Storage batch call.

        :param HttpRequest reqs: A collection of HttpRequest objects.
        :keyword bool raise_on_any_failure: When True (the default), every
            sub-response is inspected and a PartialBatchErrorException is raised
            if any has a non-2xx status.
        :returns: An iterator of HttpResponse objects.
        :rtype: Iterator[HttpResponse]
        """
        # Pop it here, so requests doesn't feel bad about additional kwarg
        raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
        batch_id = str(uuid.uuid1())

        request = self._client._client.post(  # pylint: disable=protected-access
            url=(
                f'{self.scheme}://{self.primary_hostname}/'
                f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}"
                f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}"
            ),
            headers={
                'x-ms-version': self.api_version,
                "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False)
            }
        )

        policies = [StorageHeadersPolicy()]
        if self._credential_policy:
            policies.append(self._credential_policy)

        request.set_multipart_mixed(
            *reqs,
            policies=policies,
            enforce_https=False
        )

        Pipeline._prepare_multipart_mixed_request(request)  # pylint: disable=protected-access
        body = serialize_batch_body(request.multipart_mixed_info[0], batch_id)
        request.set_bytes_body(body)

        # The body is already serialized above, so the multipart info is detached
        # while the pipeline runs and restored afterwards for response parsing.
        temp = request.multipart_mixed_info
        request.multipart_mixed_info = None
        pipeline_response = self._pipeline.run(
            request, **kwargs
        )
        response = pipeline_response.http_response
        request.multipart_mixed_info = temp

        try:
            if response.status_code not in [202]:
                raise HttpResponseError(response=response)
            # Parse the multipart body once; previously parts() was invoked a
            # second time when raise_on_any_failure was set.
            parts = response.parts()
            if raise_on_any_failure:
                parts = list(parts)
                if any(not 200 <= p.status_code < 300 for p in parts):
                    error = PartialBatchErrorException(
                        message="There is a partial failure in the batch operation.",
                        response=response, parts=parts
                    )
                    raise error
                return iter(parts)
            return parts  # type: ignore [no-any-return]
        except HttpResponseError as error:
            # NOTE(review): process_storage_error is defined elsewhere; it is
            # presumably expected to re-raise a mapped storage error - confirm.
            process_storage_error(error)


class TransportWrapper(HttpTransport):
    """Wrapper class that ensures that an inner client created
    by a `get_client` method does not close the outer transport for the parent
    when used in a context manager.
    """
    def __init__(self, transport):
        self._transport = transport

    def send(self, request, **kwargs):
        """Forward the request to the wrapped transport unchanged."""
        return self._transport.send(request, **kwargs)

    def open(self):
        """No-op: the wrapped transport's lifetime is managed by its owner."""

    def close(self):
        """No-op: the wrapped transport's lifetime is managed by its owner."""

    def __enter__(self):
        # Return the wrapper so `with wrapper as t:` binds a usable object
        # instead of None; the wrapped transport is deliberately not entered.
        return self

    def __exit__(self, *args):
        # Deliberately does not exit/close the wrapped transport.
        pass


def _format_shared_key_credential(
    account_name: Optional[str],
    credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None  # pylint: disable=line-too-long
) -> Any:
    """Normalize shared-key style credentials into a SharedKeyCredentialPolicy.

    :param account_name: Account name used when ``credential`` is a bare key string.
    :param credential: A key string, a dict with "account_name"/"account_key",
        an AzureNamedKeyCredential, or any other credential object.
    :returns: A SharedKeyCredentialPolicy for str/dict/AzureNamedKeyCredential
        input; any other credential (including None) is returned unchanged.
    :raises ValueError: If a key string is given without an account name, or a
        dict credential lacks "account_name" or "account_key".
    """
    if isinstance(credential, str):
        if not account_name:
            raise ValueError("Unable to determine account name for shared key credential.")
        credential = {"account_name": account_name, "account_key": credential}
    if isinstance(credential, dict):
        if "account_name" not in credential:
            raise ValueError("Shared key credential missing 'account_name'")
        if "account_key" not in credential:
            raise ValueError("Shared key credential missing 'account_key'")
        return SharedKeyCredentialPolicy(**credential)
    if isinstance(credential, AzureNamedKeyCredential):
        return SharedKeyCredentialPolicy(credential.named_key.name, credential.named_key.key)
    return credential


def parse_connection_str(
    conn_str: str,
    credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]],
    service: str
) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]:  # pylint: disable=line-too-long
    """Extract endpoints and (optionally) a credential from a connection string.

    :param str conn_str: Semicolon-separated ``KEY=value`` settings; keys are
        matched case-insensitively.
    :param credential: Explicit credential. When falsy, one is built from
        ACCOUNTNAME/ACCOUNTKEY, falling back to SHAREDACCESSSIGNATURE.
    :param str service: A key of ``_SERVICE_PARAMS`` ("blob", "queue", "file", "dfs").
    :returns: ``(primary, secondary, credential)``.
    :raises ValueError: If the string is blank or malformed, names only a
        secondary endpoint, or lacks the details needed to build a primary endpoint.
    """
    conn_str = conn_str.rstrip(";")
    conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")]
    if any(len(tup) != 2 for tup in conn_settings_list):
        raise ValueError("Connection string is either blank or malformed.")
    conn_settings = {key.upper(): val for key, val in conn_settings_list}
    endpoints = _SERVICE_PARAMS[service]
    primary = None
    secondary = None
    if not credential:
        try:
            credential = {"account_name": conn_settings["ACCOUNTNAME"], "account_key": conn_settings["ACCOUNTKEY"]}
        except KeyError:
            credential = conn_settings.get("SHAREDACCESSSIGNATURE")
    if endpoints["primary"] in conn_settings:
        primary = conn_settings[endpoints["primary"]]
        if endpoints["secondary"] in conn_settings:
            secondary = conn_settings[endpoints["secondary"]]
    else:
        if endpoints["secondary"] in conn_settings:
            raise ValueError("Connection string specifies only secondary endpoint.")
        try:
            primary = (
                f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://"
                f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}"
            )
            secondary = (
                f"{conn_settings['ACCOUNTNAME']}-secondary."
                f"{service}.{conn_settings['ENDPOINTSUFFIX']}"
            )
        except KeyError:
            # Best effort: missing protocol/suffix falls through to the
            # https + default-suffix construction below.
            pass

    if not primary:
        try:
            primary = (
                f"https://{conn_settings['ACCOUNTNAME']}."
                f"{service}.{conn_settings.get('ENDPOINTSUFFIX', SERVICE_HOST_BASE)}"
            )
        except KeyError as exc:
            raise ValueError("Connection string missing required connection details.") from exc
    if service == "dfs":
        # "dfs" reuses the blob endpoint settings, so point the hosts at dfs.
        primary = primary.replace(".blob.", ".dfs.")
        if secondary:
            secondary = secondary.replace(".blob.", ".dfs.")
    return primary, secondary, credential


def create_configuration(**kwargs: Any) -> StorageConfiguration:
    """Build a StorageConfiguration populated with the default policies.

    :keyword str sdk_moniker: User-agent moniker. When absent or falsy it is
        derived from the ``storage_sdk`` keyword.
    :keyword retry_policy: Retry policy to use instead of ``ExponentialRetry``.
    :returns: The configured StorageConfiguration.
    :raises KeyError: If neither ``sdk_moniker`` nor ``storage_sdk`` is supplied.
    """
    # Backwards compatibility if someone is not passing sdk_moniker
    if not kwargs.get("sdk_moniker"):
        kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}"
    config = StorageConfiguration(**kwargs)
    config.headers_policy = StorageHeadersPolicy(**kwargs)
    config.user_agent_policy = UserAgentPolicy(**kwargs)
    config.retry_policy = kwargs.get("retry_policy") or ExponentialRetry(**kwargs)
    config.logging_policy = StorageLoggingPolicy(**kwargs)
    config.proxy_policy = ProxyPolicy(**kwargs)
    return config


def parse_query(query_str: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a URL query string into a snapshot id and a SAS token.

    :param str query_str: Raw query string (without the leading "?").
    :returns: ``(snapshot, sas_token)``. ``snapshot`` is the "snapshot" value,
        else the "sharesnapshot" value, else None. ``sas_token`` is the
        re-encoded "&"-joined subset of parameters whose names appear in
        ``QueryStringConstants.to_list()``, or None when there are none.
    """
    sas_values = QueryStringConstants.to_list()
    # parse_qs yields a list per key; only the first value of each is kept.
    parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
    sas_params = [f"{k}={quote(v, safe='')}" for k, v in parsed_query.items() if k in sas_values]
    sas_token = None
    if sas_params:
        sas_token = "&".join(sas_params)

    snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
    return snapshot, sas_token
