about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py174
1 files changed, 174 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
new file mode 100644
index 00000000..256664ad
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
@@ -0,0 +1,174 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Dict, Iterable, Optional, cast
+
+from azure.ai.ml._scope_dependent_operations import OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import AzureResourceId
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._virtual_cluster_utils import (
+    CrossRegionIndexEntitiesRequest,
+    IndexEntitiesRequest,
+    IndexEntitiesRequestFilter,
+    IndexEntitiesRequestOrder,
+    IndexEntitiesRequestOrderDirection,
+    IndexServiceAPIs,
+    index_entity_response_to_job,
+)
+from azure.ai.ml._utils.azure_resource_utils import (
+    get_generic_resource_by_id,
+    get_virtual_cluster_by_name,
+    get_virtual_clusters_from_subscriptions,
+)
+from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, Scope
+from azure.ai.ml.entities import Job
+from azure.ai.ml.exceptions import UserErrorException, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class VirtualClusterOperations:
+    """VirtualClusterOperations.
+
+    You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+    for you and attaches it as an attribute.
+    """
+
+    def __init__(
+        self,
+        operation_scope: OperationScope,
+        credentials: TokenCredential,
+        *,
+        _service_client_kwargs: Dict,
+        **kwargs: Dict,
+    ):
+        ops_logger.update_filter()
+        self._resource_group_name = operation_scope.resource_group_name
+        self._subscription_id = operation_scope.subscription_id
+        self._credentials = credentials
+        self._init_kwargs = kwargs
+        self._service_client_kwargs = _service_client_kwargs
+
+        """A (mostly) autogenerated rest client for the index service.
+
+        TODO: Remove this property and the rest client when list job by virtual cluster is added to virtual cluster rp
+        """
+        self._index_service = IndexServiceAPIs(
+            credential=self._credentials, base_url="https://westus2.api.azureml.ms", **self._service_client_kwargs
+        )
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "VirtualCluster.List", ActivityType.PUBLICAPI)
+    def list(self, *, scope: Optional[str] = None) -> Iterable[Dict]:
+        """List virtual clusters a user has access to.
+
+        :keyword scope: scope of the listing, "subscription" or None, defaults to None.
+            If None, list virtual clusters across all subscriptions a customer has access to.
+        :paramtype scope: str
+        :return: An iterator like instance of dictionaries.
+        :rtype: ~azure.core.paging.ItemPaged[Dict]
+        """
+
+        if scope is None:
+            subscription_list = None
+        elif scope.lower() == Scope.SUBSCRIPTION:
+            subscription_list = [self._subscription_id]
+        else:
+            message = f"Invalid scope: {scope}. Valid values are 'subscription' or None."
+            raise UserErrorException(message=message, no_personal_data_message=message)
+
+        try:
+            return cast(
+                Iterable[Dict],
+                get_virtual_clusters_from_subscriptions(self._credentials, subscription_list=subscription_list),
+            )
+        except ImportError as e:
+            raise UserErrorException(
+                message="Met ImportError when trying to list virtual clusters. "
+                "Please install azure-mgmt-resource to enable this feature; "
+                "and please install azure-mgmt-resource to enable listing virtual clusters "
+                "across all subscriptions a customer has access to."
+            ) from e
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "VirtualCluster.ListJobs", ActivityType.PUBLICAPI)
+    def list_jobs(self, name: str) -> Iterable[Job]:
+        """List of jobs that target the virtual cluster
+
+        :param name: Name of virtual cluster
+        :type name: str
+        :return: An iterable of jobs.
+        :rtype: Iterable[Job]
+        """
+
+        def make_id(entity_type: str) -> str:
+            return str(
+                LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format(
+                    self._subscription_id, self._resource_group_name, AZUREML_RESOURCE_PROVIDER, entity_type, name
+                )
+            )
+
+        # List of virtual cluster ids to match
+        # Needs to include several capitalizations for historical reasons. Will be fixed in a service side change
+        vc_ids = [make_id("virtualClusters"), make_id("virtualclusters"), make_id("virtualClusters").lower()]
+
+        filters = [
+            IndexEntitiesRequestFilter(field="type", operator="eq", values=["runs"]),
+            IndexEntitiesRequestFilter(field="annotations/archived", operator="eq", values=["false"]),
+            IndexEntitiesRequestFilter(field="properties/userProperties/azureml.VC", operator="eq", values=vc_ids),
+        ]
+        order = [
+            IndexEntitiesRequestOrder(
+                field="properties/creationContext/createdTime", direction=IndexEntitiesRequestOrderDirection.DESC
+            )
+        ]
+        index_entities_request = IndexEntitiesRequest(filters=filters, order=order)
+
+        # cspell:ignore entites
+        return cast(
+            Iterable[Job],
+            self._index_service.index_entities.get_entites_cross_region(
+                body=CrossRegionIndexEntitiesRequest(index_entities_request=index_entities_request),
+                cls=lambda objs: [index_entity_response_to_job(obj) for obj in objs],
+            ),
+        )
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "VirtualCluster.Get", ActivityType.PUBLICAPI)
+    def get(self, name: str) -> Dict:
+        """
+        Get a virtual cluster resource. If name is provided, the virtual cluster
+        with the name in the subscription and resource group of the MLClient object
+        will be returned. If an ARM id is provided, a virtual cluster with the id will be returned.
+
+        :param name: Name or ARM ID of the virtual cluster.
+        :type name: str
+        :return: Virtual cluster object
+        :rtype: Dict
+        """
+
+        try:
+            arm_id = AzureResourceId(name)
+            sub_id = arm_id.subscription_id
+
+            return cast(
+                Dict,
+                get_generic_resource_by_id(
+                    arm_id=name, credential=self._credentials, subscription_id=sub_id, api_version="2021-03-01-preview"
+                ),
+            )
+        except ValidationException:
+            return cast(
+                Dict,
+                get_virtual_cluster_by_name(
+                    name=name,
+                    resource_group=self._resource_group_name,
+                    subscription_id=self._subscription_id,
+                    credential=self._credentials,
+                ),
+            )