diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs')
8 files changed, 1904 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py new file mode 100644 index 00000000..90affdda --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/__init__.py @@ -0,0 +1,73 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +"""This package includes the type classes which could be used in dsl.pipeline, +command function, or any other place that requires job inputs/outputs. + +.. note:: + + The following pseudo-code shows how to create a pipeline with such classes. + + .. code-block:: python + + @pipeline() + def some_pipeline( + input_param: Input(type="uri_folder", path="xxx", mode="ro_mount"), + int_param0: Input(type="integer", default=0, min=-3, max=10), + int_param1 = 2 + str_param = 'abc', + ): + pass + + + The following pseudo-code shows how to create a command with such classes. + + .. code-block:: python + + my_command = command( + name="my_command", + display_name="my_command", + description="This is a command", + tags=dict(), + command="python train.py --input-data ${{inputs.input_data}} --lr ${{inputs.learning_rate}}", + code="./src", + compute="cpu-cluster", + environment="my-env:1", + distribution=MpiDistribution(process_count_per_instance=4), + environment_variables=dict(foo="bar"), + # Customers can still do this: + # resources=Resources(instance_count=2, instance_type="STANDARD_D2"), + # limits=Limits(timeout=300), + inputs={ + "float": Input(type="number", default=1.1, min=0, max=5), + "integer": Input(type="integer", default=2, min=-1, max=4), + "integer1": 2, + "string0": Input(type="string", default="default_str0"), + "string1": "default_str1", + "boolean": Input(type="boolean", default=False), + "uri_folder": Input(type="uri_folder", path="https://my-blob/path/to/data", mode="ro_mount"), + "uri_file": Input(type="uri_file", path="https://my-blob/path/to/data", mode="download"), + }, + outputs={"my_model": Output(type="mlflow_model")}, + ) + node = my_command() +""" + +from .enum_input import EnumInput +from .external_data import Database, FileSystem +from .group_input import GroupInput +from .input import Input +from .output import Output +from .utils import _get_param_with_standard_annotation, is_group + +__all__ = [ + "Input", + "Output", + "EnumInput", + "GroupInput", + "is_group", + "_get_param_with_standard_annotation", + "Database", + "FileSystem", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py new file mode 100644 index 00000000..3a726b38 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/base.py @@ -0,0 +1,34 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Any + +from azure.ai.ml._schema.component.input_output import SUPPORTED_PARAM_TYPES +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + + +class _InputOutputBase(DictMixin, RestTranslatableMixin): + def __init__( + self, + *, + # pylint: disable=redefined-builtin + type: Any, + # pylint: disable=unused-argument + **kwargs: Any, + ) -> None: + """Base class for Input & Output class. + + This class is introduced to support literal output in the future. + + :param type: The type of the Input/Output. + :type type: str + """ + self.type = type + + def _is_literal(self) -> bool: + """Check whether input is a literal + + :return: True if this input is literal input. + :rtype: bool + """ + return self.type in SUPPORTED_PARAM_TYPES diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py new file mode 100644 index 00000000..d6c88eef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/enum_input.py @@ -0,0 +1,133 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from enum import EnumMeta +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException + +from .input import Input + + +class EnumInput(Input): + """Enum parameter parse the value according to its enum values.""" + + def __init__( + self, + *, + enum: Optional[Union[EnumMeta, Sequence[str]]] = None, + default: Any = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Enum parameter parse the value according to its enum values. + + :param enum: Enum values. + :type enum: Union[EnumMeta, Sequence[str]] + :param default: Default value of the parameter + :type default: Any + :param description: Description of the parameter + :type description: str + """ + enum_values = self._assert_enum_valid(enum) + self._enum_class: Optional[EnumMeta] = None + # This is used to parse enum class instead of enum str value if a enum class is provided. + if isinstance(enum, EnumMeta): + self._enum_class = enum + self._str2enum = dict(zip(enum_values, enum)) + else: + self._str2enum = {v: v for v in enum_values} + super().__init__(type="string", default=default, enum=enum_values, description=description) + + @property + def _allowed_types(self) -> Tuple: + return ( + (str,) + if not self._enum_class + else ( + self._enum_class, + str, + ) + ) + + @classmethod + def _assert_enum_valid(cls, enum: Optional[Union[EnumMeta, Sequence[str]]]) -> List: + """Check whether the enum is valid and return the values of the enum. + + :param enum: The enum to validate + :type enum: Type + :return: The enum values + :rtype: List[Any] + """ + if isinstance(enum, EnumMeta): + enum_values = [str(option.value) for option in enum] # type: ignore + elif isinstance(enum, Iterable): + enum_values = list(enum) + else: + msg = "enum must be a subclass of Enum or an iterable." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if len(enum_values) <= 0: + msg = "enum must have enum values." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if any(not isinstance(v, str) for v in enum_values): + msg = "enum values must be str type." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return enum_values + + def _parse(self, val: str) -> Any: + """Parse the enum value from a string value or the enum value. + + :param val: The string to parse + :type val: str + :return: The enum value + :rtype: Any + """ + if val is None: + return val + + if self._enum_class and isinstance(val, self._enum_class): + return val # Directly return the enum value if it is the enum. + + if val not in self._str2enum: + msg = "Not a valid enum value: '{}', valid values: {}" + raise ValidationException( + message=msg.format(val, ", ".join(self.enum)), + no_personal_data_message=msg.format("[val]", "[enum]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return self._str2enum[val] + + def _update_default(self, default_value: Any) -> None: + """Enum parameter support updating values with a string value. + + :param default_value: The default value for the input + :type default_value: Any + """ + enum_val = self._parse(default_value) + if self._enum_class and isinstance(enum_val, self._enum_class): + enum_val = enum_val.value + self.default = enum_val diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py new file mode 100644 index 00000000..8a4fe21f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/external_data.py @@ -0,0 +1,207 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from inspect import Parameter +from typing import Dict, List, Optional, Union + +from azure.ai.ml.constants._component import ExternalDataType +from azure.ai.ml.entities._inputs_outputs.utils import _remove_empty_values +from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin + + +class StoredProcedureParameter(DictMixin, RestTranslatableMixin): + """Define a stored procedure parameter class for DataTransfer import database task. + + :keyword name: The name of the database stored procedure. + :paramtype name: str + :keyword value: The value of the database stored procedure. + :paramtype value: str + :keyword type: The type of the database stored procedure. + :paramtype type: str + """ + + def __init__( + self, + *, + name: Optional[str] = None, + value: Optional[str] = None, + type: Optional[str] = None, # pylint: disable=redefined-builtin + ) -> None: + self.type = type + self.name = name + self.value = value + + +class Database(DictMixin, RestTranslatableMixin): + """Define a database class for a DataTransfer Component or Job. + + :keyword query: The SQL query to retrieve data from the database. + :paramtype query: str + :keyword table_name: The name of the database table. + :paramtype table_name: str + :keyword stored_procedure: The name of the stored procedure. + :paramtype stored_procedure: str + :keyword stored_procedure_params: The parameters for the stored procedure. + :paramtype stored_procedure_params: List + :keyword connection: The connection string for the database. + The credential information should be stored in the connection. + :paramtype connection: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the Database object cannot be successfully validated. + Details will be provided in the error message. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_input_output_configurations.py + :start-after: [START configure_database] + :end-before: [END configure_database] + :language: python + :dedent: 8 + :caption: Create a database and querying a database table. + """ + + _EMPTY = Parameter.empty + + def __init__( + self, + *, + query: Optional[str] = None, + table_name: Optional[str] = None, + stored_procedure: Optional[str] = None, + stored_procedure_params: Optional[List[Dict]] = None, + connection: Optional[str] = None, + ) -> None: + # As an annotation, it is not allowed to initialize the name. + # The name will be updated by the annotated variable name. + self.name = None + self.type = ExternalDataType.DATABASE + self.connection = connection + self.query = query + self.table_name = table_name + self.stored_procedure = stored_procedure + self.stored_procedure_params = stored_procedure_params + + def _to_dict(self, remove_name: bool = True) -> Dict: + """Convert the Source object to a dict. + + :param remove_name: Whether to remove the `name` key from the dict representation. Defaults to True. + :type remove_name: bool + :return: The dictionary representation of the class + :rtype: Dict + """ + keys = [ + "name", + "type", + "query", + "stored_procedure", + "stored_procedure_params", + "connection", + "table_name", + ] + if remove_name: + keys.remove("name") + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Source as component inputs, as for job input usage, + # rest object is generated by extracting Source's properties, see details in to_rest_dataset_literal_inputs() + result = self._to_dict() + return result + + def _update_name(self, name: str) -> None: + self.name = name + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "Database": + return Database(**obj) + + @property + def stored_procedure_params(self) -> Optional[List]: + """Get or set the parameters for the stored procedure. + + :return: The parameters for the stored procedure. + :rtype: List[StoredProcedureParameter] + """ + + return self._stored_procedure_params + + @stored_procedure_params.setter + def stored_procedure_params(self, value: Union[Dict[str, str], List, None]) -> None: + """Set the parameters for the stored procedure. + + :param value: The parameters for the stored procedure. + :type value: Union[Dict[str, str], StoredProcedureParameter, None] + """ + if value is None: + self._stored_procedure_params = value + else: + if not isinstance(value, list): + value = [value] + for index, item in enumerate(value): + if isinstance(item, dict): + value[index] = StoredProcedureParameter(**item) + self._stored_procedure_params = value + + +class FileSystem(DictMixin, RestTranslatableMixin): + """Define a file system class of a DataTransfer Component or Job. + + e.g. source_s3 = FileSystem(path='s3://my_bucket/my_folder', connection='azureml:my_s3_connection') + + :param path: The path to which the input is pointing. Could be pointing to the path of file system. Default is None. + :type path: str + :param connection: Connection is workspace, we didn't support storage connection here, need leverage workspace + connection to store these credential info. Default is None. + :type connection: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Source cannot be successfully validated. + Details will be provided in the error message. + """ + + _EMPTY = Parameter.empty + + def __init__( + self, + *, + path: Optional[str] = None, + connection: Optional[str] = None, + ) -> None: + self.type = ExternalDataType.FILE_SYSTEM + self.name: Optional[str] = None + self.connection = connection + self.path: Optional[str] = None + + if path is not None and not isinstance(path, str): + # this logic will make dsl data binding expression working in the same way as yaml + # it's written to handle InputOutputBase, but there will be loop import if we import InputOutputBase here + self.path = str(path) + else: + self.path = path + + def _to_dict(self, remove_name: bool = True) -> Dict: + """Convert the Source object to a dict. + + :param remove_name: Whether to remove the `name` key from the dict representation. Defaults to True. + :type remove_name: bool + :return: The dictionary representation of the object + :rtype: Dict + """ + keys = ["name", "path", "type", "connection"] + if remove_name: + keys.remove("name") + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Source as component inputs, as for job input usage, + # rest object is generated by extracting Source's properties, see details in to_rest_dataset_literal_inputs() + result = self._to_dict() + return result + + def _update_name(self, name: str) -> None: + self.name = name + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "FileSystem": + return FileSystem(**obj) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py new file mode 100644 index 00000000..e7fc565c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/group_input.py @@ -0,0 +1,251 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import copy +from enum import Enum as PyEnum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationException + +from .input import Input +from .output import Output +from .utils import is_group + +# avoid circular import error +if TYPE_CHECKING: + from azure.ai.ml.entities._job.pipeline._io import _GroupAttrDict + + +class GroupInput(Input): + """Define a group input object. + + :param values: The values of the group input. + :type values: dict + :param _group_class: The class representing the group. + :type _group_class: Any + """ + + def __init__(self, values: dict, _group_class: Any) -> None: + super().__init__(type=IOConstants.GROUP_TYPE_NAME) + self.assert_group_value_valid(values) + self.values: Any = values + # Create empty default by values + # Note Output do not have default so just set a None + self.default = self._create_default() + # Save group class for init function generation + self._group_class = _group_class + + @classmethod + def _create_group_attr_dict(cls, dct: dict) -> "_GroupAttrDict": + from .._job.pipeline._io import _GroupAttrDict + + return _GroupAttrDict(dct) + + @classmethod + def _is_group_attr_dict(cls, obj: object) -> bool: + from .._job.pipeline._io import _GroupAttrDict + + return isinstance(obj, _GroupAttrDict) + + def __getattr__(self, item: Any) -> Any: + try: + # TODO: Bug Item number: 2883363 + return super().__getattr__(item) # type: ignore + except AttributeError: + # TODO: why values is not a dict in some cases? + if isinstance(self.values, dict) and item in self.values: + return self.values[item] + raise + + def _create_default(self) -> "_GroupAttrDict": + from .._job.pipeline._io import PipelineInput + + default_dict: dict = {} + # Note: no top-level group names at this time. + for k, v in self.values.items(): + # skip create default for outputs or port inputs + if isinstance(v, Output): + continue + + # Create PipelineInput object if not subgroup + if not isinstance(v, GroupInput): + default_dict[k] = PipelineInput(name=k, data=v.default, meta=v) + continue + # Copy and insert k into group names for subgroup + default_dict[k] = copy.deepcopy(v.default) + default_dict[k].insert_group_name_for_items(k) + return self._create_group_attr_dict(default_dict) + + @classmethod + def assert_group_value_valid(cls, values: Dict) -> None: + """Check if all values in the group are supported types. + + :param values: The values of the group. + :type values: dict + :raises ValueError: If a value in the group is not a supported type or if a parameter name is duplicated. + :raises UserErrorException: If a value in the group has an unsupported type. + """ + names = set() + msg = ( + f"Parameter {{!r}} with type {{!r}} is not supported in group. " + f"Supported types are: {list(IOConstants.INPUT_TYPE_COMBINATION.keys())}" + ) + for key, value in values.items(): + if not isinstance(value, (Input, Output)): + raise ValueError(msg.format(key, type(value).__name__)) + if value.type is None: + # Skip check for parameter translated from pipeline job (lost type) + continue + if value.type not in IOConstants.INPUT_TYPE_COMBINATION and not isinstance(value, GroupInput): + raise UserErrorException(msg.format(key, value.type)) + if key in names: + if not isinstance(value, Input): + raise ValueError(f"Duplicate parameter name {value.name!r} found in Group values.") + names.add(key) + + def flatten(self, group_parameter_name: str) -> Dict: + """Flatten the group and return all parameters. + + :param group_parameter_name: The name of the group parameter. + :type group_parameter_name: str + :return: A dictionary of flattened parameters. + :rtype: dict + """ + all_parameters = {} + group_parameter_name = group_parameter_name if group_parameter_name else "" + for key, value in self.values.items(): + flattened_name = ".".join([group_parameter_name, key]) + if isinstance(value, GroupInput): + all_parameters.update(value.flatten(flattened_name)) + else: + all_parameters[flattened_name] = value + return all_parameters + + def _to_dict(self) -> dict: + attr_dict = super()._to_dict() + attr_dict["values"] = {k: v._to_dict() for k, v in self.values.items()} # pylint: disable=protected-access + return attr_dict + + @staticmethod + def custom_class_value_to_attr_dict(value: Any, group_names: Optional[List] = None) -> Any: + """Convert a custom parameter group class object to GroupAttrDict. + + :param value: The value to convert. + :type value: any + :param group_names: The names of the parent groups. + :type group_names: list + :return: The converted value as a GroupAttrDict. + :rtype: GroupAttrDict or any + """ + if not is_group(value): + return value + group_definition = getattr(value, IOConstants.GROUP_ATTR_NAME) + group_names = [*group_names] if group_names else [] + attr_dict = {} + from .._job.pipeline._io import PipelineInput + + for k, v in value.__dict__.items(): + if is_group(v): + attr_dict[k] = GroupInput.custom_class_value_to_attr_dict(v, [*group_names, k]) + continue + data = v.value if isinstance(v, PyEnum) else v + if GroupInput._is_group_attr_dict(data): + attr_dict[k] = data + continue + attr_dict[k] = PipelineInput(name=k, meta=group_definition.get(k), data=data, group_names=group_names) + return GroupInput._create_group_attr_dict(attr_dict) + + @staticmethod + def validate_conflict_keys(keys: List) -> None: + """Validate conflicting keys in a flattened input dictionary, like {'a.b.c': 1, 'a.b': 1}. + + :param keys: The keys to validate. + :type keys: list + :raises ValidationException: If conflicting keys are found. + """ + conflict_msg = "Conflict parameter key '%s' and '%s'." + + def _group_count(s: str) -> int: + return len(s.split(".")) - 1 + + # Sort order by group numbers + keys = sorted(list(keys), key=_group_count) + for idx, key1 in enumerate(keys[:-1]): + for key2 in keys[idx + 1 :]: + if _group_count(key2) == 0: + continue + # Skip case a.b.c and a.b.c1 + if _group_count(key1) == _group_count(key2): + continue + if not key2.startswith(key1): + continue + # Invalid case 'a.b' in 'a.b.c' + raise ValidationException( + message=conflict_msg % (key1, key2), + no_personal_data_message=conflict_msg % ("[key1]", "[key2]"), + target=ErrorTarget.PIPELINE, + ) + + @staticmethod + def restore_flattened_inputs(inputs: Dict) -> Dict: + """Restore flattened inputs to structured groups. + + :param inputs: The flattened input dictionary. + :type inputs: dict + :return: The restored structured inputs. + :rtype: dict + """ + GroupInput.validate_conflict_keys(list(inputs.keys())) + restored_inputs = {} + group_inputs: Dict = {} + # 1. Build all group parameters dict + for name, data in inputs.items(): + # for a.b.c, group names is [a, b] + name_splits = name.split(".") + group_names, param_name = name_splits[:-1], name_splits[-1] + if not group_names: + restored_inputs[name] = data + continue + # change {'a.b.c': data} -> {'a': {'b': {'c': data}}} + target_dict = group_inputs + for group_name in group_names: + if group_name not in target_dict: + target_dict[group_name] = {} + target_dict = target_dict[group_name] + target_dict[param_name] = data + + def restore_from_dict_recursively(_data: dict) -> Union[GroupInput, "_GroupAttrDict"]: + for key, val in _data.items(): + if type(val) == dict: # pylint: disable=unidiomatic-typecheck + _data[key] = restore_from_dict_recursively(val) + # Create GroupInput for definition and _GroupAttrDict for PipelineInput + # Regard all Input class as parameter definition, as data will not appear in group now. + if all(isinstance(val, Input) for val in _data.values()): + return GroupInput(values=_data, _group_class=None) + return GroupInput._create_group_attr_dict(dct=_data) + + # 2. Rehydrate dict to GroupInput(definition) or GroupAttrDict. + for name, data in group_inputs.items(): + restored_inputs[name] = restore_from_dict_recursively(data) + return restored_inputs + + def _update_default(self, default_value: object = None) -> None: + default_cls = type(default_value) + + # Assert '__dsl_group__' must in the class of default value + if self._is_group_attr_dict(default_value): + self.default = default_value + self.optional = False + return + if default_value and not is_group(default_cls): + raise ValueError(f"Default value must be instance of parameter group, got {default_cls}.") + if hasattr(default_value, "__dict__"): + # Convert default value with customer type to _AttrDict + self.default = GroupInput.custom_class_value_to_attr_dict(default_value) + # Update item annotation + for key, annotation in self.values.items(): + if not hasattr(default_value, key): + continue + annotation._update_default(getattr(default_value, key)) # pylint: disable=protected-access + self.optional = default_value is None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py new file mode 100644 index 00000000..4a945108 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/input.py @@ -0,0 +1,547 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin +# disable redefined-builtin to use type/min/max as argument name + +import math +from inspect import Parameter +from typing import Any, Dict, List, Optional, Union, overload + +from typing_extensions import Literal + +from azure.ai.ml.constants._component import ComponentParameterTypes, IOConstants +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + UserErrorException, + ValidationErrorType, + ValidationException, +) + +from .base import _InputOutputBase +from .utils import _get_param_with_standard_annotation, _remove_empty_values + + +class Input(_InputOutputBase): # pylint: disable=too-many-instance-attributes + """Initialize an Input object. + + :keyword type: The type of the data input. Accepted values are + 'uri_folder', 'uri_file', 'mltable', 'mlflow_model', 'custom_model', 'integer', 'number', 'string', and + 'boolean'. Defaults to 'uri_folder'. + :paramtype type: str + :keyword path: The path to the input data. Paths can be local paths, remote data uris, or a registered AzureML asset + ID. + :paramtype path: Optional[str] + :keyword mode: The access mode of the data input. Accepted values are: + * 'ro_mount': Mount the data to the compute target as read-only, + * 'download': Download the data to the compute target, + * 'direct': Pass in the URI as a string to be accessed at runtime + :paramtype mode: Optional[str] + :keyword path_on_compute: The access path of the data input for compute + :paramtype path_on_compute: Optional[str] + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job + execution will fail. + :paramtype min: Union[int, float] + :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job + execution will fail. + :paramtype max: Union[int, float] + :keyword optional: Specifies if the input is optional. + :paramtype optional: Optional[bool] + :keyword description: Description of the input + :paramtype description: Optional[str] + :keyword datastore: The datastore to upload local files to. + :paramtype datastore: str + :keyword intellectual_property: Intellectual property for the input. + :paramtype intellectual_property: IntellectualProperty + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START create_inputs_outputs] + :end-before: [END create_inputs_outputs] + :language: python + :dedent: 8 + :caption: Creating a CommandJob with two inputs. + """ + + _EMPTY = Parameter.empty + _IO_KEYS = [ + "path", + "type", + "mode", + "path_on_compute", + "description", + "default", + "min", + "max", + "enum", + "optional", + "datastore", + ] + + @overload + def __init__( + self, + *, + type: str, + path: Optional[str] = None, + mode: Optional[str] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """""" + + @overload + def __init__( + self, + *, + type: Literal["number"] = "number", + default: Optional[float] = None, + min: Optional[float] = None, + max: Optional[float] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a number input. + + :keyword type: The type of the data input. Can only be set to "number". + :paramtype type: str + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job + execution will fail. + :paramtype min: Optional[float] + :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job + execution will fail. + :paramtype max: Optional[float] + :keyword optional: Specifies if the input is optional. + :paramtype optional: bool + :keyword description: Description of the input + :paramtype description: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + """ + + @overload + def __init__( + self, + *, + type: Literal["integer"] = "integer", + default: Optional[int] = None, + min: Optional[int] = None, + max: Optional[int] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize an integer input. + + :keyword type: The type of the data input. Can only be set to "integer". + :paramtype type: str + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword min: The minimum value for the input. If a value smaller than the minimum is passed to the job, the job + execution will fail. + :paramtype min: Optional[int] + :keyword max: The maximum value for the input. If a value larger than the maximum is passed to a job, the job + execution will fail. + :paramtype max: Optional[int] + :keyword optional: Specifies if the input is optional. + :paramtype optional: bool + :keyword description: Description of the input + :paramtype description: str + """ + + @overload + def __init__( + self, + *, + type: Literal["string"] = "string", + default: Optional[str] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + path: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a string input. + + :keyword type: The type of the data input. Can only be set to "string". + :paramtype type: str + :keyword default: The default value of this input. When a `default` is set, the input will be optional. + :paramtype default: str + :keyword optional: Determine if this input is optional. + :paramtype optional: bool + :keyword description: Description of the input. + :paramtype description: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + """ + + @overload + def __init__( + self, + *, + type: Literal["boolean"] = "boolean", + default: Optional[bool] = None, + optional: Optional[bool] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a bool input. + + :keyword type: The type of the data input. Can only be set to "boolean". + :paramtype type: str + :keyword path: The path to the input data. Paths can be local paths, remote data uris, or a registered AzureML + asset id. + :paramtype path: str + :keyword default: The default value of the input. If a default is set, the input data will be optional. + :paramtype default: Union[str, int, float, bool] + :keyword optional: Specifies if the input is optional. + :paramtype optional: bool + :keyword description: Description of the input + :paramtype description: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Input cannot be successfully validated. + Details will be provided in the error message. + """ + + def __init__( + self, + *, + type: str = "uri_folder", + path: Optional[str] = None, + mode: Optional[str] = None, + path_on_compute: Optional[str] = None, + default: Optional[Union[str, int, float, bool]] = None, + optional: Optional[bool] = None, + min: Optional[Union[int, float]] = None, + max: Optional[Union[int, float]] = None, + enum: Any = None, + description: Optional[str] = None, + datastore: Optional[str] = None, + **kwargs: Any, + ) -> None: + super(Input, self).__init__(type=type) + # As an annotation, it is not allowed to initialize the _port_name. + self._port_name = None + self.description = description + self.path: Any = None + + if path is not None and not isinstance(path, str): + # this logic will make dsl data binding expression working in the same way as yaml + # it's written to handle InputOutputBase, but there will be loop import if we import InputOutputBase here + self.path = str(path) + else: + self.path = path + self.path_on_compute = path_on_compute + self.mode = None if self._is_primitive_type else mode + self._update_default(default) + self.optional = optional + # set the flag to mark if the optional=True is inferred by us. + self._is_inferred_optional = False + self.min = min + self.max = max + self.enum = enum + self.datastore = datastore + intellectual_property = kwargs.pop("intellectual_property", None) + if intellectual_property: + self._intellectual_property = ( + intellectual_property + if isinstance(intellectual_property, IntellectualProperty) + else IntellectualProperty(**intellectual_property) + ) + # normalize properties like ["default", "min", "max", "optional"] + self._normalize_self_properties() + + self._validate_parameter_combinations() + + @property + def _allowed_types(self) -> Any: + if self._multiple_types: + return None + return IOConstants.PRIMITIVE_STR_2_TYPE.get(self.type) + + @property + def _is_primitive_type(self) -> bool: + if self._multiple_types: + # note: we suppose that no primitive type will be included when there are multiple types + return False + return self.type in IOConstants.PRIMITIVE_STR_2_TYPE + + @property + def _multiple_types(self) -> bool: + """Returns True if this input has multiple types. + + Currently, there are two scenarios that need to check this property: + 1. before `in` as it may throw exception; there will be `in` operation for validation/transformation. + 2. `str()` of list is not ideal, so we need to manually create its string result. + + :return: Whether this input has multiple types + :rtype: bool + """ + return isinstance(self.type, list) + + def _is_literal(self) -> bool: + """Whether this input is a literal + + Override this function as `self.type` can be list and not hashable for operation `in`. + + :return: Whether is a literal + :rtype: bool + """ + return not self._multiple_types and super(Input, self)._is_literal() + + def _is_enum(self) -> bool: + """Whether input is an enum + + :return: True if the input is enum. + :rtype: bool + """ + res: bool = self.type == ComponentParameterTypes.STRING and self.enum + return res + + def _to_dict(self) -> Dict: + """Convert the Input object to a dict. + + :return: Dictionary representation of Input + :rtype: Dict + """ + keys = self._IO_KEYS + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _parse(self, val: Any) -> Union[int, float, bool, str, Any]: + """Parse value passed from command line. + + :param val: The input value + :type val: T + :return: The parsed value. + :rtype: Union[int, float, bool, str, T] + """ + if self.type == "integer": + return int(float(val)) # backend returns 10.0,for integer, parse it to float before int + if self.type == "number": + return float(val) + if self.type == "boolean": + lower_val = str(val).lower() + if lower_val not in {"true", "false"}: + msg = "Boolean parameter '{}' only accept True/False, got {}." + raise ValidationException( + message=msg.format(self._port_name, val), + no_personal_data_message=msg.format("[self._port_name]", "[val]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return lower_val == "true" + if self.type == "string": + return val if isinstance(val, str) else str(val) + return val + + def _parse_and_validate(self, val: Any) -> Union[int, float, bool, str, Any]: + """Parse the val passed from the command line and validate the value. + + :param val: The input string value from the command line. + :type val: T + :return: The parsed value, an exception will be raised if the value is invalid. + :rtype: Union[int, float, bool, str, T] + """ + if self._is_primitive_type: + val = self._parse(val) if isinstance(val, str) else val + self._validate_or_throw(val) + return val + + def _update_name(self, name: Any) -> None: + self._port_name = name + + def _update_default(self, default_value: Any) -> None: + """Update provided default values. + + :param default_value: The default value of the Input + :type default_value: Any + """ + name = "" if not self._port_name else f"{self._port_name!r} " + msg_prefix = f"Default value of Input {name}" + + if not self._is_primitive_type and default_value is not None: + msg = f"{msg_prefix}cannot be set: Non-primitive type Input has no default value." + raise UserErrorException(msg) + if isinstance(default_value, float) and not math.isfinite(default_value): + # Since nan/inf cannot be stored in the backend, just ignore them. + # logger.warning("Float default value %r is not allowed, ignored." % default_value) + return + # pylint: disable=pointless-string-statement + """Update provided default values. + Here we need to make sure the type of default value is allowed or it could be parsed.. + """ + if default_value is not None: + if type(default_value) not in IOConstants.PRIMITIVE_TYPE_2_STR: + msg = ( + f"{msg_prefix}cannot be set: type must be one of " + f"{list(IOConstants.PRIMITIVE_TYPE_2_STR.values())}, got '{type(default_value)}'." + ) + raise UserErrorException(msg) + + if not isinstance(default_value, self._allowed_types): + try: + default_value = self._parse(default_value) + # return original validation exception which is custom defined if raised by self._parse + except ValidationException as e: + raise e + except Exception as e: + msg = f"{msg_prefix}cannot be parsed, got '{default_value}', type = {type(default_value)!r}." + raise UserErrorException(msg) from e + self.default = default_value + + def _validate_or_throw(self, value: Any) -> None: + """Validate input parameter value, throw exception if not as expected. + + It will throw exception if validate failed, otherwise do nothing. + + :param value: A value to validate + :type value: Any + """ + if not self.optional and value is None: + msg = "Parameter {} cannot be None since it is not optional." + raise ValidationException( + message=msg.format(self._port_name), + no_personal_data_message=msg.format("[self._port_name]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + if self._allowed_types and value is not None: + if not isinstance(value, self._allowed_types): + msg = "Unexpected data type for parameter '{}'. Expected {} but got {}." + raise ValidationException( + message=msg.format(self._port_name, self._allowed_types, type(value)), + no_personal_data_message=msg.format("[_port_name]", self._allowed_types, type(value)), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + # for numeric values, need extra check for min max value + if not self._multiple_types and self.type in ("integer", "number"): + if self.min is not None and value < self.min: + msg = "Parameter '{}' should not be less than {}." + raise ValidationException( + message=msg.format(self._port_name, self.min), + no_personal_data_message=msg.format("[_port_name]", self.min), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + if self.max is not None and value > self.max: + msg = "Parameter '{}' should not be greater than {}." + raise ValidationException( + message=msg.format(self._port_name, self.max), + no_personal_data_message=msg.format("[_port_name]", self.max), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _get_python_builtin_type_str(self) -> str: + """Get python builtin type for current input in string, eg: str. + + Return yaml type if not available. + + :return: The name of the input type + :rtype: str + """ + if self._multiple_types: + return "[" + ", ".join(self.type) + "]" + if self._is_primitive_type: + res_primitive_type: str = IOConstants.PRIMITIVE_STR_2_TYPE[self.type].__name__ + return res_primitive_type + res: str = self.type + return res + + def _validate_parameter_combinations(self) -> None: + """Validate different parameter combinations according to type.""" + parameters = ["type", "path", "mode", "default", "min", "max"] + parameters_dict: dict = {key: getattr(self, key, None) for key in parameters} + type = parameters_dict.pop("type") + + # validate parameter combination + if not self._multiple_types and type in IOConstants.INPUT_TYPE_COMBINATION: + valid_parameters = IOConstants.INPUT_TYPE_COMBINATION[type] + for key, value in parameters_dict.items(): + if key not in valid_parameters and value is not None: + msg = "Invalid parameter for '{}' Input, parameter '{}' should be None but got '{}'" + raise ValidationException( + message=msg.format(type, key, value), + no_personal_data_message=msg.format("[type]", "[parameter]", "[parameter_value]"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.PIPELINE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _simple_parse(self, value: Any, _type: Any = None) -> Any: + if self._multiple_types: + return value + if _type is None: + _type = self.type + if _type in IOConstants.PARAM_PARSERS: + return IOConstants.PARAM_PARSERS[_type](value) + return value + + def _normalize_self_properties(self) -> None: + # parse value from string to its original type. eg: "false" -> False + for key in ["min", "max"]: + if getattr(self, key) is not None: + origin_value = getattr(self, key) + new_value = self._simple_parse(origin_value) + setattr(self, key, new_value) + if self.optional: + self.optional = self._simple_parse(getattr(self, "optional", "false"), _type="boolean") + + @classmethod + def _get_input_by_type(cls, t: type, optional: Any = None) -> Optional["Input"]: + if t in IOConstants.PRIMITIVE_TYPE_2_STR: + return cls(type=IOConstants.PRIMITIVE_TYPE_2_STR[t], optional=optional) + return None + + @classmethod + def _get_default_unknown_input(cls, optional: Optional[bool] = None) -> "Input": + # Set type as None here to avoid schema validation failed + res: Input = cls(type=None, optional=optional) # type: ignore + return res + + @classmethod + def _get_param_with_standard_annotation(cls, func: Any) -> Dict: + return _get_param_with_standard_annotation(func, is_func=True) + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Input as component inputs, as for job input usage, + # rest object is generated by extracting Input's properties, see details in to_rest_dataset_literal_inputs() + result = self._to_dict() + # parse string -> String, integer -> Integer, etc. + if result["type"] in IOConstants.TYPE_MAPPING_YAML_2_REST: + result["type"] = IOConstants.TYPE_MAPPING_YAML_2_REST[result["type"]] + return result + + @classmethod + def _map_from_rest_type(cls, _type: Union[str, List]) -> Union[str, List]: + # this is for component rest object when using Input as component inputs + reversed_data_type_mapping = {v: k for k, v in IOConstants.TYPE_MAPPING_YAML_2_REST.items()} + # parse String -> string, Integer -> integer, etc + if not isinstance(_type, list) and _type in reversed_data_type_mapping: + res: str = reversed_data_type_mapping[_type] + return res + return _type + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "Input": + obj["type"] = cls._map_from_rest_type(obj["type"]) + + return cls(**obj) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py new file mode 100644 index 00000000..1c4dcd06 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/output.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=redefined-builtin +import re +from typing import Any, Dict, Optional, overload + +from typing_extensions import Literal + +from azure.ai.ml.constants import AssetTypes +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty +from azure.ai.ml.exceptions import UserErrorException + +from .base import _InputOutputBase +from .utils import _remove_empty_values + + +class Output(_InputOutputBase): + _IO_KEYS = ["name", "version", "path", "path_on_compute", "type", "mode", "description", "early_available"] + + @overload + def __init__( + self, + *, + type: str, + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + **kwargs: Any, + ): ... + + @overload + def __init__( + self, + type: Literal["uri_file"] = "uri_file", + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + ): + """Define a URI file output. + + :keyword type: The type of the data output. Can only be set to 'uri_file'. + :paramtype type: str + :keyword path: The remote path where the output should be stored. + :paramtype path: str + :keyword mode: The access mode of the data output. Accepted values are + * 'rw_mount': Read-write mount the data, + * 'upload': Upload the data from the compute target, + * 'direct': Pass in the URI as a string + :paramtype mode: str + :keyword description: The description of the output. + :paramtype description: str + :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without + setting a version. + :paramtype name: str + :keyword version: The version used to register the output as a Data or Model asset. A version can be set only + when name is set. + :paramtype version: str + """ + + def __init__( # type: ignore[misc] + self, + *, + type: str = AssetTypes.URI_FOLDER, + path: Optional[str] = None, + mode: Optional[str] = None, + description: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Define an output. + + :keyword type: The type of the data output. Accepted values are 'uri_folder', 'uri_file', 'mltable', + 'mlflow_model', 'custom_model', and user-defined types. Defaults to 'uri_folder'. + :paramtype type: str + :keyword path: The remote path where the output should be stored. + :paramtype path: Optional[str] + :keyword mode: The access mode of the data output. Accepted values are + * 'rw_mount': Read-write mount the data + * 'upload': Upload the data from the compute target + * 'direct': Pass in the URI as a string + :paramtype mode: Optional[str] + :keyword path_on_compute: The access path of the data output for compute + :paramtype path_on_compute: Optional[str] + :keyword description: The description of the output. + :paramtype description: Optional[str] + :keyword name: The name to be used to register the output as a Data or Model asset. A name can be set without + setting a version. + :paramtype name: str + :keyword version: The version used to register the output as a Data or Model asset. A version can be set only + when name is set. + :paramtype version: str + :keyword is_control: Determine if the output is a control output. + :paramtype is_control: bool + :keyword early_available: Mark the output for early node orchestration. + :paramtype early_available: bool + :keyword intellectual_property: Intellectual property associated with the output. + It can be an instance of `IntellectualProperty` or a dictionary that will be used to create an instance. + :paramtype intellectual_property: Union[ + ~azure.ai.ml.entities._assets.intellectual_property.IntellectualProperty, dict] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START create_inputs_outputs] + :end-before: [END create_inputs_outputs] + :language: python + :dedent: 8 + :caption: Creating a CommandJob with a folder output. + """ + super(Output, self).__init__(type=type) + # As an annotation, it is not allowed to initialize the _port_name. + self._port_name = None + self.name = kwargs.pop("name", None) + self.version = kwargs.pop("version", None) + self._is_primitive_type = self.type in IOConstants.PRIMITIVE_STR_2_TYPE + self.description = description + self.path = path + self.path_on_compute = kwargs.pop("path_on_compute", None) + self.mode = mode + # use this field to mark Output for early node orchestrate, currently hide in kwargs + self.early_available = kwargs.pop("early_available", None) + self._intellectual_property = None + intellectual_property = kwargs.pop("intellectual_property", None) + if intellectual_property: + self._intellectual_property = ( + intellectual_property + if isinstance(intellectual_property, IntellectualProperty) + else IntellectualProperty(**intellectual_property) + ) + self._assert_name_and_version() + # normalize properties + self._normalize_self_properties() + + def _get_hint(self, new_line_style: bool = False) -> Optional[str]: + comment_str = self.description.replace('"', '\\"') if self.description else self.type + return '"""%s"""' % comment_str if comment_str and new_line_style else comment_str + + def _to_dict(self) -> Dict: + """Convert the Output object to a dict. + + :return: The dictionary representation of Output + :rtype: Dict + """ + keys = self._IO_KEYS + result = {key: getattr(self, key) for key in keys} + res: dict = _remove_empty_values(result) + return res + + def _to_rest_object(self) -> Dict: + # this is for component rest object when using Output as component outputs, as for job output usage, + # rest object is generated by extracting Output's properties, see details in to_rest_data_outputs() + return self._to_dict() + + def _simple_parse(self, value: Any, _type: Any = None) -> Any: + if _type is None: + _type = self.type + if _type in IOConstants.PARAM_PARSERS: + return IOConstants.PARAM_PARSERS[_type](value) + return value + + def _normalize_self_properties(self) -> None: + # parse value from string to its original type. eg: "false" -> False + if self.early_available: + self.early_available = self._simple_parse(getattr(self, "early_available", "false"), _type="boolean") + + @classmethod + def _from_rest_object(cls, obj: Dict) -> "Output": + # this is for component rest object when using Output as component outputs + return Output(**obj) + + def _assert_name_and_version(self) -> None: + if self.name and not (re.match("^[A-Za-z0-9_-]*$", self.name) and len(self.name) <= 255): + raise UserErrorException( + f"The output name {self.name} can only contain alphanumeric characters, dashes and underscores, " + f"with a limit of 255 characters." + ) + if self.version and not self.name: + raise UserErrorException("Output name is required when output version is specified.") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py new file mode 100644 index 00000000..bd752571 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_inputs_outputs/utils.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access +# enable protected access for protected helper functions + +import copy +from collections import OrderedDict +from enum import Enum as PyEnum +from enum import EnumMeta +from inspect import Parameter, getmro, signature +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast + +from typing_extensions import Annotated, Literal, TypeAlias + +from azure.ai.ml.constants._component import IOConstants +from azure.ai.ml.exceptions import UserErrorException + +# avoid circular import error +if TYPE_CHECKING: + from .input import Input + from .output import Output + +SUPPORTED_RETURN_TYPES_PRIMITIVE = list(IOConstants.PRIMITIVE_TYPE_2_STR.keys()) + +Annotation: TypeAlias = Union[str, Type, Annotated[Any, Any], None] # type: ignore + + +def is_group(obj: object) -> bool: + """Return True if obj is a group or an instance of a parameter group class. + + :param obj: The object to check. + :type obj: Any + :return: True if obj is a group or an instance, False otherwise. + :rtype: bool + """ + return hasattr(obj, IOConstants.GROUP_ATTR_NAME) + + +def _get_annotation_by_value(val: Any) -> Union["Input", Type["Input"]]: + # TODO: we'd better remove this potential recursive import + from .enum_input import EnumInput + from .input import Input + + annotation: Any = None + + def _is_dataset(data: Any) -> bool: + from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin + + DATASET_TYPES = JobIOMixin + return isinstance(data, DATASET_TYPES) + + if _is_dataset(val): + annotation = Input + elif val is Parameter.empty or val is None: + # If no default value or default is None, create val as the basic parameter type, + # it could be replaced using component parameter definition. + annotation = Input._get_default_unknown_input() + elif isinstance(val, PyEnum): + # Handle enum values + annotation = EnumInput(enum=val.__class__) + else: + _new_annotation = _get_annotation_cls_by_type(type(val), raise_error=False) + if not _new_annotation: + # Fall back to default + annotation = Input._get_default_unknown_input() + else: + return _new_annotation + return cast(Union["Input", Type["Input"]], annotation) + + +def _get_annotation_cls_by_type( + t: type, raise_error: bool = False, optional: Optional[bool] = None +) -> Optional["Input"]: + # TODO: we'd better remove this potential recursive import + from .input import Input + + cls = Input._get_input_by_type(t, optional=optional) + if cls is None and raise_error: + raise UserErrorException(f"Can't convert type {t} to azure.ai.ml.Input") + return cls + + +# pylint: disable=too-many-statements +def _get_param_with_standard_annotation( + cls_or_func: Union[Callable, Type], is_func: bool = False, skip_params: Optional[List[str]] = None +) -> Dict[str, Union[Annotation, "Input", "Output"]]: + """Standardize function parameters or class fields with dsl.types annotation. + + :param cls_or_func: Either a class or a function + :type cls_or_func: Union[Callable, Type] + :param is_func: Whether `cls_or_func` is a function. Defaults to False. + :type is_func: bool + :param skip_params: + :type skip_params: Optional[List[str]] + :return: A dictionary of field annotations + :rtype: Dict[str, Union[Annotation, "Input", "Output"]] + """ + # TODO: we'd better remove this potential recursive import + from .group_input import GroupInput + from .input import Input + from .output import Output + + def _is_dsl_type_cls(t: Any) -> bool: + if type(t) is not type: # pylint: disable=unidiomatic-typecheck + return False + return issubclass(t, (Input, Output)) + + def _is_dsl_types(o: object) -> bool: + return _is_dsl_type_cls(type(o)) + + def _get_fields(annotations: Dict) -> Dict: + """Return field names to annotations mapping in class. + + :param annotations: The annotations + :type annotations: Dict[str, Union[Annotation, Input, Output]] + :return: The field dict + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + annotation_fields = OrderedDict() + for name, annotation in annotations.items(): + # Skip return type + if name == "return": + continue + # Handle EnumMeta annotation + if isinstance(annotation, EnumMeta): + from .enum_input import EnumInput + + annotation = EnumInput(type="string", enum=annotation) + # Handle Group annotation + if is_group(annotation): + _deep_copy: GroupInput = copy.deepcopy(getattr(annotation, IOConstants.GROUP_ATTR_NAME)) + annotation = _deep_copy + # Try creating annotation by type when got like 'param: int' + if not _is_dsl_type_cls(annotation) and not _is_dsl_types(annotation): + origin_annotation = annotation + annotation = cast(Input, _get_annotation_cls_by_type(annotation, raise_error=False)) + if not annotation: + msg = f"Unsupported annotation type {origin_annotation!r} for parameter {name!r}." + raise UserErrorException(msg) + annotation_fields[name] = annotation + return annotation_fields + + def _merge_field_keys( + annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any] + ) -> List[str]: + """Merge field keys from annotations and cls dict to get all fields in class. + + :param annotation_fields: The field annotations + :type annotation_fields: Dict[str, Union[Annotation, Input, Output]] + :param defaults_dict: The map of variable name to default value + :type defaults_dict: Dict[str, Any] + :return: A list of field keys + :rtype: List[str] + """ + anno_keys = list(annotation_fields.keys()) + dict_keys = defaults_dict.keys() + if not dict_keys: + return anno_keys + return [*anno_keys, *[key for key in dict_keys if key not in anno_keys]] + + def _update_annotation_with_default( + anno: Union[Annotation, Input, Output], name: str, default: Any + ) -> Union[Annotation, Input, Output]: + """Create annotation if is type class and update the default. + + :param anno: The annotation + :type anno: Union[Annotation, Input, Output] + :param name: The port name + :type name: str + :param default: The default value + :type default: Any + :return: The updated annotation + :rtype: Union[Annotation, Input, Output] + """ + # Create instance if is type class + complete_annotation = anno + if _is_dsl_type_cls(anno): + if anno is not None and not isinstance(anno, (str, Input, Output)): + complete_annotation = anno() + if complete_annotation is not None and not isinstance(complete_annotation, str): + complete_annotation._port_name = name + if default is Input._EMPTY: + return complete_annotation + if isinstance(complete_annotation, Input): + # Non-parameter Input has no default attribute + if complete_annotation._is_primitive_type and complete_annotation.default is not None: + # logger.warning( + # f"Warning: Default value of f{complete_annotation.name!r} is set twice: " + # f"{complete_annotation.default!r} and {default!r}, will use {default!r}" + # ) + pass + complete_annotation._update_default(default) + if isinstance(complete_annotation, Output) and default is not None: + msg = ( + f"Default value of Output {complete_annotation._port_name!r} cannot be set:" + f"Output has no default value." + ) + raise UserErrorException(msg) + return complete_annotation + + def _update_fields_with_default( + annotation_fields: Dict[str, Union[Annotation, Input, Output]], defaults_dict: Dict[str, Any] + ) -> Dict[str, Union[Annotation, Input, Output]]: + """Use public values in class dict to update annotations. + + :param annotation_fields: The field annotations + :type annotation_fields: Dict[str, Union[Annotation, Input, Output]] + :param defaults_dict: A map of variable name to default value + :type defaults_dict: Dict[str, Any] + :return: List of field names + :rtype: List[str] + """ + all_fields = OrderedDict() + all_filed_keys = _merge_field_keys(annotation_fields, defaults_dict) + for name in all_filed_keys: + # Get or create annotation + annotation = ( + annotation_fields[name] + if name in annotation_fields + else _get_annotation_by_value(defaults_dict.get(name, Input._EMPTY)) + ) + # Create annotation if is class type and update default + annotation = _update_annotation_with_default(annotation, name, defaults_dict.get(name, Input._EMPTY)) + all_fields[name] = annotation + return all_fields + + def _merge_and_reorder( + inherited_fields: Dict[str, Union[Annotation, Input, Output]], + cls_fields: Dict[str, Union[Annotation, Input, Output]], + ) -> Dict[str, Union[Annotation, Input, Output]]: + """Merge inherited fields with cls fields. + + The order inside each part will not be changed. Order will be: + + {inherited_no_default_fields} + {cls_no_default_fields} + {inherited_default_fields} + {cls_default_fields}. + + + :param inherited_fields: The inherited fields + :type inherited_fields: Dict[str, Union[Annotation, Input, Output]] + :param cls_fields: The class fields + :type cls_fields: Dict[str, Union[Annotation, Input, Output]] + :return: The merged fields + :rtype: Dict[str, Union[Annotation, Input, Output]] + + .. admonition:: Additional Note + + :class: note + + If cls overwrite an inherited no default field with default, it will be put in the + cls_default_fields part and deleted from inherited_no_default_fields: + + .. code-block:: python + + @dsl.group + class SubGroup: + int_param0: Integer + int_param1: int + + @dsl.group + class Group(SubGroup): + int_param3: Integer + int_param1: int = 1 + + The init function of Group will be `def __init__(self, *, int_param0, int_param3, int_param1=1)`. + """ + + def _split( + _fields: Dict[str, Union[Annotation, Input, Output]] + ) -> Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]]: + """Split fields to two parts from the first default field. + + :param _fields: The fields + :type _fields: Dict[str, Union[Annotation, Input, Output]] + :return: A 2-tuple of (fields with no defaults, fields with defaults) + :rtype: Tuple[Dict[str, Union[Annotation, Input, Output]], Dict[str, Union[Annotation, Input, Output]]] + """ + _no_defaults_fields, _defaults_fields = {}, {} + seen_default = False + for key, val in _fields.items(): + if val is not None and not isinstance(val, str): + if val.get("default", None) or seen_default: + seen_default = True + _defaults_fields[key] = val + else: + _no_defaults_fields[key] = val + return _no_defaults_fields, _defaults_fields + + inherited_no_default, inherited_default = _split(inherited_fields) + cls_no_default, cls_default = _split(cls_fields) + # Cross comparison and delete from inherited_fields if same key appeared in cls_fields + # pylint: disable=consider-iterating-dictionary + for key in cls_default.keys(): + if key in inherited_no_default.keys(): + del inherited_no_default[key] + for key in cls_no_default.keys(): + if key in inherited_default.keys(): + del inherited_default[key] + return OrderedDict( + { + **inherited_no_default, + **cls_no_default, + **inherited_default, + **cls_default, + } + ) + + def _get_inherited_fields() -> Dict[str, Union[Annotation, Input, Output]]: + """Get all fields inherited from @group decorated base classes. + + :return: The field dict + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + # Return value of _get_param_with_standard_annotation + _fields: Dict[str, Union[Annotation, Input, Output]] = OrderedDict({}) + if is_func: + return _fields + # In reversed order so that more derived classes + # override earlier field definitions in base classes. + if isinstance(cls_or_func, type): + for base in cls_or_func.__mro__[-1:0:-1]: + if is_group(base): + # merge and reorder fields from current base with previous + _fields = _merge_and_reorder( + _fields, copy.deepcopy(getattr(base, IOConstants.GROUP_ATTR_NAME).values) + ) + return _fields + + skip_params = skip_params or [] + inherited_fields = _get_inherited_fields() + # From annotations get field with type + annotations: Dict[str, Annotation] = getattr(cls_or_func, "__annotations__", {}) + annotations = {k: v for k, v in annotations.items() if k not in skip_params} + annotations = _update_io_from_mldesigner(annotations) + annotation_fields = _get_fields(annotations) + defaults_dict: Dict[str, Any] = {} + # Update fields use class field with defaults from class dict or signature(func).paramters + if not is_func: + # Only consider public fields in class dict + defaults_dict = { + key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_") and key not in skip_params + } + else: + # Infer parameter type from value if is function + defaults_dict = { + key: val.default + for key, val in signature(cls_or_func).parameters.items() + if key not in skip_params and val.kind != val.VAR_KEYWORD + } + fields = _update_fields_with_default(annotation_fields, defaults_dict) + all_fields = _merge_and_reorder(inherited_fields, fields) + return all_fields + + +def _update_io_from_mldesigner(annotations: Dict[str, Annotation]) -> Dict[str, Union[Annotation, "Input", "Output"]]: + """Translates IOBase from mldesigner package to azure.ml.entities.Input/Output. + + This function depends on: + + * `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output instance annotations + to IO entities. + * class names of `mldesigner._input_output` to translate Input/Output class annotations + to IO entities. + + :param annotations: A map of variable names to annotations + :type annotations: Dict[str, Annotation] + :return: Dict with mldesigner IO types converted to azure-ai-ml Input/Output + :rtype: Dict[str, Union[Annotation, Input, Output]] + """ + from typing_extensions import get_args, get_origin + + from azure.ai.ml import Input, Output + + from .enum_input import EnumInput + + mldesigner_pkg = "mldesigner" + param_name = "_Param" + return_annotation_key = "return" + + def _is_primitive_type(io: type) -> bool: + """Checks whether type is a primitive type + + :param io: A type + :type io: type + :return: Return true if type is subclass of mldesigner._input_output._Param + :rtype: bool + """ + return any(io.__module__.startswith(mldesigner_pkg) and item.__name__ == param_name for item in getmro(io)) + + def _is_input_or_output_type(io: type, type_str: Literal["Input", "Output", "Meta"]) -> bool: + """Checks whether a type is an Input or Output type + + :param io: A type + :type io: type + :param type_str: The kind of type to check for + :type type_str: Literal["Input", "Output", "Meta"] + :return: Return true if type name contains type_str + :rtype: bool + """ + if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg): + if type_str in io.__name__: + return True + return False + + result = {} + for key, io in annotations.items(): # pylint: disable=too-many-nested-blocks + if isinstance(io, type): + if _is_input_or_output_type(io, "Input"): + # mldesigner.Input -> entities.Input + io = Input + elif _is_input_or_output_type(io, "Output"): + # mldesigner.Output -> entities.Output + io = Output + elif _is_primitive_type(io): + io = ( + Output(type=io.TYPE_NAME) # type: ignore + if key == return_annotation_key + else Input(type=io.TYPE_NAME) # type: ignore + ) + elif hasattr(io, "_to_io_entity_args_dict"): + try: + if _is_input_or_output_type(type(io), "Input"): + # mldesigner.Input() -> entities.Input() + if io is not None: + io = Input(**io._to_io_entity_args_dict()) + elif _is_input_or_output_type(type(io), "Output"): + # mldesigner.Output() -> entities.Output() + if io is not None: + io = Output(**io._to_io_entity_args_dict()) + elif _is_primitive_type(type(io)): + if io is not None and not isinstance(io, str): + if io._is_enum(): + io = EnumInput(**io._to_io_entity_args_dict()) + else: + io = ( + Output(**io._to_io_entity_args_dict()) + if key == return_annotation_key + else Input(**io._to_io_entity_args_dict()) + ) + except BaseException as e: + raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e + # Handle Annotated annotation + elif get_origin(io) is Annotated: + hint_type, arg, *hint_args = get_args(io) # pylint: disable=unused-variable + if hint_type in SUPPORTED_RETURN_TYPES_PRIMITIVE: + if not _is_input_or_output_type(type(arg), "Meta"): + raise UserErrorException( + f"Annotated Metadata class only support " + f"mldesigner._input_output.Meta, " + f"it is {type(arg)} now." + ) + if arg.type is not None and arg.type != hint_type: + raise UserErrorException( + f"Meta class type {arg.type} should be same as Annotated type: " f"{hint_type}" + ) + arg.type = hint_type + io = ( + Output(**arg._to_io_entity_args_dict()) + if key == return_annotation_key + else Input(**arg._to_io_entity_args_dict()) + ) + result[key] = io + return result + + +def _remove_empty_values(data: Any) -> Any: + """Recursively removes None values from a dict + + :param data: The value to remove None from + :type data: T + :return: + * `data` if `data` is not a dict + * `data` with None values recursively filtered out if data is a dict + :rtype: T + """ + if not isinstance(data, dict): + return data + return {k: _remove_empty_values(v) for k, v in data.items() if v is not None} |
