about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py222
1 files changed, 222 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py
new file mode 100644
index 00000000..09c5d1ae
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_registry_utils.py
@@ -0,0 +1,222 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Optional, Tuple
+
+from typing_extensions import Literal
+
+from azure.ai.ml._azure_environments import _get_default_cloud_name, _get_registry_discovery_endpoint_from_metadata
+from azure.ai.ml._restclient.registry_discovery import AzureMachineLearningWorkspaces as ServiceClientRegistryDiscovery
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import AzureMachineLearningWorkspaces
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import (
+    BlobReferenceSASRequestDto,
+    TemporaryDataReferenceRequestDto,
+)
+from azure.ai.ml.constants._common import REGISTRY_ASSET_ID
+from azure.ai.ml.exceptions import MlException
+from azure.core.exceptions import HttpResponseError
+
+module_logger = logging.getLogger(__name__)
+
+MFE_PATH_PREFIX = "mferp/managementfrontend"
+
+
+class RegistryDiscovery:
+    def __init__(
+        self,
+        credential: str,
+        registry_name: str,
+        service_client_registry_discovery_client: ServiceClientRegistryDiscovery,
+        **kwargs,
+    ):
+        self.credential = credential
+        self.registry_name = registry_name
+        self.service_client_registry_discovery_client = service_client_registry_discovery_client
+        self.kwargs = kwargs
+        self._resource_group = None
+        self._subscription_id = None
+        self._base_url = None
+        self.workspace_region = kwargs.get("workspace_location", None)
+
+    def _get_registry_details(self) -> str:
+        response = self.service_client_registry_discovery_client.registry_management_non_workspace.registry_management_non_workspace(  # pylint: disable=line-too-long
+            self.registry_name
+        )
+        if self.workspace_region:
+            _check_region_fqdn(self.workspace_region, response)
+            self._base_url = f"https://cert-{self.workspace_region}.experiments.azureml.net/{MFE_PATH_PREFIX}"
+        else:
+            self._base_url = f"{response.primary_region_resource_provider_uri}{MFE_PATH_PREFIX}"
+        self._subscription_id = response.subscription_id
+        self._resource_group = response.resource_group
+
+    def get_registry_service_client(self) -> AzureMachineLearningWorkspaces:
+        self._get_registry_details()
+        self.kwargs.pop("subscription_id", None)
+        self.kwargs.pop("resource_group", None)
+        service_client_10_2021_dataplanepreview = AzureMachineLearningWorkspaces(
+            subscription_id=self._subscription_id,
+            resource_group=self._resource_group,
+            credential=self.credential,
+            base_url=self._base_url,
+            **self.kwargs,
+        )
+        return service_client_10_2021_dataplanepreview
+
+    @property
+    def subscription_id(self) -> str:
+        """The subscription id of the registry.
+
+        :return: Subscription Id
+        :rtype: str
+        """
+        return self._subscription_id
+
+    @property
+    def resource_group(self) -> str:
+        """The resource group of the registry.
+
+        :return: Resource Group
+        :rtype: str
+        """
+        return self._resource_group
+
+
+def get_sas_uri_for_registry_asset(service_client, name, version, resource_group, registry, body) -> str:
+    """Get sas_uri for registry asset.
+
+    :param service_client: Service client
+    :type service_client: AzureMachineLearningWorkspaces
+    :param name: Asset name
+    :type name: str
+    :param version: Asset version
+    :type version: str
+    :param resource_group: Resource group
+    :type resource_group: str
+    :param registry: Registry name
+    :type registry: str
+    :param body: Request body
+    :type body: TemporaryDataReferenceRequestDto
+    :rtype: str
+    """
+    sas_uri = None
+    try:
+        res = service_client.temporary_data_references.create_or_get_temporary_data_reference(
+            name=name,
+            version=version,
+            resource_group_name=resource_group,
+            registry_name=registry,
+            body=body,
+        )
+        sas_uri = res.blob_reference_for_consumption.credential.additional_properties["sasUri"]
+    except HttpResponseError as e:
+        # "Asset already exists" exception is thrown from service with error code 409, that we need to ignore
+        if e.status_code == 409:
+            module_logger.debug("Skipping file upload, reason:  %s", str(e.reason))
+        else:
+            raise e
+    return sas_uri
+
+
+def get_asset_body_for_registry_storage(
+    registry_name: str, asset_type: str, asset_name: str, asset_version: str
+) -> TemporaryDataReferenceRequestDto:
+    """Get Asset body for registry.
+
+    :param registry_name: Registry name.
+    :type registry_name: str
+    :param asset_type: Asset type.
+    :type asset_type: str
+    :param asset_name: Asset name.
+    :type asset_name: str
+    :param asset_version: Asset version.
+    :type asset_version: str
+    :return: The temporary data reference request dto
+    :rtype: TemporaryDataReferenceRequestDto
+    """
+    body = TemporaryDataReferenceRequestDto(
+        asset_id=REGISTRY_ASSET_ID.format(registry_name, asset_type, asset_name, asset_version),
+        temporary_data_reference_type="TemporaryBlobReference",
+    )
+    return body
+
+
+def get_storage_details_for_registry_assets(
+    service_client: AzureMachineLearningWorkspaces,
+    asset_type: str,
+    asset_name: str,
+    asset_version: str,
+    rg_name: str,
+    reg_name: str,
+    uri: str,
+) -> Tuple[str, Literal["NoCredentials", "SAS"]]:
+    """Get storage details for registry assets.
+
+    :param service_client: AzureMachineLearningWorkspaces service client.
+    :type service_client: AzureMachineLearningWorkspaces
+    :param asset_type: Asset type.
+    :type asset_type: str
+    :param asset_name: Asset name.
+    :type asset_name: str
+    :param asset_version: Asset version.
+    :type asset_version: str
+    :param rg_name: Resource group name.
+    :type rg_name: str
+    :param reg_name: Registry name
+    :type reg_name: str
+    :param uri: asset uri
+    :type uri: str
+    :return: A 2-tuple of a URI and a string. Either:
+      * A blob uri and "NoCredentials"
+      * A sas URI and "SAS"
+    :rtype: Tuple[str, Literal["NoCredentials", "SAS"]]
+    """
+    body = BlobReferenceSASRequestDto(
+        asset_id=REGISTRY_ASSET_ID.format(reg_name, asset_type, asset_name, asset_version),
+        blob_uri=uri,
+    )
+    sas_uri = service_client.data_references.get_blob_reference_sas(
+        name=asset_name,
+        version=asset_version,
+        resource_group_name=rg_name,
+        registry_name=reg_name,
+        body=body,
+    )
+    if sas_uri.blob_reference_for_consumption.credential.credential_type == "no_credentials":
+        return sas_uri.blob_reference_for_consumption.blob_uri, "NoCredentials"
+
+    return (
+        sas_uri.blob_reference_for_consumption.credential.additional_properties["sasUri"],
+        "SAS",
+    )
+
+
+def get_registry_client(credential, registry_name, workspace_location: Optional[str] = None, **kwargs):
+    base_url = _get_registry_discovery_endpoint_from_metadata(_get_default_cloud_name())
+    kwargs.pop("base_url", None)
+
+    service_client_registry_discovery_client = ServiceClientRegistryDiscovery(
+        credential=credential, base_url=base_url, **kwargs
+    )
+    if workspace_location:
+        workspace_kwargs = {"workspace_location": workspace_location}
+        kwargs.update(workspace_kwargs)
+
+    registry_discovery = RegistryDiscovery(
+        credential, registry_name, service_client_registry_discovery_client, **kwargs
+    )
+    service_client_10_2021_dataplanepreview = registry_discovery.get_registry_service_client()
+    subscription_id = registry_discovery.subscription_id
+    resource_group_name = registry_discovery.resource_group
+    return service_client_10_2021_dataplanepreview, resource_group_name, subscription_id
+
+
+def _check_region_fqdn(workspace_region, response):
+    if workspace_region in response.additional_properties["registryFqdns"].keys():
+        return
+    regions = list(response.additional_properties["registryFqdns"].keys())
+    msg = f"Workspace region {workspace_region} not supported by the \
+            registry {response.registry_name} regions {regions}"
+    raise MlException(message=msg, no_personal_data_message=msg)