# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from typing import Callable, Dict, Optional, TypeVar, cast

from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

T = TypeVar("T")

module_logger = logging.getLogger(__name__)


class OperationConfig(object):
    """Store common configurations that are shared across operation objects of an MLClient object.

    :param show_progress: Whether to display progress bars for long running operations.
    :type show_progress: bool
    :param enable_telemetry: Whether to enable telemetry for Jupyter Notebooks.
    :type enable_telemetry: bool
    """

    def __init__(self, show_progress: bool, enable_telemetry: bool) -> None:
        self._show_progress = show_progress
        self._enable_telemetry = enable_telemetry

    @property
    def show_progress(self) -> bool:
        """Decide whether to display progress bars for long running operations.

        :return: show_progress
        :rtype: bool
        """
        return self._show_progress

    @property
    def enable_telemetry(self) -> bool:
        """Decide whether to enable telemetry for Jupyter Notebooks - telemetry cannot be enabled for other contexts.

        :return: enable_telemetry
        :rtype: bool
        """
        return self._enable_telemetry


class OperationScope(object):
    """Hold the subscription, resource group, workspace and registry identifiers given at construction.

    ``workspace_name`` and ``registry_name`` can be reassigned after construction; the other values are
    read-only.

    :param subscription_id: The subscription ID.
    :type subscription_id: str
    :param resource_group_name: The resource group name.
    :type resource_group_name: str
    :param workspace_name: The workspace name, or None.
    :type workspace_name: Optional[str]
    :param registry_name: The registry name, defaults to None.
    :type registry_name: Optional[str]
    :param workspace_id: The workspace ID, defaults to None.
    :type workspace_id: Optional[str]
    :param workspace_location: The workspace location, defaults to None.
    :type workspace_location: Optional[str]
    """

    def __init__(
        self,
        subscription_id: str,
        resource_group_name: str,
        workspace_name: Optional[str],
        registry_name: Optional[str] = None,
        workspace_id: Optional[str] = None,
        workspace_location: Optional[str] = None,
    ) -> None:
        self._subscription_id = subscription_id
        self._resource_group_name = resource_group_name
        self._workspace_name = workspace_name
        self._registry_name = registry_name
        self._workspace_id = workspace_id
        self._workspace_location = workspace_location

    @property
    def subscription_id(self) -> str:
        """Return the subscription ID given at construction.

        :return: subscription_id
        :rtype: str
        """
        return self._subscription_id

    @property
    def resource_group_name(self) -> str:
        """Return the resource group name given at construction.

        :return: resource_group_name
        :rtype: str
        """
        return self._resource_group_name

    @property
    def workspace_name(self) -> Optional[str]:
        """Return the current workspace name.

        :return: workspace_name
        :rtype: Optional[str]
        """
        return self._workspace_name

    @workspace_name.setter
    def workspace_name(self, value: Optional[str]) -> None:
        """Replace the workspace name.

        :param value: The new workspace name.
        :type value: Optional[str]
        """
        self._workspace_name = value

    @property
    def registry_name(self) -> Optional[str]:
        """Return the current registry name.

        :return: registry_name
        :rtype: Optional[str]
        """
        return self._registry_name

    @registry_name.setter
    def registry_name(self, value: Optional[str]) -> None:
        """Replace the registry name.

        :param value: The new registry name.
        :type value: Optional[str]
        """
        self._registry_name = value

    @property
    def workspace_id(self) -> Optional[str]:
        """Return the workspace ID given at construction.

        :return: workspace_id
        :rtype: Optional[str]
        """
        return self._workspace_id

    @property
    def workspace_location(self) -> Optional[str]:
        """Return the workspace location given at construction.

        :return: workspace_location
        :rtype: Optional[str]
        """
        return self._workspace_location


class _ScopeDependentOperations(object):
    """Expose the values of an OperationScope and an OperationConfig through private properties.

    Every property reads through to the stored scope/config object on each access, so later changes to
    the scope's ``workspace_name`` or ``registry_name`` are visible here.

    :param operation_scope: The scope whose identifiers are exposed.
    :type operation_scope: OperationScope
    :param operation_config: The configuration whose flags are exposed.
    :type operation_config: OperationConfig
    """

    def __init__(self, operation_scope: OperationScope, operation_config: OperationConfig):
        self._operation_scope = operation_scope
        self._operation_config = operation_config
        # Snapshot taken once here; unlike the properties below it is not refreshed afterwards.
        self._scope_kwargs: Dict = {
            "resource_group_name": self._operation_scope.resource_group_name,
        }

    @property  # type: ignore
    def _workspace_name(self) -> str:
        # cast() only informs the type checker; no runtime check is made, so None can still be returned.
        return cast(str, self._operation_scope.workspace_name)

    @property  # type: ignore
    def _registry_name(self) -> str:
        # cast() only informs the type checker; no runtime check is made, so None can still be returned.
        return cast(str, self._operation_scope.registry_name)

    @property
    def _subscription_id(self) -> str:
        return self._operation_scope.subscription_id

    @property
    def _resource_group_name(self) -> str:
        return self._operation_scope.resource_group_name

    @property
    def _show_progress(self) -> bool:
        return self._operation_config.show_progress

    @property
    def _enable_telemetry(self) -> bool:
        return self._operation_config.enable_telemetry


class OperationsContainer(object):
    """Registry mapping a name to an operations object."""

    def __init__(self) -> None:
        self._all_operations: Dict = {}

    @property
    def all_operations(self) -> Dict:
        """Return the underlying name-to-operation mapping (the live dict, not a copy).

        :return: all_operations
        :rtype: Dict
        """
        return self._all_operations

    def add(self, name: str, operation: _ScopeDependentOperations) -> None:
        """Register an operations object under a name, replacing any existing entry for that name.

        :param name: The key to register the operation under.
        :type name: str
        :param operation: The operations object to store.
        :type operation: _ScopeDependentOperations
        """
        self._all_operations[name] = operation

    def get_operation(self, resource_type: str, type_check: Callable[[T], bool]) -> T:
        """Look up a registered operations object and validate it with ``type_check``.

        :param resource_type: The name the operation was registered under.
        :type resource_type: str
        :param type_check: Predicate called with the stored object; a truthy result accepts it.
        :type type_check: Callable[[T], bool]
        :return: The stored operations object.
        :rtype: T
        :raises ValidationException: If nothing is registered under ``resource_type``, or if the stored
            object is not a ``MagicMock`` and ``type_check`` rejects it.
        """
        operations = self.all_operations
        if resource_type not in operations:
            raise self._invalid_value_error(f"Operation {resource_type} is not available for this client.")

        operation = operations[resource_type]

        # Imported only once an entry has been found, so unittest.mock is not loaded at module import.
        from unittest.mock import MagicMock

        # A MagicMock is accepted without consulting type_check.
        # NOTE(review): presumably so tests can register mocked operations - confirm.
        if isinstance(operation, MagicMock) or type_check(operation):
            return operation

        raise self._invalid_value_error(
            f"{resource_type} operations are initialized with wrong type: {type(operation)}."
        )

    @staticmethod
    def _invalid_value_error(msg: str) -> ValidationException:
        """Build the ValidationException used for every lookup failure in this container.

        :param msg: Text used for both the message and the no-personal-data message.
        :type msg: str
        :return: The exception, ready to be raised by the caller.
        :rtype: ValidationException
        """
        return ValidationException(
            message=msg,
            no_personal_data_message=msg,
            error_category=ErrorCategory.USER_ERROR,
            target=ErrorTarget.JOB,
            error_type=ValidationErrorType.INVALID_VALUE,
        )