about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py
new file mode 100644
index 00000000..20c46169
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/dsl/_condition.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import List, Optional, Union
+
+from azure.ai.ml.entities._builders import BaseNode
+from azure.ai.ml.entities._builders.condition_node import ConditionNode
+from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
+from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
+from azure.ai.ml.exceptions import UserErrorException
+
+# pylint: disable=redefined-outer-name
+
+
+def condition(
+    condition: Union[str, bool, InputOutputBase, BaseNode, PipelineExpression],
+    *,
+    true_block: Optional[Union[BaseNode, List[BaseNode]]] = None,
+    false_block: Optional[Union[BaseNode, List[BaseNode]]] = None,
+) -> ConditionNode:
+    """Create a condition node to provide runtime condition graph experience.
+
+    Below is an example of using an expression result to control which step is executed.
+    If the pipeline parameter 'int_param1' is greater than 'int_param2', then 'true_step' will be executed,
+    otherwise, the 'false_step' will be executed.
+
+    .. code-block:: python
+
+        @dsl.pipeline
+        def pipeline_func(int_param1: int, int_param2: int):
+            true_step = component_func()
+            false_step = another_component_func()
+            dsl.condition(
+                int_param1 > int_param2,
+                true_block=true_step,
+                false_block=false_step
+            )
+
+    :param condition: The condition of the execution flow.
+        The value could be a boolean type control output, a node with exactly one boolean type output,
+        or a pipeline expression.
+    :type condition: Union[
+        str,
+        bool,
+        ~azure.ai.ml.entities._job.pipeline._io.InputOutputBase,
+        ~azure.ai.ml.entities._builders.BaseNode,
+        ~azure.ai.ml.entities._job.pipeline._pipeline_expression.PipelineExpression]
+    :keyword true_block: The block to be executed if the condition resolves to True.
+    :paramtype true_block: Union[
+        ~azure.ai.ml.entities._builders.BaseNode,
+        List[~azure.ai.ml.entities._builders.BaseNode]]
+    :keyword false_block: The block to be executed if the condition resolves to False.
+    :paramtype false_block: Union[
+        ~azure.ai.ml.entities._builders.BaseNode,
+        List[~azure.ai.ml.entities._builders.BaseNode]]
+    :return: The condition node.
+    :rtype: ConditionNode
+    :raises UserErrorException: Raised if the condition node has an incorrect number of outputs.
+    """
+    # resolve expression as command component
+    if isinstance(condition, PipelineExpression):
+        condition = condition.resolve()
+    if isinstance(condition, BaseNode):
+        if len(condition.outputs) != 1:
+            error_message = (
+                f"Exactly one output is expected for condition node, {len(condition.outputs)} outputs found."
+            )
+            raise UserErrorException(message=error_message, no_personal_data_message=error_message)
+        condition = list(condition.outputs.values())[0]
+    return ConditionNode(
+        condition=condition,
+        true_block=true_block,  # type: ignore[arg-type]
+        false_block=false_block,  # type: ignore[arg-type]
+        _from_component_func=True,
+    )