about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
blob: bd75257116353cdef81b7965e1fbf17387423020 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access
# enable protected access for protected helper functions

import copy
from collections import OrderedDict
from enum import Enum as PyEnum
from enum import EnumMeta
from inspect import Parameter, getmro, signature
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from typing_extensions import Annotated, Literal, TypeAlias

from azure.ai.ml.constants._component import IOConstants
from azure.ai.ml.exceptions import UserErrorException

# avoid circular import error
if TYPE_CHECKING:
    from .input import Input
    from .output import Output

# Keys of IOConstants.PRIMITIVE_TYPE_2_STR, materialized as a list. Used below to decide whether the
# first argument of an Annotated[...] hint is a supported primitive.
# NOTE(review): presumably these keys are the Python primitive types — confirm against IOConstants.
SUPPORTED_RETURN_TYPES_PRIMITIVE = list(IOConstants.PRIMITIVE_TYPE_2_STR.keys())

# Shapes a raw annotation may take before it is standardized:
# a string, a class, an Annotated[...] form, or None.
Annotation: TypeAlias = Union[str, Type, Annotated[Any, Any], None]  # type: ignore


def is_group(obj: object) -> bool:
    """Tell whether ``obj`` carries the parameter-group marker attribute.

    :param obj: The object to check; a group class or an instance of a parameter group class.
    :type obj: Any
    :return: True when the attribute named by ``IOConstants.GROUP_ATTR_NAME`` exists on ``obj``,
        False otherwise.
    :rtype: bool
    """
    marker_attr = IOConstants.GROUP_ATTR_NAME
    return hasattr(obj, marker_attr)


def _get_annotation_by_value(val: Any) -> Union["Input", Type["Input"]]:
    """Infer an annotation from a default value.

    :param val: The default value to inspect (may be ``Parameter.empty`` or None)
    :type val: Any
    :return: The ``Input`` class for job-IO values, an ``EnumInput`` for enum members, the input
        resolved from ``type(val)`` when one exists, and the default unknown input otherwise.
    :rtype: Union[Input, Type[Input]]
    """
    # TODO: we'd better remove this potential recursive import
    from .enum_input import EnumInput
    from .input import Input

    from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin

    # Dataset-like values map to the Input class itself, not to an instance.
    if isinstance(val, JobIOMixin):
        return cast(Union["Input", Type["Input"]], Input)

    # If no default value or default is None, create val as the basic parameter type,
    # it could be replaced using component parameter definition.
    if val is Parameter.empty or val is None:
        return cast(Union["Input", Type["Input"]], Input._get_default_unknown_input())

    # Enum members are described by their enum class.
    if isinstance(val, PyEnum):
        return cast(Union["Input", Type["Input"]], EnumInput(enum=val.__class__))

    inferred = _get_annotation_cls_by_type(type(val), raise_error=False)
    if inferred:
        return inferred
    # Fall back to default
    return cast(Union["Input", Type["Input"]], Input._get_default_unknown_input())


def _get_annotation_cls_by_type(
    t: type, raise_error: bool = False, optional: Optional[bool] = None
) -> Optional["Input"]:
    """Look up the Input that corresponds to a Python type.

    :param t: The type to convert
    :type t: type
    :param raise_error: Whether to raise instead of returning None when no Input matches
    :type raise_error: bool
    :param optional: Forwarded unchanged to ``Input._get_input_by_type``
    :type optional: Optional[bool]
    :return: The matching Input, or None when there is none and ``raise_error`` is False
    :rtype: Optional[Input]
    :raises UserErrorException: If no Input matches and ``raise_error`` is True
    """
    # TODO: we'd better remove this potential recursive import
    from .input import Input

    matched_input = Input._get_input_by_type(t, optional=optional)
    if raise_error and matched_input is None:
        raise UserErrorException(f"Can't convert type {t} to azure.ai.ml.Input")
    return matched_input


