diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/entities/command.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/entities/command.py | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/entities/command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/entities/command.py new file mode 100644 index 00000000..9ed732ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/entities/command.py @@ -0,0 +1,203 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=protected-access +from typing import Dict, List, Optional, Union + +from marshmallow import INCLUDE, Schema + +from ... import MpiDistribution, PyTorchDistribution, RayDistribution, TensorFlowDistribution +from ..._schema import PathAwareSchema +from ..._schema.core.fields import DistributionField +from ...entities import CommandJobLimits, JobResourceConfiguration +from ...entities._util import get_rest_dict_for_node_attrs +from .._schema.component import NodeType +from ..entities.component import InternalComponent +from ..entities.node import InternalBaseNode + + +class Command(InternalBaseNode): + """Node of internal command components in pipeline with specific run settings. + + Different from azure.ai.ml.entities.Command, type of this class is CommandComponent. + """ + + def __init__(self, **kwargs): + node_type = kwargs.pop("type", None) or NodeType.COMMAND + super(Command, self).__init__(type=node_type, **kwargs) + self._init = True + self._resources = kwargs.pop("resources", JobResourceConfiguration()) + self._compute = kwargs.pop("compute", None) + self._environment = kwargs.pop("environment", None) + self._environment_variables = kwargs.pop("environment_variables", None) + self._limits = kwargs.pop("limits", CommandJobLimits()) + self._init = False + + @property + def compute(self) -> Optional[str]: + """Get the compute definition for the command. + + :return: The compute definition + :rtype: Optional[str] + """ + return self._compute + + @compute.setter + def compute(self, value: str) -> None: + """Set the compute definition for the command. + + :param value: The new compute definition + :type value: str + """ + self._compute = value + + @property + def environment(self) -> Optional[str]: + """Get the environment definition for the command. + + :return: The environment definition + :rtype: Optional[str] + """ + return self._environment + + @environment.setter + def environment(self, value: str) -> None: + """Set the environment definition for the command. + + :param value: The new environment definition + :type value: str + """ + self._environment = value + + @property + def environment_variables(self) -> Optional[Dict[str, str]]: + """Get the environment variables for the command. + + :return: The environment variables + :rtype: Optional[Dict[str, str]] + """ + return self._environment_variables + + @environment_variables.setter + def environment_variables(self, value: Dict[str, str]) -> None: + """Set the environment variables for the command. + + :param value: The new environment variables + :type value: Dict[str, str] + """ + self._environment_variables = value + + @property + def limits(self) -> CommandJobLimits: + return self._limits + + @limits.setter + def limits(self, value: CommandJobLimits): + self._limits = value + + @property + def resources(self) -> JobResourceConfiguration: + """Compute Resource configuration for the component. + + :return: The resource configuration + :rtype: JobResourceConfiguration + """ + return self._resources + + @resources.setter + def resources(self, value: JobResourceConfiguration): + self._resources = value + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return ["environment", "limits", "resources", "environment_variables"] + + @classmethod + def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]: + from .._schema.command import CommandSchema + + return CommandSchema(context=context) + + def _to_rest_object(self, **kwargs) -> dict: + rest_obj = super()._to_rest_object(**kwargs) + rest_obj.update( + { + "limits": get_rest_dict_for_node_attrs(self.limits, clear_empty_value=True), + "resources": get_rest_dict_for_node_attrs(self.resources, clear_empty_value=True), + } + ) + return rest_obj + + @classmethod + def _from_rest_object_to_init_params(cls, obj): + obj = InternalBaseNode._from_rest_object_to_init_params(obj) + + if "resources" in obj and obj["resources"]: + obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"]) + + # handle limits + if "limits" in obj and obj["limits"]: + obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"]) + return obj + + +class Distributed(Command): + def __init__(self, **kwargs): + super(Distributed, self).__init__(**kwargs) + self._distribution = kwargs.pop("distribution", None) + self._type = NodeType.DISTRIBUTED + if self._distribution is None: + # hack: distribution.type is required to set distribution, which is defined in launcher.type + if ( + isinstance(self.component, InternalComponent) + and self.component.launcher + and "type" in self.component.launcher + ): + self.distribution = {"type": self.component.launcher["type"]} + else: + raise ValueError( + "launcher.type must be specified in definition of DistributedComponent but got {}".format( + self.component + ) + ) + + @property + def distribution( + self, + ) -> Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution, RayDistribution]: + """The distribution config of component, e.g. distribution={'type': 'mpi'}. + + :return: The distribution config + :rtype: Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution, RayDistribution] + """ + return self._distribution + + @distribution.setter + def distribution( + self, + value: Union[Dict, PyTorchDistribution, TensorFlowDistribution, MpiDistribution, RayDistribution], + ): + if isinstance(value, dict): + dist_schema = DistributionField(unknown=INCLUDE) + value = dist_schema._deserialize(value=value, attr=None, data=None) + self._distribution = value + + @classmethod + def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]: + from .._schema.command import DistributedSchema + + return DistributedSchema(context=context) + + @classmethod + def _picked_fields_from_dict_to_rest_object(cls) -> List[str]: + return Command._picked_fields_from_dict_to_rest_object() + ["distribution"] + + def _to_rest_object(self, **kwargs) -> dict: + rest_obj = super()._to_rest_object(**kwargs) + distribution = self.distribution._to_rest_object() if self.distribution else None # pylint: disable=no-member + rest_obj.update( + { + "distribution": get_rest_dict_for_node_attrs(distribution), + } + ) + return rest_obj |