Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py                |   5
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py            | 244
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py           | 119
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py  |  96
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py          |  78
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py            |  66
6 files changed, 608 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py
new file mode 100644
index 00000000..fdf8caba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py
new file mode 100644
index 00000000..49b2c992
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_job.py
@@ -0,0 +1,244 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData
+from azure.ai.ml._schema.job.parallel_job import ParallelJobSchema
+from azure.ai.ml._utils.utils import is_data_binding_expression
+from azure.ai.ml.constants import JobType
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
+from azure.ai.ml.entities._credentials import (
+    AmlTokenConfiguration,
+    ManagedIdentityConfiguration,
+    UserIdentityConfiguration,
+)
+from azure.ai.ml.entities._inputs_outputs import Input, Output
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+
+from ..job import Job
+from ..job_io_mixin import JobIOMixin
+from .parameterized_parallel import ParameterizedParallel
+
+# avoid circular import error
+if TYPE_CHECKING:
+    from azure.ai.ml.entities._builders import Parallel
+    from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+module_logger = logging.getLogger(__name__)
+
+
+class ParallelJob(Job, ParameterizedParallel, JobIOMixin):
+    """Parallel job.
+
+    :param name: Name of the job.
+    :type name: str
+    :param version: Version of the job.
+    :type version: str
+    :param id: Global ID of the resource, i.e. the Azure Resource Manager ID.
+    :type id: str
+    :param type: Type of the job; the only supported value is 'parallel'.
+    :type type: str
+    :param description: Description of the job.
+    :type description: str
+    :param tags: Internal use only.
+    :type tags: dict
+    :param properties: Internal use only.
+    :type properties: dict
+    :param display_name: Display name of the job.
+    :type display_name: str
+    :param retry_settings: Retry settings for failed runs of the parallel job.
+    :type retry_settings: BatchRetrySettings
+    :param logging_level: Name of the logging level to use.
+    :type logging_level: str
+    :param max_concurrency_per_instance: The maximum parallelism that each compute instance supports.
+    :type max_concurrency_per_instance: int
+    :param error_threshold: The number of item processing failures that should be ignored.
+    :type error_threshold: int
+    :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored.
+    :type mini_batch_error_threshold: int
+    :keyword identity: The identity that the job will use while running on compute.
+    :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
+        ~azure.ai.ml.UserIdentityConfiguration]]
+    :param task: The parallel task.
+    :type task: ParallelTask
+    :param mini_batch_size: The mini batch size.
+    :type mini_batch_size: str
+    :param partition_keys: The partition keys.
+    :type partition_keys: list
+    :param input_data: The input data.
+    :type input_data: str
+    :param inputs: Inputs of the job.
+    :type inputs: dict
+    :param outputs: Outputs of the job.
+    :type outputs: dict
+    """
+
+    def __init__(
+        self,
+        *,
+        inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
+        outputs: Optional[Dict[str, Output]] = None,
+        identity: Optional[
+            Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration, Dict]
+        ] = None,
+        **kwargs: Any,
+    ):
+        kwargs[TYPE] = JobType.PARALLEL
+
+        super().__init__(**kwargs)
+
+        self.inputs = inputs  # type: ignore[assignment]
+        self.outputs = outputs  # type: ignore[assignment]
+        self.identity = identity
+
+    def _to_dict(self) -> Dict:
+        res: dict = ParallelJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+        return res
+
+    def _to_rest_object(self) -> None:
+        pass
+
+    @classmethod
+    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "ParallelJob":
+        loaded_data = load_from_dict(ParallelJobSchema, data, context, additional_message, **kwargs)
+        return ParallelJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)
+
+    @classmethod
+    def _load_from_rest(cls, obj: JobBaseData) -> None:
+        pass
+
+    def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "ParallelComponent":
+        """Translate a parallel job to component job.
+
+        :param context: Context of parallel job YAML file.
+        :type context: dict
+        :return: Translated parallel component.
+        :rtype: ParallelComponent
+        """
+        from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+        pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
+        context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}
+
+        # Create anonymous parallel component with default version as 1
+        init_kwargs = {}
+        for key in [
+            "mini_batch_size",
+            "partition_keys",
+            "logging_level",
+            "max_concurrency_per_instance",
+            "error_threshold",
+            "mini_batch_error_threshold",
+            "retry_settings",
+            "resources",
+        ]:
+            value = getattr(self, key)
+            from azure.ai.ml.entities import BatchRetrySettings, JobResourceConfiguration
+
+            values_to_check: List = []
+            if key == "retry_settings" and isinstance(value, BatchRetrySettings):
+                values_to_check = [value.max_retries, value.timeout]
+            elif key == "resources" and isinstance(value, JobResourceConfiguration):
+                values_to_check = [
+                    value.locations,
+                    value.instance_count,
+                    value.instance_type,
+                    value.shm_size,
+                    value.max_instance_count,
+                    value.docker_args,
+                ]
+            else:
+                values_to_check = [value]
+
+            # note that component level attributes can not be data binding expressions
+            # so filter out data binding expression properties here;
+            # they will still take effect at node level according to _to_node
+            if any(
+                map(
+                    lambda x: is_data_binding_expression(x, binding_prefix=["parent", "inputs"], is_singular=False)
+                    or is_data_binding_expression(x, binding_prefix=["inputs"], is_singular=False),
+                    values_to_check,
+                )
+            ):
+                continue
+
+            init_kwargs[key] = getattr(self, key)
+
+        return ParallelComponent(
+            base_path=context[BASE_PATH_CONTEXT_KEY],
+            # for parallel_job.task, all attributes are currently strings, so data binding expressions are
+            # naturally allowed at the SDK level; whether such a component is valid is left to service-side validation
+            task=self.task,
+            inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
+            outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
+            input_data=self.input_data,
+            # keep them if no data binding expression detected to keep the behavior of to_component
+            **init_kwargs,
+        )
+
+    def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Parallel":
+        """Translate a parallel job to a pipeline node.
+
+        :param context: Context of parallel job YAML file.
+        :type context: dict
+        :return: Translated parallel node.
+        :rtype: Parallel
+        """
+        from azure.ai.ml.entities._builders import Parallel
+
+        component = self._to_component(context, **kwargs)
+
+        return Parallel(
+            component=component,
+            compute=self.compute,
+            # Need to supply the inputs with double curly.
+            inputs=self.inputs,  # type: ignore[arg-type]
+            outputs=self.outputs,  # type: ignore[arg-type]
+            mini_batch_size=self.mini_batch_size,
+            partition_keys=self.partition_keys,
+            input_data=self.input_data,
+            # task will be inherited from component & base_path will be set correctly.
+            retry_settings=self.retry_settings,
+            logging_level=self.logging_level,
+            max_concurrency_per_instance=self.max_concurrency_per_instance,
+            error_threshold=self.error_threshold,
+            mini_batch_error_threshold=self.mini_batch_error_threshold,
+            environment_variables=self.environment_variables,
+            properties=self.properties,
+            identity=self.identity,
+            resources=self.resources if self.resources and not isinstance(self.resources, dict) else None,
+        )
+
+    def _validate(self) -> None:
+        if self.name is None:
+            msg = "Job name is required"
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.JOB,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.MISSING_FIELD,
+            )
+        if self.compute is None:
+            msg = "compute is required"
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.JOB,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.MISSING_FIELD,
+            )
+        if self.task is None:
+            msg = "task is required"
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.JOB,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.MISSING_FIELD,
+            )
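
The _to_component translation above drops any run setting whose value is a pipeline data binding expression, since component-level attributes must be concrete; the binding is instead re-applied at node level by _to_node. A minimal sketch of that check, reusing the is_data_binding_expression helper the file imports (the bound value below is a hypothetical example, not taken from this diff):

# Sketch only: the filter ParallelJob._to_component applies to each setting.
from azure.ai.ml._utils.utils import is_data_binding_expression

concrete = "DEBUG"                            # literal logging_level value -> kept on the component
bound = "${{parent.inputs.error_threshold}}"  # hypothetical pipeline binding -> skipped, applied on the node

for value in (concrete, bound):
    is_binding = is_data_binding_expression(
        value, binding_prefix=["parent", "inputs"], is_singular=False
    ) or is_data_binding_expression(value, binding_prefix=["inputs"], is_singular=False)
    print(value, "-> skipped at component level" if is_binding else "-> kept on the component")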
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py
new file mode 100644
index 00000000..7325aed3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parallel_task.py
@@ -0,0 +1,119 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+# from azure.ai.ml.entities._deployment.code_configuration import CodeConfiguration
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._assets.environment import Environment
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+
+class ParallelTask(RestTranslatableMixin, DictMixin):
+    """Parallel task.
+
+    :param type: The type of the parallel task.
+        Possible values are 'run_function' and 'model'.
+    :type type: str
+    :param code: A local or remote path pointing at source code.
+    :type code: str
+    :param entry_script: User script which will be run in parallel on multiple nodes. This is
+        specified as a local file path.
+        The entry_script should contain two functions:
+        ``init()``: this function should be used for any costly or common preparation for subsequent inferences,
+        e.g., deserializing and loading the model into a global object.
+        ``run(mini_batch)``: The method to be parallelized. Each invocation handles one mini-batch.
+        'mini_batch': Batch inference invokes the run() method and passes either a list or a Pandas DataFrame
+        as the argument. Each entry in mini_batch is a file path if the input is a FileDataset,
+        or a Pandas DataFrame if the input is a TabularDataset.
+        The run() method should return a Pandas DataFrame or an array.
+        For the append_row output_action, the returned elements are appended to the common output file.
+        For summary_only, the contents of the elements are ignored. For all output actions,
+        each returned output element indicates one successful inference of an input element in the input mini-batch.
+        Each parallel worker process calls `init` once and then loops over the `run` function until all
+        mini-batches are processed.
+    :type entry_script: str
+    :param program_arguments: The arguments of the parallel task.
+    :type program_arguments: str
+    :param model: The model of the parallel task.
+    :type model: str
+    :param append_row_to: All values output by run() method invocations will be aggregated into
+        one unique file which is created in the output location.
+        If it is not set, 'summary_only' is used, which means the user script is expected to store the output itself.
+    :type append_row_to: str
+    :param environment: The environment that the training job runs in.
+    :type environment: Union[Environment, str]
+    """
+
+    def __init__(
+        self,  # pylint: disable=unused-argument
+        *,
+        type: Optional[str] = None,  # pylint: disable=redefined-builtin
+        code: Optional[str] = None,
+        entry_script: Optional[str] = None,
+        program_arguments: Optional[str] = None,
+        model: Optional[str] = None,
+        append_row_to: Optional[str] = None,
+        environment: Optional[Union[Environment, str]] = None,
+        **kwargs: Any,
+    ):
+        self.type = type
+        self.code = code
+        self.entry_script = entry_script
+        self.program_arguments = program_arguments
+        self.model = model
+        self.append_row_to = append_row_to
+        self.environment: Any = environment
+
+    def _to_dict(self) -> Dict:
+        # pylint: disable=no-member
+        res: dict = ComponentParallelTaskSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
+        return res
+
+    @classmethod
+    def _load(
+        cls,  # pylint: disable=unused-argument
+        path: Optional[Union[PathLike, str]] = None,
+        params_override: Optional[list] = None,
+        **kwargs: Any,
+    ) -> "ParallelTask":
+        params_override = params_override or []
+        data = load_yaml(path)
+        return ParallelTask._load_from_dict(data=data, path=path, params_override=params_override)
+
+    @classmethod
+    def _load_from_dict(
+        cls,
+        data: dict,
+        path: Optional[Union[PathLike, str]] = None,
+        params_override: Optional[list] = None,
+        **kwargs: Any,
+    ) -> "ParallelTask":
+        params_override = params_override or []
+        context = {
+            BASE_PATH_CONTEXT_KEY: Path(path).parent if path else Path.cwd(),
+            PARAMS_OVERRIDE_KEY: params_override,
+        }
+        res: ParallelTask = load_from_dict(ComponentParallelTaskSchema, data, context, **kwargs)
+        return res
+
+    @classmethod
+    def _from_dict(cls, dct: dict) -> "ParallelTask":
+        obj = cls(**dict(dct.items()))
+        return obj
+
+    def _validate(self) -> None:
+        if self.type is None:
+            msg = "'type' is required for ParallelTask {}."
+            raise ValidationException(
+                message=msg.format(self.type),
+                target=ErrorTarget.COMPONENT,
+                no_personal_data_message=msg.format(""),
+                error_category=ErrorCategory.USER_ERROR,
+            )
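
The entry script contract described in the ParallelTask docstring comes down to two functions, init() and run(mini_batch). A minimal sketch of a hypothetical entry script follows; the file name, the placeholder model, and the column handling are assumptions for illustration only:

# score.py - hypothetical entry script for a ParallelTask of type 'run_function'.
import pandas as pd

model = None


def init():
    # Called once per worker process: do costly, shared preparation here,
    # e.g. deserialize a model and keep it in a global object.
    global model
    model = lambda df: df.assign(prediction=0)  # stand-in for a real model


def run(mini_batch):
    # Called once per mini-batch; receives a pandas DataFrame (TabularDataset)
    # or a list of file paths (FileDataset) and must return a DataFrame or an array.
    if isinstance(mini_batch, pd.DataFrame):
        return model(mini_batch)
    # For file inputs, return one result row per processed file.
    return pd.DataFrame({"file": [str(p) for p in mini_batch], "prediction": 0})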
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py
new file mode 100644
index 00000000..6b5dbced
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/parameterized_parallel.py
@@ -0,0 +1,96 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from ..job_resource_configuration import JobResourceConfiguration
+from .parallel_task import ParallelTask
+from .retry_settings import RetrySettings
+
+module_logger = logging.getLogger(__name__)
+
+
+class ParameterizedParallel:
+    """Parallel component that contains the traning parallel and supporting parameters for the parallel.
+
+    :param retry_settings: Retry settings for failed runs of the parallel component.
+    :type retry_settings: BatchRetrySettings
+    :param logging_level: Name of the logging level to use.
+    :type logging_level: str
+    :param max_concurrency_per_instance: The maximum parallelism that each compute instance supports.
+    :type max_concurrency_per_instance: int
+    :param error_threshold: The number of item processing failures that should be ignored.
+    :type error_threshold: int
+    :param mini_batch_error_threshold: The number of mini-batch processing failures that should be ignored.
+    :type mini_batch_error_threshold: int
+    :param task: The parallel task.
+    :type task: ParallelTask
+    :param mini_batch_size: The mini batch size.
+    :type mini_batch_size: str
+    :param input_data: The input data.
+    :type input_data: str
+    :param resources: Compute Resource configuration for the job.
+    :type resources: Union[Dict, ~azure.ai.ml.entities.JobResourceConfiguration]
+    """
+
+    # pylint: disable=too-many-instance-attributes
+    def __init__(
+        self,
+        retry_settings: Optional[RetrySettings] = None,
+        logging_level: Optional[str] = None,
+        max_concurrency_per_instance: Optional[int] = None,
+        error_threshold: Optional[int] = None,
+        mini_batch_error_threshold: Optional[int] = None,
+        input_data: Optional[str] = None,
+        task: Optional[ParallelTask] = None,
+        mini_batch_size: Optional[int] = None,
+        partition_keys: Optional[List] = None,
+        resources: Optional[Union[dict, JobResourceConfiguration]] = None,
+        environment_variables: Optional[Dict] = None,
+    ):
+        self.mini_batch_size = mini_batch_size
+        self.partition_keys = partition_keys
+        self.task = task
+        self.retry_settings = retry_settings
+        self.input_data = input_data
+        self.logging_level = logging_level
+        self.max_concurrency_per_instance = max_concurrency_per_instance
+        self.error_threshold = error_threshold
+        self.mini_batch_error_threshold = mini_batch_error_threshold
+        self.resources = resources
+        self.environment_variables = dict(environment_variables) if environment_variables else {}
+
+    @property
+    def task(self) -> Optional[ParallelTask]:
+        res: Optional[ParallelTask] = self._task
+        return res
+
+    @task.setter
+    def task(self, value: Any) -> None:
+        if isinstance(value, dict):
+            value = ParallelTask(**value)
+        self._task = value
+
+    @property
+    def resources(self) -> Optional[Union[dict, JobResourceConfiguration]]:
+        res: Optional[Union[dict, JobResourceConfiguration]] = self._resources
+        return res
+
+    @resources.setter
+    def resources(self, value: Any) -> None:
+        if isinstance(value, dict):
+            value = JobResourceConfiguration(**value)
+        self._resources = value
+
+    @property
+    def retry_settings(self) -> Optional[RetrySettings]:
+        res: Optional[RetrySettings] = self._retry_settings
+        return res
+
+    @retry_settings.setter
+    def retry_settings(self, value: Any) -> None:
+        if isinstance(value, dict):
+            value = RetrySettings(**value)
+        self._retry_settings = value
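
The task, resources, and retry_settings properties above coerce plain dictionaries into their entity types, so values loaded from YAML and hand-written dicts end up as the same objects. A minimal sketch of that behaviour (constructing the mixin directly is for illustration only; in the SDK it is mixed into ParallelJob and the parallel component):

# Sketch only: dict values are converted by the property setters.
from azure.ai.ml.entities._job.parallel.parameterized_parallel import ParameterizedParallel

pp = ParameterizedParallel(
    retry_settings={"max_retries": 2, "timeout": 60},           # dict -> RetrySettings
    task={"type": "run_function", "entry_script": "score.py"},  # dict -> ParallelTask
    resources={"instance_count": 2},                            # dict -> JobResourceConfiguration
)

print(type(pp.retry_settings).__name__)  # RetrySettings
print(type(pp.task).__name__)            # ParallelTask
print(type(pp.resources).__name__)       # JobResourceConfiguration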
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py
new file mode 100644
index 00000000..2fb19ba1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/retry_settings.py
@@ -0,0 +1,78 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._utils.utils import load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY
+from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin
+from azure.ai.ml.entities._util import load_from_dict
+
+
+class RetrySettings(RestTranslatableMixin, DictMixin):
+    """Parallel RetrySettings.
+
+    :param timeout: Timeout in seconds for each invocation of the run() method.
+        (optional) This value could be set through PipelineParameter.
+    :type timeout: int
+    :param max_retries: The maximum number of retries for a failed or timed-out mini-batch.
+        The range is [1, int.max]. This value could be set through PipelineParameter.
+        A mini-batch whose dequeue count exceeds this value won't be processed again and will be deleted directly.
+    :type max_retries: int
+    """
+
+    def __init__(
+        self,  # pylint: disable=unused-argument
+        *,
+        timeout: Optional[Union[int, str]] = None,
+        max_retries: Optional[Union[int, str]] = None,
+        **kwargs: Any,
+    ):
+        self.timeout = timeout
+        self.max_retries = max_retries
+
+    def _to_dict(self) -> Dict:
+        res: dict = RetrySettingsSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)  # pylint: disable=no-member
+        return res
+
+    @classmethod
+    def _load(
+        cls,  # pylint: disable=unused-argument
+        path: Optional[Union[PathLike, str]] = None,
+        params_override: Optional[list] = None,
+        **kwargs: Any,
+    ) -> "RetrySettings":
+        params_override = params_override or []
+        data = load_yaml(path)
+        return RetrySettings._load_from_dict(data=data, path=path, params_override=params_override)
+
+    @classmethod
+    def _load_from_dict(
+        cls,
+        data: dict,
+        path: Optional[Union[PathLike, str]] = None,
+        params_override: Optional[list] = None,
+        **kwargs: Any,
+    ) -> "RetrySettings":
+        params_override = params_override or []
+        context = {
+            BASE_PATH_CONTEXT_KEY: Path(path).parent if path else Path.cwd(),
+            PARAMS_OVERRIDE_KEY: params_override,
+        }
+        res: RetrySettings = load_from_dict(RetrySettingsSchema, data, context, **kwargs)
+        return res
+
+    @classmethod
+    def _from_dict(cls, dct: dict) -> "RetrySettings":
+        obj = cls(**dict(dct.items()))
+        return obj
+
+    def _to_rest_object(self) -> Dict:
+        return self._to_dict()
+
+    @classmethod
+    def _from_rest_object(cls, obj: dict) -> "RetrySettings":
+        return cls._from_dict(obj)
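
Both fields are typed as Union[int, str] because, per the docstring, they can be bound to a PipelineParameter as well as set to a literal. A minimal construction sketch (the binding string is a hypothetical example):

# Sketch only: RetrySettings accepts literal values or pipeline-parameter bindings.
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings

literal = RetrySettings(max_retries=3, timeout=60)
bound = RetrySettings(max_retries="${{parent.inputs.max_retries}}", timeout=120)

print(literal.max_retries, literal.timeout)  # 3 60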
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py
new file mode 100644
index 00000000..180cee76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/parallel/run_function.py
@@ -0,0 +1,66 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+from typing import Any, Optional, Union
+
+from azure.ai.ml.constants import ParallelTaskType
+from azure.ai.ml.entities._assets.environment import Environment
+
+from .parallel_task import ParallelTask
+
+
+class RunFunction(ParallelTask):
+    """Run Function.
+
+    :param code: A local or remote path pointing at source code.
+    :type code: str
+    :param entry_script: User script which will be run in parallel on multiple nodes. This is
+        specified as a local file path.
+        The entry_script should contain two functions:
+        ``init()``: this function should be used for any costly or common preparation for subsequent inferences,
+        e.g., deserializing and loading the model into a global object.
+        ``run(mini_batch)``: The method to be parallelized. Each invocation handles one mini-batch.
+        'mini_batch': Batch inference invokes the run() method and passes either a list or a Pandas DataFrame
+        as the argument. Each entry in mini_batch is a file path if the input is a FileDataset,
+        or a Pandas DataFrame if the input is a TabularDataset.
+        The run() method should return a Pandas DataFrame or an array.
+        For the append_row output_action, the returned elements are appended to the common output file.
+        For summary_only, the contents of the elements are ignored. For all output actions,
+        each returned output element indicates one successful inference of an input element in the input mini-batch.
+        Each parallel worker process calls `init` once and then loops over the `run` function until all
+        mini-batches are processed.
+    :type entry_script: str
+    :param program_arguments: The arguments of the parallel task.
+    :type program_arguments: str
+    :param model: The model of the parallel task.
+    :type model: str
+    :param append_row_to: All values output by run() method invocations will be aggregated into
+        one unique file which is created in the output location.
+        If it is not set, 'summary_only' is used, which means the user script is expected to store the output itself.
+    :type append_row_to: str
+    :param environment: The environment that the training job runs in.
+    :type environment: Union[Environment, str]
+    """
+
+    def __init__(
+        self,
+        *,
+        code: Optional[str] = None,
+        entry_script: Optional[str] = None,
+        program_arguments: Optional[str] = None,
+        model: Optional[str] = None,
+        append_row_to: Optional[str] = None,
+        environment: Optional[Union[Environment, str]] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            code=code,
+            entry_script=entry_script,
+            program_arguments=program_arguments,
+            model=model,
+            append_row_to=append_row_to,
+            environment=environment,
+            type=ParallelTaskType.RUN_FUNCTION,
+        )
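
A RunFunction is normally consumed through the parallel_run_function builder exposed by azure.ai.ml.parallel rather than built in isolation. The sketch below shows that typical wiring; the names, paths, environment reference, and input/output shapes are illustrative assumptions, not values taken from this diff:

# Sketch only: using RunFunction with the parallel_run_function builder.
from azure.ai.ml import Input, Output
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.parallel import RunFunction, parallel_run_function

batch_score = parallel_run_function(
    name="batch_score",
    display_name="Batch score with a tabular input",
    inputs=dict(job_data_path=Input(type=AssetTypes.MLTABLE)),
    outputs=dict(job_output_path=Output(type=AssetTypes.MLTABLE)),
    input_data="${{inputs.job_data_path}}",
    instance_count=2,
    max_concurrency_per_instance=2,
    mini_batch_size="100",
    mini_batch_error_threshold=5,
    retry_settings=dict(max_retries=2, timeout=60),
    logging_level="DEBUG",
    task=RunFunction(
        code="./src",                             # hypothetical source folder
        entry_script="score.py",                  # entry script with init()/run(mini_batch)
        environment="azureml:batch-score-env:1",  # hypothetical registered environment
        program_arguments="--job_output_path ${{outputs.job_output_path}}",
        append_row_to="${{outputs.job_output_path}}",
    ),
)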