about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py479
1 files changed, 479 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
new file mode 100644
index 00000000..bd752571
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py
@@ -0,0 +1,479 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+# enable protected access for protected helper functions
+
+import copy
+from collections import OrderedDict
+from enum import Enum as PyEnum
+from enum import EnumMeta
+from inspect import Parameter, getmro, signature
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
+
+from typing_extensions import Annotated, Literal, TypeAlias
+
+from azure.ai.ml.constants._component import IOConstants
+from azure.ai.ml.exceptions import UserErrorException
+
+# avoid circular import error
+if TYPE_CHECKING:
+    from .input import Input
+    from .output import Output
+
+SUPPORTED_RETURN_TYPES_PRIMITIVE = list(IOConstants.PRIMITIVE_TYPE_2_STR.keys())
+
+Annotation: TypeAlias = Union[str, Type, Annotated[Any, Any], None]  # type: ignore
+
+
+def is_group(obj: object) -> bool:
+    """Return True if obj is a group or an instance of a parameter group class.
+
+    :param obj: The object to check.
+    :type obj: Any
+    :return: True if obj is a group or an instance, False otherwise.
+    :rtype: bool
+    """
+    return hasattr(obj, IOConstants.GROUP_ATTR_NAME)
+
+
+def _get_annotation_by_value(val: Any) -> Union["Input", Type["Input"]]:
+    # TODO: we'd better remove this potential recursive import
+    from .enum_input import EnumInput
+    from .input import Input
+
+    annotation: Any = None
+
+    def _is_dataset(data: Any) -> bool:
+        from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
+
+        DATASET_TYPES = JobIOMixin
+        return isinstance(data, DATASET_TYPES)
+
+    if _is_dataset(val):
+        annotation = Input
+    elif val is Parameter.empty or val is None:
+        # If no default value or default is None, create val as the basic parameter type,
+        # it could be replaced using component parameter definition.
+        annotation = Input._get_default_unknown_input()
+    elif isinstance(val, PyEnum):
+        # Handle enum values
+        annotation = EnumInput(enum=val.__class__)
+    else:
+        _new_annotation = _get_annotation_cls_by_type(type(val), raise_error=False)
+        if not _new_annotation:
+            # Fall back to default
+            annotation = Input._get_default_unknown_input()
+        else:
+            return _new_annotation
+    return cast(Union["Input", Type["Input"]], annotation)
+
+
+def _get_annotation_cls_by_type(
+    t: type, raise_error: bool = False, optional: Optional[bool] = None
+) -> Optional["Input"]:
+    # TODO: we'd better remove this potential recursive import
+    from .input import Input
+
+    cls = Input._get_input_by_type(t, optional=optional)
+    if cls is None and raise_error:
+        raise UserErrorException(f"Can't convert type {t} to azure.ai.ml.Input")
+    return cls
+
+
+# pylint: disable=too-many-statements
+def _get_param_with_standard_annotation(
+    cls_or_func: Union[Callable, Type], is_func: bool = False, skip_params: Optional[List[str]] = None
+) -> Dict[str, Union[Annotation, "Input", "Output"]]:
+    """Standardize function parameters or class fields with dsl.types annotation.
+
+    :param cls_or_func: Either a class or a function
+    :type cls_or_func: Union[Callable, Type]
+    :param is_func: Whether `cls_or_func` is a function. Defaults to False.
+    :type is_func: bool
+    :param skip_params:
+    :type skip_params: Optional[List[str]]
+    :return: A dictionary of field annotations
+    :rtype: Dict[str, Union[Annotation, "Input", "Output"]]
+    """
+    # TODO: we'd better remove this potential recursive import
+    from .group_input import GroupInput
+    from .input import Input
+    from .output import Output
+
+    def _is_dsl_type_cls(t: Any) -> bool:
+        if type(t) is not type:  # pylint: disable=unidiomatic-typecheck
+            return False
+        return issubclass(t, (Input, Output))
+
+    def _is_dsl_types(o: object) -> bool:
+        return _is_dsl_type_cls(type(o))
+
+    def _get_fields(annotations: Dict) -> Dict:
+        """Return field names to annotations mapping in class.
+
+        :param annotations: The annotations
+        :type annotations: Dict[str, Union[Annotation, Input, Output]]
+        :return: The field dict
+        :rtype: Dict[str, Union[Annotation, Input, Output]]
+        """
+        annotation_fields = OrderedDict()
+        for name, annotation in annotations.items():
+            # Skip return type
+            if name == "return":
+                continue
+            # Handle EnumMeta annotation
+            if isinstance(annotation, EnumMeta):
+                from .enum_input import EnumInput
+
+                annotation = EnumInput(type="string", enum=annotation)
+            # Handle Group annotation
+            if is_group(annotation):
+                _deep_copy: GroupInput = copy.deepcopy(getattr(annotation, IOConstants.GROUP_ATTR_NAME))
+                annotation = _deep_copy
+            # Try creating annotation by type when got like 'param: int'
+            if not _is_dsl_type_cls(annotation) and not _is_dsl_types(annotation):
+                origin_annotation = annotation
+                annotation = cast(Input, _get_annotation_cls_by_type(annotation, raise_error=False))
+                if not annotation:
+                    msg = f"Unsupported annotation type {origin_annotation!r} for parameter {name!r}."
+                    raise UserErrorException(msg)
+            annotation_fields[name] = annotation
+        return annotation_fields
+
+    def _merge_field_keys(
+        annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any]
+    ) -> List[str]:
+        """Merge field keys from annotations and cls dict to get all fields in class.
+
+        :param annotation_fields: The field annotations
+        :type annotation_fields: Dict[str, Union[Annotation, Input, Output]]
+        :param defaults_dict: The map of variable name to default value
+        :type defaults_dict: Dict[str, Any]
+        :return: A list of field keys
+        :rtype: List[str]
+        """
+        anno_keys = list(annotation_fields.keys())
+        dict_keys = defaults_dict.keys()
+        if not dict_keys:
+            return anno_keys
+        return [*anno_keys, *[key for key in dict_keys if key not in anno_keys]]
+
+    def _update_annotation_with_default(
+        anno: Union[Annotation, Input, Output], name: str, default: Any
+    ) -> Union[Annotation, Input, Output]:
+        """Create annotation if is type class and update the default.
+
+        :param anno: The annotation
+        :type anno: Union[Annotation, Input, Output]
+        :param name: The port name
+        :type name: str
+        :param default: The default value
+        :type default: Any
+        :return: The updated annotation
+        :rtype: Union[Annotation, Input, Output]
+        """
+        # Create instance if is type class
+        complete_annotation = anno
+        if _is_dsl_type_cls(anno):
+            if anno is not None and not isinstance(anno, (str, Input, Output)):
+                complete_annotation = anno()
+        if complete_annotation is not None and not isinstance(complete_annotation, str):
+            complete_annotation._port_name = name
+        if default is Input._EMPTY:
+            return complete_annotation
+        if isinstance(complete_annotation, Input):
+            # Non-parameter Input has no default attribute
+            if complete_annotation._is_primitive_type and complete_annotation.default is not None:
+                # logger.warning(
+                #     f"Warning: Default value of f{complete_annotation.name!r} is set twice: "
+                #     f"{complete_annotation.default!r} and {default!r}, will use {default!r}"
+                # )
+                pass
+            complete_annotation._update_default(default)
+        if isinstance(complete_annotation, Output) and default is not None:
+            msg = (
+                f"Default value of Output {complete_annotation._port_name!r} cannot be set:"
+                f"Output has no default value."
+            )
+            raise UserErrorException(msg)
+        return complete_annotation
+
+    def _update_fields_with_default(
+        annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any]
+    ) -> Dict[str, Union[Annotation, Input, Output]]:
+        """Use public values in class dict to update annotations.
+
+        :param annotation_fields: The field annotations
+        :type annotation_fields: Dict[str, Union[Annotation, Input, Output]]
+        :param defaults_dict: A map of variable name to default value
+        :type defaults_dict: Dict[str, Any]
+        :return: List of field names
+        :rtype: List[str]
+        """
+        all_fields = OrderedDict()
+        all_filed_keys = _merge_field_keys(annotation_fields, defaults_dict)
+        for name in all_filed_keys:
+            # Get or create annotation
+            annotation = (
+                annotation_fields[name]
+                if name in annotation_fields
+                else _get_annotation_by_value(defaults_dict.get(name, Input._EMPTY))
+            )
+            # Create annotation if is class type and update default
+            annotation = _update_annotation_with_default(annotation, name, defaults_dict.get(name, Input._EMPTY))
+            all_fields[name] = annotation
+        return all_fields
+
+    def _merge_and_reorder(
+        inherited_fields: Dict[str, Union[Annotation, Input, Output]],
+        cls_fields: Dict[str, Union[Annotation, Input, Output]],
+    ) -> Dict[str, Union[Annotation, Input, Output]]:
+        """Merge inherited fields with cls fields.
+
+        The order inside each part will not be changed. Order will be:
+
+        {inherited_no_default_fields} + {cls_no_default_fields} + {inherited_default_fields} + {cls_default_fields}.
+
+
+        :param inherited_fields: The inherited fields
+        :type inherited_fields: Dict[str, Union[Annotation, Input, Output]]
+        :param cls_fields: The class fields
+        :type cls_fields: Dict[str, Union[Annotation, Input, Output]]
+        :return: The merged fields
+        :rtype: Dict[str, Union[Annotation, Input, Output]]
+
+        .. admonition:: Additional Note
+
+           :class: note
+
+           If cls overwrite an inherited no default field with default, it will be put in the
+           cls_default_fields part and deleted from inherited_no_default_fields:
+
+           .. code-block:: python
+
+              @dsl.group
+              class SubGroup:
+                  int_param0: Integer
+                  int_param1: int
+
+              @dsl.group
+              class Group(SubGroup):
+                  int_param3: Integer
+                  int_param1: int = 1
+
+           The init function of Group will be `def __init__(self, *, int_param0, int_param3, int_param1=1)`.
+        """
+
+        def _split(
+            _fields: Dict[str, Union[Annotation, Input, Output]]
+        ) -> Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]:
+            """Split fields to two parts from the first default field.
+
+            :param _fields: The fields
+            :type _fields: Dict[str, Union[Annotation, Input, Output]]
+            :return: A 2-tuple of (fields with no defaults, fields with defaults)
+            :rtype: Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]
+            """
+            _no_defaults_fields, _defaults_fields = {}, {}
+            seen_default = False
+            for key, val in _fields.items():
+                if val is not None and not isinstance(val, str):
+                    if val.get("default", None) or seen_default:
+                        seen_default = True
+                        _defaults_fields[key] = val
+                    else:
+                        _no_defaults_fields[key] = val
+            return _no_defaults_fields, _defaults_fields
+
+        inherited_no_default, inherited_default = _split(inherited_fields)
+        cls_no_default, cls_default = _split(cls_fields)
+        # Cross comparison and delete from inherited_fields if same key appeared in cls_fields
+        # pylint: disable=consider-iterating-dictionary
+        for key in cls_default.keys():
+            if key in inherited_no_default.keys():
+                del inherited_no_default[key]
+        for key in cls_no_default.keys():
+            if key in inherited_default.keys():
+                del inherited_default[key]
+        return OrderedDict(
+            {
+                **inherited_no_default,
+                **cls_no_default,
+                **inherited_default,
+                **cls_default,
+            }
+        )
+
+    def _get_inherited_fields() -> Dict[str, Union[Annotation, Input, Output]]:
+        """Get all fields inherited from @group decorated base classes.
+
+        :return: The field dict
+        :rtype: Dict[str, Union[Annotation, Input, Output]]
+        """
+        # Return value of _get_param_with_standard_annotation
+        _fields: Dict[str, Union[Annotation, Input, Output]] = OrderedDict({})
+        if is_func:
+            return _fields
+        # In reversed order so that more derived classes
+        # override earlier field definitions in base classes.
+        if isinstance(cls_or_func, type):
+            for base in cls_or_func.__mro__[-1:0:-1]:
+                if is_group(base):
+                    # merge and reorder fields from current base with previous
+                    _fields = _merge_and_reorder(
+                        _fields, copy.deepcopy(getattr(base, IOConstants.GROUP_ATTR_NAME).values)
+                    )
+        return _fields
+
+    skip_params = skip_params or []
+    inherited_fields = _get_inherited_fields()
+    # From annotations get field with type
+    annotations: Dict[str, Annotation] = getattr(cls_or_func, "__annotations__", {})
+    annotations = {k: v for k, v in annotations.items() if k not in skip_params}
+    annotations = _update_io_from_mldesigner(annotations)
+    annotation_fields = _get_fields(annotations)
+    defaults_dict: Dict[str, Any] = {}
+    # Update fields use class field with defaults from class dict or signature(func).paramters
+    if not is_func:
+        # Only consider public fields in class dict
+        defaults_dict = {
+            key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_") and key not in skip_params
+        }
+    else:
+        # Infer parameter type from value if is function
+        defaults_dict = {
+            key: val.default
+            for key, val in signature(cls_or_func).parameters.items()
+            if key not in skip_params and val.kind != val.VAR_KEYWORD
+        }
+    fields = _update_fields_with_default(annotation_fields, defaults_dict)
+    all_fields = _merge_and_reorder(inherited_fields, fields)
+    return all_fields
+
+
+def _update_io_from_mldesigner(annotations: Dict[str, Annotation]) -> Dict[str, Union[Annotation, "Input", "Output"]]:
+    """Translates IOBase from mldesigner package to azure.ml.entities.Input/Output.
+
+    This function depends on:
+
+    * `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output instance annotations
+      to IO entities.
+    * class names of `mldesigner._input_output` to translate Input/Output class annotations
+      to IO entities.
+
+    :param annotations: A map of variable names to annotations
+    :type annotations: Dict[str, Annotation]
+    :return: Dict with mldesigner IO types converted to azure-ai-ml Input/Output
+    :rtype: Dict[str, Union[Annotation, Input, Output]]
+    """
+    from typing_extensions import get_args, get_origin
+
+    from azure.ai.ml import Input, Output
+
+    from .enum_input import EnumInput
+
+    mldesigner_pkg = "mldesigner"
+    param_name = "_Param"
+    return_annotation_key = "return"
+
+    def _is_primitive_type(io: type) -> bool:
+        """Checks whether type is a primitive type
+
+        :param io: A type
+        :type io: type
+        :return: Return true if type is subclass of mldesigner._input_output._Param
+        :rtype: bool
+        """
+        return any(io.__module__.startswith(mldesigner_pkg) and item.__name__ == param_name for item in getmro(io))
+
+    def _is_input_or_output_type(io: type, type_str: Literal["Input", "Output", "Meta"]) -> bool:
+        """Checks whether a type is an Input or Output type
+
+        :param io: A type
+        :type io: type
+        :param type_str: The kind of type to check for
+        :type type_str: Literal["Input", "Output", "Meta"]
+        :return: Return true if type name contains type_str
+        :rtype: bool
+        """
+        if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg):
+            if type_str in io.__name__:
+                return True
+        return False
+
+    result = {}
+    for key, io in annotations.items():  # pylint: disable=too-many-nested-blocks
+        if isinstance(io, type):
+            if _is_input_or_output_type(io, "Input"):
+                # mldesigner.Input -> entities.Input
+                io = Input
+            elif _is_input_or_output_type(io, "Output"):
+                # mldesigner.Output -> entities.Output
+                io = Output
+            elif _is_primitive_type(io):
+                io = (
+                    Output(type=io.TYPE_NAME)  # type: ignore
+                    if key == return_annotation_key
+                    else Input(type=io.TYPE_NAME)  # type: ignore
+                )
+        elif hasattr(io, "_to_io_entity_args_dict"):
+            try:
+                if _is_input_or_output_type(type(io), "Input"):
+                    # mldesigner.Input() -> entities.Input()
+                    if io is not None:
+                        io = Input(**io._to_io_entity_args_dict())
+                elif _is_input_or_output_type(type(io), "Output"):
+                    # mldesigner.Output() -> entities.Output()
+                    if io is not None:
+                        io = Output(**io._to_io_entity_args_dict())
+                elif _is_primitive_type(type(io)):
+                    if io is not None and not isinstance(io, str):
+                        if io._is_enum():
+                            io = EnumInput(**io._to_io_entity_args_dict())
+                        else:
+                            io = (
+                                Output(**io._to_io_entity_args_dict())
+                                if key == return_annotation_key
+                                else Input(**io._to_io_entity_args_dict())
+                            )
+            except BaseException as e:
+                raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e
+                # Handle Annotated annotation
+        elif get_origin(io) is Annotated:
+            hint_type, arg, *hint_args = get_args(io)  # pylint: disable=unused-variable
+            if hint_type in SUPPORTED_RETURN_TYPES_PRIMITIVE:
+                if not _is_input_or_output_type(type(arg), "Meta"):
+                    raise UserErrorException(
+                        f"Annotated Metadata class only support "
+                        f"mldesigner._input_output.Meta, "
+                        f"it is {type(arg)} now."
+                    )
+                if arg.type is not None and arg.type != hint_type:
+                    raise UserErrorException(
+                        f"Meta class type {arg.type} should be same as Annotated type: " f"{hint_type}"
+                    )
+                arg.type = hint_type
+                io = (
+                    Output(**arg._to_io_entity_args_dict())
+                    if key == return_annotation_key
+                    else Input(**arg._to_io_entity_args_dict())
+                )
+        result[key] = io
+    return result
+
+
+def _remove_empty_values(data: Any) -> Any:
+    """Recursively removes None values from a dict
+
+    :param data: The value to remove None from
+    :type data: T
+    :return:
+      * `data` if `data` is not a dict
+      * `data` with None values recursively filtered out if data is a dict
+    :rtype: T
+    """
+    if not isinstance(data, dict):
+        return data
+    return {k: _remove_empty_values(v) for k, v in data.items() if v is not None}