aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_scope_dependent_operations.py
blob: 6ea92a157b25ebedbc28ae78735f2bd53d0e9ba6 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
import sys
from typing import Callable, Dict, Optional, TypeVar, cast

from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

# Type variable used by OperationsContainer.get_operation so the returned operation is typed
# according to the parameter type of the caller's ``type_check`` predicate.
T = TypeVar("T")
# Logger named after this module (not referenced by the code in this section).
module_logger = logging.getLogger(__name__)


class OperationConfig(object):
    """Store configuration shared across the operation objects of an MLClient object.

    :param show_progress: Whether to display progress bars for long running operations.
    :type show_progress: bool
    :param enable_telemetry: Whether to enable telemetry for Jupyter Notebooks.
    :type enable_telemetry: bool
    """

    def __init__(self, show_progress: bool, enable_telemetry: bool) -> None:
        self._show_progress = show_progress
        self._enable_telemetry = enable_telemetry

    @property
    def show_progress(self) -> bool:
        """Decide whether to display progress bars for long running operations.

        :return: The ``show_progress`` value given at construction.
        :rtype: bool
        """
        return self._show_progress

    @property
    def enable_telemetry(self) -> bool:
        """Decide whether to enable telemetry for Jupyter Notebooks - telemetry cannot be enabled for other contexts.

        :return: The ``enable_telemetry`` value given at construction.
        :rtype: bool
        """
        return self._enable_telemetry


class OperationScope(object):
    """Hold the Azure scope (subscription, resource group, workspace or registry) that operations act on.

    :param subscription_id: Azure subscription ID.
    :type subscription_id: str
    :param resource_group_name: Name of the resource group.
    :type resource_group_name: str
    :param workspace_name: Name of the workspace, or None.
    :type workspace_name: Optional[str]
    :param registry_name: Name of the registry, or None. Defaults to None.
    :type registry_name: Optional[str]
    :param workspace_id: Workspace ID, or None. Defaults to None. Stored as ``_workspace_id``; no public
        property exposes it.
    :type workspace_id: Optional[str]
    :param workspace_location: Workspace location, or None. Defaults to None. Stored as
        ``_workspace_location``; no public property exposes it.
    :type workspace_location: Optional[str]
    """

    def __init__(
        self,
        subscription_id: str,
        resource_group_name: str,
        workspace_name: Optional[str],
        registry_name: Optional[str] = None,
        workspace_id: Optional[str] = None,
        workspace_location: Optional[str] = None,
    ) -> None:
        self._subscription_id = subscription_id
        self._resource_group_name = resource_group_name
        self._workspace_name = workspace_name
        self._registry_name = registry_name
        self._workspace_id = workspace_id
        self._workspace_location = workspace_location

    @property
    def subscription_id(self) -> str:
        """Subscription ID of this scope (read-only).

        :rtype: str
        """
        return self._subscription_id

    @property
    def resource_group_name(self) -> str:
        """Resource group name of this scope (read-only).

        :rtype: str
        """
        return self._resource_group_name

    @property
    def workspace_name(self) -> Optional[str]:
        """Workspace name of this scope; may be None and may be reassigned.

        :rtype: Optional[str]
        """
        return self._workspace_name

    @workspace_name.setter
    def workspace_name(self, value: Optional[str]) -> None:
        # Typed Optional to match the getter and the constructor, which both allow None.
        self._workspace_name = value

    @property
    def registry_name(self) -> Optional[str]:
        """Registry name of this scope; may be None and may be reassigned.

        :rtype: Optional[str]
        """
        return self._registry_name

    @registry_name.setter
    def registry_name(self, value: Optional[str]) -> None:
        # Typed Optional to match the getter and the constructor default of None.
        self._registry_name = value


