aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from typing import Callable, Dict, Optional, TypeVar, cast

from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

T = TypeVar("T")
module_logger = logging.getLogger(__name__)


class OperationConfig(object):
    """Common configuration shared across the operation objects of an MLClient object.

    :param show_progress: Whether to display progress bars for long running operations.
    :type show_progress: bool
    :param enable_telemetry: Whether to enable telemetry. Telemetry is only intended for
        Jupyter Notebook contexts.
    :type enable_telemetry: bool
    """

    def __init__(self, show_progress: bool, enable_telemetry: bool) -> None:
        self._show_progress = show_progress
        self._enable_telemetry = enable_telemetry

    @property
    def show_progress(self) -> bool:
        """Decide whether to display progress bars for long running operations.

        :return: The ``show_progress`` value this config was constructed with.
        :rtype: bool
        """
        return self._show_progress

    @property
    def enable_telemetry(self) -> bool:
        """Decide whether to enable telemetry for Jupyter Notebooks - telemetry cannot be enabled for other contexts.

        :return: The ``enable_telemetry`` value this config was constructed with.
        :rtype: bool
        """
        return self._enable_telemetry


class OperationScope(object):
    """The scope (subscription, resource group, and optional workspace/registry) that operations act on.

    :param subscription_id: The subscription ID.
    :type subscription_id: str
    :param resource_group_name: The resource group name.
    :type resource_group_name: str
    :param workspace_name: The workspace name, or None.
    :type workspace_name: Optional[str]
    :param registry_name: The registry name. Defaults to None.
    :type registry_name: Optional[str]
    :param workspace_id: The workspace ID. Defaults to None. Stored as ``_workspace_id``;
        no public property exposes it in this class.
    :type workspace_id: Optional[str]
    :param workspace_location: The workspace location. Defaults to None. Stored as
        ``_workspace_location``; no public property exposes it in this class.
    :type workspace_location: Optional[str]
    """

    def __init__(
        self,
        subscription_id: str,
        resource_group_name: str,
        workspace_name: Optional[str],
        registry_name: Optional[str] = None,
        workspace_id: Optional[str] = None,
        workspace_location: Optional[str] = None,
    ) -> None:
        self._subscription_id = subscription_id
        self._resource_group_name = resource_group_name
        self._workspace_name = workspace_name
        self._registry_name = registry_name
        self._workspace_id = workspace_id
        self._workspace_location = workspace_location

    @property
    def subscription_id(self) -> str:
        """The subscription ID this scope was constructed with.

        :rtype: str
        """
        return self._subscription_id

    @property
    def resource_group_name(self) -> str:
        """The resource group name this scope was constructed with.

        :rtype: str
        """
        return self._resource_group_name

    @property
    def workspace_name(self) -> Optional[str]:
        """The workspace name, or None when no workspace is set.

        :rtype: Optional[str]
        """
        return self._workspace_name

    # Typed Optional[str] to match the getter and the constructor, which both allow None.
    @workspace_name.setter
    def workspace_name(self, value: Optional[str]) -> None:
        self._workspace_name = value

    @property
    def registry_name(self) -> Optional[str]:
        """The registry name, or None when no registry is set.

        :rtype: Optional[str]
        """
        return self._registry_name

    # Typed Optional[str] to match the getter and the constructor, which both allow None.
    @registry_name.setter
    def registry_name(self, value: Optional[str]) -> None:
        self._registry_name = value


class _ScopeDependentOperations(object):
    """Base for operation classes that are bound to an OperationScope and an OperationConfig.

    Re-exposes the scope's identifiers and the config's flags as private read-only
    properties, and prepares ``_scope_kwargs`` (currently just the resource group name).

    :param operation_scope: The scope the operations act on.
    :type operation_scope: OperationScope
    :param operation_config: Configuration shared across operation objects.
    :type operation_config: OperationConfig
    """

    def __init__(self, operation_scope: OperationScope, operation_config: OperationConfig):
        self._operation_scope = operation_scope
        self._operation_config = operation_config
        # NOTE(review): presumably forwarded as keyword arguments by subclasses; confirm.
        self._scope_kwargs: Dict = dict(resource_group_name=self._operation_scope.resource_group_name)

    # ---- values taken from the operation scope ----

    @property
    def _subscription_id(self) -> str:
        scope = self._operation_scope
        return scope.subscription_id

    @property
    def _resource_group_name(self) -> str:
        scope = self._operation_scope
        return scope.resource_group_name

    @property  # type: ignore
    def _workspace_name(self) -> str:
        # The scope's value is Optional[str]; cast only narrows the static type and may still be None.
        name = self._operation_scope.workspace_name
        return cast(str, name)

    @property  # type: ignore
    def _registry_name(self) -> str:
        # The scope's value is Optional[str]; cast only narrows the static type and may still be None.
        name = self._operation_scope.registry_name
        return cast(str, name)

    # ---- values taken from the operation config ----

    @property
    def _show_progress(self) -> bool:
        config = self._operation_config
        return config.show_progress

    @property
    def _enable_telemetry(self) -> bool:
        config = self._operation_config
        return config.enable_telemetry


class OperationsContainer(object):
    """Holds operation objects keyed by the resource type name they were registered under."""

    def __init__(self) -> None:
        # Maps the name passed to ``add`` to the registered operations object.
        self._all_operations: Dict = {}

    @property
    def all_operations(self) -> Dict:
        """The mapping of registered names to operations objects.

        :rtype: Dict
        """
        return self._all_operations

    def add(self, name: str, operation: _ScopeDependentOperations) -> None:
        """Register ``operation`` under ``name``, replacing any existing entry with that name.

        :param name: The key to register the operations object under.
        :type name: str
        :param operation: The operations object.
        :type operation: _ScopeDependentOperations
        """
        self._all_operations[name] = operation

    def get_operation(self, resource_type: str, type_check: Callable[[T], bool]) -> T:
        """Return the operations object registered under ``resource_type``.

        :param resource_type: The name the operations object was registered under via ``add``.
        :type resource_type: str
        :param type_check: Predicate returning True when the registered object has the expected type.
        :type type_check: Callable[[T], bool]
        :return: The registered operations object.
        :rtype: T
        :raises ValidationException: If nothing is registered under ``resource_type``, or the
            registered object is not a MagicMock and fails ``type_check``.
        """
        if resource_type not in self.all_operations:
            raise self._invalid_value_error(f"Operation {resource_type} is not available for this client.")

        operation = self.all_operations[resource_type]
        # Imported locally, presumably so unittest.mock is not loaded at module import time.
        from unittest.mock import MagicMock

        # MagicMock instances skip ``type_check`` (presumably so unit tests can register mocks).
        if isinstance(operation, MagicMock) or type_check(operation):
            return operation
        raise self._invalid_value_error(
            f"{resource_type} operations are initialized with wrong type: {type(operation)}."
        )

    @staticmethod
    def _invalid_value_error(msg: str) -> ValidationException:
        """Build the user-error ValidationException raised by ``get_operation``.

        :param msg: Message used as both the message and the no-personal-data message.
        :type msg: str
        :return: The exception to raise.
        :rtype: ValidationException
        """
        return ValidationException(
            message=msg,
            no_personal_data_message=msg,
            error_category=ErrorCategory.USER_ERROR,
            target=ErrorTarget.JOB,
            error_type=ValidationErrorType.INVALID_VALUE,
        )