aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from typing import Any, Optional, Tuple

from typing_extensions import Literal

from azure.ai.ml._azure_environments import _get_default_cloud_name, _get_registry_discovery_endpoint_from_metadata
from azure.ai.ml._restclient.registry_discovery import AzureMachineLearningWorkspaces as ServiceClientRegistryDiscovery
from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import AzureMachineLearningWorkspaces
from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import (
    BlobReferenceSASRequestDto,
    TemporaryDataReferenceRequestDto,
)
from azure.ai.ml.constants._common import REGISTRY_ASSET_ID
from azure.ai.ml.exceptions import MlException
from azure.core.exceptions import HttpResponseError

# Module logger; used, for example, to record skipped uploads when an asset already exists.
module_logger = logging.getLogger(__name__)

# Path appended to the registry's resource-provider host to form the dataplane client base URL
# (presumably the management front-end route -- "mfe"; confirm against service docs).
MFE_PATH_PREFIX = "mferp/managementfrontend"


class RegistryDiscovery:
    """Discover where a registry lives and build a dataplane service client for it.

    The registry discovery service is queried for the registry's subscription id,
    resource group and resource-provider URI. The results are stored on the instance
    and used to construct an ``AzureMachineLearningWorkspaces`` client from the
    2021-10-01 dataplane preview REST client.
    """

    def __init__(
        self,
        credential: Any,
        registry_name: str,
        service_client_registry_discovery_client: ServiceClientRegistryDiscovery,
        **kwargs,
    ):
        """
        :param credential: Credential forwarded unchanged to the dataplane client.
        :type credential: Any
        :param registry_name: Name of the registry to discover.
        :type registry_name: str
        :param service_client_registry_discovery_client: Client used to look up the registry.
        :type service_client_registry_discovery_client: ServiceClientRegistryDiscovery
        :keyword workspace_location: Optional workspace region. When set, the registry must be
            available in this region and the dataplane client targets that region's endpoint.
        :paramtype workspace_location: Optional[str]
        """
        self.credential = credential
        self.registry_name = registry_name
        self.service_client_registry_discovery_client = service_client_registry_discovery_client
        # Forwarded to the dataplane client, except for the keys filtered out in
        # get_registry_service_client.
        self.kwargs = kwargs
        self._resource_group: Optional[str] = None
        self._subscription_id: Optional[str] = None
        self._base_url: Optional[str] = None
        self.workspace_region: Optional[str] = kwargs.get("workspace_location", None)

    def _get_registry_details(self) -> None:
        """Query the discovery service and cache subscription id, resource group and base URL.

        :raises MlException: If ``workspace_region`` is set but the registry is not
            available in that region.
        """
        response = self.service_client_registry_discovery_client.registry_management_non_workspace.registry_management_non_workspace(  # pylint: disable=line-too-long
            self.registry_name
        )
        if self.workspace_region:
            _check_region_fqdn(self.workspace_region, response)
            # NOTE(review): host is hard-coded here; confirm it is valid outside the public cloud.
            self._base_url = f"https://cert-{self.workspace_region}.experiments.azureml.net/{MFE_PATH_PREFIX}"
        else:
            self._base_url = f"{response.primary_region_resource_provider_uri}{MFE_PATH_PREFIX}"
        self._subscription_id = response.subscription_id
        self._resource_group = response.resource_group

    def get_registry_service_client(self) -> AzureMachineLearningWorkspaces:
        """Discover the registry and return a dataplane client scoped to it.

        :return: A client bound to the registry's subscription, resource group and base URL.
        :rtype: AzureMachineLearningWorkspaces
        """
        self._get_registry_details()
        # Subscription and resource group are replaced by the discovered values.
        self.kwargs.pop("subscription_id", None)
        self.kwargs.pop("resource_group", None)
        # workspace_location is consumed by this class (see workspace_region) and is not a
        # dataplane client option, so it is not forwarded.
        client_kwargs = {key: value for key, value in self.kwargs.items() if key != "workspace_location"}
        return AzureMachineLearningWorkspaces(
            subscription_id=self._subscription_id,
            resource_group=self._resource_group,
            credential=self.credential,
            base_url=self._base_url,
            **client_kwargs,
        )

    @property
    def subscription_id(self) -> Optional[str]:
        """The subscription id of the registry.

        None until ``get_registry_service_client`` has run discovery.

        :return: Subscription Id
        :rtype: Optional[str]
        """
        return self._subscription_id

    @property
    def resource_group(self) -> Optional[str]:
        """The resource group of the registry.

        None until ``get_registry_service_client`` has run discovery.

        :return: Resource Group
        :rtype: Optional[str]
        """
        return self._resource_group


def get_sas_uri_for_registry_asset(service_client, name, version, resource_group, registry, body) -> Optional[str]:
    """Get sas_uri for registry asset.

    Requests (or fetches an existing) temporary data reference for the asset and
    returns the SAS URI from it.

    :param service_client: Service client
    :type service_client: AzureMachineLearningWorkspaces
    :param name: Asset name
    :type name: str
    :param version: Asset version
    :type version: str
    :param resource_group: Resource group
    :type resource_group: str
    :param registry: Registry name
    :type registry: str
    :param body: Request body
    :type body: TemporaryDataReferenceRequestDto
    :return: The SAS URI, or None when the service answers 409 ("asset already exists"),
        in which case the upload is skipped.
    :rtype: Optional[str]
    :raises HttpResponseError: For any service error other than 409.
    """
    try:
        res = service_client.temporary_data_references.create_or_get_temporary_data_reference(
            name=name,
            version=version,
            resource_group_name=resource_group,
            registry_name=registry,
            body=body,
        )
    except HttpResponseError as e:
        # "Asset already exists" is reported by the service as 409; that case is not an error.
        if e.status_code != 409:
            raise
        module_logger.debug("Skipping file upload, reason:  %s", str(e.reason))
        return None
    return res.blob_reference_for_consumption.credential.additional_properties["sasUri"]