class _ScopeDependentOperations(object):
    """Base for operation groups bound to an OperationScope and an OperationConfig.

    Scope and config values are re-read from the bound objects on every property access, so later
    changes to the scope (e.g. reassigning ``workspace_name``) are visible here. ``_scope_kwargs`` is
    built once at construction and holds only the resource group name.
    """

    def __init__(self, operation_scope: OperationScope, operation_config: OperationConfig):
        self._operation_scope = operation_scope
        self._operation_config = operation_config
        rg_name = self._operation_scope.resource_group_name
        self._scope_kwargs: Dict = dict(resource_group_name=rg_name)

    # ---- scope-derived values -------------------------------------------------------------------

    @property  # type: ignore
    def _workspace_name(self) -> str:
        # cast() is a typing-only no-op: the underlying value may still be None at runtime.
        ws_name = self._operation_scope.workspace_name
        return cast(str, ws_name)

    @property  # type: ignore
    def _registry_name(self) -> str:
        # cast() is a typing-only no-op: the underlying value may still be None at runtime.
        reg_name = self._operation_scope.registry_name
        return cast(str, reg_name)

    @property
    def _subscription_id(self) -> str:
        scope = self._operation_scope
        return scope.subscription_id

    @property
    def _resource_group_name(self) -> str:
        scope = self._operation_scope
        return scope.resource_group_name

    # ---- config-derived values ------------------------------------------------------------------

    @property
    def _show_progress(self) -> bool:
        config = self._operation_config
        return config.show_progress

    @property
    def _enable_telemetry(self) -> bool:
        config = self._operation_config
        return config.enable_telemetry


class OperationsContainer(object):
    """Registry of operation objects, keyed by resource-type name.

    Operations are registered with :meth:`add` and retrieved, after a caller-supplied type check,
    with :meth:`get_operation`.
    """

    def __init__(self) -> None:
        self._all_operations: Dict[str, _ScopeDependentOperations] = {}

    @property
    def all_operations(self) -> Dict:
        """Mapping from registered name to operation object.

        :return: The live internal mapping (not a copy).
        :rtype: Dict
        """
        return self._all_operations

    def add(self, name: str, operation: _ScopeDependentOperations) -> None:
        """Register ``operation`` under ``name``, replacing any operation already stored under that name.

        :param name: Key used to retrieve the operation later.
        :type name: str
        :param operation: The operation object to register.
        :type operation: _ScopeDependentOperations
        """
        self._all_operations[name] = operation

    def get_operation(self, resource_type: str, type_check: Callable[[T], bool]) -> T:
        """Return the operation registered under ``resource_type`` after validating its type.

        :param resource_type: Name the operation was registered under via :meth:`add`.
        :type resource_type: str
        :param type_check: Predicate returning True when the stored operation has the expected type.
            It is not called for ``unittest.mock.MagicMock`` instances, which are always accepted.
        :type type_check: Callable[[T], bool]
        :return: The registered operation.
        :rtype: T
        :raises ValidationException: If nothing is registered under ``resource_type``, or if the
            registered operation is not a MagicMock and fails ``type_check``.
        """
        if resource_type not in self.all_operations:
            msg = f"Operation {resource_type} is not available for this client."
            raise self._validation_error(msg)

        operation = self.all_operations[resource_type]
        # MagicMock stand-ins (used by unit tests) are accepted without running type_check.
        if self._is_magic_mock(operation) or type_check(operation):
            return operation

        msg = f"{resource_type} operations are initialized with wrong type: {type(operation)}."
        raise self._validation_error(msg)

    @staticmethod
    def _is_magic_mock(obj: object) -> bool:
        """Return True if ``obj`` is a ``unittest.mock.MagicMock`` instance.

        The class is looked up in ``sys.modules`` rather than imported, so production code never
        loads ``unittest.mock`` (and its own imports) just to perform this check. If the module
        has not been imported, no MagicMock instance can exist, so the result matches an
        ``isinstance`` check against an imported ``MagicMock``.
        """
        mock_module = sys.modules.get("unittest.mock")
        return mock_module is not None and isinstance(obj, mock_module.MagicMock)

    @staticmethod
    def _validation_error(msg: str) -> ValidationException:
        """Build the user-error ValidationException that get_operation raises with ``msg``."""
        return ValidationException(
            message=msg,
            no_personal_data_message=msg,
            error_category=ErrorCategory.USER_ERROR,
            target=ErrorTarget.JOB,
            error_type=ValidationErrorType.INVALID_VALUE,
        )