# --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- from typing import List, Optional, Dict from azure.ai.ml.constants._common import AssetTypes from azure.ai.ml.entities._inputs_outputs import Input, Output from azure.ai.ml.entities._job.finetuning.custom_model_finetuning_job import ( CustomModelFineTuningJob, ) from azure.ai.ml.entities._job.job_resources import JobResources from azure.ai.ml.entities._job.queue_settings import QueueSettings from azure.ai.ml._utils._experimental import experimental @experimental def create_finetuning_job( *, model: str, task: str, training_data: str, output_model_name_prefix: str, validation_data: Optional[str] = None, hyperparameters: Optional[Dict[str, str]] = None, compute: Optional[str] = None, instance_types: Optional[List[str]] = None, job_tier: Optional[str] = None, **kwargs, ) -> CustomModelFineTuningJob: if not model: raise ValueError("model is required") if not task: raise ValueError("task is required") if not training_data: raise ValueError("training_data is required") if not output_model_name_prefix: raise ValueError("output_model_name_prefix is required") model_input = Input( type=AssetTypes.MLFLOW_MODEL, path=model, ) outputs = {"registered_model": Output(type="mlflow_model", name=output_model_name_prefix)} # For image tasks this would be mltable, check how to handle this training_data_input = Input( type=AssetTypes.URI_FILE, path=training_data, ) if validation_data: validation_data_input = Input( type=AssetTypes.URI_FILE, path=validation_data, ) job_resources = None if instance_types: job_resources = JobResources(instance_types=instance_types) queue_settings = None if job_tier: queue_settings = QueueSettings(job_tier=job_tier) custom_model_finetuning_job = CustomModelFineTuningJob( task=task, model=model_input, training_data=training_data_input, validation_data=validation_data_input, # pylint: disable=(possibly-used-before-assignment hyperparameters=hyperparameters, compute=compute, resources=job_resources, queue_settings=queue_settings, outputs=outputs, **kwargs, ) return custom_model_finetuning_job