about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
new file mode 100644
index 00000000..a1d2901c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields, post_dump, ValidationError
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.core.fields import DataBindingStr, NodeBindingStr, UnionField
+from azure.ai.ml._schema.pipeline.control_flow_job import ControlFlowSchema
+from azure.ai.ml.constants._component import ControlFlowType
+
+
+# ConditionNodeSchema did not inherit from BaseNodeSchema since it doesn't have inputs/outputs like other nodes.
+class ConditionNodeSchema(ControlFlowSchema):
+    type = StringTransformedEnum(allowed_values=[ControlFlowType.IF_ELSE])
+    condition = UnionField([DataBindingStr(), fields.Bool()])
+    true_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())])
+    false_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())])
+
+    @post_dump
+    def simplify_blocks(self, data, **kwargs):  # pylint: disable=unused-argument
+        # simplify true_block and false_block to single node if there is only one node in the list
+        # this is to make sure the request to backend won't change after we support list true/false blocks
+        block_keys = ["true_block", "false_block"]
+        for block in block_keys:
+            if isinstance(data.get(block), list) and len(data.get(block)) == 1:
+                data[block] = data.get(block)[0]
+
+        # validate blocks intersection
+        def _normalize_blocks(key):
+            blocks = data.get(key, [])
+            if blocks:
+                if not isinstance(blocks, list):
+                    blocks = [blocks]
+            else:
+                blocks = []
+            return blocks
+
+        true_block = _normalize_blocks("true_block")
+        false_block = _normalize_blocks("false_block")
+
+        if not true_block and not false_block:
+            raise ValidationError("True block and false block cannot be empty at the same time.")
+
+        intersection = set(true_block).intersection(set(false_block))
+        if intersection:
+            raise ValidationError(f"True block and false block cannot contain same nodes: {intersection}")
+
+        return data