# --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- # pylint: disable=protected-access import copy import logging from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional, Union from azure.ai.ml._restclient.v2025_01_01_preview.models import CommandJob as RestCommandJob from azure.ai.ml._restclient.v2025_01_01_preview.models import JobBase from azure.ai.ml._schema.job.command_job import CommandJobSchema from azure.ai.ml._utils.utils import map_single_brackets_and_warn from azure.ai.ml.constants import JobType from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET, TYPE from azure.ai.ml.entities import Environment from azure.ai.ml.entities._credentials import ( AmlTokenConfiguration, ManagedIdentityConfiguration, UserIdentityConfiguration, _BaseJobIdentityConfiguration, ) from azure.ai.ml.entities._inputs_outputs import Input, Output from azure.ai.ml.entities._job._input_output_helpers import ( from_rest_data_outputs, from_rest_inputs_to_dataset_literal, to_rest_data_outputs, to_rest_dataset_literal_inputs, validate_inputs_for_command, ) from azure.ai.ml.entities._job.distribution import DistributionConfiguration from azure.ai.ml.entities._job.job_service import ( JobService, JobServiceBase, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService, ) from azure.ai.ml.entities._system_data import SystemData from azure.ai.ml.entities._util import load_from_dict from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException from .job import Job from .job_io_mixin import JobIOMixin from .job_limits import CommandJobLimits from .job_resource_configuration import JobResourceConfiguration from .parameterized_command import ParameterizedCommand from .queue_settings import QueueSettings # avoid circular import 
# error (tail of the "# avoid circular import error" comment above):
# these names are only needed for type annotations, so importing them under
# TYPE_CHECKING avoids a circular import at runtime.
if TYPE_CHECKING:
    from azure.ai.ml.entities import CommandComponent
    from azure.ai.ml.entities._builders import Command

module_logger = logging.getLogger(__name__)


