diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py | 13 | ||||
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py | 111 |
2 files changed, 124 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py new file mode 100644 index 00000000..3c35d28d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py @@ -0,0 +1,13 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings +from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings +from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings + +from ._distillation import distillation + +__all__ = ["distillation", "EndpointRequestSettings", "PromptSettings", "TeacherModelSettings"] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py new file mode 100644 index 00000000..4227100a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py @@ -0,0 +1,111 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +"""Entrypoint for creating Distillation task.""" +from typing import Any, Dict, Optional, Union + +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants import DataGenerationType +from azure.ai.ml.constants._common import AssetTypes +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob +from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings +from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings +from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration +from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection + + +@experimental +def distillation( + *, + experiment_name: str, + data_generation_type: str, + data_generation_task_type: str, + teacher_model_endpoint_connection: WorkspaceConnection, + student_model: Union[Input, str], + training_data: Optional[Union[Input, str]] = None, + validation_data: Optional[Union[Input, str]] = None, + teacher_model_settings: Optional[TeacherModelSettings] = None, + prompt_settings: Optional[PromptSettings] = None, + hyperparameters: Optional[Dict] = None, + resources: Optional[ResourceConfiguration] = None, + **kwargs: Any, +) -> "DistillationJob": + """Function to create a Distillation job. + + A distillation job is used to transfer knowledge from a teacher model to student model by a two step process of + generating synthetic data from the teacher model and then finetuning the student model with the generated + synthetic data. + + :param experiment_name: The name of the experiment. + :type experiment_name: str + :param data_generation_type: The type of data generation to perform. + + Acceptable values: label_generation + :type data_generation_type: str + :param data_generation_task_type: The type of data to generate + + Acceptable values: NLI, NLU_QA, CONVERSATION, MATH, SUMMARIZATION + :type: data_generation_task_type: str + :param teacher_model_endpoint_connection: The kind of teacher model connection that includes the name, endpoint + url, and api_key. + :type: teacher_model_endpoint_connection: WorkspaceConnection + :param student_model: The model to train + :type student_model: typing.Union[Input, str] + :param training_data: The training data to use. Should contain the questions but not the labels, defaults to None + :type training_data: typing.Optional[typing.Union[Input, str]], optional + :param validation_data: The validation data to use. Should contain the questions but not the labels, defaults to + None + :type validation_data: typing.Optional[typing.Union[Input, str]], optional + :param teacher_model_settings: The settings for the teacher model. Accepts both the inference parameters and + endpoint settings, defaults to None + + Acceptable keys for inference parameters: temperature, max_tokens, top_p, frequency_penalty, presence_penalty, + stop + :type teacher_model_settings: typing.Optional[TeacherModelSettings], optional + :param prompt_settings: The settings for the prompt that affect the system prompt used for data generation, + defaults to None + :type prompt_settings: typing.Optional[PromptSettings], optional + :param hyperparameters: The hyperparameters to use for finetuning, defaults to None + :type hyperparameters: typing.Optional[typing.Dict], optional + :param resources: The compute resource to use for the data generation step in the distillation job, defaults to + None + :type resources: typing.Optional[ResourceConfiguration], optional + :raises ValueError: Raises ValueError if there is no training data and data generation type is 'label_generation' + :return: A DistillationJob to submit + :rtype: DistillationJob + """ + if isinstance(student_model, str): + student_model = Input(type=AssetTypes.URI_FILE, path=student_model) + if isinstance(training_data, str): + training_data = Input(type=AssetTypes.URI_FILE, path=training_data) + if isinstance(validation_data, str): + validation_data = Input(type=AssetTypes.URI_FILE, path=validation_data) + + if training_data is None and data_generation_type == DataGenerationType.LABEL_GENERATION: + raise ValueError( + f"Training data can not be None when data generation type is set to " + f"{DataGenerationType.LABEL_GENERATION}." + ) + + if validation_data is None and data_generation_type == DataGenerationType.LABEL_GENERATION: + raise ValueError( + f"Validation data can not be None when data generation type is set to " + f"{DataGenerationType.LABEL_GENERATION}." + ) + + return DistillationJob( + data_generation_type=data_generation_type, + data_generation_task_type=data_generation_task_type, + teacher_model_endpoint_connection=teacher_model_endpoint_connection, + student_model=student_model, + training_data=training_data, + validation_data=validation_data, + teacher_model_settings=teacher_model_settings, + prompt_settings=prompt_settings, + hyperparameters=hyperparameters, + resources=resources, + experiment_name=experiment_name, + **kwargs, + ) |