# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

"""Entrypoint for creating Distillation task."""
from typing import Any, Dict, Optional, Union

from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants import DataGenerationType
from azure.ai.ml.constants._common import AssetTypes
from azure.ai.ml.entities._inputs_outputs import Input
from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection


@experimental
def distillation(
    *,
    experiment_name: str,
    data_generation_type: str,
    data_generation_task_type: str,
    teacher_model_endpoint_connection: WorkspaceConnection,
    student_model: Union[Input, str],
    training_data: Optional[Union[Input, str]] = None,
    validation_data: Optional[Union[Input, str]] = None,
    teacher_model_settings: Optional[TeacherModelSettings] = None,
    prompt_settings: Optional[PromptSettings] = None,
    hyperparameters: Optional[Dict] = None,
    resources: Optional[ResourceConfiguration] = None,
    **kwargs: Any,
) -> "DistillationJob":
    """Function to create a Distillation job.

    A distillation job is used to transfer knowledge from a teacher model to student model by a two step process of
    generating synthetic data from the teacher model and then finetuning the student model with the generated
    synthetic data.

    ``student_model``, ``training_data`` and ``validation_data`` may each be given as a plain string, in which
    case the string is used as the path of a new ``Input`` of type ``AssetTypes.URI_FILE``.

    :keyword experiment_name: The name of the experiment.
    :paramtype experiment_name: str
    :keyword data_generation_type: The type of data generation to perform.
        Acceptable values: label_generation
    :paramtype data_generation_type: str
    :keyword data_generation_task_type: The type of data to generate.
        Acceptable values: NLI, NLU_QA, CONVERSATION, MATH, SUMMARIZATION
    :paramtype data_generation_task_type: str
    :keyword teacher_model_endpoint_connection: The kind of teacher model connection that includes the name,
        endpoint url, and api_key.
    :paramtype teacher_model_endpoint_connection: WorkspaceConnection
    :keyword student_model: The model to train
    :paramtype student_model: typing.Union[Input, str]
    :keyword training_data: The training data to use. Should contain the questions but not the labels,
        defaults to None. Required when ``data_generation_type`` is label_generation.
    :paramtype training_data: typing.Optional[typing.Union[Input, str]]
    :keyword validation_data: The validation data to use. Should contain the questions but not the labels,
        defaults to None. Required when ``data_generation_type`` is label_generation.
    :paramtype validation_data: typing.Optional[typing.Union[Input, str]]
    :keyword teacher_model_settings: The settings for the teacher model. Accepts both the inference parameters
        and endpoint settings, defaults to None.
        Acceptable keys for inference parameters: temperature, max_tokens, top_p, frequency_penalty,
        presence_penalty, stop
    :paramtype teacher_model_settings: typing.Optional[TeacherModelSettings]
    :keyword prompt_settings: The settings for the prompt that affect the system prompt used for data generation,
        defaults to None
    :paramtype prompt_settings: typing.Optional[PromptSettings]
    :keyword hyperparameters: The hyperparameters to use for finetuning, defaults to None
    :paramtype hyperparameters: typing.Optional[typing.Dict]
    :keyword resources: The compute resource to use for the data generation step in the distillation job,
        defaults to None
    :paramtype resources: typing.Optional[ResourceConfiguration]
    :keyword kwargs: Additional keyword arguments, forwarded unchanged to ``DistillationJob``.
    :paramtype kwargs: typing.Any
    :raises ValueError: Raises ValueError if the training data or the validation data is None and the data
        generation type is 'label_generation'
    :return: A DistillationJob to submit
    :rtype: DistillationJob
    """
    # Plain strings are wrapped so that DistillationJob always receives Input objects (or None).
    if isinstance(student_model, str):
        student_model = Input(type=AssetTypes.URI_FILE, path=student_model)
    if isinstance(training_data, str):
        training_data = Input(type=AssetTypes.URI_FILE, path=training_data)
    if isinstance(validation_data, str):
        validation_data = Input(type=AssetTypes.URI_FILE, path=validation_data)

    # Both datasets are checked for label generation; training data is checked first,
    # so its error wins when both are missing.
    if data_generation_type == DataGenerationType.LABEL_GENERATION:
        if training_data is None:
            raise ValueError(
                "Training data can not be None when data generation type is set to "
                f"{DataGenerationType.LABEL_GENERATION}."
            )
        if validation_data is None:
            raise ValueError(
                "Validation data can not be None when data generation type is set to "
                f"{DataGenerationType.LABEL_GENERATION}."
            )

    return DistillationJob(
        data_generation_type=data_generation_type,
        data_generation_task_type=data_generation_task_type,
        teacher_model_endpoint_connection=teacher_model_endpoint_connection,
        student_model=student_model,
        training_data=training_data,
        validation_data=validation_data,
        teacher_model_settings=teacher_model_settings,
        prompt_settings=prompt_settings,
        hyperparameters=hyperparameters,
        resources=resources,
        experiment_name=experiment_name,
        **kwargs,
    )