# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

from typing import Any, Dict, Iterable, List, Optional, Union, cast

from marshmallow import ValidationError

from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview
from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkProvisionOptions
from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml._utils.utils import (
    _get_workspace_base_url,
    get_resource_and_group_name_from_resource_id,
    get_resource_group_name_from_resource_group_id,
    modified_operation_client,
)
from azure.ai.ml.constants._common import AzureMLResourceType, Scope, WorkspaceKind
from azure.ai.ml.entities import (
    DiagnoseRequestProperties,
    DiagnoseResponseResult,
    DiagnoseResponseResultValue,
    DiagnoseWorkspaceParameters,
    ManagedNetworkProvisionStatus,
    Workspace,
    WorkspaceKeys,
)
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.core.credentials import TokenCredential
from azure.core.exceptions import HttpResponseError
from azure.core.polling import LROPoller
from azure.core.tracing.decorator import distributed_trace

from ._workspace_operations_base import WorkspaceOperationsBase

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger


class WorkspaceOperations(WorkspaceOperationsBase):
    """Handles workspaces and its subclasses, hubs and projects.

    You should not instantiate this class directly. Instead, you should create
    an MLClient instance that instantiates it for you and attaches it as an attribute.
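
    A minimal usage sketch (the credential and placeholder names below are illustrative
    assumptions, not values defined in this module):

    .. code-block:: python

        from azure.identity import DefaultAzureCredential
        from azure.ai.ml import MLClient

        # MLClient builds WorkspaceOperations and exposes it as ml_client.workspaces.
        ml_client = MLClient(
            credential=DefaultAzureCredential(),
            subscription_id="<subscription-id>",
            resource_group_name="<resource-group>",
        )
        workspace = ml_client.workspaces.get("<workspace-name>")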
    """

    def __init__(
        self,
        operation_scope: OperationScope,
        service_client: ServiceClient102024Preview,
        all_operations: OperationsContainer,
        credentials: Optional[TokenCredential] = None,
        **kwargs: Any,
    ):
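        # An optional dataplane client supplies the hub-join call used by _begin_join;
        # it is popped from kwargs so it is not forwarded to the base class.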
        self.dataplane_workspace_operations = (
            kwargs.pop("dataplane_client").workspaces if kwargs.get("dataplane_client") else None
        )
        self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline", None)
        ops_logger.update_filter()
        self._provision_network_operation = service_client.managed_network_provisions
        super().__init__(
            operation_scope=operation_scope,
            service_client=service_client,
            all_operations=all_operations,
            credentials=credentials,
            **kwargs,
        )

    @monitor_with_activity(ops_logger, "Workspace.List", ActivityType.PUBLICAPI)
    def list(
        self, *, scope: str = Scope.RESOURCE_GROUP, filtered_kinds: Optional[Union[str, List[str]]] = None
    ) -> Iterable[Workspace]:
        """List all Workspaces that the user has access to in the current resource group or subscription.

        :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group"
        :paramtype scope: str
        :keyword filtered_kinds: The kinds of workspaces to list. If not provided, all workspace varieties will
            be listed. Accepts either a single kind or a list of them.
            Valid kind options include: "default", "project", and "hub".
        :return: An iterator-like instance of Workspace objects
        :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Workspace]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_list]
                :end-before: [END workspace_list]
                :language: python
                :dedent: 8
                :caption: List the workspaces by resource group or subscription.
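
        A short illustrative call using ``filtered_kinds`` (the client object and kind values
        below are placeholders, not part of the linked sample):

        .. code-block:: python

            # List only hubs and projects across the whole subscription (sketch).
            for ws in ml_client.workspaces.list(
                scope="subscription", filtered_kinds=["hub", "project"]
            ):
                print(ws.name)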
        """

        # Kind should be converted to a comma-separated string if multiple values are supplied.
        formatted_kinds = filtered_kinds
        if filtered_kinds and not isinstance(filtered_kinds, str):
            formatted_kinds = ",".join(filtered_kinds)  # type: ignore[arg-type]

        if scope == Scope.SUBSCRIPTION:
            return cast(
                Iterable[Workspace],
                self._operation.list_by_subscription(
                    cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs],
                    kind=formatted_kinds,
                ),
            )
        return cast(
            Iterable[Workspace],
            self._operation.list_by_resource_group(
                self._resource_group_name,
                cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs],
                kind=formatted_kinds,
            ),
        )

    @monitor_with_activity(ops_logger, "Workspace.Get", ActivityType.PUBLICAPI)
    @distributed_trace
    # pylint: disable=arguments-renamed
    def get(self, name: Optional[str] = None, **kwargs: Dict) -> Optional[Workspace]:
        """Get a Workspace by name.

        :param name: Name of the workspace.
        :type name: str
        :return: The workspace with the provided name.
        :rtype: ~azure.ai.ml.entities.Workspace

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_get]
                :end-before: [END workspace_get]
                :language: python
                :dedent: 8
                :caption: Get the workspace with the given name.
        """

        return super().get(workspace_name=name, **kwargs)

    @monitor_with_activity(ops_logger, "Workspace.Get_Keys", ActivityType.PUBLICAPI)
    @distributed_trace
    def get_keys(self, name: Optional[str] = None) -> Optional[WorkspaceKeys]:
        """Get WorkspaceKeys by workspace name.

        :param name: Name of the workspace.
        :type name: str
        :return: Keys of workspace dependent resources.
        :rtype: ~azure.ai.ml.entities.WorkspaceKeys

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_get_keys]
                :end-before: [END workspace_get_keys]
                :language: python
                :dedent: 8
                :caption: Get the workspace keys for the workspace with the given name.
        """
        workspace_name = self._check_workspace_name(name)
        obj = self._operation.list_keys(self._resource_group_name, workspace_name)
        return WorkspaceKeys._from_rest_object(obj)

    @monitor_with_activity(ops_logger, "Workspace.BeginSyncKeys", ActivityType.PUBLICAPI)
    @distributed_trace
    def begin_sync_keys(self, name: Optional[str] = None) -> LROPoller[None]:
        """Triggers the workspace to immediately synchronize keys. If keys for any resource in the workspace are
        changed, it can take around an hour for them to automatically be updated. This function enables keys to be
        updated upon request. An example scenario is needing immediate access to storage after regenerating storage
        keys.

        :param name: Name of the workspace.
        :type name: str
        :return: An instance of LROPoller that returns either None or the sync keys result.
        :rtype: ~azure.core.polling.LROPoller[None]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_sync_keys]
                :end-before: [END workspace_sync_keys]
                :language: python
                :dedent: 8
                :caption: Begin sync keys for the workspace with the given name.
        """
        workspace_name = self._check_workspace_name(name)
        return self._operation.begin_resync_keys(self._resource_group_name, workspace_name)

    @monitor_with_activity(ops_logger, "Workspace.BeginProvisionNetwork", ActivityType.PUBLICAPI)
    @distributed_trace
    def begin_provision_network(
        self,
        *,
        workspace_name: Optional[str] = None,
        include_spark: bool = False,
        **kwargs: Any,
    ) -> LROPoller[ManagedNetworkProvisionStatus]:
        """Triggers the workspace to provision the managed network. Specifying spark enabled
        as true prepares the workspace managed network for supporting Spark.

        :keyword workspace_name: Name of the workspace.
        :paramtype workspace_name: str
        :keyword include_spark: Whether the workspace managed network should prepare to support Spark.
        :paramtype include_spark: bool
        :return: An instance of LROPoller.
        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ManagedNetworkProvisionStatus]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_provision_network]
                :end-before: [END workspace_provision_network]
                :language: python
                :dedent: 8
                :caption: Begin provision network for a workspace with managed network.
        """
        workspace_name = self._check_workspace_name(workspace_name)

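        # Delegate to the managed network provisions operation; the cls callback converts
        # the REST response into a ManagedNetworkProvisionStatus entity.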
        poller = self._provision_network_operation.begin_provision_managed_network(
            self._resource_group_name,
            workspace_name,
            ManagedNetworkProvisionOptions(include_spark=include_spark),
            polling=True,
            cls=lambda response, deserialized, headers: ManagedNetworkProvisionStatus._from_rest_object(deserialized),
            **kwargs,
        )
        module_logger.info("Provision network request initiated for workspace: %s\n", workspace_name)
        return poller

    @monitor_with_activity(ops_logger, "Workspace.BeginCreate", ActivityType.PUBLICAPI)
    @distributed_trace
    # pylint: disable=arguments-differ
    def begin_create(
        self,
        workspace: Workspace,
        update_dependent_resources: bool = False,
        **kwargs: Any,
    ) -> LROPoller[Workspace]:
        """Create a new Azure Machine Learning Workspace.

        Returns the workspace if it already exists.

        :param workspace: Workspace definition.
        :type workspace: ~azure.ai.ml.entities.Workspace
        :param update_dependent_resources: Whether to update dependent resources, defaults to False.
        :type update_dependent_resources: boolean
        :return: An instance of LROPoller that returns a Workspace.
        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_begin_create]
                :end-before: [END workspace_begin_create]
                :language: python
                :dedent: 8
                :caption: Begin create for a workspace.
        """
        # Add hub values to project if possible
        if workspace._kind == WorkspaceKind.PROJECT:
            try:
                parent_name = workspace._hub_id.split("/")[-1] if workspace._hub_id else ""
                parent = self.get(parent_name)
                if parent:
                    # Project location can not differ from hub, so try to force match them if possible.
                    workspace.location = parent.location
                    # Projects technically don't save their public network access (PNA), since it
                    # implicitly matches their parent's. However, some PNA-dependent code runs
                    # server-side before that alignment is made, so make sure they're aligned
                    # before the request hits the server.
                    workspace.public_network_access = parent.public_network_access
            except HttpResponseError:
                module_logger.warning("Failed to get parent hub for project, some values won't be transferred:")
        try:
            return super().begin_create(workspace, update_dependent_resources=update_dependent_resources, **kwargs)
        except HttpResponseError as error:
            if error.status_code == 403 and workspace._kind == WorkspaceKind.PROJECT:
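                # A 403 on project creation may be a permissions issue that a hub join can work
                # around, but only attempt the join when the hub's default workspace resource
                # group matches the user-provided resource group (checked below).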
                resource_group = kwargs.get("resource_group") or self._resource_group_name
                hub_name, _ = get_resource_and_group_name_from_resource_id(workspace._hub_id)
                rest_workspace_obj = self._operation.get(resource_group, hub_name)
                hub_default_project_resource_group = get_resource_group_name_from_resource_group_id(
                    rest_workspace_obj.workspace_hub_config.default_workspace_resource_group
                )
                # We only want to try joining the workspaceHub when the default workspace resource group
                # is the same as the user-provided resource group.
                if hub_default_project_resource_group == resource_group:
                    log_msg = (
                        "User lacked permission to create project workspace, "
                        + "trying to join the workspaceHub default resource group."
                    )
                    module_logger.info(log_msg)
                    return self._begin_join(workspace, **kwargs)
            raise error

    @monitor_with_activity(ops_logger, "Workspace.BeginUpdate", ActivityType.PUBLICAPI)
    @distributed_trace
    def begin_update(
        self,
        workspace: Workspace,
        *,
        update_dependent_resources: bool = False,
        **kwargs: Any,
    ) -> LROPoller[Workspace]:
        """Updates a Azure Machine Learning Workspace.

        :param workspace: Workspace definition.
        :type workspace: ~azure.ai.ml.entities.Workspace
        :keyword update_dependent_resources: Whether to update dependent resources, defaults to False.
        :paramtype update_dependent_resources: boolean
        :return: An instance of LROPoller that returns a Workspace.
        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_begin_update]
                :end-before: [END workspace_begin_update]
                :language: python
                :dedent: 8
                :caption: Begin update for a workspace.
        """
        return super().begin_update(workspace, update_dependent_resources=update_dependent_resources, **kwargs)

    @monitor_with_activity(ops_logger, "Workspace.BeginDelete", ActivityType.PUBLICAPI)
    @distributed_trace
    def begin_delete(
        self, name: str, *, delete_dependent_resources: bool, permanently_delete: bool = False, **kwargs: Dict
    ) -> LROPoller[None]:
        """Delete a workspace.

        :param name: Name of the workspace
        :type name: str
        :keyword delete_dependent_resources: Whether to delete resources associated with the workspace,
            i.e., container registry, storage account, key vault, application insights, log analytics.
            The default is False. Set to True to delete these resources.
        :paramtype delete_dependent_resources: bool
        :keyword permanently_delete: Workspaces are soft-deleted by default to allow recovery of workspace data.
            Set this flag to true to override the soft-delete behavior and permanently delete your workspace.
        :paramtype permanently_delete: bool
        :return: A poller to track the operation status.
        :rtype: ~azure.core.polling.LROPoller[None]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_begin_delete]
                :end-before: [END workspace_begin_delete]
                :language: python
                :dedent: 8
                :caption: Begin permanent (force) deletion for a workspace and delete dependent resources.
        """
        return super().begin_delete(
            name, delete_dependent_resources=delete_dependent_resources, permanently_delete=permanently_delete, **kwargs
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "Workspace.BeginDiagnose", ActivityType.PUBLICAPI)
    def begin_diagnose(self, name: str, **kwargs: Dict) -> LROPoller[DiagnoseResponseResultValue]:
        """Diagnose workspace setup problems.

        If your workspace is not working as expected, you can run this diagnosis to
        check whether the workspace setup is broken.
        For a private endpoint workspace, it also helps check whether the network
        setup to the workspace and its dependent resources has any problems.

        :param name: Name of the workspace
        :type name: str
        :return: A poller to track the operation status.
        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.DiagnoseResponseResultValue]

        .. admonition:: Example:

            .. literalinclude:: ../samples/ml_samples_workspace.py
                :start-after: [START workspace_begin_diagnose]
                :end-before: [END workspace_begin_diagnose]
                :language: python
                :dedent: 8
                :caption: Begin diagnose operation for a workspace.
        """
        resource_group = kwargs.get("resource_group") or self._resource_group_name
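        # Build the diagnose request body with default (empty) diagnose request properties.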
        parameters = DiagnoseWorkspaceParameters(value=DiagnoseRequestProperties())._to_rest_object()

        # pylint: disable=unused-argument, docstring-missing-param
        def callback(_: Any, deserialized: Any, args: Any) -> Optional[DiagnoseResponseResultValue]:
            """Callback to be called after completion

            :return: DiagnoseResponseResult deserialized.
            :rtype: ~azure.ai.ml.entities.DiagnoseResponseResult
            """
            diagnose_response_result = DiagnoseResponseResult._from_rest_object(deserialized)
            res = None
            if diagnose_response_result:
                res = diagnose_response_result.value
            return res

        poller = self._operation.begin_diagnose(resource_group, name, parameters, polling=True, cls=callback)
        module_logger.info("Diagnose request initiated for workspace: %s\n", name)
        return poller

    @distributed_trace
    def _begin_join(self, workspace: Workspace, **kwargs: Dict) -> LROPoller[Workspace]:
        """Join a WorkspaceHub by creating a project workspace under workspaceHub's default resource group.

        :param workspace: Project workspace definition to create
        :type workspace: Workspace
        :return: An instance of LROPoller that returns a project Workspace.
        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]
        """
        if not workspace._hub_id:
            raise ValidationError(
                "{0} is not a Project workspace; the join operation can only be performed "
                "when a workspaceHub is provided".format(workspace.name)
            )

        resource_group = kwargs.get("resource_group") or self._resource_group_name
        hub_name, _ = get_resource_and_group_name_from_resource_id(workspace._hub_id)
        rest_workspace_obj = self._operation.get(resource_group, hub_name)

        # override the location to the same as the workspaceHub
        workspace.location = rest_workspace_obj.location
        # set this to system assigned so ARM will create MSI
        if not hasattr(workspace, "identity") or not workspace.identity:
            workspace.identity = IdentityConfiguration(type="system_assigned")

        workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
        workspace_base_uri = _get_workspace_base_url(workspace_operations, hub_name, self._requests_pipeline)

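        # Convert the REST response from the hub-join call into a Workspace entity.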
        # pylint:disable=unused-argument
        def callback(_: Any, deserialized: Any, args: Any) -> Optional[Workspace]:
            return Workspace._from_rest_object(deserialized)

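        # Temporarily point the dataplane workspace client at the hub's base URL so the
        # hub-join request is sent to the correct endpoint.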
        with modified_operation_client(self.dataplane_workspace_operations, workspace_base_uri):
            result = self.dataplane_workspace_operations.begin_hub_join(  # type: ignore
                resource_group_name=resource_group,
                workspace_name=hub_name,
                project_workspace_name=workspace.name,
                body=workspace._to_rest_object(),
                cls=callback,
                **self._init_kwargs,
            )
            return result