aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_compute/compute.py
blob: de18da5a24de3b432ddf8994267f7962d4dc0ac1 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

from abc import abstractmethod
from os import PathLike
from pathlib import Path
from typing import IO, Any, AnyStr, Dict, Optional, Union, cast

from azure.ai.ml._restclient.v2022_10_01_preview.models import ComputeResource
from azure.ai.ml._schema.compute.compute import ComputeSchema
from azure.ai.ml._utils.utils import dump_yaml_to_file
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, CommonYamlFields
from azure.ai.ml.constants._compute import ComputeType
from azure.ai.ml.entities._mixins import RestTranslatableMixin
from azure.ai.ml.entities._resource import Resource
from azure.ai.ml.entities._util import find_type_in_override
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException


class Compute(Resource, RestTranslatableMixin):
    """Base class for compute resources.

    This class should not be instantiated directly. Instead, use one of its subclasses.

    :param type: The compute type. Accepted values are "amlcompute", "computeinstance",
        "virtualmachine", "kubernetes", and "synapsespark".
    :type type: str
    :param name: Name of the compute resource.
    :type name: str
    :param location: The resource location. Defaults to workspace location.
    :type location: Optional[str]
    :param description: Description of the resource. Defaults to None.
    :type description: Optional[str]
    :param resource_id: ARM resource id of the underlying compute. Defaults to None.
    :type resource_id: Optional[str]
    :param tags: A set of tags. Contains resource tags defined as key/value pairs.
    :type tags: Optional[dict[str, str]]
    """

    def __init__(
        self,
        name: str,
        location: Optional[str] = None,
        description: Optional[str] = None,
        resource_id: Optional[str] = None,
        tags: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        self._type: Optional[str] = kwargs.pop("type", None)
        if self._type:
            self._type = self._type.lower()

        self._created_on: Optional[str] = kwargs.pop("created_on", None)
        self._provisioning_state: Optional[str] = kwargs.pop("provisioning_state", None)
        self._provisioning_errors: Optional[str] = kwargs.pop("provisioning_errors", None)

        super().__init__(name=name, description=description, **kwargs)
        self.resource_id = resource_id
        self.location = location
        self.tags = tags

    @property
    def type(self) -> Optional[str]:
        """The compute type.

        :return: The compute type.
        :rtype: Optional[str]
        """
        return self._type

    @property
    def created_on(self) -> Optional[str]:
        """The compute resource creation timestamp.

        :return: The compute resource creation timestamp.
        :rtype: Optional[str]
        """
        return self._created_on

    @property
    def provisioning_state(self) -> Optional[str]:
        """The compute resource's provisioning state.

        :return: The compute resource's provisioning state.
        :rtype: Optional[str]
        """
        return self._provisioning_state

    @property
    def provisioning_errors(self) -> Optional[str]:
        """The compute resource provisioning errors.

        :return: The compute resource provisioning errors.
        :rtype: Optional[str]
        """
        return self._provisioning_errors

    def _to_rest_object(self) -> ComputeResource:
        pass

    @classmethod
    def _from_rest_object(cls, obj: ComputeResource) -> "Compute":
        from azure.ai.ml.entities import (
            AmlCompute,
            ComputeInstance,
            KubernetesCompute,
            SynapseSparkCompute,
            UnsupportedCompute,
            VirtualMachineCompute,
        )

        mapping = {
            ComputeType.AMLCOMPUTE.lower(): AmlCompute,
            ComputeType.COMPUTEINSTANCE.lower(): ComputeInstance,
            ComputeType.VIRTUALMACHINE.lower(): VirtualMachineCompute,
            ComputeType.KUBERNETES.lower(): KubernetesCompute,
            ComputeType.SYNAPSESPARK.lower(): SynapseSparkCompute,
        }
        compute_type = obj.properties.compute_type.lower() if obj.properties.compute_type else None

        class_type = cast(
            Optional[Union[AmlCompute, ComputeInstance, VirtualMachineCompute, KubernetesCompute, SynapseSparkCompute]],
            mapping.get(compute_type, None),  # type: ignore
        )
        if class_type:
            return class_type._load_from_rest(obj)
        _unsupported_from_rest: Compute = UnsupportedCompute._load_from_rest(obj)
        return _unsupported_from_rest

    @classmethod
    @abstractmethod
    def _load_from_rest(cls, rest_obj: ComputeResource) -> "Compute":
        pass

    def _set_full_subnet_name(self, subscription_id: str, rg: str) -> None:
        pass

    def dump(self, dest: Union[str, PathLike, IO[AnyStr]], **kwargs: Any) -> None:
        """Dump the compute content into a file in yaml format.

        :param dest: The destination to receive this compute's content.
            Must be either a path to a local file, or an already-open file stream.
            If dest is a file path, a new file will be created,
            and an exception is raised if the file exists.
            If dest is an open file, the file will be written to directly,
            and an exception will be raised if the file is not writable.'.
        :type dest: Union[PathLike, str, IO[AnyStr]]
        """
        path = kwargs.pop("path", None)
        yaml_serialized = self._to_dict()
        dump_yaml_to_file(dest, yaml_serialized, default_flow_style=False, path=path, **kwargs)

    def _to_dict(self) -> Dict:
        res: dict = ComputeSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
        return res

    @classmethod
    def _load(
        cls,
        data: Optional[Dict] = None,
        yaml_path: Optional[Union[PathLike, str]] = None,
        params_override: Optional[list] = None,
        **kwargs: Any,
    ) -> "Compute":
        data = data or {}
        params_override = params_override or []
        context = {
            BASE_PATH_CONTEXT_KEY: Path(yaml_path).parent if yaml_path else Path("./"),
            PARAMS_OVERRIDE_KEY: params_override,
        }
        from azure.ai.ml.entities import (
            AmlCompute,
            ComputeInstance,
            KubernetesCompute,
            SynapseSparkCompute,
            VirtualMachineCompute,
        )

        type_in_override = find_type_in_override(params_override) if params_override else None
        compute_type = type_in_override or data.get(CommonYamlFields.TYPE, None)  # override takes the priority
        if compute_type:
            if compute_type.lower() == ComputeType.VIRTUALMACHINE:
                _vm_load_from_dict: Compute = VirtualMachineCompute._load_from_dict(data, context, **kwargs)
                return _vm_load_from_dict
            if compute_type.lower() == ComputeType.AMLCOMPUTE:
                _aml_load_from_dict: Compute = AmlCompute._load_from_dict(data, context, **kwargs)
                return _aml_load_from_dict
            if compute_type.lower() == ComputeType.COMPUTEINSTANCE:
                _compute_instance_load_from_dict: Compute = ComputeInstance._load_from_dict(data, context, **kwargs)
                return _compute_instance_load_from_dict
            if compute_type.lower() == ComputeType.KUBERNETES:
                _kub_load_from_dict: Compute = KubernetesCompute._load_from_dict(data, context, **kwargs)
                return _kub_load_from_dict
            if compute_type.lower() == ComputeType.SYNAPSESPARK:
                _synapse_spark_load_from_dict: Compute = SynapseSparkCompute._load_from_dict(data, context, **kwargs)
                return _synapse_spark_load_from_dict
        msg = f"Unknown compute type: {compute_type}"
        raise ValidationException(
            message=msg,
            target=ErrorTarget.COMPUTE,
            no_personal_data_message=msg,
            error_category=ErrorCategory.USER_ERROR,
        )

    @classmethod
    @abstractmethod
    def _load_from_dict(cls, data: Dict, context: Dict, **kwargs: Any) -> "Compute":
        pass


class NetworkSettings:
    """Network settings for a compute resource. If the workspace and VNet are in different resource groups,
    please provide the full URI for subnet and leave vnet_name as None.

    :param vnet_name: The virtual network name.
    :type vnet_name: Optional[str]
    :param subnet: The subnet name.
    :type subnet: Optional[str]

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_compute.py
            :start-after: [START network_settings]
            :end-before: [END network_settings]
            :language: python
            :dedent: 8
            :caption: Configuring NetworkSettings for an AmlCompute object.
    """

    def __init__(
        self,
        *,
        vnet_name: Optional[str] = None,
        subnet: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        self.vnet_name = vnet_name
        self.subnet = subnet
        self._public_ip_address: str = kwargs.pop("public_ip_address", None)
        self._private_ip_address: str = kwargs.pop("private_ip_address", None)

    @property
    def public_ip_address(self) -> str:
        """Public IP address of the compute instance.

        :return: Public IP address.
        :rtype: str
        """
        return self._public_ip_address

    @property
    def private_ip_address(self) -> str:
        """Private IP address of the compute instance.

        :return: Private IP address.
        :rtype: str
        """
        return self._private_ip_address