aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py174
1 files changed, 174 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
new file mode 100644
index 00000000..256664ad
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
@@ -0,0 +1,174 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Dict, Iterable, Optional, cast
+
+from azure.ai.ml._scope_dependent_operations import OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import AzureResourceId
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._virtual_cluster_utils import (
+ CrossRegionIndexEntitiesRequest,
+ IndexEntitiesRequest,
+ IndexEntitiesRequestFilter,
+ IndexEntitiesRequestOrder,
+ IndexEntitiesRequestOrderDirection,
+ IndexServiceAPIs,
+ index_entity_response_to_job,
+)
+from azure.ai.ml._utils.azure_resource_utils import (
+ get_generic_resource_by_id,
+ get_virtual_cluster_by_name,
+ get_virtual_clusters_from_subscriptions,
+)
+from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, Scope
+from azure.ai.ml.entities import Job
+from azure.ai.ml.exceptions import UserErrorException, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class VirtualClusterOperations:
+ """VirtualClusterOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ credentials: TokenCredential,
+ *,
+ _service_client_kwargs: Dict,
+ **kwargs: Dict,
+ ):
+ ops_logger.update_filter()
+ self._resource_group_name = operation_scope.resource_group_name
+ self._subscription_id = operation_scope.subscription_id
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+ self._service_client_kwargs = _service_client_kwargs
+
+ """A (mostly) autogenerated rest client for the index service.
+
+ TODO: Remove this property and the rest client when list job by virtual cluster is added to virtual cluster rp
+ """
+ self._index_service = IndexServiceAPIs(
+ credential=self._credentials, base_url="https://westus2.api.azureml.ms", **self._service_client_kwargs
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "VirtualCluster.List", ActivityType.PUBLICAPI)
+ def list(self, *, scope: Optional[str] = None) -> Iterable[Dict]:
+ """List virtual clusters a user has access to.
+
+ :keyword scope: scope of the listing, "subscription" or None, defaults to None.
+ If None, list virtual clusters across all subscriptions a customer has access to.
+ :paramtype scope: str
+ :return: An iterator like instance of dictionaries.
+ :rtype: ~azure.core.paging.ItemPaged[Dict]
+ """
+
+ if scope is None:
+ subscription_list = None
+ elif scope.lower() == Scope.SUBSCRIPTION:
+ subscription_list = [self._subscription_id]
+ else:
+ message = f"Invalid scope: {scope}. Valid values are 'subscription' or None."
+ raise UserErrorException(message=message, no_personal_data_message=message)
+
+ try:
+ return cast(
+ Iterable[Dict],
+ get_virtual_clusters_from_subscriptions(self._credentials, subscription_list=subscription_list),
+ )
+ except ImportError as e:
+ raise UserErrorException(
+ message="Met ImportError when trying to list virtual clusters. "
+ "Please install azure-mgmt-resource to enable this feature; "
+ "and please install azure-mgmt-resource to enable listing virtual clusters "
+ "across all subscriptions a customer has access to."
+ ) from e
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "VirtualCluster.ListJobs", ActivityType.PUBLICAPI)
+ def list_jobs(self, name: str) -> Iterable[Job]:
+ """List of jobs that target the virtual cluster
+
+ :param name: Name of virtual cluster
+ :type name: str
+ :return: An iterable of jobs.
+ :rtype: Iterable[Job]
+ """
+
+ def make_id(entity_type: str) -> str:
+ return str(
+ LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format(
+ self._subscription_id, self._resource_group_name, AZUREML_RESOURCE_PROVIDER, entity_type, name
+ )
+ )
+
+ # List of virtual cluster ids to match
+ # Needs to include several capitalizations for historical reasons. Will be fixed in a service side change
+ vc_ids = [make_id("virtualClusters"), make_id("virtualclusters"), make_id("virtualClusters").lower()]
+
+ filters = [
+ IndexEntitiesRequestFilter(field="type", operator="eq", values=["runs"]),
+ IndexEntitiesRequestFilter(field="annotations/archived", operator="eq", values=["false"]),
+ IndexEntitiesRequestFilter(field="properties/userProperties/azureml.VC", operator="eq", values=vc_ids),
+ ]
+ order = [
+ IndexEntitiesRequestOrder(
+ field="properties/creationContext/createdTime", direction=IndexEntitiesRequestOrderDirection.DESC
+ )
+ ]
+ index_entities_request = IndexEntitiesRequest(filters=filters, order=order)
+
+ # cspell:ignore entites
+ return cast(
+ Iterable[Job],
+ self._index_service.index_entities.get_entites_cross_region(
+ body=CrossRegionIndexEntitiesRequest(index_entities_request=index_entities_request),
+ cls=lambda objs: [index_entity_response_to_job(obj) for obj in objs],
+ ),
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "VirtualCluster.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str) -> Dict:
+ """
+ Get a virtual cluster resource. If name is provided, the virtual cluster
+ with the name in the subscription and resource group of the MLClient object
+ will be returned. If an ARM id is provided, a virtual cluster with the id will be returned.
+
+ :param name: Name or ARM ID of the virtual cluster.
+ :type name: str
+ :return: Virtual cluster object
+ :rtype: Dict
+ """
+
+ try:
+ arm_id = AzureResourceId(name)
+ sub_id = arm_id.subscription_id
+
+ return cast(
+ Dict,
+ get_generic_resource_by_id(
+ arm_id=name, credential=self._credentials, subscription_id=sub_id, api_version="2021-03-01-preview"
+ ),
+ )
+ except ValidationException:
+ return cast(
+ Dict,
+ get_virtual_cluster_by_name(
+ name=name,
+ resource_group=self._resource_group_name,
+ subscription_id=self._subscription_id,
+ credential=self._credentials,
+ ),
+ )