aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import INCLUDE, fields, pre_dump

from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField
from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils.utils import is_private_preview_enabled
from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes
from azure.ai.ml.constants._component import ComponentParameterTypes

# The constant holders below are plain classes rather than enums, because enum
# members fail to serialize as plain strings through marshmallow. Their values
# are therefore gathered ad hoc: every attribute whose name is entirely upper
# case is treated as one constant (``dir`` yields names in sorted order).
def _upper_case_attribute_values(holder):
    """Return the values of *holder*'s attributes whose names are all upper case.

    :param holder: An instance of a plain constant-holder class.
    :type holder: object
    :return: The matching attribute values, ordered by attribute name.
    :rtype: list
    """
    return [getattr(holder, attr_name) for attr_name in dir(holder) if attr_name.isupper()]


asset_type_obj = AssetTypes()
# ``LegacyAssetTypes.PATH`` is accepted in addition to every current asset type.
SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH, *_upper_case_attribute_values(asset_type_obj)]

param_obj = ComponentParameterTypes()
SUPPORTED_PARAM_TYPES = _upper_case_attribute_values(param_obj)

input_output_type_obj = InputOutputModes()
# "link" is appended by hand: link mode is currently only supported at the component level.
SUPPORTED_INPUT_OUTPUT_MODES = _upper_case_attribute_values(input_output_type_obj) + ["link"]


class InputPortSchema(metaclass=PatchedSchemaMeta):
    """Schema for a component input port whose ``type`` is one of ``SUPPORTED_PORT_TYPES``."""

    type = DumpableEnumField(allowed_values=SUPPORTED_PORT_TYPES, required=True)
    description = fields.Str()
    optional = fields.Bool()
    default = fields.Str()
    mode = DumpableEnumField(allowed_values=SUPPORTED_INPUT_OUTPUT_MODES)
    # Intellectual-property settings stay hidden unless private preview is enabled;
    # for inputs, only the protection level may be specified.
    if is_private_preview_enabled():
        intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))

    @pre_dump
    def add_private_fields_to_dump(self, data, **kwargs):  # pylint: disable=unused-argument
        # The input object keeps its IP-protection setting in the private attribute
        # "_intellectual_property", but marshmallow dumps the declared public field
        # "intellectual_property". Copy a truthy value across (mutating ``data`` in place)
        # so it is dumped and shown back to the user. Missing or falsy values are left alone.
        protection = getattr(data, "_intellectual_property", None)
        if protection:
            data.intellectual_property = protection
        return data


class OutputPortSchema(metaclass=PatchedSchemaMeta):
    """Schema for a component output port whose ``type`` is one of ``SUPPORTED_PORT_TYPES``."""

    type = DumpableEnumField(allowed_values=SUPPORTED_PORT_TYPES, required=True)
    description = fields.Str()
    mode = DumpableEnumField(allowed_values=SUPPORTED_INPUT_OUTPUT_MODES)
    # Intellectual-property settings stay hidden unless private preview is enabled;
    # for outputs, only the protection level may be specified.
    if is_private_preview_enabled():
        intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))

    @pre_dump
    def add_private_fields_to_dump(self, data, **kwargs):  # pylint: disable=unused-argument
        # The output object keeps its IP-protection setting in the private attribute
        # "_intellectual_property", but marshmallow dumps the declared public field
        # "intellectual_property". Copy a truthy value across (mutating ``data`` in place)
        # so it is dumped and shown back to the user. Missing or falsy values are left alone.
        protection = getattr(data, "_intellectual_property", None)
        if protection:
            data.intellectual_property = protection
        return data


class PrimitiveOutputSchema(OutputPortSchema):
    """Output-port schema whose ``type`` is validated against ``SUPPORTED_PARAM_TYPES``.

    Everything else is inherited from ``OutputPortSchema``.
    """

    # INCLUDE keeps undeclared keys when loading. The authors note, citing marshmallow's
    # "Handling Unknown Fields" guide
    # (https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields),
    # that passing ``unknown`` at instantiation time does not take effect in this usage.
    # Declaring it on Meta makes the behavior explicit. This matters when private preview
    # is disabled, because ``early_available`` is then not a declared field (see below).
    class Meta:
        unknown = INCLUDE

    type = DumpableEnumField(allowed_values=SUPPORTED_PARAM_TYPES, required=True)
    # ``early_available`` is kept out of the spec unless private preview is enabled.
    if is_private_preview_enabled():
        early_available = fields.Bool()

    def _serialize(self, obj, *, many: bool = False):
        """Serialize ``obj``, then add the hidden ``early_available`` value if it was omitted.

        :param obj: The object being dumped.
        :type obj: object
        :keyword many: Whether obj is a collection of objects.
        :paramtype many: bool
        :return: The serialized data, with ``early_available`` added when ``obj`` carries a
            non-None value that the declared fields did not already emit.
        :rtype: dict or list
        """
        # Imported lazily (presumably to avoid an import cycle with the entities package).
        from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe

        serialized = super()._serialize(obj, many=many)  # pylint: disable=no-member
        if not has_attr_safe(obj, "early_available"):
            return serialized
        if obj.early_available is None or "early_available" in serialized:
            return serialized
        serialized["early_available"] = obj.early_available
        return serialized


class ParameterSchema(metaclass=PatchedSchemaMeta):
    """Schema for a component parameter whose ``type`` is one of ``SUPPORTED_PARAM_TYPES``."""

    type = DumpableEnumField(
        allowed_values=SUPPORTED_PARAM_TYPES,
        required=True,
    )
    optional = fields.Bool()
    # The default may be given as a string, a number or a boolean.
    default = UnionField([fields.Str(), fields.Number(), fields.Bool()])
    description = fields.Str()
    # Bounds accept either a string or a number.
    max = UnionField([fields.Str(), fields.Number()])
    min = UnionField([fields.Str(), fields.Number()])
    # Presumably the list of allowed values for the parameter; each entry is a string.
    enum = fields.List(fields.Str())