about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/compute.py
blob: ff91a8149a41fdd2b14f78214895653f09341122 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from azure.ai.ml._exception_helper import log_and_raise_error
from azure.ai.ml._restclient.v2023_06_01_preview.models import AmlTokenComputeIdentity, MonitorServerlessSparkCompute
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException


class ServerlessSparkCompute:
    """Serverless Spark compute.

    :keyword runtime_version: The runtime version of the compute. Only "3.4" is accepted; any other
        value fails validation when the object is converted with ``_to_rest_object``.
    :paramtype runtime_version: str
    :keyword instance_type: The instance type of the compute.
    :paramtype instance_type: str
    """

    # The only runtime version ``_validate`` accepts.
    _SUPPORTED_RUNTIME_VERSION = "3.4"

    def __init__(
        self,
        *,
        runtime_version: str,
        instance_type: str,
    ):
        self.runtime_version = runtime_version
        self.instance_type = instance_type

    def _to_rest_object(self) -> MonitorServerlessSparkCompute:
        """Validate this compute and convert it to its REST representation.

        The REST object is always given an AmlToken compute identity.

        :return: The REST model for this compute.
        :rtype: MonitorServerlessSparkCompute
        :raises ~azure.ai.ml.exceptions.ValidationException: If ``runtime_version`` is not supported.
        """
        self._validate()
        return MonitorServerlessSparkCompute(
            runtime_version=self.runtime_version,
            instance_type=self.instance_type,
            compute_identity=AmlTokenComputeIdentity(
                compute_identity_type="AmlToken",
            ),
        )

    @classmethod
    def _from_rest_object(cls, obj: MonitorServerlessSparkCompute) -> "ServerlessSparkCompute":
        """Build a ServerlessSparkCompute from its REST representation.

        Only ``runtime_version`` and ``instance_type`` are copied; the REST object's compute identity
        is not carried over (``_to_rest_object`` always re-creates an AmlToken identity).

        :param obj: The REST model to convert.
        :type obj: MonitorServerlessSparkCompute
        :return: The corresponding entity.
        :rtype: ServerlessSparkCompute
        """
        return cls(
            runtime_version=obj.runtime_version,
            instance_type=obj.instance_type,
        )

    def _validate(self) -> None:
        """Check that the runtime version is the supported one.

        The error is raised through ``log_and_raise_error``.

        :raises ~azure.ai.ml.exceptions.ValidationException: If ``runtime_version`` is not "3.4".
            The error type is MISSING_FIELD when no version was given, and INVALID_VALUE otherwise.
        """
        if self.runtime_version == self._SUPPORTED_RUNTIME_VERSION:
            return
        msg = f"Compute runtime version must be {self._SUPPORTED_RUNTIME_VERSION}"
        # An absent/empty version is a missing field; a present but unsupported version is an
        # invalid value. Previously both were reported as MISSING_FIELD.
        error_type = (
            ValidationErrorType.MISSING_FIELD if not self.runtime_version else ValidationErrorType.INVALID_VALUE
        )
        err = ValidationException(
            message=msg,
            target=ErrorTarget.MODEL_MONITORING,
            no_personal_data_message=msg,
            error_category=ErrorCategory.USER_ERROR,
            error_type=error_type,
        )
        log_and_raise_error(err)