# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from marshmallow import INCLUDE, fields, pre_dump
from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField
from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils.utils import is_private_preview_enabled
from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes
from azure.ai.ml.constants._component import ComponentParameterTypes
# Here we use an adhoc way to collect all class constant attributes by checking if it's upper letter
# because making those constants enum will fail in string serialization in marshmallow
asset_type_obj = AssetTypes()
SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH] + [
getattr(asset_type_obj, k) for k in dir(asset_type_obj) if k.isupper()
]
param_obj = ComponentParameterTypes()
SUPPORTED_PARAM_TYPES = [getattr(param_obj, k) for k in dir(param_obj) if k.isupper()]
input_output_type_obj = InputOutputModes()
# Link mode is only supported in component level currently
SUPPORTED_INPUT_OUTPUT_MODES = [
getattr(input_output_type_obj, k) for k in dir(input_output_type_obj) if k.isupper()
] + ["link"]
class InputPortSchema(metaclass=PatchedSchemaMeta):
type = DumpableEnumField(
allowed_values=SUPPORTED_PORT_TYPES,
required=True,
)
description = fields.Str()
optional = fields.Bool()
default = fields.Str()
mode = DumpableEnumField(
allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
)
# hide in private preview
if is_private_preview_enabled():
# only protection_level is allowed for inputs
intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
@pre_dump
def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
# The ipp field is set on the output object as "_intellectual_property".
# We need to set it as "intellectual_property" before dumping so that Marshmallow
# can pick up the field correctly on dump and show it back to the user.
if hasattr(data, "_intellectual_property"):
ipp_field = data._intellectual_property # pylint: disable=protected-access
if ipp_field:
setattr(data, "intellectual_property", ipp_field)
return data
class OutputPortSchema(metaclass=PatchedSchemaMeta):
type = DumpableEnumField(
allowed_values=SUPPORTED_PORT_TYPES,
required=True,
)
description = fields.Str()
mode = DumpableEnumField(
allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
)
# hide in private preview
if is_private_preview_enabled():
# only protection_level is allowed for outputs
intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
@pre_dump
def add_private_fields_to_dump(self, data, **kwargs): # pylint: disable=unused-argument
# The ipp field is set on the output object as "_intellectual_property".
# We need to set it as "intellectual_property" before dumping so that Marshmallow
# can pick up the field correctly on dump and show it back to the user.
if hasattr(data, "_intellectual_property"):
ipp_field = data._intellectual_property # pylint: disable=protected-access
if ipp_field:
setattr(data, "intellectual_property", ipp_field)
return data
class PrimitiveOutputSchema(OutputPortSchema):
# Note: according to marshmallow doc on Handling Unknown Fields:
# https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields
# specify unknown at instantiation time will not take effect;
# still add here just for explicitly declare this behavior:
# primitive type output used in environment that private preview flag is not enabled.
class Meta:
unknown = INCLUDE
type = DumpableEnumField(
allowed_values=SUPPORTED_PARAM_TYPES,
required=True,
)
# hide early_available in spec
if is_private_preview_enabled():
early_available = fields.Bool()
# pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
def _serialize(self, obj, *, many: bool = False):
"""Override to add private preview hidden fields
:keyword many: Whether obj is a collection of objects.
:paramtype many: bool
"""
from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe
ret = super()._serialize(obj, many=many) # pylint: disable=no-member
if has_attr_safe(obj, "early_available") and obj.early_available is not None and "early_available" not in ret:
ret["early_available"] = obj.early_available
return ret
class ParameterSchema(metaclass=PatchedSchemaMeta):
type = DumpableEnumField(
allowed_values=SUPPORTED_PARAM_TYPES,
required=True,
)
optional = fields.Bool()
default = UnionField([fields.Str(), fields.Number(), fields.Bool()])
description = fields.Str()
max = UnionField([fields.Str(), fields.Number()])
min = UnionField([fields.Str(), fields.Number()])
enum = fields.List(fields.Str())