# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Dict, Iterable, Optional, cast

from azure.ai.ml._scope_dependent_operations import OperationScope
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._arm_id_utils import AzureResourceId
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml._utils._virtual_cluster_utils import (
    CrossRegionIndexEntitiesRequest,
    IndexEntitiesRequest,
    IndexEntitiesRequestFilter,
    IndexEntitiesRequestOrder,
    IndexEntitiesRequestOrderDirection,
    IndexServiceAPIs,
    index_entity_response_to_job,
)
from azure.ai.ml._utils.azure_resource_utils import (
    get_generic_resource_by_id,
    get_virtual_cluster_by_name,
    get_virtual_clusters_from_subscriptions,
)
from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, Scope
from azure.ai.ml.entities import Job
from azure.ai.ml.exceptions import UserErrorException, ValidationException
from azure.core.credentials import TokenCredential
from azure.core.tracing.decorator import distributed_trace

# Module-wide operations logger, keyed by this module's name. It is handed to the
# monitor_with_activity decorators on the public operations of this module.
ops_logger = OpsLogger(__name__)
# Convenience alias for the logger object exposed by ops_logger.
module_logger = ops_logger.module_logger


class VirtualClusterOperations:
    """Operations for listing and retrieving virtual clusters and their jobs.

    You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
    for you and attaches it as an attribute.
    """

    def __init__(
        self,
        operation_scope: OperationScope,
        credentials: TokenCredential,
        *,
        _service_client_kwargs: Dict,
        **kwargs: Dict,
    ):
        """Store the operation scope and credentials and build the index service client.

        :param operation_scope: Supplies the subscription id and resource group name used by
            name-based lookups.
        :type operation_scope: OperationScope
        :param credentials: Credential passed to every downstream call.
        :type credentials: ~azure.core.credentials.TokenCredential
        :keyword _service_client_kwargs: Extra keyword arguments forwarded to the index service client.
        :paramtype _service_client_kwargs: Dict
        """
        ops_logger.update_filter()
        self._resource_group_name = operation_scope.resource_group_name
        self._subscription_id = operation_scope.subscription_id
        self._credentials = credentials
        self._init_kwargs = kwargs
        self._service_client_kwargs = _service_client_kwargs

        # A (mostly) autogenerated rest client for the index service.
        # NOTE: the base URL is pinned to a single regional endpoint here.
        # TODO: Remove this attribute and the rest client when list job by virtual cluster
        # is added to virtual cluster rp
        self._index_service = IndexServiceAPIs(
            credential=self._credentials, base_url="https://westus2.api.azureml.ms", **self._service_client_kwargs
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "VirtualCluster.List", ActivityType.PUBLICAPI)
    def list(self, *, scope: Optional[str] = None) -> Iterable[Dict]:
        """List virtual clusters a user has access to.

        :keyword scope: scope of the listing, "subscription" or None, defaults to None.
            If None, list virtual clusters across all subscriptions a customer has access to.
            The comparison is case-insensitive.
        :paramtype scope: str
        :return: An iterator like instance of dictionaries.
        :rtype: ~azure.core.paging.ItemPaged[Dict]
        :raises ~azure.ai.ml.exceptions.UserErrorException: If ``scope`` is neither None nor
            "subscription", or if the listing helper raises ImportError.
        """

        if scope is None:
            # None is forwarded as-is: the helper then decides which subscriptions to query.
            subscription_list = None
        elif scope.lower() == Scope.SUBSCRIPTION:
            subscription_list = [self._subscription_id]
        else:
            message = f"Invalid scope: {scope}. Valid values are 'subscription' or None."
            raise UserErrorException(message=message, no_personal_data_message=message)

        try:
            return cast(
                Iterable[Dict],
                get_virtual_clusters_from_subscriptions(self._credentials, subscription_list=subscription_list),
            )
        except ImportError as e:
            # NOTE(review): the message names azure-mgmt-resource twice; presumably one of the two
            # mentions was meant to name a different optional package - confirm against the helper
            # before changing the text.
            raise UserErrorException(
                message="Met ImportError when trying to list virtual clusters. "
                "Please install azure-mgmt-resource to enable this feature; "
                "and please install azure-mgmt-resource to enable listing virtual clusters "
                "across all subscriptions a customer has access to."
            ) from e

    @distributed_trace
    @monitor_with_activity(ops_logger, "VirtualCluster.ListJobs", ActivityType.PUBLICAPI)
    def list_jobs(self, name: str) -> Iterable[Job]:
        """List of jobs that target the virtual cluster

        The virtual cluster is identified by ``name`` within the subscription and resource group
        of this operations object. Results are requested newest-first by creation time and
        exclude archived runs.

        :param name: Name of virtual cluster
        :type name: str
        :return: An iterable of jobs.
        :rtype: Iterable[Job]
        """

        def make_id(entity_type: str) -> str:
            # Build the resource id of the virtual cluster using the given resource-type spelling.
            return str(
                LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format(
                    self._subscription_id, self._resource_group_name, AZUREML_RESOURCE_PROVIDER, entity_type, name
                )
            )

        # List of virtual cluster ids to match
        # Needs to include several capitalizations for historical reasons. Will be fixed in a service side change
        vc_ids = [make_id("virtualClusters"), make_id("virtualclusters"), make_id("virtualClusters").lower()]

        filters = [
            IndexEntitiesRequestFilter(field="type", operator="eq", values=["runs"]),
            IndexEntitiesRequestFilter(field="annotations/archived", operator="eq", values=["false"]),
            IndexEntitiesRequestFilter(field="properties/userProperties/azureml.VC", operator="eq", values=vc_ids),
        ]
        order = [
            IndexEntitiesRequestOrder(
                field="properties/creationContext/createdTime", direction=IndexEntitiesRequestOrderDirection.DESC
            )
        ]
        index_entities_request = IndexEntitiesRequest(filters=filters, order=order)

        # cspell:ignore entites
        return cast(
            Iterable[Job],
            self._index_service.index_entities.get_entites_cross_region(
                body=CrossRegionIndexEntitiesRequest(index_entities_request=index_entities_request),
                # Convert every raw index entity handed to the callback into a Job.
                cls=lambda objs: [index_entity_response_to_job(obj) for obj in objs],
            ),
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "VirtualCluster.Get", ActivityType.PUBLICAPI)
    def get(self, name: str) -> Dict:
        """
        Get a virtual cluster resource. If name is provided, the virtual cluster
        with the name in the subscription and resource group of the MLClient object
        will be returned. If an ARM id is provided, a virtual cluster with the id will be returned.

        Only a failure to parse ``name`` as an ARM ID triggers the name-based lookup; errors
        raised while fetching a well-formed ARM ID propagate to the caller.

        :param name: Name or ARM ID of the virtual cluster.
        :type name: str
        :return: Virtual cluster object
        :rtype: Dict
        """

        try:
            # Keep the try body limited to the parse, so that a ValidationException coming from the
            # resource lookup itself is not mistaken for "name is not an ARM ID".
            arm_id = AzureResourceId(name)
            sub_id = arm_id.subscription_id
        except ValidationException:
            # Not an ARM ID: treat it as a plain name in this object's subscription and resource group.
            return cast(
                Dict,
                get_virtual_cluster_by_name(
                    name=name,
                    resource_group=self._resource_group_name,
                    subscription_id=self._subscription_id,
                    credential=self._credentials,
                ),
            )

        # The subscription comes from the ARM ID itself, not from this object's scope.
        return cast(
            Dict,
            get_generic_resource_by_id(
                arm_id=name, credential=self._credentials, subscription_id=sub_id, api_version="2021-03-01-preview"
            ),
        )