about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/compute_configuration.py
blob: dcc00825966fc15f8cc20c1ab7320c8eeb13c441 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
import logging
from typing import Any, Dict, Optional

from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import ComputeConfiguration as RestComputeConfiguration
from azure.ai.ml.constants._common import LOCAL_COMPUTE_TARGET
from azure.ai.ml.constants._job.job import JobComputePropertyFields
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin

# Module-level logger, named after this module's import path via ``__name__``.
module_logger = logging.getLogger(__name__)


class ComputeConfiguration(RestTranslatableMixin, DictMixin):
    """Compute resource configuration

    :param target: The compute target. Falls back to the local compute target when not provided.
    :type target: Optional[str]
    :param instance_count: The number of instances.
    :type instance_count: Optional[int]
    :param is_local: Specifies if the compute will be on the local machine. Also treated as true
        whenever the resolved target is the local compute target.
    :type is_local: Optional[bool]
    :param instance_type: The instance type of the compute resource.
    :type instance_type: Optional[str]
    :param location: The location of the compute resource.
    :type location: Optional[str]
    :param properties: The resource properties
    :type properties: Optional[Dict[str, Any]]
    :param deserialize_properties: Specifies if property bag should be deserialized. When True, every
        property value is passed through ``json.loads``; values that cannot be decoded are kept as-is.
        Defaults to False.
    :type deserialize_properties: bool
    """

    def __init__(
        self,
        *,
        target: Optional[str] = None,
        instance_count: Optional[int] = None,
        is_local: Optional[bool] = None,
        instance_type: Optional[str] = None,
        location: Optional[str] = None,
        properties: Optional[Dict[str, Any]] = None,
        deserialize_properties: bool = False,
    ) -> None:
        self.instance_count = instance_count
        self.target = target or LOCAL_COMPUTE_TARGET
        self.is_local = is_local or self.target == LOCAL_COMPUTE_TARGET
        self.instance_type = instance_type
        self.location = location
        self.properties: Optional[Dict[str, Any]] = properties
        if deserialize_properties and properties:
            # Decode into a fresh dict instead of rewriting the caller's mapping in place, so the
            # object the properties came from (e.g. a REST model) is left untouched.
            self.properties = {key: self._deserialize_property(value) for key, value in properties.items()}

    @staticmethod
    def _deserialize_property(value: Any) -> Any:
        """Return ``value`` decoded from JSON, or ``value`` unchanged if it cannot be decoded.

        :param value: A property value, expected to be a JSON document in string form.
        :type value: Any
        :return: The decoded value, or the original value when decoding fails.
        :rtype: Any
        """
        try:
            return json.loads(value)
        except Exception:  # pylint: disable=W0718
            # Best effort: keep the serialized string (or non-string value) if load fails.
            return value

    def _to_rest_object(self) -> RestComputeConfiguration:
        """Convert this configuration into its REST representation.

        Every property value is serialized with ``json.dumps``. Property keys matching the Singularity
        or AISupercomputer field names (case-insensitively) are normalized to the AISupercomputer
        field name. A property that cannot be processed is skipped and a warning is logged.
        The target is sent as ``None`` when the configuration is local.

        :return: The REST compute configuration.
        :rtype: RestComputeConfiguration
        """
        serialized_properties: Optional[Dict[str, str]] = None
        if self.properties:
            serialized_properties = {}
            for key, value in self.properties.items():
                try:
                    rest_key = key
                    if key.lower() == JobComputePropertyFields.SINGULARITY.lower():
                        # Map Singularity -> AISupercomputer in SDK until MFE does mapping
                        rest_key = JobComputePropertyFields.AISUPERCOMPUTER
                    # Ensure keymatch is case invariant
                    elif key.lower() == JobComputePropertyFields.AISUPERCOMPUTER.lower():
                        rest_key = JobComputePropertyFields.AISUPERCOMPUTER
                    serialized_properties[rest_key] = json.dumps(value)
                except Exception as ex:  # pylint: disable=W0718
                    # Stay best-effort (one bad property must not fail the whole conversion), but
                    # leave a trace so a dropped property is not lost silently.
                    module_logger.warning(
                        "Skipping compute property %r: it could not be serialized (%s).",
                        key,
                        type(ex).__name__,
                    )
        return RestComputeConfiguration(
            target=self.target if not self.is_local else None,
            is_local=self.is_local,
            instance_count=self.instance_count,
            instance_type=self.instance_type,
            location=self.location,
            properties=serialized_properties,
        )

    @classmethod
    def _from_rest_object(cls, obj: RestComputeConfiguration) -> "ComputeConfiguration":
        """Build a configuration from its REST representation, deserializing the property bag.

        :param obj: The REST compute configuration.
        :type obj: RestComputeConfiguration
        :return: The corresponding compute configuration entity.
        :rtype: ComputeConfiguration
        """
        return ComputeConfiguration(
            target=obj.target,
            is_local=obj.is_local,
            instance_count=obj.instance_count,
            location=obj.location,
            instance_type=obj.instance_type,
            properties=obj.properties,
            deserialize_properties=True,
        )

    def __eq__(self, other: object) -> bool:
        """Compare instance count, target, locality, location and instance type.

        Note that ``properties`` does not take part in the comparison.
        """
        if not isinstance(other, ComputeConfiguration):
            return NotImplemented
        return (
            self.instance_count == other.instance_count
            and self.target == other.target
            and self.is_local == other.is_local
            and self.location == other.location
            and self.instance_type == other.instance_type
        )

    def __ne__(self, other: object) -> bool:
        """Return the negation of ``__eq__`` for another ``ComputeConfiguration``."""
        if not isinstance(other, ComputeConfiguration):
            return NotImplemented
        return not self.__eq__(other)