about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_builders/condition_node.py
blob: 5a5ad58ba8c11a5c20ff58a149dc9b2c98bc0468 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Any, Dict, List, Optional

from azure.ai.ml._schema import PathAwareSchema
from azure.ai.ml._utils.utils import is_data_binding_expression
from azure.ai.ml.constants._component import ControlFlowType
from azure.ai.ml.entities._builders import BaseNode
from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
from azure.ai.ml.entities._validation import MutableValidationResult


class ConditionNode(ControlFlowNode):
    """Conditional node in the pipeline.

    Please do not directly use this class.

    :param condition: The condition for the conditional node.
    :type condition: Any
    :param true_block: The list of nodes to execute when the condition is true.
    :type true_block: List[~azure.ai.ml.entities._builders.BaseNode]
    :param false_block: The list of nodes to execute when the condition is false.
    :type false_block: List[~azure.ai.ml.entities._builders.BaseNode]
    """

    def __init__(
        self, condition: Any, *, true_block: Optional[List] = None, false_block: Optional[List] = None, **kwargs: Any
    ) -> None:
        # The node type is always IF_ELSE; discard any "type" passed in by the caller
        # (e.g. when the kwargs come from a loaded dict) so it is not supplied twice.
        kwargs.pop("type", None)
        super().__init__(type=ControlFlowType.IF_ELSE, **kwargs)
        self.condition = condition
        # A single truthy non-list value is wrapped in a list; falsy values
        # (None, empty list, ...) are stored unchanged.
        if true_block and not isinstance(true_block, list):
            true_block = [true_block]
        self._true_block = true_block
        if false_block and not isinstance(false_block, list):
            false_block = [false_block]
        self._false_block = false_block

    @classmethod
    def _create_schema_for_validation(cls, context: Any) -> PathAwareSchema:
        """Create the schema used to validate this node.

        :param context: The context passed through to the schema.
        :type context: Any
        :return: A ConditionNodeSchema built with the given context.
        :rtype: PathAwareSchema
        """
        # Imported here rather than at module level, as in the original code.
        from azure.ai.ml._schema.pipeline.condition_node import ConditionNodeSchema

        return ConditionNodeSchema(context=context)

    @classmethod
    def _from_rest_object(cls, obj: dict) -> "ConditionNode":
        """Create a condition node from a REST dict.

        :param obj: The dict whose entries are passed to the constructor as keyword arguments.
        :type obj: dict
        :return: The ConditionNode node
        :rtype: ConditionNode
        """
        return cls(**obj)

    @classmethod
    def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ConditionNode":
        """Create a condition node instance from schema parsed dict.

        :param loaded_data: The loaded data
        :type loaded_data: Dict
        :return: The ConditionNode node
        :rtype: ConditionNode
        """
        return cls(**loaded_data)

    @property
    def true_block(self) -> Optional[List]:
        """Get the list of nodes to execute when the condition is true.

        :return: The list of nodes to execute when the condition is true.
        :rtype: List[~azure.ai.ml.entities._builders.BaseNode]
        """
        return self._true_block

    @property
    def false_block(self) -> Optional[List]:
        """Get the list of nodes to execute when the condition is false.

        :return: The list of nodes to execute when the condition is false.
        :rtype: List[~azure.ai.ml.entities._builders.BaseNode]
        """
        return self._false_block

    def _customized_validate(self) -> MutableValidationResult:
        """Run the condition-node specific validation.

        :return: The result produced by ``_validate_params``.
        :rtype: MutableValidationResult
        """
        return self._validate_params()

    def _validate_params(self) -> MutableValidationResult:
        """Validate ``condition``, ``true_block`` and ``false_block``.

        Errors are appended in this order: condition type, condition output
        definition, condition binding expression, block element types
        (true_block then false_block), block binding expressions
        (true_block then false_block).

        :return: A validation result holding every error found.
        :rtype: MutableValidationResult
        """
        # pylint: disable=protected-access
        validation_result = self._create_empty_validation_result()
        condition = self.condition

        if not isinstance(condition, (str, bool, InputOutputBase)):
            validation_result.append_error(
                yaml_path="condition",
                message=f"'condition' of dsl.condition node must be an instance of "
                f"{str}, {bool} or {InputOutputBase}, got {type(condition)}.",
            )

        # Check if output is a control output.
        if isinstance(condition, InputOutputBase) and condition._meta is not None:
            output_definition = condition._meta
            if not output_definition._is_primitive_type:
                validation_result.append_error(
                    yaml_path="condition",
                    message=f"'condition' of dsl.condition node must be primitive type "
                    f"with value 'True', got {output_definition._is_primitive_type}",
                )

        # Check if condition is a valid binding.
        if isinstance(condition, str) and not is_data_binding_expression(condition, ["parent"], is_singular=False):
            error_tail = "for example, ${{parent.jobs.xxx.outputs.output}}"
            validation_result.append_error(
                yaml_path="condition",
                message=f"'condition' of dsl.condition has invalid binding expression: {condition}, {error_tail}",
            )

        # Treat a missing/empty block as an empty list so both passes below can iterate it.
        named_blocks = (
            ("true_block", self.true_block or []),
            ("false_block", self.false_block or []),
        )

        # First pass: every non-None element must be a node, an AutoML job or a binding string.
        for name, blocks in named_blocks:
            for block in blocks:
                if block is not None and not isinstance(block, (BaseNode, AutoMLJob, str)):
                    validation_result.append_error(
                        yaml_path=name,
                        message=(
                            f"{name!r} of dsl.condition node must be an instance of "
                            f"{BaseNode}, {AutoMLJob} or {str}, got {type(block)!r}."
                        ),
                    )

        # Second pass: string elements must be valid binding expressions.
        error_tail = "for example, ${{parent.jobs.xxx}}"
        for name, blocks in named_blocks:
            for block in blocks:
                if not isinstance(block, str):
                    continue
                if not is_data_binding_expression(block, ["parent", "jobs"], is_singular=False):
                    validation_result.append_error(
                        yaml_path=name,
                        message=f"'{name}' of dsl.condition has invalid binding expression: {block}, {error_tail}",
                    )

        return validation_result