# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Any, Dict, Iterable, Optional, cast

from azure.ai.ml._scope_dependent_operations import OperationScope
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._arm_id_utils import AzureResourceId
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml._utils._virtual_cluster_utils import (
    CrossRegionIndexEntitiesRequest,
    IndexEntitiesRequest,
    IndexEntitiesRequestFilter,
    IndexEntitiesRequestOrder,
    IndexEntitiesRequestOrderDirection,
    IndexServiceAPIs,
    index_entity_response_to_job,
)
from azure.ai.ml._utils.azure_resource_utils import (
    get_generic_resource_by_id,
    get_virtual_cluster_by_name,
    get_virtual_clusters_from_subscriptions,
)
from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, Scope
from azure.ai.ml.entities import Job
from azure.ai.ml.exceptions import UserErrorException, ValidationException
from azure.core.credentials import TokenCredential
from azure.core.tracing.decorator import distributed_trace

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger


class VirtualClusterOperations:
    """VirtualClusterOperations.

    You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
    for you and attaches it as an attribute.
    """

    def __init__(
        self,
        operation_scope: OperationScope,
        credentials: TokenCredential,
        *,
        _service_client_kwargs: Dict,
        **kwargs: Any,
    ):
        """Store the scope and credentials and build the index-service client.

        :param operation_scope: Supplies the subscription id and resource group name used when a virtual cluster
            is referenced by name (``get``) and when building virtual cluster ids (``list_jobs``).
        :type operation_scope: OperationScope
        :param credentials: Credential passed to every resource / index-service call.
        :type credentials: TokenCredential
        :keyword _service_client_kwargs: Extra keyword arguments forwarded to the ``IndexServiceAPIs`` client.
        :paramtype _service_client_kwargs: Dict
        """
        ops_logger.update_filter()
        self._resource_group_name = operation_scope.resource_group_name
        self._subscription_id = operation_scope.subscription_id
        self._credentials = credentials
        self._init_kwargs = kwargs
        self._service_client_kwargs = _service_client_kwargs

        # A (mostly) autogenerated rest client for the index service.
        # TODO: Remove this property and the rest client when list job by virtual cluster is added to virtual
        # cluster rp
        # NOTE(review): base_url is pinned to westus2; list_jobs issues a cross-region query through it, which
        # presumably is why a single region suffices — confirm.
        self._index_service = IndexServiceAPIs(
            credential=self._credentials, base_url="https://westus2.api.azureml.ms", **self._service_client_kwargs
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "VirtualCluster.List", ActivityType.PUBLICAPI)
    def list(self, *, scope: Optional[str] = None) -> Iterable[Dict]:
        """List virtual clusters a user has access to.

        :keyword scope: scope of the listing, "subscription" or None, defaults to None.
            If None, list virtual clusters across all subscriptions a customer has access to.
        :paramtype scope: str
        :return: An iterator like instance of dictionaries.
        :rtype: ~azure.core.paging.ItemPaged[Dict]
        :raises ~azure.ai.ml.exceptions.UserErrorException: If ``scope`` is neither None nor a value whose
            lowercase form equals ``Scope.SUBSCRIPTION``, or if an ImportError occurs while listing (an optional
            management package is missing).
        """
        if scope is None:
            # No restriction: search every subscription the credential can see.
            subscription_list = None
        elif scope.lower() == Scope.SUBSCRIPTION:
            # Restrict the search to the MLClient's own subscription.
            subscription_list = [self._subscription_id]
        else:
            message = f"Invalid scope: {scope}. Valid values are 'subscription' or None."
            raise UserErrorException(message=message, no_personal_data_message=message)

        try:
            return cast(
                Iterable[Dict],
                get_virtual_clusters_from_subscriptions(self._credentials, subscription_list=subscription_list),
            )
        except ImportError as e:
            # NOTE(review): the original message named azure-mgmt-resource twice; confirm whether a second
            # package (e.g. a resource-graph client) is actually what the helper imports.
            import_message = (
                "Met ImportError when trying to list virtual clusters. "
                "Please install azure-mgmt-resource to enable this feature, including listing virtual clusters "
                "across all subscriptions a customer has access to."
            )
            # The message contains no user data, so it is also safe as the no-personal-data variant.
            raise UserErrorException(message=import_message, no_personal_data_message=import_message) from e

    @distributed_trace
    @monitor_with_activity(ops_logger, "VirtualCluster.ListJobs", ActivityType.PUBLICAPI)
    def list_jobs(self, name: str) -> Iterable[Job]:
        """List of jobs that target the virtual cluster

        The index service is queried (cross-region) for non-archived run entities whose
        ``azureml.VC`` user property equals this virtual cluster's ARM id, newest first.

        :param name: Name of virtual cluster
        :type name: str
        :return: An iterable of jobs.
        :rtype: Iterable[Job]
        """

        def make_id(entity_type: str) -> str:
            # Build the virtual cluster ARM id inside the MLClient's subscription / resource group.
            return str(
                LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format(
                    self._subscription_id, self._resource_group_name, AZUREML_RESOURCE_PROVIDER, entity_type, name
                )
            )

        # List of virtual cluster ids to match
        # Needs to include several capitalizations for historical reasons. Will be fixed in a service side change
        vc_ids = [make_id("virtualClusters"), make_id("virtualclusters"), make_id("virtualClusters").lower()]

        filters = [
            IndexEntitiesRequestFilter(field="type", operator="eq", values=["runs"]),
            IndexEntitiesRequestFilter(field="annotations/archived", operator="eq", values=["false"]),
            IndexEntitiesRequestFilter(field="properties/userProperties/azureml.VC", operator="eq", values=vc_ids),
        ]
        order = [
            IndexEntitiesRequestOrder(
                field="properties/creationContext/createdTime", direction=IndexEntitiesRequestOrderDirection.DESC
            )
        ]
        index_entities_request = IndexEntitiesRequest(filters=filters, order=order)

        # `cls` post-processes the raw response: each index entity is converted into a Job.
        # cspell:ignore entites
        return cast(
            Iterable[Job],
            self._index_service.index_entities.get_entites_cross_region(
                body=CrossRegionIndexEntitiesRequest(index_entities_request=index_entities_request),
                cls=lambda objs: [index_entity_response_to_job(obj) for obj in objs],
            ),
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "VirtualCluster.Get", ActivityType.PUBLICAPI)
    def get(self, name: str) -> Dict:
        """
        Get a virtual cluster resource. If name is provided, the virtual cluster
        with the name in the subscription and resource group of the MLClient object
        will be returned. If an ARM id is provided, a virtual cluster with the id will be returned.

        :param name: Name or ARM ID of the virtual cluster.
        :type name: str
        :return: Virtual cluster object
        :rtype: Dict
        """
        try:
            # AzureResourceId is expected to raise ValidationException when `name` is not an ARM id. Only this
            # parse step is guarded, so a ValidationException from the service call below propagates instead of
            # being mistaken for "not an ARM id" and retried as a name lookup with the full id string.
            arm_id = AzureResourceId(name)
            sub_id = arm_id.subscription_id
        except ValidationException:
            # Plain name: resolve it in the MLClient's subscription and resource group.
            return cast(
                Dict,
                get_virtual_cluster_by_name(
                    name=name,
                    resource_group=self._resource_group_name,
                    subscription_id=self._subscription_id,
                    credential=self._credentials,
                ),
            )

        # ARM id: fetch the resource directly, using the subscription encoded in the id.
        return cast(
            Dict,
            get_generic_resource_by_id(
                arm_id=name, credential=self._credentials, subscription_id=sub_id, api_version="2021-03-01-preview"
            ),
        )