diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py | 314 |
1 file changed, 314 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py new file mode 100644 index 00000000..0a0c7e82 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/command_job.py @@ -0,0 +1,314 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import copy +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from azure.ai.ml._restclient.v2025_01_01_preview.models import CommandJob as RestCommandJob +from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase +from azure.ai.ml._schema.job.command_job import CommandJobSchema +from azure.ai.ml._utils.utils import map_single_brackets_and_warn +from azure.ai.ml.constants import JobType +from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET, TYPE +from azure.ai.ml.entities import Environment +from azure.ai.ml.entities._credentials import ( + AmlTokenConfiguration, + ManagedIdentityConfiguration, + UserIdentityConfiguration, + _BaseJobIdentityConfiguration, +) +from azure.ai.ml.entities._inputs_outputs import Input, Output +from azure.ai.ml.entities._job._input_output_helpers import ( + from_rest_data_outputs, + from_rest_inputs_to_dataset_literal, + to_rest_data_outputs, + to_rest_dataset_literal_inputs, + validate_inputs_for_command, +) +from azure.ai.ml.entities._job.distribution import DistributionConfiguration +from azure.ai.ml.entities._job.job_service import ( + JobService, + JobServiceBase, + JupyterLabJobService, + SshJobService, + TensorBoardJobService, + VsCodeJobService, +) +from azure.ai.ml.entities._system_data import SystemData +from azure.ai.ml.entities._util import load_from_dict +from 
azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

from .job import Job
from .job_io_mixin import JobIOMixin
from .job_limits import CommandJobLimits
from .job_resource_configuration import JobResourceConfiguration
from .parameterized_command import ParameterizedCommand
from .queue_settings import QueueSettings

# avoid circular import error
if TYPE_CHECKING:
    from azure.ai.ml.entities import CommandComponent
    from azure.ai.ml.entities._builders import Command

module_logger = logging.getLogger(__name__)


