# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
from typing import Dict, List, Optional, Union

from marshmallow import INCLUDE, Schema

from ... import MpiDistribution, PyTorchDistribution, RayDistribution, TensorFlowDistribution
from ..._schema import PathAwareSchema
from ..._schema.core.fields import DistributionField
from ...entities import CommandJobLimits, JobResourceConfiguration
from ...entities._util import get_rest_dict_for_node_attrs
from .._schema.component import NodeType
from ..entities.component import InternalComponent
from ..entities.node import InternalBaseNode


class Command(InternalBaseNode):
    """Pipeline node for an internal command component, carrying its run settings.

    Unlike azure.ai.ml.entities.Command, the type of this class is CommandComponent.
    """

    def __init__(self, **kwargs):
        """Build the node from keyword arguments.

        ``type`` falls back to ``NodeType.COMMAND`` when missing or falsy. The
        remaining keyword arguments are forwarded to the base initializer, after
        which ``resources``, ``compute``, ``environment``,
        ``environment_variables`` and ``limits`` are read from them.
        """
        resolved_type = kwargs.pop("type", None) or NodeType.COMMAND
        super().__init__(type=resolved_type, **kwargs)

        # The run settings below are assigned while the ``_init`` flag is raised.
        self._init = True
        self._resources = kwargs.pop("resources", JobResourceConfiguration())
        self._compute = kwargs.pop("compute", None)
        self._environment = kwargs.pop("environment", None)
        self._environment_variables = kwargs.pop("environment_variables", None)
        self._limits = kwargs.pop("limits", CommandJobLimits())
        self._init = False

    @property
    def compute(self) -> Optional[str]:
        """Compute definition of the command.

        :return: The compute definition
        :rtype: Optional[str]
        """
        return self._compute

    @compute.setter
    def compute(self, value: str) -> None:
        """Replace the compute definition of the command.

        :param value: The new compute definition
        :type value: str
        """
        self._compute = value

    @property
    def environment(self) -> Optional[str]:
        """Environment definition of the command.

        :return: The environment definition
        :rtype: Optional[str]
        """
        return self._environment

    @environment.setter
    def environment(self, value: str) -> None:
        """Replace the environment definition of the command.

        :param value: The new environment definition
        :type value: str
        """
        self._environment = value

    @property
    def environment_variables(self) -> Optional[Dict[str, str]]:
        """Environment variables of the command.

        :return: The environment variables
        :rtype: Optional[Dict[str, str]]
        """
        return self._environment_variables

    @environment_variables.setter
    def environment_variables(self, value: Dict[str, str]) -> None:
        """Replace the environment variables of the command.

        :param value: The new environment variables
        :type value: Dict[str, str]
        """
        self._environment_variables = value

    @property
    def limits(self) -> CommandJobLimits:
        """Limits configured for the command.

        :return: The limits
        :rtype: CommandJobLimits
        """
        return self._limits

    @limits.setter
    def limits(self, value: CommandJobLimits):
        """Replace the limits of the command.

        :param value: The new limits
        :type value: CommandJobLimits
        """
        self._limits = value

    @property
    def resources(self) -> JobResourceConfiguration:
        """Compute resource configuration for the component.

        :return: The resource configuration
        :rtype: JobResourceConfiguration
        """
        return self._resources

    @resources.setter
    def resources(self, value: JobResourceConfiguration):
        """Replace the compute resource configuration.

        :param value: The new resource configuration
        :type value: JobResourceConfiguration
        """
        self._resources = value

    @classmethod
    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
        """Return the names of the run-setting fields picked for the REST object."""
        return ["environment", "limits", "resources", "environment_variables"]

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Return a ``CommandSchema`` bound to ``context``."""
        # Imported inside the method, as in the original code.
        from .._schema.command import CommandSchema

        return CommandSchema(context=context)

    def _to_rest_object(self, **kwargs) -> dict:
        """Extend the base REST dict with the ``limits`` and ``resources`` entries."""
        payload = super()._to_rest_object(**kwargs)
        payload.update(
            limits=get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True),
            resources=get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True),
        )
        return payload

    @classmethod
    def _from_rest_object_to_init_params(cls, obj):
        """Turn a REST dict into init parameters.

        Truthy ``resources`` and ``limits`` entries are converted with the
        ``_from_rest_object`` of their respective entity classes.
        """
        obj = InternalBaseNode._from_rest_object_to_init_params(obj)

        if obj.get("resources"):
            obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

        # handle limits
        if obj.get("limits"):
            obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])

        return obj


class Distributed(Command):
    """Command node whose type is ``NodeType.DISTRIBUTED`` and which holds a distribution config."""

    def __init__(self, **kwargs):
        """Build the node, deriving ``distribution`` from the component launcher if not given.

        :raises ValueError: If no distribution is passed and the component does not
            provide a launcher with a ``type`` entry.
        """
        super().__init__(**kwargs)
        self._distribution = kwargs.pop("distribution", None)
        self._type = NodeType.DISTRIBUTED

        if self._distribution is not None:
            return

        # hack: distribution.type is required to set distribution, which is defined in launcher.type
        component = self.component
        launcher = component.launcher if isinstance(component, InternalComponent) else None
        if not launcher or "type" not in launcher:
            message = "launcher.type must be specified in definition of DistributedComponent but got {}".format(
                component
            )
            raise ValueError(message)
        self.distribution = {"type": launcher["type"]}

    @property
    def distribution(
        self,
    ) -> Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution, RayDistribution]:
        """The distribution config of component, e.g. distribution={'type': 'mpi'}.

        :return: The distribution config
        :rtype: Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution, RayDistribution]
        """
        return self._distribution

    @distribution.setter
    def distribution(
        self,
        value: Union[Dict, PyTorchDistribution, TensorFlowDistribution, MpiDistribution, RayDistribution],
    ):
        """Set the distribution config; a dict is deserialized through ``DistributionField`` first.

        :param value: The new distribution config
        :type value: Union[Dict, PyTorchDistribution, TensorFlowDistribution, MpiDistribution, RayDistribution]
        """
        if isinstance(value, dict):
            field = DistributionField(unknown=INCLUDE)
            value = field._deserialize(value=value, attr=None, data=None)
        self._distribution = value

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Return a ``DistributedSchema`` bound to ``context``."""
        # Imported inside the method, as in the original code.
        from .._schema.command import DistributedSchema

        return DistributedSchema(context=context)

    @classmethod
    def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
        """Return the ``Command`` picked fields followed by ``distribution``."""
        return [*Command._picked_fields_from_dict_to_rest_object(), "distribution"]

    def _to_rest_object(self, **kwargs) -> dict:
        """Extend the ``Command`` REST dict with the ``distribution`` entry."""
        payload = super()._to_rest_object(**kwargs)

        current = self.distribution
        if current:
            rest_distribution = current._to_rest_object()  # pylint: disable=no-member
        else:
            rest_distribution = None

        payload.update(distribution=get_rest_dict_for_node_attrs(rest_distribution))
        return payload