diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py new file mode 100644 index 00000000..a1d2901c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/condition_node.py @@ -0,0 +1,48 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from marshmallow import fields, post_dump, ValidationError + +from azure.ai.ml._schema import StringTransformedEnum +from azure.ai.ml._schema.core.fields import DataBindingStr, NodeBindingStr, UnionField +from azure.ai.ml._schema.pipeline.control_flow_job import ControlFlowSchema +from azure.ai.ml.constants._component import ControlFlowType + + +# ConditionNodeSchema did not inherit from BaseNodeSchema since it doesn't have inputs/outputs like other nodes. +class ConditionNodeSchema(ControlFlowSchema): + type = StringTransformedEnum(allowed_values=[ControlFlowType.IF_ELSE]) + condition = UnionField([DataBindingStr(), fields.Bool()]) + true_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())]) + false_block = UnionField([NodeBindingStr(), fields.List(NodeBindingStr())]) + + @post_dump + def simplify_blocks(self, data, **kwargs): # pylint: disable=unused-argument + # simplify true_block and false_block to single node if there is only one node in the list + # this is to make sure the request to backend won't change after we support list true/false blocks + block_keys = ["true_block", "false_block"] + for block in block_keys: + if isinstance(data.get(block), list) and len(data.get(block)) == 1: + data[block] = data.get(block)[0] + + # validate blocks intersection + def _normalize_blocks(key): + blocks = data.get(key, []) + if blocks: + if not isinstance(blocks, list): + blocks = [blocks] + else: + blocks = [] + return blocks + + true_block = _normalize_blocks("true_block") + false_block = _normalize_blocks("false_block") + + if not true_block and not false_block: + raise ValidationError("True block and false block cannot be empty at the same time.") + + intersection = set(true_block).intersection(set(false_block)) + if intersection: + raise ValidationError(f"True block and false block cannot contain same nodes: {intersection}") + + return data |