# pylint: disable=too-many-statements
def _get_param_with_standard_annotation(
    cls_or_func: Union[Callable, Type], is_func: bool = False, skip_params: Optional[List[str]] = None
) -> Dict[str, Union[Annotation, "Input", "Output"]]:
    """Standardize function parameters or class fields with dsl.types annotation.

    Fields are collected from ``__annotations__`` of ``cls_or_func`` and from its defaults
    (public entries of the class ``__dict__``, or ``signature(cls_or_func).parameters`` for a
    function). For a class, the result is then merged with the fields inherited from
    group-decorated base classes.

    :param cls_or_func: Either a class or a function
    :type cls_or_func: Union[Callable, Type]
    :param is_func: Whether `cls_or_func` is a function. Defaults to False.
    :type is_func: bool
    :param skip_params: Names to leave out of both the annotations and the defaults.
        None is treated as an empty list.
    :type skip_params: Optional[List[str]]
    :return: A dictionary of field annotations
    :rtype: Dict[str, Union[Annotation, "Input", "Output"]]
    :raises UserErrorException: If an annotation cannot be converted to an Input, or if an
        explicit default other than None is supplied for an Output.
    """
    # TODO: we'd better remove this potential recursive import
    from .group_input import GroupInput
    from .input import Input
    from .output import Output

    def _is_dsl_type_cls(t: Any) -> bool:
        """Return True if ``t`` is a class whose type is exactly ``type`` and that subclasses Input or Output.

        :param t: The object to check
        :type t: Any
        :return: Whether ``t`` is an Input/Output class
        :rtype: bool
        """
        # Anything that is not a plain class (instances, classes with another metaclass) is rejected
        # here, which also keeps issubclass() from being called on a non-class.
        if type(t) is not type:  # pylint: disable=unidiomatic-typecheck
            return False
        return issubclass(t, (Input, Output))

    def _is_dsl_types(o: object) -> bool:
        """Return True if the class of ``o`` passes ``_is_dsl_type_cls``.

        :param o: The object to check
        :type o: object
        :return: Whether ``o`` is an instance of an Input/Output class
        :rtype: bool
        """
        return _is_dsl_type_cls(type(o))

    def _get_fields(annotations: Dict) -> Dict:
        """Return field names to annotations mapping in class.

        :param annotations: The annotations
        :type annotations: Dict[str, Union[Annotation, Input, Output]]
        :return: The field dict
        :rtype: Dict[str, Union[Annotation, Input, Output]]
        :raises UserErrorException: If an annotation cannot be converted to an Input
        """
        annotation_fields = OrderedDict()
        for name, annotation in annotations.items():
            # Skip return type
            if name == "return":
                continue
            # Handle EnumMeta annotation: an enum class becomes a string-typed EnumInput
            if isinstance(annotation, EnumMeta):
                from .enum_input import EnumInput

                annotation = EnumInput(type="string", enum=annotation)
            # Handle Group annotation: use a deep copy of the object stored under the group attribute,
            # so later updates do not touch the original stored on the annotated class
            if is_group(annotation):
                _deep_copy: GroupInput = copy.deepcopy(getattr(annotation, IOConstants.GROUP_ATTR_NAME))
                annotation = _deep_copy
            # Try creating annotation by type when got like 'param: int'
            if not _is_dsl_type_cls(annotation) and not _is_dsl_types(annotation):
                origin_annotation = annotation
                annotation = cast(Input, _get_annotation_cls_by_type(annotation, raise_error=False))
                if not annotation:
                    msg = f"Unsupported annotation type {origin_annotation!r} for parameter {name!r}."
                    raise UserErrorException(msg)
            annotation_fields[name] = annotation
        return annotation_fields

    def _merge_field_keys(
        annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any]
    ) -> List[str]:
        """Merge field keys from annotations and cls dict to get all fields in class.

        Annotated names come first in their original order, followed by names that only
        appear in ``defaults_dict``.

        :param annotation_fields: The field annotations
        :type annotation_fields: Dict[str, Union[Annotation, Input, Output]]
        :param defaults_dict: The map of variable name to default value
        :type defaults_dict: Dict[str, Any]
        :return: A list of field keys
        :rtype: List[str]
        """
        anno_keys = list(annotation_fields.keys())
        dict_keys = defaults_dict.keys()
        if not dict_keys:
            return anno_keys
        return [*anno_keys, *[key for key in dict_keys if key not in anno_keys]]

    def _update_annotation_with_default(
        anno: Union[Annotation, Input, Output], name: str, default: Any
    ) -> Union[Annotation, Input, Output]:
        """Create annotation if is type class and update the default.

        :param anno: The annotation
        :type anno: Union[Annotation, Input, Output]
        :param name: The port name
        :type name: str
        :param default: The default value
        :type default: Any
        :return: The updated annotation
        :rtype: Union[Annotation, Input, Output]
        :raises UserErrorException: If ``anno`` resolves to an Output and ``default`` is neither
            ``Input._EMPTY`` nor None
        """
        # Create instance if is type class
        complete_annotation = anno
        if _is_dsl_type_cls(anno):
            # The extra checks repeat what _is_dsl_type_cls already established;
            # presumably they exist for static type narrowing.
            if anno is not None and not isinstance(anno, (str, Input, Output)):
                complete_annotation = anno()
        if complete_annotation is not None and not isinstance(complete_annotation, str):
            complete_annotation._port_name = name
        # Input._EMPTY marks "no default supplied": nothing more to update.
        if default is Input._EMPTY:
            return complete_annotation
        if isinstance(complete_annotation, Input):
            # Non-parameter Input has no default attribute
            if complete_annotation._is_primitive_type and complete_annotation.default is not None:
                # The default is being set twice; the warning below is disabled, so this branch
                # is currently a no-op and the new default silently wins.
                # logger.warning(
                #     f"Warning: Default value of f{complete_annotation.name!r} is set twice: "
                #     f"{complete_annotation.default!r} and {default!r}, will use {default!r}"
                # )
                pass
            complete_annotation._update_default(default)
        if isinstance(complete_annotation, Output) and default is not None:
            # NOTE(review): the two message pieces are concatenated with no space after the colon
            # ("...cannot be set:Output has...") — confirm whether that is intended.
            msg = (
                f"Default value of Output {complete_annotation._port_name!r} cannot be set:"
                f"Output has no default value."
            )
            raise UserErrorException(msg)
        return complete_annotation

    def _update_fields_with_default(
        annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any]
    ) -> Dict[str, Union[Annotation, Input, Output]]:
        """Use public values in class dict to update annotations.

        :param annotation_fields: The field annotations
        :type annotation_fields: Dict[str, Union[Annotation, Input, Output]]
        :param defaults_dict: A map of variable name to default value
        :type defaults_dict: Dict[str, Any]
        :return: A map of every field name to its annotation, with the default applied
        :rtype: Dict[str, Union[Annotation, Input, Output]]
        """
        all_fields = OrderedDict()
        all_filed_keys = _merge_field_keys(annotation_fields, defaults_dict)
        for name in all_filed_keys:
            # Get or create annotation: names without an annotation are inferred from their default value
            annotation = (
                annotation_fields[name]
                if name in annotation_fields
                else _get_annotation_by_value(defaults_dict.get(name, Input._EMPTY))
            )
            # Create annotation if is class type and update default
            annotation = _update_annotation_with_default(annotation, name, defaults_dict.get(name, Input._EMPTY))
            all_fields[name] = annotation
        return all_fields

    def _merge_and_reorder(
        inherited_fields: Dict[str, Union[Annotation, Input, Output]],
        cls_fields: Dict[str, Union[Annotation, Input, Output]],
    ) -> Dict[str, Union[Annotation, Input, Output]]:
        """Merge inherited fields with cls fields.

        The order inside each part will not be changed. Order will be:

        {inherited_no_default_fields} + {cls_no_default_fields} + {inherited_default_fields} + {cls_default_fields}.


        :param inherited_fields: The inherited fields
        :type inherited_fields: Dict[str, Union[Annotation, Input, Output]]
        :param cls_fields: The class fields
        :type cls_fields: Dict[str, Union[Annotation, Input, Output]]
        :return: The merged fields
        :rtype: Dict[str, Union[Annotation, Input, Output]]

        .. admonition:: Additional Note

           :class: note

           If cls overwrite an inherited no default field with default, it will be put in the
           cls_default_fields part and deleted from inherited_no_default_fields:

           .. code-block:: python

              @dsl.group
              class SubGroup:
                  int_param0: Integer
                  int_param1: int

              @dsl.group
              class Group(SubGroup):
                  int_param3: Integer
                  int_param1: int = 1

           The init function of Group will be `def __init__(self, *, int_param0, int_param3, int_param1=1)`.
        """

        def _split(
            _fields: Dict[str, Union[Annotation, Input, Output]]
        ) -> Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]:
            """Split fields to two parts from the first default field.

            Once one field with a default has been seen, every later field goes to the defaults
            part as well. Entries whose value is None or a str end up in neither part.

            :param _fields: The fields
            :type _fields: Dict[str, Union[Annotation, Input, Output]]
            :return: A 2-tuple of (fields with no defaults, fields with defaults)
            :rtype: Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]
            """
            _no_defaults_fields, _defaults_fields = {}, {}
            seen_default = False
            for key, val in _fields.items():
                if val is not None and not isinstance(val, str):
                    # NOTE(review): the default is tested for truthiness, so a falsy default
                    # (e.g. 0, False, "") would not by itself start the defaults part — confirm
                    # whether `is not None` was intended.
                    if val.get("default", None) or seen_default:
                        seen_default = True
                        _defaults_fields[key] = val
                    else:
                        _no_defaults_fields[key] = val
            return _no_defaults_fields, _defaults_fields

        inherited_no_default, inherited_default = _split(inherited_fields)
        cls_no_default, cls_default = _split(cls_fields)
        # Cross comparison and delete from inherited_fields if same key appeared in cls_fields
        # pylint: disable=consider-iterating-dictionary
        for key in cls_default.keys():
            if key in inherited_no_default.keys():
                del inherited_no_default[key]
        for key in cls_no_default.keys():
            if key in inherited_default.keys():
                del inherited_default[key]
        # A key present in several parts keeps the position of its first occurrence
        # and the value of its last one (regular dict-unpacking semantics).
        return OrderedDict(
            {
                **inherited_no_default,
                **cls_no_default,
                **inherited_default,
                **cls_default,
            }
        )

    def _get_inherited_fields() -> Dict[str, Union[Annotation, Input, Output]]:
        """Get all fields inherited from @group decorated base classes.

        Returns an empty mapping when ``is_func`` is True or ``cls_or_func`` is not a class.

        :return: The field dict
        :rtype: Dict[str, Union[Annotation, Input, Output]]
        """
        # Return value of _get_param_with_standard_annotation
        _fields: Dict[str, Union[Annotation, Input, Output]] = OrderedDict({})
        if is_func:
            return _fields
        # In reversed order so that more derived classes
        # override earlier field definitions in base classes.
        if isinstance(cls_or_func, type):
            # __mro__[-1:0:-1] walks the MRO from its last entry back to index 1,
            # i.e. every base class but not cls_or_func itself (index 0).
            for base in cls_or_func.__mro__[-1:0:-1]:
                if is_group(base):
                    # merge and reorder fields from current base with previous
                    _fields = _merge_and_reorder(
                        _fields, copy.deepcopy(getattr(base, IOConstants.GROUP_ATTR_NAME).values)
                    )
        return _fields

    skip_params = skip_params or []
    inherited_fields = _get_inherited_fields()
    # From annotations get field with type
    annotations: Dict[str, Annotation] = getattr(cls_or_func, "__annotations__", {})
    annotations = {k: v for k, v in annotations.items() if k not in skip_params}
    # Translate mldesigner IO annotations to azure-ai-ml Input/Output before standardizing
    annotations = _update_io_from_mldesigner(annotations)
    annotation_fields = _get_fields(annotations)
    defaults_dict: Dict[str, Any] = {}
    # Update fields use class field with defaults from class dict or signature(func).parameters
    if not is_func:
        # Only consider public fields in class dict
        defaults_dict = {
            key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_") and key not in skip_params
        }
    else:
        # Infer parameter type from value if is function; a **kwargs parameter is left out
        defaults_dict = {
            key: val.default
            for key, val in signature(cls_or_func).parameters.items()
            if key not in skip_params and val.kind != val.VAR_KEYWORD
        }
    fields = _update_fields_with_default(annotation_fields, defaults_dict)
    all_fields = _merge_and_reorder(inherited_fields, fields)
    return all_fields