def get_asset_body_for_registry_storage(
    registry_name: str, asset_type: str, asset_name: str, asset_version: str
) -> TemporaryDataReferenceRequestDto:
    """Build the request body that asks for a temporary blob reference for a registry asset.

    :param registry_name: Name of the registry that owns the asset.
    :type registry_name: str
    :param asset_type: Type of the asset.
    :type asset_type: str
    :param asset_name: Name of the asset.
    :type asset_name: str
    :param asset_version: Version of the asset.
    :type asset_version: str
    :return: A request DTO whose asset id is ``REGISTRY_ASSET_ID`` formatted with the four
        arguments and whose reference type is ``"TemporaryBlobReference"``.
    :rtype: TemporaryDataReferenceRequestDto
    """
    registry_asset_id = REGISTRY_ASSET_ID.format(registry_name, asset_type, asset_name, asset_version)
    return TemporaryDataReferenceRequestDto(
        asset_id=registry_asset_id,
        temporary_data_reference_type="TemporaryBlobReference",
    )


def get_storage_details_for_registry_assets(
    service_client: AzureMachineLearningWorkspaces,
    asset_type: str,
    asset_name: str,
    asset_version: str,
    rg_name: str,
    reg_name: str,
    uri: str,
) -> Tuple[str, Literal["NoCredentials", "SAS"]]:
    """Resolve how to read a registry asset's blob: either directly or through a SAS URI.

    :param service_client: Dataplane client used to request the blob reference.
    :type service_client: AzureMachineLearningWorkspaces
    :param asset_type: Type of the asset.
    :type asset_type: str
    :param asset_name: Name of the asset.
    :type asset_name: str
    :param asset_version: Version of the asset.
    :type asset_version: str
    :param rg_name: Resource group of the registry.
    :type rg_name: str
    :param reg_name: Name of the registry.
    :type reg_name: str
    :param uri: Blob URI of the asset.
    :type uri: str
    :return: ``(blob_uri, "NoCredentials")`` when the service reports a credential type of
        ``"no_credentials"``, otherwise ``(sas_uri, "SAS")``.
    :rtype: Tuple[str, Literal["NoCredentials", "SAS"]]
    """
    request_body = BlobReferenceSASRequestDto(
        asset_id=REGISTRY_ASSET_ID.format(reg_name, asset_type, asset_name, asset_version),
        blob_uri=uri,
    )
    # Despite the operation name, this is a response object, not a URI string.
    response = service_client.data_references.get_blob_reference_sas(
        name=asset_name,
        version=asset_version,
        resource_group_name=rg_name,
        registry_name=reg_name,
        body=request_body,
    )
    blob_reference = response.blob_reference_for_consumption
    credential = blob_reference.credential
    if credential.credential_type != "no_credentials":
        return credential.additional_properties["sasUri"], "SAS"
    # No credential required: hand back the plain blob URI.
    return blob_reference.blob_uri, "NoCredentials"


def get_registry_client(credential, registry_name, workspace_location: Optional[str] = None, **kwargs):
    """Create a dataplane service client for a registry.

    A registry discovery client is built against the default cloud's discovery endpoint,
    the registry is located through it, and a dataplane client bound to that registry is
    returned together with its resource group and subscription id.

    :param credential: Credential passed to both the discovery client and the dataplane client.
    :param registry_name: Name of the registry.
    :type registry_name: str
    :param workspace_location: Optional workspace region; when truthy it is passed on to
        ``RegistryDiscovery`` as the ``workspace_location`` keyword.
    :type workspace_location: Optional[str]
    :return: ``(service_client, resource_group_name, subscription_id)``.
    :rtype: tuple
    """
    discovery_endpoint = _get_registry_discovery_endpoint_from_metadata(_get_default_cloud_name())
    # The discovery endpoint always comes from cloud metadata; a caller-supplied base_url is dropped.
    kwargs.pop("base_url", None)

    # Built before workspace_location is added, so the discovery client never receives it.
    discovery_client = ServiceClientRegistryDiscovery(credential=credential, base_url=discovery_endpoint, **kwargs)

    if workspace_location:
        kwargs["workspace_location"] = workspace_location

    discovery = RegistryDiscovery(credential, registry_name, discovery_client, **kwargs)
    dataplane_client = discovery.get_registry_service_client()
    subscription_id = discovery.subscription_id
    resource_group_name = discovery.resource_group
    return dataplane_client, resource_group_name, subscription_id


def _check_region_fqdn(workspace_region, response):
    """Ensure the registry is available in the given workspace region.

    :param workspace_region: Region of the workspace that wants to use the registry.
    :type workspace_region: str
    :param response: Registry discovery response; its
        ``additional_properties["registryFqdns"]`` mapping is keyed by region.
    :raises MlException: If ``workspace_region`` is not one of the registry's regions.
    """
    registry_fqdns = response.additional_properties["registryFqdns"]
    if workspace_region in registry_fqdns:
        return
    msg = _region_not_supported_message(workspace_region, response.registry_name, list(registry_fqdns))
    raise MlException(message=msg, no_personal_data_message=msg)


def _region_not_supported_message(workspace_region, registry_name, regions) -> str:
    """Build the error message raised when a registry is not available in a workspace region."""
    # Implicit concatenation (not a backslash continuation inside the literal), so no
    # indentation whitespace leaks into the message.
    return (
        f"Workspace region {workspace_region} not supported by the "
        f"registry {registry_name} regions {regions}"
    )