diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py')
| -rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py | 479 |
1 files changed, 479 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py new file mode 100644 index 00000000..bd752571 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access +# enable protected access for protected helper functions + +import copy +from collections import OrderedDict +from enum import Enum as PyEnum +from enum import EnumMeta +from inspect import Parameter, getmro, signature +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast + +from typing_extensions import Annotated, Literal, TypeAlias + +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.exceptions import UserErrorException + +# avoid circular import error +if TYPE_CHECKING: + from .input import Input + from .output import Output + +SUPPORTED_RETURN_TYPES_PRIMITIVE = list(IOConstants.PRIMITIVE_TYPE_2_STR.keys()) + +Annotation: TypeAlias = Union[str, Type, Annotated[Any, Any], None] # type: ignore + + +def is_group(obj: object) -> bool: + """Return True if obj is a group or an instance of a parameter group class. + + :param obj: The object to check. + :type obj: Any + :return: True if obj is a group or an instance, False otherwise. + :rtype: bool + """ + return hasattr(obj, IOConstants.GROUP_ATTR_NAME) + + +def _get_annotation_by_value(val: Any) -> Union["Input", Type["Input"]]: + # TODO: we'd better remove this potential recursive import + from .enum_input import EnumInput + from .input import Input + + annotation: Any = None + + def _is_dataset(data: Any) -> bool: + from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin + + DATASET_TYPES = JobIOMixin + return isinstance(data, DATASET_TYPES) + + if _is_dataset(val): + annotation = Input + elif val is Parameter.empty or val is None: + # If no default value or default is None, create val as the basic parameter type, + # it could be replaced using component parameter definition. + annotation = Input._get_default_unknown_input() + elif isinstance(val, PyEnum): + # Handle enum values + annotation = EnumInput(enum=val.__class__) + else: + _new_annotation = _get_annotation_cls_by_type(type(val), raise_error=False) + if not _new_annotation: + # Fall back to default + annotation = Input._get_default_unknown_input() + else: + return _new_annotation + return cast(Union["Input", Type["Input"]], annotation) + + +def _get_annotation_cls_by_type( + t: type, raise_error: bool = False, optional: Optional[bool] = None +) -> Optional["Input"]: + # TODO: we'd better remove this potential recursive import + from .input import Input + + cls = Input._get_input_by_type(t, optional=optional) + if cls is None and raise_error: + raise UserErrorException(f"Can't convert type {t} to azure.ai.ml.Input") + return cls + + +# pylint: disable=too-many-statements +def _get_param_with_standard_annotation( + cls_or_func: Union[Callable, Type], is_func: bool = False, skip_params: Optional[List[str]] = None +) -> Dict[str, Union[Annotation, "Input", "Output"]]: + """Standardize function parameters or class fields with dsl.types annotation. + + :param cls_or_func: Either a class or a function + :type cls_or_func: Union[Callable, Type] + :param is_func: Whether `cls_or_func` is a function. Defaults to False. + :type is_func: bool + :param skip_params: + :type skip_params: Optional[List[str]] + :return: A dictionary of field annotations + :rtype: Dict[str, Union[Annotation, "Input", "Output"]] + """ + # TODO: we'd better remove this potential recursive import + from .group_input import GroupInput + from .input import Input + from .output import Output + + def _is_dsl_type_cls(t: Any) -> bool: + if type(t) is not type: # pylint: disable=unidiomatic-typecheck + return False + return issubclass(t, (Input, Output)) + + def _is_dsl_types(o: object) -> bool: + return _is_dsl_type_cls(type(o)) + + def _get_fields(annotations: Dict) -> Dict: + """Return field names to annotations mapping in class. + + :param annotations: The annotations + :type annotations: Dict[str, Union[Annotation, Input, Output]] + :return: The field dict + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + annotation_fields = OrderedDict() + for name, annotation in annotations.items(): + # Skip return type + if name == "return": + continue + # Handle EnumMeta annotation + if isinstance(annotation, EnumMeta): + from .enum_input import EnumInput + + annotation = EnumInput(type="string", enum=annotation) + # Handle Group annotation + if is_group(annotation): + _deep_copy: GroupInput = copy.deepcopy(getattr(annotation, IOConstants.GROUP_ATTR_NAME)) + annotation = _deep_copy + # Try creating annotation by type when got like 'param: int' + if not _is_dsl_type_cls(annotation) and not _is_dsl_types(annotation): + origin_annotation = annotation + annotation = cast(Input, _get_annotation_cls_by_type(annotation, raise_error=False)) + if not annotation: + msg = f"Unsupported annotation type {origin_annotation!r} for parameter {name!r}." + raise UserErrorException(msg) + annotation_fields[name] = annotation + return annotation_fields + + def _merge_field_keys( + annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any] + ) -> List[str]: + """Merge field keys from annotations and cls dict to get all fields in class. + + :param annotation_fields: The field annotations + :type annotation_fields: Dict[str, Union[Annotation, Input, Output]] + :param defaults_dict: The map of variable name to default value + :type defaults_dict: Dict[str, Any] + :return: A list of field keys + :rtype: List[str] + """ + anno_keys = list(annotation_fields.keys()) + dict_keys = defaults_dict.keys() + if not dict_keys: + return anno_keys + return [*anno_keys, *[key for key in dict_keys if key not in anno_keys]] + + def _update_annotation_with_default( + anno: Union[Annotation, Input, Output], name: str, default: Any + ) -> Union[Annotation, Input, Output]: + """Create annotation if is type class and update the default. + + :param anno: The annotation + :type anno: Union[Annotation, Input, Output] + :param name: The port name + :type name: str + :param default: The default value + :type default: Any + :return: The updated annotation + :rtype: Union[Annotation, Input, Output] + """ + # Create instance if is type class + complete_annotation = anno + if _is_dsl_type_cls(anno): + if anno is not None and not isinstance(anno, (str, Input, Output)): + complete_annotation = anno() + if complete_annotation is not None and not isinstance(complete_annotation, str): + complete_annotation._port_name = name + if default is Input._EMPTY: + return complete_annotation + if isinstance(complete_annotation, Input): + # Non-parameter Input has no default attribute + if complete_annotation._is_primitive_type and complete_annotation.default is not None: + # logger.warning( + # f"Warning: Default value of f{complete_annotation.name!r} is set twice: " + # f"{complete_annotation.default!r} and {default!r}, will use {default!r}" + # ) + pass + complete_annotation._update_default(default) + if isinstance(complete_annotation, Output) and default is not None: + msg = ( + f"Default value of Output {complete_annotation._port_name!r} cannot be set:" + f"Output has no default value." + ) + raise UserErrorException(msg) + return complete_annotation + + def _update_fields_with_default( + annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any] + ) -> Dict[str, Union[Annotation, Input, Output]]: + """Use public values in class dict to update annotations. + + :param annotation_fields: The field annotations + :type annotation_fields: Dict[str, Union[Annotation, Input, Output]] + :param defaults_dict: A map of variable name to default value + :type defaults_dict: Dict[str, Any] + :return: List of field names + :rtype: List[str] + """ + all_fields = OrderedDict() + all_filed_keys = _merge_field_keys(annotation_fields, defaults_dict) + for name in all_filed_keys: + # Get or create annotation + annotation = ( + annotation_fields[name] + if name in annotation_fields + else _get_annotation_by_value(defaults_dict.get(name, Input._EMPTY)) + ) + # Create annotation if is class type and update default + annotation = _update_annotation_with_default(annotation, name, defaults_dict.get(name, Input._EMPTY)) + all_fields[name] = annotation + return all_fields + + def _merge_and_reorder( + inherited_fields: Dict[str, Union[Annotation, Input, Output]], + cls_fields: Dict[str, Union[Annotation, Input, Output]], + ) -> Dict[str, Union[Annotation, Input, Output]]: + """Merge inherited fields with cls fields. + + The order inside each part will not be changed. Order will be: + + {inherited_no_default_fields} + {cls_no_default_fields} + {inherited_default_fields} + {cls_default_fields}. + + + :param inherited_fields: The inherited fields + :type inherited_fields: Dict[str, Union[Annotation, Input, Output]] + :param cls_fields: The class fields + :type cls_fields: Dict[str, Union[Annotation, Input, Output]] + :return: The merged fields + :rtype: Dict[str, Union[Annotation, Input, Output]] + + .. admonition:: Additional Note + + :class: note + + If cls overwrite an inherited no default field with default, it will be put in the + cls_default_fields part and deleted from inherited_no_default_fields: + + .. code-block:: python + + @dsl.group + class SubGroup: + int_param0: Integer + int_param1: int + + @dsl.group + class Group(SubGroup): + int_param3: Integer + int_param1: int = 1 + + The init function of Group will be `def __init__(self, *, int_param0, int_param3, int_param1=1)`. + """ + + def _split( + _fields: Dict[str, Union[Annotation, Input, Output]] + ) -> Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]: + """Split fields to two parts from the first default field. + + :param _fields: The fields + :type _fields: Dict[str, Union[Annotation, Input, Output]] + :return: A 2-tuple of (fields with no defaults, fields with defaults) + :rtype: Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]] + """ + _no_defaults_fields, _defaults_fields = {}, {} + seen_default = False + for key, val in _fields.items(): + if val is not None and not isinstance(val, str): + if val.get("default", None) or seen_default: + seen_default = True + _defaults_fields[key] = val + else: + _no_defaults_fields[key] = val + return _no_defaults_fields, _defaults_fields + + inherited_no_default, inherited_default = _split(inherited_fields) + cls_no_default, cls_default = _split(cls_fields) + # Cross comparison and delete from inherited_fields if same key appeared in cls_fields + # pylint: disable=consider-iterating-dictionary + for key in cls_default.keys(): + if key in inherited_no_default.keys(): + del inherited_no_default[key] + for key in cls_no_default.keys(): + if key in inherited_default.keys(): + del inherited_default[key] + return OrderedDict( + { + **inherited_no_default, + **cls_no_default, + **inherited_default, + **cls_default, + } + ) + + def _get_inherited_fields() -> Dict[str, Union[Annotation, Input, Output]]: + """Get all fields inherited from @group decorated base classes. + + :return: The field dict + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + # Return value of _get_param_with_standard_annotation + _fields: Dict[str, Union[Annotation, Input, Output]] = OrderedDict({}) + if is_func: + return _fields + # In reversed order so that more derived classes + # override earlier field definitions in base classes. + if isinstance(cls_or_func, type): + for base in cls_or_func.__mro__[-1:0:-1]: + if is_group(base): + # merge and reorder fields from current base with previous + _fields = _merge_and_reorder( + _fields, copy.deepcopy(getattr(base, IOConstants.GROUP_ATTR_NAME).values) + ) + return _fields + + skip_params = skip_params or [] + inherited_fields = _get_inherited_fields() + # From annotations get field with type + annotations: Dict[str, Annotation] = getattr(cls_or_func, "__annotations__", {}) + annotations = {k: v for k, v in annotations.items() if k not in skip_params} + annotations = _update_io_from_mldesigner(annotations) + annotation_fields = _get_fields(annotations) + defaults_dict: Dict[str, Any] = {} + # Update fields use class field with defaults from class dict or signature(func).paramters + if not is_func: + # Only consider public fields in class dict + defaults_dict = { + key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_") and key not in skip_params + } + else: + # Infer parameter type from value if is function + defaults_dict = { + key: val.default + for key, val in signature(cls_or_func).parameters.items() + if key not in skip_params and val.kind != val.VAR_KEYWORD + } + fields = _update_fields_with_default(annotation_fields, defaults_dict) + all_fields = _merge_and_reorder(inherited_fields, fields) + return all_fields + + +def _update_io_from_mldesigner(annotations: Dict[str, Annotation]) -> Dict[str, Union[Annotation, "Input", "Output"]]: + """Translates IOBase from mldesigner package to azure.ml.entities.Input/Output. + + This function depends on: + + * `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output instance annotations + to IO entities. + * class names of `mldesigner._input_output` to translate Input/Output class annotations + to IO entities. + + :param annotations: A map of variable names to annotations + :type annotations: Dict[str, Annotation] + :return: Dict with mldesigner IO types converted to azure-ai-ml Input/Output + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + from typing_extensions import get_args, get_origin + + from azure.ai.ml import Input, Output + + from .enum_input import EnumInput + + mldesigner_pkg = "mldesigner" + param_name = "_Param" + return_annotation_key = "return" + + def _is_primitive_type(io: type) -> bool: + """Checks whether type is a primitive type + + :param io: A type + :type io: type + :return: Return true if type is subclass of mldesigner._input_output._Param + :rtype: bool + """ + return any(io.__module__.startswith(mldesigner_pkg) and item.__name__ == param_name for item in getmro(io)) + + def _is_input_or_output_type(io: type, type_str: Literal["Input", "Output", "Meta"]) -> bool: + """Checks whether a type is an Input or Output type + + :param io: A type + :type io: type + :param type_str: The kind of type to check for + :type type_str: Literal["Input", "Output", "Meta"] + :return: Return true if type name contains type_str + :rtype: bool + """ + if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg): + if type_str in io.__name__: + return True + return False + + result = {} + for key, io in annotations.items(): # pylint: disable=too-many-nested-blocks + if isinstance(io, type): + if _is_input_or_output_type(io, "Input"): + # mldesigner.Input -> entities.Input + io = Input + elif _is_input_or_output_type(io, "Output"): + # mldesigner.Output -> entities.Output + io = Output + elif _is_primitive_type(io): + io = ( + Output(type=io.TYPE_NAME) # type: ignore + if key == return_annotation_key + else Input(type=io.TYPE_NAME) # type: ignore + ) + elif hasattr(io, "_to_io_entity_args_dict"): + try: + if _is_input_or_output_type(type(io), "Input"): + # mldesigner.Input() -> entities.Input() + if io is not None: + io = Input(**io._to_io_entity_args_dict()) + elif _is_input_or_output_type(type(io), "Output"): + # mldesigner.Output() -> entities.Output() + if io is not None: + io = Output(**io._to_io_entity_args_dict()) + elif _is_primitive_type(type(io)): + if io is not None and not isinstance(io, str): + if io._is_enum(): + io = EnumInput(**io._to_io_entity_args_dict()) + else: + io = ( + Output(**io._to_io_entity_args_dict()) + if key == return_annotation_key + else Input(**io._to_io_entity_args_dict()) + ) + except BaseException as e: + raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e + # Handle Annotated annotation + elif get_origin(io) is Annotated: + hint_type, arg, *hint_args = get_args(io) # pylint: disable=unused-variable + if hint_type in SUPPORTED_RETURN_TYPES_PRIMITIVE: + if not _is_input_or_output_type(type(arg), "Meta"): + raise UserErrorException( + f"Annotated Metadata class only support " + f"mldesigner._input_output.Meta, " + f"it is {type(arg)} now." + ) + if arg.type is not None and arg.type != hint_type: + raise UserErrorException( + f"Meta class type {arg.type} should be same as Annotated type: " f"{hint_type}" + ) + arg.type = hint_type + io = ( + Output(**arg._to_io_entity_args_dict()) + if key == return_annotation_key + else Input(**arg._to_io_entity_args_dict()) + ) + result[key] = io + return result + + +def _remove_empty_values(data: Any) -> Any: + """Recursively removes None values from a dict + + :param data: The value to remove None from + :type data: T + :return: + * `data` if `data` is not a dict + * `data` with None values recursively filtered out if data is a dict + :rtype: T + """ + if not isinstance(data, dict): + return data + return {k: _remove_empty_values(v) for k, v in data.items() if v is not None} |