class CommandJob(Job, ParameterizedCommand, JobIOMixin):
    """Command job.

    A standalone job that runs a command against a compute target, combining the generic
    :class:`Job` contract with the command/environment settings of ``ParameterizedCommand``
    and the input/output handling of ``JobIOMixin``.

    .. note::
        For sweep jobs, inputs, outputs, and parameters are accessible as environment variables using the prefix
        ``AZUREML_PARAMETER_``. For example, if you have a parameter named "input_data", you can access it as
        ``AZUREML_PARAMETER_input_data``.

    :keyword services: Read-only information on services associated with the job.
    :paramtype services: Optional[dict[str, ~azure.ai.ml.entities.JobService]]
    :keyword inputs: Mapping of input data bindings used in the command.
    :paramtype inputs: Optional[dict[str, Union[~azure.ai.ml.Input, str, bool, int, float]]]
    :keyword outputs: Mapping of output data bindings used in the job.
    :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]]
    :keyword identity: The identity that the job will use while running on compute.
    :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
        ~azure.ai.ml.UserIdentityConfiguration]]
    :keyword limits: The limits for the job.
    :paramtype limits: Optional[~azure.ai.ml.entities.CommandJobLimits]
    :keyword parent_job_name: parent job id for command job
    :paramtype parent_job_name: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict


    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_command_configurations.py
            :start-after: [START command_job_definition]
            :end-before: [END command_job_definition]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandJob.
    """

    def __init__(
        self,
        *,
        inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
        outputs: Optional[Dict[str, Output]] = None,
        limits: Optional[CommandJobLimits] = None,
        identity: Optional[
            Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
        ] = None,
        services: Optional[
            Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
        ] = None,
        parent_job_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Force the job type before delegating the remaining kwargs to the base classes.
        kwargs[TYPE] = JobType.COMMAND
        # "parameters" is popped here (not passed to super()) and exposed read-only
        # via the `parameters` property below.
        self._parameters: dict = kwargs.pop("parameters", {})
        self.parent_job_name = parent_job_name

        super().__init__(**kwargs)

        self.outputs = outputs  # type: ignore[assignment]
        self.inputs = inputs  # type: ignore[assignment]
        self.limits = limits
        self.identity = identity
        self.services = services

    @property
    def parameters(self) -> Dict[str, str]:
        """MLFlow parameters.

        :return: MLFlow parameters logged in job.
        :rtype: dict[str, str]
        """
        return self._parameters

    def _to_dict(self) -> Dict:
        # Serialize via the marshmallow schema; base_path "./" makes paths relative to cwd.
        res: dict = CommandJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
        return res

    def _to_rest_object(self) -> JobBase:
        """Convert this entity to the autorest ``JobBase`` wire format.

        Validates required fields first, then maps each entity attribute onto the
        corresponding ``RestCommandJob`` field, with special handling for local compute.

        :return: The REST representation of this job.
        :rtype: JobBase
        """
        self._validate()
        # Rewrites single-bracket templates ({x}) to the supported double-bracket form,
        # warning the user about the substitution.
        self.command = map_single_brackets_and_warn(self.command)
        modified_properties = copy.deepcopy(self.properties)
        # Remove any properties set on the service as read-only
        modified_properties.pop("_azureml.ComputeTargetType", None)
        # Handle local compute case: the service expects no compute_id, and a marker
        # property on the resource configuration instead.
        compute = self.compute
        resources = self.resources
        if self.compute == LOCAL_COMPUTE_TARGET:
            compute = None
            if resources is None:
                resources = JobResourceConfiguration()
            if not isinstance(resources, Dict):
                if resources.properties is None:
                    resources.properties = {}
                # This is the format of the October Api response. We need to match it exactly
                resources.properties[LOCAL_COMPUTE_PROPERTY] = {LOCAL_COMPUTE_PROPERTY: True}

        properties = RestCommandJob(
            display_name=self.display_name,
            description=self.description,
            command=self.command,
            code_id=self.code,
            compute_id=compute,
            properties=modified_properties,
            experiment_name=self.experiment_name,
            inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type),
            outputs=to_rest_data_outputs(self.outputs),
            environment_id=self.environment,
            # Nested configuration objects serialize themselves; raw dicts are skipped
            # (presumably already in wire format or handled elsewhere — NOTE(review): confirm).
            distribution=(
                self.distribution._to_rest_object()
                if self.distribution and not isinstance(self.distribution, Dict)
                else None
            ),
            tags=self.tags,
            identity=(
                self.identity._to_job_rest_object() if self.identity and not isinstance(self.identity, Dict) else None
            ),
            environment_variables=self.environment_variables,
            resources=resources._to_rest_object() if resources and not isinstance(resources, Dict) else None,
            limits=self.limits._to_rest_object() if self.limits else None,
            services=JobServiceBase._to_rest_job_services(self.services),
            queue_settings=self.queue_settings._to_rest_object() if self.queue_settings else None,
            parent_job_name=self.parent_job_name,
        )
        result = JobBase(properties=properties)
        result.name = self.name
        return result

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "CommandJob":
        """Build a CommandJob from YAML-loaded dict data via the schema.

        :param data: The dict loaded from a job YAML file.
        :param context: Loader context; must contain the YAML file's base path.
        :param additional_message: Extra text appended to schema validation errors.
        :return: The deserialized job entity.
        :rtype: CommandJob
        """
        loaded_data = load_from_dict(CommandJobSchema, data, context, additional_message, **kwargs)
        return CommandJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)

    @classmethod
    def _load_from_rest(cls, obj: JobBase) -> "CommandJob":
        """Build a CommandJob from the autorest ``JobBase`` wire object.

        Inverse of :meth:`_to_rest_object`, including undoing the local-compute
        marker-property encoding.

        :param obj: The REST job wrapper returned by the service.
        :return: The deserialized job entity.
        :rtype: CommandJob
        """
        rest_command_job: RestCommandJob = obj.properties
        command_job = CommandJob(
            name=obj.name,
            id=obj.id,
            display_name=rest_command_job.display_name,
            description=rest_command_job.description,
            tags=rest_command_job.tags,
            properties=rest_command_job.properties,
            command=rest_command_job.command,
            experiment_name=rest_command_job.experiment_name,
            services=JobServiceBase._from_rest_job_services(rest_command_job.services),
            status=rest_command_job.status,
            creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
            code=rest_command_job.code_id,
            compute=rest_command_job.compute_id,
            environment=rest_command_job.environment_id,
            distribution=DistributionConfiguration._from_rest_object(rest_command_job.distribution),
            parameters=rest_command_job.parameters,
            # pylint: disable=protected-access
            identity=(
                _BaseJobIdentityConfiguration._from_rest_object(rest_command_job.identity)
                if rest_command_job.identity
                else None
            ),
            environment_variables=rest_command_job.environment_variables,
            resources=JobResourceConfiguration._from_rest_object(rest_command_job.resources),
            limits=CommandJobLimits._from_rest_object(rest_command_job.limits),
            inputs=from_rest_inputs_to_dataset_literal(rest_command_job.inputs),
            outputs=from_rest_data_outputs(rest_command_job.outputs),
            queue_settings=QueueSettings._from_rest_object(rest_command_job.queue_settings),
            parent_job_name=rest_command_job.parent_job_name,
        )
        # Handle special case of local job: translate the marker property written by
        # _to_rest_object back into compute == LOCAL_COMPUTE_TARGET and drop the marker.
        if (
            command_job.resources is not None
            and not isinstance(command_job.resources, Dict)
            and command_job.resources.properties is not None
            and command_job.resources.properties.get(LOCAL_COMPUTE_PROPERTY, None)
        ):
            command_job.compute = LOCAL_COMPUTE_TARGET
            command_job.resources.properties.pop(LOCAL_COMPUTE_PROPERTY)
        return command_job

    def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "CommandComponent":
        """Translate a command job to component.

        :param context: Context of command job YAML file.
        :type context: dict
        :return: Translated command component.
        :rtype: CommandComponent
        """
        from azure.ai.ml.entities import CommandComponent

        pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
        context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}

        # Create anonymous command component with default version as 1
        return CommandComponent(
            tags=self.tags,
            is_anonymous=True,
            base_path=context[BASE_PATH_CONTEXT_KEY],
            code=self.code,
            command=self.command,
            environment=self.environment,
            description=self.description,
            inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
            outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
            resources=self.resources if self.resources else None,
            distribution=self.distribution if self.distribution else None,
        )

    def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Command":
        """Translate a command job to a pipeline node.

        Wraps the (anonymous) component produced by :meth:`_to_component` in a
        ``Command`` builder node carrying this job's runtime settings.

        :param context: Context of command job YAML file.
        :type context: dict
        :return: Translated command component.
        :rtype: Command
        """
        from azure.ai.ml.entities._builders import Command

        component = self._to_component(context, **kwargs)

        return Command(
            component=component,
            compute=self.compute,
            # Need to supply the inputs with double curly.
            inputs=self.inputs,  # type: ignore[arg-type]
            outputs=self.outputs,  # type: ignore[arg-type]
            environment_variables=self.environment_variables,
            description=self.description,
            tags=self.tags,
            display_name=self.display_name,
            limits=self.limits,
            services=self.services,
            properties=self.properties,
            identity=self.identity,
            queue_settings=self.queue_settings,
        )

    def _validate(self) -> None:
        """Check required fields before REST serialization.

        :raises ValidationException: If ``command`` or ``environment`` is missing,
            or if the inputs referenced by the command fail validation.
        """
        if self.command is None:
            msg = "command is required"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
        if self.environment is None:
            msg = "environment is required for non-local runs"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
        if isinstance(self.environment, Environment):
            self.environment.validate()
        # Ensure every {{inputs.x}} referenced in the command string has a matching input.
        validate_inputs_for_command(self.command, self.inputs)