def _update_io_from_mldesigner(annotations: Dict[str, Annotation]) -> Dict[str, Union[Annotation, "Input", "Output"]]:
    """Translates IOBase from mldesigner package to azure.ml.entities.Input/Output.

    This function depends on:

    * `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output instance annotations
      to IO entities.
    * class names of `mldesigner._input_output` to translate Input/Output class annotations
      to IO entities.

    Annotations that match none of the handled shapes are copied to the result unchanged.

    :param annotations: A map of variable names to annotations
    :type annotations: Dict[str, Annotation]
    :return: Dict with mldesigner IO types converted to azure-ai-ml Input/Output
    :rtype: Dict[str, Union[Annotation, Input, Output]]
    :raises UserErrorException: If an mldesigner IO instance cannot be converted, if the metadata
        of a primitive ``Annotated`` hint is not an mldesigner Meta object, or if that metadata
        declares a type different from the annotated one.
    """
    from typing_extensions import get_args, get_origin

    from azure.ai.ml import Input, Output

    from .enum_input import EnumInput

    mldesigner_pkg = "mldesigner"
    param_name = "_Param"
    return_annotation_key = "return"

    def _is_primitive_type(io: type) -> bool:
        """Checks whether type is a primitive type

        :param io: A type
        :type io: type
        :return: Return true if type is subclass of mldesigner._input_output._Param
        :rtype: bool
        """
        return any(io.__module__.startswith(mldesigner_pkg) and item.__name__ == param_name for item in getmro(io))

    def _is_input_or_output_type(io: type, type_str: Literal["Input", "Output", "Meta"]) -> bool:
        """Checks whether a type is an Input or Output type

        :param io: A type
        :type io: type
        :param type_str: The kind of type to check for
        :type type_str: Literal["Input", "Output", "Meta"]
        :return: Return true if type name contains type_str
        :rtype: bool
        """
        if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg):
            if type_str in io.__name__:
                return True
        return False

    result = {}
    for key, io in annotations.items():  # pylint: disable=too-many-nested-blocks
        if isinstance(io, type):
            # Class annotations: matched by module and class name.
            if _is_input_or_output_type(io, "Input"):
                # mldesigner.Input -> entities.Input
                io = Input
            elif _is_input_or_output_type(io, "Output"):
                # mldesigner.Output -> entities.Output
                io = Output
            elif _is_primitive_type(io):
                # A primitive class annotation on the return value is an Output, otherwise an Input.
                io = (
                    Output(type=io.TYPE_NAME)  # type: ignore
                    if key == return_annotation_key
                    else Input(type=io.TYPE_NAME)  # type: ignore
                )
        elif hasattr(io, "_to_io_entity_args_dict"):
            # Instance annotations: rebuilt from the arguments the instance exports.
            try:
                if _is_input_or_output_type(type(io), "Input"):
                    # mldesigner.Input() -> entities.Input()
                    if io is not None:
                        io = Input(**io._to_io_entity_args_dict())
                elif _is_input_or_output_type(type(io), "Output"):
                    # mldesigner.Output() -> entities.Output()
                    if io is not None:
                        io = Output(**io._to_io_entity_args_dict())
                elif _is_primitive_type(type(io)):
                    if io is not None and not isinstance(io, str):
                        if io._is_enum():
                            io = EnumInput(**io._to_io_entity_args_dict())
                        else:
                            io = (
                                Output(**io._to_io_entity_args_dict())
                                if key == return_annotation_key
                                else Input(**io._to_io_entity_args_dict())
                            )
            except Exception as e:
                # Only ordinary errors are reported as user errors. KeyboardInterrupt, SystemExit
                # and GeneratorExit derive from BaseException and must propagate untouched.
                raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e
        # Handle Annotated annotation
        elif get_origin(io) is Annotated:
            # Only the hinted type and the first metadata item are used.
            hint_type, arg, *_ = get_args(io)
            if hint_type in SUPPORTED_RETURN_TYPES_PRIMITIVE:
                if not _is_input_or_output_type(type(arg), "Meta"):
                    raise UserErrorException(
                        f"Annotated Metadata class only support "
                        f"mldesigner._input_output.Meta, "
                        f"it is {type(arg)} now."
                    )
                if arg.type is not None and arg.type != hint_type:
                    raise UserErrorException(
                        f"Meta class type {arg.type} should be same as Annotated type: " f"{hint_type}"
                    )
                # The metadata object is updated in place with the annotated type before conversion.
                arg.type = hint_type
                io = (
                    Output(**arg._to_io_entity_args_dict())
                    if key == return_annotation_key
                    else Input(**arg._to_io_entity_args_dict())
                )
        result[key] = io
    return result


def _remove_empty_values(data: Any) -> Any:
    """Recursively removes None values from a dict

    :param data: The value to remove None from
    :type data: T
    :return:
      * `data` if `data` is not a dict
      * `data` with None values recursively filtered out if data is a dict
    :rtype: T
    """
    if not isinstance(data, dict):
        return data
    return {k: _remove_empty_values(v) for k, v in data.items() if v is not None}