aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
from typing import Dict, List, Optional, Union

from marshmallow import INCLUDE, Schema

from ... import MpiDistribution, PyTorchDistribution, RayDistribution, TensorFlowDistribution
from ..._schema import PathAwareSchema
from ..._schema.core.fields import DistributionField
from ...entities import CommandJobLimits, JobResourceConfiguration
from ...entities._util import get_rest_dict_for_node_attrs
from .._schema.component import NodeType
from ..entities.component import InternalComponent
from ..entities.node import InternalBaseNode


class Command(InternalBaseNode):
    """Node of internal command components in pipeline with specific run settings.

    Different from azure.ai.ml.entities.Command, type of this class is CommandComponent.

    The run settings ``resources``, ``compute``, ``environment``, ``environment_variables`` and
    ``limits`` are captured from the constructor keyword arguments and exposed as read/write properties.
    """

    def __init__(self, **kwargs):
        """Create an internal command node.

        :keyword type: The node type. Falls back to ``NodeType.COMMAND`` when missing or falsy.
        :keyword resources: Compute resource configuration. Defaults to a new ``JobResourceConfiguration()``.
        :keyword compute: The compute definition. Defaults to None.
        :keyword environment: The environment definition. Defaults to None.
        :keyword environment_variables: Environment variables for the command. Defaults to None.
        :keyword limits: Job limits. Defaults to a new ``CommandJobLimits()``.

        Every keyword argument except ``type`` (including the run settings above) is forwarded to
        ``InternalBaseNode.__init__``; the run settings are then additionally stored on this node.
        """
        node_type = kwargs.pop("type", None) or NodeType.COMMAND
        # kwargs is unpacked into a fresh dict for the base class, so the pops below only affect this local dict.
        super(Command, self).__init__(type=node_type, **kwargs)
        # NOTE(review): `_init` brackets the private-attribute assignments below; presumably the base
        # class consults it (e.g. in attribute handling) during construction — confirm in InternalBaseNode.
        self._init = True
        self._resources = kwargs.pop("resources", JobResourceConfiguration())
        self._compute = kwargs.pop("compute", None)
        self._environment = kwargs.pop("environment", None)
        self._environment_variables = kwargs.pop("environment_variables", None)
        self._limits = kwargs.pop("limits", CommandJobLimits())
        self._init = False

    @property
    def compute(self) -> Optional[str]:
        """Get the compute definition for the command.

        :return: The compute definition
        :rtype: Optional[str]
        """
        return self._compute

    @compute.setter
    def compute(self, value: Optional[str]) -> None:
        """Set the compute definition for the command.

        :param value: The new compute definition
        :type value: Optional[str]
        """
        self._compute = value

    @property
    def environment(self) -> Optional[str]:
        """Get the environment definition for the command.

        :return: The environment definition
        :rtype: Optional[str]
        """
        return self._environment

    @environment.setter
    def environment(self, value: Optional[str]) -> None:
        """Set the environment definition for the command.

        :param value: The new environment definition
        :type value: Optional[str]
        """
        self._environment = value

    @property
    def environment_variables(self) -> Optional[Dict[str, str]]:
        """Get the environment variables for the command.

        :return: The environment variables
        :rtype: Optional[Dict[str, str]]
        """
        return self._environment_variables

    @environment_variables.setter
    def environment_variables(self, value: Optional[Dict[str, str]]) -> None:
        """Set the environment variables for the command.

        :param value: The new environment variables
        :type value: Optional[Dict[str, str]]
        """
        self._environment_variables = value

    @property
    def limits(self) -> CommandJobLimits:
        """Get the job limits for the command.

        :return: The job limits
        :rtype: CommandJobLimits
        """
        return self._limits

    @limits.setter
    def limits(self, value: CommandJobLimits) -> None:
        """Set the job limits for the command.

        :param value: The new job limits
        :type value: CommandJobLimits
        """
        self._limits = value

    @property
    def resources(self) -> JobResourceConfiguration:
        """Compute Resource configuration for the component.

        :return: The resource configuration
        :rtype: JobResourceConfiguration
        """
        return self._resources

    @resources.setter
    def resources(self, value: JobResourceConfiguration) -> None:
        """Set the compute resource configuration for the component.

        :param value: The new resource configuration
        :type value: JobResourceConfiguration
        """
        self._resources = value

    @classmethod
    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
        """Return the attribute names this node type contributes to dict-to-REST conversion.

        NOTE(review): presumably consumed by the base node's REST conversion — confirm in InternalBaseNode.

        :return: The field names
        :rtype: List[str]
        """
        return ["environment", "limits", "resources", "environment_variables"]

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Build the schema used to validate this node.

        :param context: The schema context.
        :return: A ``CommandSchema`` bound to ``context``.
        :rtype: Union[PathAwareSchema, Schema]
        """
        # Imported locally, as in the original code; possibly to avoid an import cycle — confirm.
        from .._schema.command import CommandSchema

        return CommandSchema(context=context)

    def _to_rest_object(self, **kwargs) -> dict:
        """Serialize the node, adding ``limits`` and ``resources`` to the base REST dict.

        Both are converted with ``get_rest_dict_for_node_attrs(..., clear_empty_value=True)``.

        :return: The REST representation of the node.
        :rtype: dict
        """
        rest_obj = super()._to_rest_object(**kwargs)
        rest_obj.update(
            {
                "limits": get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True),
                "resources": get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True),
            }
        )
        return rest_obj

    @classmethod
    def _from_rest_object_to_init_params(cls, obj):
        """Convert a REST dict into constructor keyword arguments.

        Delegates to ``InternalBaseNode`` first, then rebuilds non-empty ``resources`` and ``limits``
        entries into ``JobResourceConfiguration`` / ``CommandJobLimits`` objects. Falsy entries are left as-is.

        :param obj: The REST dict.
        :return: The init params dict.
        """
        obj = InternalBaseNode._from_rest_object_to_init_params(obj)

        resources = obj.get("resources")
        if resources:
            obj["resources"] = JobResourceConfiguration._from_rest_object(resources)

        # handle limits
        limits = obj.get("limits")
        if limits:
            obj["limits"] = CommandJobLimits._from_rest_object(limits)
        return obj


class Distributed(Command):
    """Node of internal distributed components in a pipeline.

    Behaves like :class:`Command`, but reports ``NodeType.DISTRIBUTED`` as its type and carries a
    ``distribution`` setting. When no distribution is passed in, one is built from the component's
    ``launcher["type"]``. If that is unavailable, a ValueError is raised.
    """

    def __init__(self, **kwargs):
        super(Distributed, self).__init__(**kwargs)
        self._distribution = kwargs.pop("distribution", None)
        self._type = NodeType.DISTRIBUTED
        if self._distribution is not None:
            # An explicit distribution was supplied, so there is nothing to derive.
            return

        # hack: distribution.type is required to set distribution, which is defined in launcher.type
        component = self.component
        has_launcher_type = (
            isinstance(component, InternalComponent) and component.launcher and "type" in component.launcher
        )
        if not has_launcher_type:
            raise ValueError(
                "launcher.type must be specified in definition of DistributedComponent but got {}".format(component)
            )
        self.distribution = {"type": component.launcher["type"]}

    @property
    def distribution(
        self,
    ) -> Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution, RayDistribution]:
        """Distribution settings of this node, e.g. as produced by distribution={'type': 'mpi'}.

        :return: The distribution config
        :rtype: Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution, RayDistribution]
        """
        return self._distribution

    @distribution.setter
    def distribution(
        self,
        value: Union[Dict, PyTorchDistribution, TensorFlowDistribution, MpiDistribution, RayDistribution],
    ):
        """Store the distribution settings.

        A dict is deserialized through ``DistributionField`` (unknown keys included). Any other value,
        including None, is stored unchanged.
        """
        if not isinstance(value, dict):
            self._distribution = value
            return
        field = DistributionField(unknown=INCLUDE)
        self._distribution = field._deserialize(value=value, attr=None, data=None)

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Return a ``DistributedSchema`` bound to ``context`` for validating this node."""
        from .._schema.command import DistributedSchema

        return DistributedSchema(context=context)

    @classmethod
    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
        """Return the Command field names plus ``distribution``."""
        return [*Command._picked_fields_from_dict_to_rest_object(), "distribution"]

    def _to_rest_object(self, **kwargs) -> dict:
        """Serialize the node, adding the converted ``distribution`` (or None) to the Command REST dict."""
        rest_obj = super()._to_rest_object(**kwargs)
        dist = self.distribution
        rest_distribution = dist._to_rest_object() if dist else None  # pylint: disable=no-member
        rest_obj["distribution"] = get_rest_dict_for_node_attrs(rest_distribution)
        return rest_obj