aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
new file mode 100644
index 00000000..a1d2901c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py
@@ -0,0 +1,48 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields, post_dump, ValidationError
+
+from azure.ai.ml._schema import StringTransformedEnum
+from azure.ai.ml._schema.core.fields import DataBindingStr, NodeBindingStr, UnionField
+from azure.ai.ml._schema.pipeline.control_flow_job import ControlFlowSchema
+from azure.ai.ml.constants._component import ControlFlowType
+
+
+# ConditionNodeSchema did not inherit from BaseNodeSchema since it doesn't have inputs/outputs like other nodes.
+class ConditionNodeSchema(ControlFlowSchema):
+ type = StringTransformedEnum(allowed_values=[ControlFlowType.IF_ELSE])
+ condition = UnionField([DataBindingStr(), fields.Bool()])
+ true_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())])
+ false_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())])
+
+ @post_dump
+ def simplify_blocks(self, data, **kwargs): # pylint: disable=unused-argument
+ # simplify true_block and false_block to single node if there is only one node in the list
+ # this is to make sure the request to backend won't change after we support list true/false blocks
+ block_keys = ["true_block", "false_block"]
+ for block in block_keys:
+ if isinstance(data.get(block), list) and len(data.get(block)) == 1:
+ data[block] = data.get(block)[0]
+
+ # validate blocks intersection
+ def _normalize_blocks(key):
+ blocks = data.get(key, [])
+ if blocks:
+ if not isinstance(blocks, list):
+ blocks = [blocks]
+ else:
+ blocks = []
+ return blocks
+
+ true_block = _normalize_blocks("true_block")
+ false_block = _normalize_blocks("false_block")
+
+ if not true_block and not false_block:
+ raise ValidationError("True block and false block cannot be empty at the same time.")
+
+ intersection = set(true_block).intersection(set(false_block))
+ if intersection:
+ raise ValidationError(f"True block and false block cannot contain same nodes: {intersection}")
+
+ return data