aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Dict, Optional, Union

from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._builders.do_while import DoWhile
from azure.ai.ml.entities._builders.pipeline import Pipeline
from azure.ai.ml.entities._inputs_outputs import Output
from azure.ai.ml.entities._job.pipeline._io import NodeOutput


def do_while(
    body: Union[Pipeline, Command], mapping: Dict, max_iteration_count: int, condition: Optional[Output] = None
) -> DoWhile:
    """Build a do_while node by specifying the loop body, output-input mapping, and termination condition.

    .. note::
        The following example shows how to use the `do_while` function to create a pipeline with a `do_while` node.

        .. code-block:: python

            from azure.ai.ml.dsl import pipeline
            from mldesigner.dsl import do_while

            @pipeline()
            def your_do_while_body():
                pass

            @pipeline()
            def pipeline_with_do_while_node():
                do_while_body = your_do_while_body()
                do_while_node = do_while(
                    body=do_while_body,
                    condition=do_while_body.outputs.condition_output,
                    mapping={
                        do_while_body.outputs.output1: do_while_body.inputs.input1,
                        do_while_body.outputs.output2: [
                            do_while_body.inputs.input2,
                            do_while_body.inputs.input3,
                        ],
                    },
                    max_iteration_count=3,
                )
                # Connect to the do_while_node outputs
                component = component_func(
                    input1=do_while_body.outputs.output1, input2=do_while_body.outputs.output2
                )

    :param body: The pipeline job or command node for the do-while loop body.
    :type body: Union[~azure.ai.ml.entities._builders.pipeline.Pipeline, ~azure.ai.ml.entities._builders.Command]
    :param mapping: The output-input mapping for each round of the do-while loop.
        The key is the last round's output of the body (an output port or the output's name), and the value is
        the input of the body for the current round (an input port, the input's name, or a list of those).
        When a mapped body output has a type and a mapped body input does not, the input's type is inferred
        from the output and the input is marked as inferred-optional.
    :type mapping: Dict[
        Union[str,  ~azure.ai.ml.entities.Output],
        Union[str, ~azure.ai.ml.entities.Input, List]]
    :param max_iteration_count: The limit on running the do-while node.
    :type max_iteration_count: int
    :param condition: A boolean output of the body.
        The do-while loop stops when its value evaluates to False.
        If not specified, it runs as a while-true loop (still subject to ``max_iteration_count``).
    :type condition:  ~azure.ai.ml.entities.Output
    :return: The do-while node.
    :rtype: ~azure.ai.ml.entities._builders.do_while.DoWhile
    """
    do_while_node = DoWhile(
        body=body,
        condition=condition,  # type: ignore[arg-type]
        mapping=mapping,
        _from_component_func=True,
    )
    do_while_node.set_limits(max_iteration_count=max_iteration_count)

    def _infer_and_update_body_input_from_mapping() -> None:
        """Copy each typed body output's type onto the untyped body inputs it is mapped to."""
        # pylint: disable=protected-access
        for source_output, body_input in mapping.items():
            # mapping key may be an output port (NodeOutput) or the output's name
            output_name = source_output._port_name if isinstance(source_output, NodeOutput) else source_output
            # the output type is the same for every input mapped from it, so look it up once
            inferred_type = body.outputs[output_name].type
            # if loop body output type is not specified, skip as we have no place to infer
            if inferred_type is None:
                continue
            # body input can be a single input or a list of inputs; normalize to a list to process
            target_inputs = body_input if isinstance(body_input, list) else [body_input]
            for target in target_inputs:
                # a plain string names a body input port; resolve it against body.inputs the same way
                # string keys are resolved against body.outputs, instead of failing on `.type`
                single_input = body.inputs[target] if isinstance(target, str) else target
                # if input type is specified, no need to infer and skip
                if single_input.type is not None:
                    continue
                # update node input
                single_input._meta._is_inferred_optional = True
                single_input.type = inferred_type
                # update node corresponding component input
                input_name = single_input._meta.name
                component_input = body.component.inputs[input_name]  # type: ignore[union-attr]
                component_input._is_inferred_optional = True
                component_input.type = inferred_type

    # when mapping is a dictionary, infer and update for dynamic input
    if isinstance(mapping, dict):
        _infer_and_update_body_input_from_mapping()

    return do_while_node