Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py        13
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py  111
2 files changed, 124 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py
new file mode 100644
index 00000000..3c35d28d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+
+from ._distillation import distillation
+
+__all__ = ["distillation", "EndpointRequestSettings", "PromptSettings", "TeacherModelSettings"]
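
The __init__.py above defines the public surface of the model_customization package. A minimal import sketch, assuming the azure-ai-ml distribution that ships this module is installed (names are exactly the ones listed in __all__):

from azure.ai.ml.model_customization import (
    EndpointRequestSettings,
    PromptSettings,
    TeacherModelSettings,
    distillation,
)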
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py
new file mode 100644
index 00000000..4227100a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py
@@ -0,0 +1,111 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+"""Entrypoint for creating Distillation task."""
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DataGenerationType
+from azure.ai.ml.constants._common import AssetTypes
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
+from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+
+
+@experimental
+def distillation(
+    *,
+    experiment_name: str,
+    data_generation_type: str,
+    data_generation_task_type: str,
+    teacher_model_endpoint_connection: WorkspaceConnection,
+    student_model: Union[Input, str],
+    training_data: Optional[Union[Input, str]] = None,
+    validation_data: Optional[Union[Input, str]] = None,
+    teacher_model_settings: Optional[TeacherModelSettings] = None,
+    prompt_settings: Optional[PromptSettings] = None,
+    hyperparameters: Optional[Dict] = None,
+    resources: Optional[ResourceConfiguration] = None,
+    **kwargs: Any,
+) -> "DistillationJob":
+    """Function to create a Distillation job.
+
+    A distillation job transfers knowledge from a teacher model to a student model in two steps:
+    generating synthetic data from the teacher model and then fine-tuning the student model on the
+    generated synthetic data.
+
+    :param experiment_name: The name of the experiment.
+    :type experiment_name: str
+    :param data_generation_type: The type of data generation to perform.
+
+        Acceptable values: label_generation
+    :type data_generation_type: str
+    :param data_generation_task_type: The type of data to generate.
+
+        Acceptable values: NLI, NLU_QA, CONVERSATION, MATH, SUMMARIZATION
+    :type data_generation_task_type: str
+    :param teacher_model_endpoint_connection: The teacher model connection, which includes the name, endpoint
+        url, and api_key.
+    :type teacher_model_endpoint_connection: WorkspaceConnection
+    :param student_model: The model to train.
+    :type student_model: typing.Union[Input, str]
+    :param training_data: The training data to use. Should contain the questions but not the labels, defaults to None
+    :type training_data: typing.Optional[typing.Union[Input, str]], optional
+    :param validation_data: The validation data to use. Should contain the questions but not the labels, defaults to
+        None
+    :type validation_data: typing.Optional[typing.Union[Input, str]], optional
+    :param teacher_model_settings: The settings for the teacher model. Accepts both the inference parameters and
+        endpoint settings, defaults to None
+
+        Acceptable keys for inference parameters: temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
+        stop
+    :type teacher_model_settings: typing.Optional[TeacherModelSettings], optional
+    :param prompt_settings: The settings for the prompt that affect the system prompt used for data generation,
+        defaults to None
+    :type prompt_settings: typing.Optional[PromptSettings], optional
+    :param hyperparameters: The hyperparameters to use for fine-tuning, defaults to None
+    :type hyperparameters: typing.Optional[typing.Dict], optional
+    :param resources: The compute resources to use for the data generation step of the distillation job, defaults to
+        None
+    :type resources: typing.Optional[ResourceConfiguration], optional
+    :raises ValueError: Raised if training data or validation data is None when data generation type is 'label_generation'
+    :return: A DistillationJob to submit
+    :rtype: DistillationJob
+    """
+    if isinstance(student_model, str):
+        student_model = Input(type=AssetTypes.URI_FILE, path=student_model)
+    if isinstance(training_data, str):
+        training_data = Input(type=AssetTypes.URI_FILE, path=training_data)
+    if isinstance(validation_data, str):
+        validation_data = Input(type=AssetTypes.URI_FILE, path=validation_data)
+
+    if training_data is None and data_generation_type == DataGenerationType.LABEL_GENERATION:
+        raise ValueError(
+            f"Training data cannot be None when data generation type is set to "
+            f"{DataGenerationType.LABEL_GENERATION}."
+        )
+
+    if validation_data is None and data_generation_type == DataGenerationType.LABEL_GENERATION:
+        raise ValueError(
+            f"Validation data cannot be None when data generation type is set to "
+            f"{DataGenerationType.LABEL_GENERATION}."
+        )
+
+    return DistillationJob(
+        data_generation_type=data_generation_type,
+        data_generation_task_type=data_generation_task_type,
+        teacher_model_endpoint_connection=teacher_model_endpoint_connection,
+        student_model=student_model,
+        training_data=training_data,
+        validation_data=validation_data,
+        teacher_model_settings=teacher_model_settings,
+        prompt_settings=prompt_settings,
+        hyperparameters=hyperparameters,
+        resources=resources,
+        experiment_name=experiment_name,
+        **kwargs,
+    )
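
A hedged usage sketch of the distillation() entrypoint added above. The workspace identifiers, the connection name "teacher-endpoint-connection", the data paths, and the task type "MATH" are placeholders for illustration and not taken from this diff; the distillation() call itself uses only the keyword arguments shown in the signature above, and MLClient, Input, and DefaultAzureCredential are the standard azure-ai-ml / azure-identity entry points.

from azure.ai.ml import Input, MLClient
from azure.ai.ml.constants import AssetTypes, DataGenerationType
from azure.ai.ml.model_customization import distillation
from azure.identity import DefaultAzureCredential

# Connect to a workspace (placeholder identifiers).
ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# Fetch a pre-created connection to the teacher model endpoint (assumed name).
teacher_connection = ml_client.connections.get("teacher-endpoint-connection")

# Build the distillation job. Plain string paths for student_model, training_data,
# and validation_data are wrapped into URI_FILE Inputs by distillation() itself,
# as shown in the function body above.
job = distillation(
    experiment_name="distillation-example",
    data_generation_type=DataGenerationType.LABEL_GENERATION,
    data_generation_task_type="MATH",
    teacher_model_endpoint_connection=teacher_connection,
    student_model="<student-model-asset-or-path>",
    training_data="train_questions.jsonl",
    validation_data="validation_questions.jsonl",
)

# Submit the returned DistillationJob like any other job.
returned_job = ml_client.jobs.create_or_update(job)

Because data_generation_type is label_generation, both training_data and validation_data must be provided, otherwise distillation() raises the ValueError shown in the diff.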