class CommandJob(Job, ParameterizedCommand, JobIOMixin):
    """Command job.

    .. note::
        For sweep jobs, inputs, outputs, and parameters are accessible as environment variables using the prefix
        ``AZUREML_PARAMETER_``. For example, if you have a parameter named "input_data", you can access it as
        ``AZUREML_PARAMETER_input_data``.

    :keyword services: Read-only information on services associated with the job.
    :paramtype services: Optional[dict[str, ~azure.ai.ml.entities.JobService]]
    :keyword inputs: Mapping of output data bindings used in the command.
    :paramtype inputs: Optional[dict[str, Union[~azure.ai.ml.Input, str, bool, int, float]]]
    :keyword outputs: Mapping of output data bindings used in the job.
    :paramtype outputs: Optional[dict[str, ~azure.ai.ml.Output]]
    :keyword identity: The identity that the job will use while running on compute.
    :paramtype identity: Optional[Union[~azure.ai.ml.ManagedIdentityConfiguration, ~azure.ai.ml.AmlTokenConfiguration,
        ~azure.ai.ml.UserIdentityConfiguration]]
    :keyword limits: The limits for the job.
    :paramtype limits: Optional[~azure.ai.ml.entities.CommandJobLimits]
    :keyword parent_job_name: parent job id for command job
    :paramtype parent_job_name: Optional[str]
    :keyword kwargs: A dictionary of additional configuration parameters.
    :paramtype kwargs: dict

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_command_configurations.py
            :start-after: [START command_job_definition]
            :end-before: [END command_job_definition]
            :language: python
            :dedent: 8
            :caption: Configuring a CommandJob.
    """

    def __init__(
        self,
        *,
        inputs: Optional[Dict[str, Union[Input, str, bool, int, float]]] = None,
        outputs: Optional[Dict[str, Output]] = None,
        limits: Optional[CommandJobLimits] = None,
        identity: Optional[
            Union[Dict, ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
        ] = None,
        services: Optional[
            Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
        ] = None,
        parent_job_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Force the job type before the base-class initializer consumes kwargs.
        kwargs[TYPE] = JobType.COMMAND
        # MLFlow parameters are stored privately and surfaced through the
        # read-only ``parameters`` property below.
        self._parameters: dict = kwargs.pop("parameters", {})
        self.parent_job_name = parent_job_name

        super().__init__(**kwargs)

        self.outputs = outputs  # type: ignore[assignment]
        self.inputs = inputs  # type: ignore[assignment]
        self.limits = limits
        self.identity = identity
        self.services = services

    @property
    def parameters(self) -> Dict[str, str]:
        """MLFlow parameters.

        :return: MLFlow parameters logged in job.
        :rtype: dict[str, str]
        """
        return self._parameters

    def _to_dict(self) -> Dict:
        """Serialize this job to a plain dict via the marshmallow schema.

        :return: Schema-dumped representation of the job.
        :rtype: dict
        """
        res: dict = CommandJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
        return res

    def _to_rest_object(self) -> JobBase:
        """Translate this job into the autorest ``JobBase`` request object.

        Validates the job first, then maps each entity-level field onto the
        REST ``CommandJob`` model, with special handling for local compute.

        :return: REST-layer job wrapper with ``properties`` populated.
        :rtype: ~azure.ai.ml._restclient.v2025_01_01_preview.models.JobBase
        """
        self._validate()
        self.command = map_single_brackets_and_warn(self.command)
        modified_properties = copy.deepcopy(self.properties)
        # Remove any properties set on the service as read-only
        modified_properties.pop("_azureml.ComputeTargetType", None)
        # Handle local compute case: the service expects no compute id and a
        # marker property on the resource configuration instead.
        compute = self.compute
        resources = self.resources
        if self.compute == LOCAL_COMPUTE_TARGET:
            compute = None
            if resources is None:
                resources = JobResourceConfiguration()
            if not isinstance(resources, Dict):
                if resources.properties is None:
                    resources.properties = {}
                # This is the format of the October Api response. We need to match it exactly
                resources.properties[LOCAL_COMPUTE_PROPERTY] = {LOCAL_COMPUTE_PROPERTY: True}

        properties = RestCommandJob(
            display_name=self.display_name,
            description=self.description,
            command=self.command,
            code_id=self.code,
            compute_id=compute,
            properties=modified_properties,
            experiment_name=self.experiment_name,
            inputs=to_rest_dataset_literal_inputs(self.inputs, job_type=self.type),
            outputs=to_rest_data_outputs(self.outputs),
            environment_id=self.environment,
            # dict-typed values are treated as already-serialized and skipped here.
            distribution=(
                self.distribution._to_rest_object()
                if self.distribution and not isinstance(self.distribution, Dict)
                else None
            ),
            tags=self.tags,
            identity=(
                self.identity._to_job_rest_object() if self.identity and not isinstance(self.identity, Dict) else None
            ),
            environment_variables=self.environment_variables,
            resources=resources._to_rest_object() if resources and not isinstance(resources, Dict) else None,
            limits=self.limits._to_rest_object() if self.limits else None,
            services=JobServiceBase._to_rest_job_services(self.services),
            queue_settings=self.queue_settings._to_rest_object() if self.queue_settings else None,
            parent_job_name=self.parent_job_name,
        )
        result = JobBase(properties=properties)
        result.name = self.name
        return result

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs: Any) -> "CommandJob":
        """Construct a CommandJob from a schema-validated dict (e.g. loaded YAML).

        :param data: Raw job dict to validate and load.
        :type data: dict
        :param context: Load context; must contain the base-path entry.
        :type context: dict
        :param additional_message: Extra text appended to validation errors.
        :type additional_message: str
        :return: The loaded job entity.
        :rtype: CommandJob
        """
        loaded_data = load_from_dict(CommandJobSchema, data, context, additional_message, **kwargs)
        return CommandJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data)

    @classmethod
    def _load_from_rest(cls, obj: JobBase) -> "CommandJob":
        """Build a CommandJob entity from a REST ``JobBase`` response.

        Inverse of :meth:`_to_rest_object`, including undoing the
        local-compute marker-property encoding.

        :param obj: REST job wrapper whose ``properties`` is a REST CommandJob.
        :type obj: ~azure.ai.ml._restclient.v2025_01_01_preview.models.JobBase
        :return: The hydrated job entity.
        :rtype: CommandJob
        """
        rest_command_job: RestCommandJob = obj.properties
        command_job = CommandJob(
            name=obj.name,
            id=obj.id,
            display_name=rest_command_job.display_name,
            description=rest_command_job.description,
            tags=rest_command_job.tags,
            properties=rest_command_job.properties,
            command=rest_command_job.command,
            experiment_name=rest_command_job.experiment_name,
            services=JobServiceBase._from_rest_job_services(rest_command_job.services),
            status=rest_command_job.status,
            creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
            code=rest_command_job.code_id,
            compute=rest_command_job.compute_id,
            environment=rest_command_job.environment_id,
            distribution=DistributionConfiguration._from_rest_object(rest_command_job.distribution),
            parameters=rest_command_job.parameters,
            # pylint: disable=protected-access
            identity=(
                _BaseJobIdentityConfiguration._from_rest_object(rest_command_job.identity)
                if rest_command_job.identity
                else None
            ),
            environment_variables=rest_command_job.environment_variables,
            resources=JobResourceConfiguration._from_rest_object(rest_command_job.resources),
            limits=CommandJobLimits._from_rest_object(rest_command_job.limits),
            inputs=from_rest_inputs_to_dataset_literal(rest_command_job.inputs),
            outputs=from_rest_data_outputs(rest_command_job.outputs),
            queue_settings=QueueSettings._from_rest_object(rest_command_job.queue_settings),
            parent_job_name=rest_command_job.parent_job_name,
        )
        # Handle special case of local job: the marker property set by
        # _to_rest_object is translated back into compute == "local".
        if (
            command_job.resources is not None
            and not isinstance(command_job.resources, Dict)
            and command_job.resources.properties is not None
            and command_job.resources.properties.get(LOCAL_COMPUTE_PROPERTY, None)
        ):
            command_job.compute = LOCAL_COMPUTE_TARGET
            command_job.resources.properties.pop(LOCAL_COMPUTE_PROPERTY)
        return command_job

    def _to_component(self, context: Optional[Dict] = None, **kwargs: Any) -> "CommandComponent":
        """Translate a command job to component.

        :param context: Context of command job YAML file.
        :type context: dict
        :return: Translated command component.
        :rtype: CommandComponent
        """
        # Imported locally to avoid a circular import (see TYPE_CHECKING guard).
        from azure.ai.ml.entities import CommandComponent

        pipeline_job_dict = kwargs.get("pipeline_job_dict", {})
        context = context or {BASE_PATH_CONTEXT_KEY: Path("./")}

        # Create anonymous command component with default version as 1
        return CommandComponent(
            tags=self.tags,
            is_anonymous=True,
            base_path=context[BASE_PATH_CONTEXT_KEY],
            code=self.code,
            command=self.command,
            environment=self.environment,
            description=self.description,
            inputs=self._to_inputs(inputs=self.inputs, pipeline_job_dict=pipeline_job_dict),
            outputs=self._to_outputs(outputs=self.outputs, pipeline_job_dict=pipeline_job_dict),
            resources=self.resources if self.resources else None,
            distribution=self.distribution if self.distribution else None,
        )

    def _to_node(self, context: Optional[Dict] = None, **kwargs: Any) -> "Command":
        """Translate a command job to a pipeline node.

        :param context: Context of command job YAML file.
        :type context: dict
        :return: Translated command node wrapping an anonymous component.
        :rtype: Command
        """
        # Imported locally to avoid a circular import (see TYPE_CHECKING guard).
        from azure.ai.ml.entities._builders import Command

        component = self._to_component(context, **kwargs)

        return Command(
            component=component,
            compute=self.compute,
            # Need to supply the inputs with double curly.
            inputs=self.inputs,  # type: ignore[arg-type]
            outputs=self.outputs,  # type: ignore[arg-type]
            environment_variables=self.environment_variables,
            description=self.description,
            tags=self.tags,
            display_name=self.display_name,
            limits=self.limits,
            services=self.services,
            properties=self.properties,
            identity=self.identity,
            queue_settings=self.queue_settings,
        )

    def _validate(self) -> None:
        """Validate required fields before REST serialization.

        :raises ~azure.ai.ml.exceptions.ValidationException: If ``command`` or
            ``environment`` is missing, or the environment fails validation.
        """
        if self.command is None:
            msg = "command is required"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
        if self.environment is None:
            msg = "environment is required for non-local runs"
            raise ValidationException(
                message=msg,
                no_personal_data_message=msg,
                target=ErrorTarget.JOB,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.MISSING_FIELD,
            )
        if isinstance(self.environment, Environment):
            self.environment.validate()
        validate_inputs_for_command(self.command, self.inputs)