diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py new file mode 100644 index 00000000..09c5d1ae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py @@ -0,0 +1,222 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import Optional, Tuple + +from typing_extensions import Literal + +from azure.ai.ml._azure_environments import _get_default_cloud_name, _get_registry_discovery_endpoint_from_metadata +from azure.ai.ml._restclient.registry_discovery import AzureMachineLearningWorkspaces as ServiceClientRegistryDiscovery +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import AzureMachineLearningWorkspaces +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import ( + BlobReferenceSASRequestDto, + TemporaryDataReferenceRequestDto, +) +from azure.ai.ml.constants._common import REGISTRY_ASSET_ID +from azure.ai.ml.exceptions import MlException +from azure.core.exceptions import HttpResponseError + +module_logger = logging.getLogger(__name__) + +MFE_PATH_PREFIX = "mferp/managementfrontend" + + +class RegistryDiscovery: + def __init__( + self, + credential: str, + registry_name: str, + service_client_registry_discovery_client: ServiceClientRegistryDiscovery, + **kwargs, + ): + self.credential = credential + self.registry_name = registry_name + self.service_client_registry_discovery_client = service_client_registry_discovery_client + self.kwargs = kwargs + self._resource_group = None + self._subscription_id = None + self._base_url = None + self.workspace_region = kwargs.get("workspace_location", None) + + def _get_registry_details(self) -> str: + response = self.service_client_registry_discovery_client.registry_management_non_workspace.registry_management_non_workspace( # pylint: disable=line-too-long + self.registry_name + ) + if self.workspace_region: + _check_region_fqdn(self.workspace_region, response) + self._base_url = f"https://cert-{self.workspace_region}.experiments.azureml.net/{MFE_PATH_PREFIX}" + else: + self._base_url = f"{response.primary_region_resource_provider_uri}{MFE_PATH_PREFIX}" + self._subscription_id = response.subscription_id + self._resource_group = response.resource_group + + def get_registry_service_client(self) -> AzureMachineLearningWorkspaces: + self._get_registry_details() + self.kwargs.pop("subscription_id", None) + self.kwargs.pop("resource_group", None) + service_client_10_2021_dataplanepreview = AzureMachineLearningWorkspaces( + subscription_id=self._subscription_id, + resource_group=self._resource_group, + credential=self.credential, + base_url=self._base_url, + **self.kwargs, + ) + return service_client_10_2021_dataplanepreview + + @property + def subscription_id(self) -> str: + """The subscription id of the registry. + + :return: Subscription Id + :rtype: str + """ + return self._subscription_id + + @property + def resource_group(self) -> str: + """The resource group of the registry. + + :return: Resource Group + :rtype: str + """ + return self._resource_group + + +def get_sas_uri_for_registry_asset(service_client, name, version, resource_group, registry, body) -> str: + """Get sas_uri for registry asset. + + :param service_client: Service client + :type service_client: AzureMachineLearningWorkspaces + :param name: Asset name + :type name: str + :param version: Asset version + :type version: str + :param resource_group: Resource group + :type resource_group: str + :param registry: Registry name + :type registry: str + :param body: Request body + :type body: TemporaryDataReferenceRequestDto + :rtype: str + """ + sas_uri = None + try: + res = service_client.temporary_data_references.create_or_get_temporary_data_reference( + name=name, + version=version, + resource_group_name=resource_group, + registry_name=registry, + body=body, + ) + sas_uri = res.blob_reference_for_consumption.credential.additional_properties["sasUri"] + except HttpResponseError as e: + # "Asset already exists" exception is thrown from service with error code 409, that we need to ignore + if e.status_code == 409: + module_logger.debug("Skipping file upload, reason: %s", str(e.reason)) + else: + raise e + return sas_uri + + +def get_asset_body_for_registry_storage( + registry_name: str, asset_type: str, asset_name: str, asset_version: str +) -> TemporaryDataReferenceRequestDto: + """Get Asset body for registry. + + :param registry_name: Registry name. + :type registry_name: str + :param asset_type: Asset type. + :type asset_type: str + :param asset_name: Asset name. + :type asset_name: str + :param asset_version: Asset version. + :type asset_version: str + :return: The temporary data reference request dto + :rtype: TemporaryDataReferenceRequestDto + """ + body = TemporaryDataReferenceRequestDto( + asset_id=REGISTRY_ASSET_ID.format(registry_name, asset_type, asset_name, asset_version), + temporary_data_reference_type="TemporaryBlobReference", + ) + return body + + +def get_storage_details_for_registry_assets( + service_client: AzureMachineLearningWorkspaces, + asset_type: str, + asset_name: str, + asset_version: str, + rg_name: str, + reg_name: str, + uri: str, +) -> Tuple[str, Literal["NoCredentials", "SAS"]]: + """Get storage details for registry assets. + + :param service_client: AzureMachineLearningWorkspaces service client. + :type service_client: AzureMachineLearningWorkspaces + :param asset_type: Asset type. + :type asset_type: str + :param asset_name: Asset name. + :type asset_name: str + :param asset_version: Asset version. + :type asset_version: str + :param rg_name: Resource group name. + :type rg_name: str + :param reg_name: Registry name + :type reg_name: str + :param uri: asset uri + :type uri: str + :return: A 2-tuple of a URI and a string. Either: + * A blob uri and "NoCredentials" + * A sas URI and "SAS" + :rtype: Tuple[str, Literal["NoCredentials", "SAS"]] + """ + body = BlobReferenceSASRequestDto( + asset_id=REGISTRY_ASSET_ID.format(reg_name, asset_type, asset_name, asset_version), + blob_uri=uri, + ) + sas_uri = service_client.data_references.get_blob_reference_sas( + name=asset_name, + version=asset_version, + resource_group_name=rg_name, + registry_name=reg_name, + body=body, + ) + if sas_uri.blob_reference_for_consumption.credential.credential_type == "no_credentials": + return sas_uri.blob_reference_for_consumption.blob_uri, "NoCredentials" + + return ( + sas_uri.blob_reference_for_consumption.credential.additional_properties["sasUri"], + "SAS", + ) + + +def get_registry_client(credential, registry_name, workspace_location: Optional[str] = None, **kwargs): + base_url = _get_registry_discovery_endpoint_from_metadata(_get_default_cloud_name()) + kwargs.pop("base_url", None) + + service_client_registry_discovery_client = ServiceClientRegistryDiscovery( + credential=credential, base_url=base_url, **kwargs + ) + if workspace_location: + workspace_kwargs = {"workspace_location": workspace_location} + kwargs.update(workspace_kwargs) + + registry_discovery = RegistryDiscovery( + credential, registry_name, service_client_registry_discovery_client, **kwargs + ) + service_client_10_2021_dataplanepreview = registry_discovery.get_registry_service_client() + subscription_id = registry_discovery.subscription_id + resource_group_name = registry_discovery.resource_group + return service_client_10_2021_dataplanepreview, resource_group_name, subscription_id + + +def _check_region_fqdn(workspace_region, response): + if workspace_region in response.additional_properties["registryFqdns"].keys(): + return + regions = list(response.additional_properties["registryFqdns"].keys()) + msg = f"Workspace region {workspace_region} not supported by the \ + registry {response.registry_name} regions {regions}" + raise MlException(message=msg, no_personal_data_message=msg) |