about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py126
1 files changed, 126 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
new file mode 100644
index 00000000..9fef9489
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/input_output.py
@@ -0,0 +1,126 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, pre_dump
+
+from azure.ai.ml._schema.core.fields import DumpableEnumField, ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.core.intellectual_property import ProtectionLevelSchema
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import AssetTypes, InputOutputModes, LegacyAssetTypes
+from azure.ai.ml.constants._component import ComponentParameterTypes
+
+# Here we use an adhoc way to collect all class constant attributes by checking if it's upper letter
+# because making those constants enum will fail in string serialization in marshmallow
+asset_type_obj = AssetTypes()
+SUPPORTED_PORT_TYPES = [LegacyAssetTypes.PATH] + [
+    getattr(asset_type_obj, k) for k in dir(asset_type_obj) if k.isupper()
+]
+param_obj = ComponentParameterTypes()
+SUPPORTED_PARAM_TYPES = [getattr(param_obj, k) for k in dir(param_obj) if k.isupper()]
+
+input_output_type_obj = InputOutputModes()
+# Link mode is only supported in component level currently
+SUPPORTED_INPUT_OUTPUT_MODES = [
+    getattr(input_output_type_obj, k) for k in dir(input_output_type_obj) if k.isupper()
+] + ["link"]
+
+
+class InputPortSchema(metaclass=PatchedSchemaMeta):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_PORT_TYPES,
+        required=True,
+    )
+    description = fields.Str()
+    optional = fields.Bool()
+    default = fields.Str()
+    mode = DumpableEnumField(
+        allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+    )
+    # hide in private preview
+    if is_private_preview_enabled():
+        # only protection_level is allowed for inputs
+        intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+    @pre_dump
+    def add_private_fields_to_dump(self, data, **kwargs):  # pylint: disable=unused-argument
+        # The ipp field is set on the output object as "_intellectual_property".
+        # We need to set it as "intellectual_property" before dumping so that Marshmallow
+        # can pick up the field correctly on dump and show it back to the user.
+        if hasattr(data, "_intellectual_property"):
+            ipp_field = data._intellectual_property  # pylint: disable=protected-access
+            if ipp_field:
+                setattr(data, "intellectual_property", ipp_field)
+        return data
+
+
+class OutputPortSchema(metaclass=PatchedSchemaMeta):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_PORT_TYPES,
+        required=True,
+    )
+    description = fields.Str()
+    mode = DumpableEnumField(
+        allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
+    )
+    # hide in private preview
+    if is_private_preview_enabled():
+        # only protection_level is allowed for outputs
+        intellectual_property = ExperimentalField(NestedField(ProtectionLevelSchema))
+
+    @pre_dump
+    def add_private_fields_to_dump(self, data, **kwargs):  # pylint: disable=unused-argument
+        # The ipp field is set on the output object as "_intellectual_property".
+        # We need to set it as "intellectual_property" before dumping so that Marshmallow
+        # can pick up the field correctly on dump and show it back to the user.
+        if hasattr(data, "_intellectual_property"):
+            ipp_field = data._intellectual_property  # pylint: disable=protected-access
+            if ipp_field:
+                setattr(data, "intellectual_property", ipp_field)
+        return data
+
+
+class PrimitiveOutputSchema(OutputPortSchema):
+    # Note: according to marshmallow doc on Handling Unknown Fields:
+    # https://marshmallow.readthedocs.io/en/stable/quickstart.html#handling-unknown-fields
+    # specify unknown at instantiation time will not take effect;
+    # still add here just for explicitly declare this behavior:
+    # primitive type output used in environment that private preview flag is not enabled.
+    class Meta:
+        unknown = INCLUDE
+
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_PARAM_TYPES,
+        required=True,
+    )
+    # hide early_available in spec
+    if is_private_preview_enabled():
+        early_available = fields.Bool()
+
+    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+    def _serialize(self, obj, *, many: bool = False):
+        """Override to add private preview hidden fields
+
+        :keyword many: Whether obj is a collection of objects.
+        :paramtype many: bool
+        """
+        from azure.ai.ml.entities._job.pipeline._attr_dict import has_attr_safe
+
+        ret = super()._serialize(obj, many=many)  # pylint: disable=no-member
+        if has_attr_safe(obj, "early_available") and obj.early_available is not None and "early_available" not in ret:
+            ret["early_available"] = obj.early_available
+        return ret
+
+
+class ParameterSchema(metaclass=PatchedSchemaMeta):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_PARAM_TYPES,
+        required=True,
+    )
+    optional = fields.Bool()
+    default = UnionField([fields.Str(), fields.Number(), fields.Bool()])
+    description = fields.Str()
+    max = UnionField([fields.Str(), fields.Number()])
+    min = UnionField([fields.Str(), fields.Number()])
+    enum = fields.List(fields.Str())