diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py new file mode 100644 index 00000000..9fef9489 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py @@ -0,0 +1,126 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import INCLUDE, fields, pre_dump + +from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField +from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils.utils import is_private_preview_enabled +from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes +from azure.ai.ml.constants._component import ComponentParameterTypes + +# Here we use an adhoc way to collect all class constant attributes by checking if it's upper letter +# because making those constants enum will fail in string serialization in marshmallow +asset_type_obj = AssetTypes() +SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH] + [ + getattr(asset_type_obj, k) for k in dir(asset_type_obj) if k.isupper() +] +param_obj = ComponentParameterTypes() +SUPPORTED_PARAM_TYPES = [getattr(param_obj, k) for k in dir(param_obj) if k.isupper()] + +input_output_type_obj = InputOutputModes() +# Link mode is only supported in component level currently +SUPPORTED_INPUT_OUTPUT_MODES = [ + getattr(input_output_type_obj, k) for k in dir(input_output_type_obj) if k.isupper() +] + ["link"] + + +class InputPortSchema(metaclass=PatchedSchemaMeta): + type = DumpableEnumField( + allowed_values=SUPPORTED_PORT_TYPES, + required=True, + ) + description = fields.Str() + optional = fields.Bool() + default = fields.Str() + mode = DumpableEnumField( + allowed_values=SUPPORTED_INPUT_OUTPUT_MODES, + ) + # hide in private preview + if is_private_preview_enabled(): + # only protection_level is allowed for inputs + intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema)) + + @pre_dump + def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument + # The ipp field is set on the output object as "_intellectual_property". + # We need to set it as "intellectual_property" before dumping so that Marshmallow + # can pick up the field correctly on dump and show it back to the user. + if hasattr(data, "_intellectual_property"): + ipp_field = data._intellectual_property # pylint: disable=protected-access + if ipp_field: + setattr(data, "intellectual_property", ipp_field) + return data + + +class OutputPortSchema(metaclass=PatchedSchemaMeta): + type = DumpableEnumField( + allowed_values=SUPPORTED_PORT_TYPES, + required=True, + ) + description = fields.Str() + mode = DumpableEnumField( + allowed_values=SUPPORTED_INPUT_OUTPUT_MODES, + ) + # hide in private preview + if is_private_preview_enabled(): + # only protection_level is allowed for outputs + intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema)) + + @pre_dump + def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument + # The ipp field is set on the output object as "_intellectual_property". + # We need to set it as "intellectual_property" before dumping so that Marshmallow + # can pick up the field correctly on dump and show it back to the user. + if hasattr(data, "_intellectual_property"): + ipp_field = data._intellectual_property # pylint: disable=protected-access + if ipp_field: + setattr(data, "intellectual_property", ipp_field) + return data + + +class PrimitiveOutputSchema(OutputPortSchema): + # Note: according to marshmallow doc on Handling Unknown Fields: + # https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields + # specify unknown at instantiation time will not take effect; + # still add here just for explicitly declare this behavior: + # primitive type output used in environment that private preview flag is not enabled. + class Meta: + unknown = INCLUDE + + type = DumpableEnumField( + allowed_values=SUPPORTED_PARAM_TYPES, + required=True, + ) + # hide early_available in spec + if is_private_preview_enabled(): + early_available = fields.Bool() + + # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype + def _serialize(self, obj, *, many: bool = False): + """Override to add private preview hidden fields + + :keyword many: Whether obj is a collection of objects. + :paramtype many: bool + """ + from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe + + ret = super()._serialize(obj, many=many) # pylint: disable=no-member + if has_attr_safe(obj, "early_available") and obj.early_available is not None and "early_available" not in ret: + ret["early_available"] = obj.early_available + return ret + + +class ParameterSchema(metaclass=PatchedSchemaMeta): + type = DumpableEnumField( + allowed_values=SUPPORTED_PARAM_TYPES, + required=True, + ) + optional = fields.Bool() + default = UnionField([fields.Str(), fields.Number(), fields.Bool()]) + description = fields.Str() + max = UnionField([fields.Str(), fields.Number()]) + min = UnionField([fields.Str(), fields.Number()]) + enum = fields.List(fields.Str()) |