Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization')
 -rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py      |  13
 -rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py | 111
 2 files changed, 124 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py
new file mode 100644
index 00000000..3c35d28d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/__init__.py
@@ -0,0 +1,13 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+
+from ._distillation import distillation
+
+__all__ = ["distillation", "EndpointRequestSettings", "PromptSettings", "TeacherModelSettings"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py
new file mode 100644
index 00000000..4227100a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/model_customization/_distillation.py
@@ -0,0 +1,111 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+"""Entrypoint for creating Distillation task."""
+from typing import Any, Dict, Optional, Union
+
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DataGenerationType
+from azure.ai.ml.constants._common import AssetTypes
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
+from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+
+
+@experimental
+def distillation(
+    *,
+    experiment_name: str,
+    data_generation_type: str,
+    data_generation_task_type: str,
+    teacher_model_endpoint_connection: WorkspaceConnection,
+    student_model: Union[Input, str],
+    training_data: Optional[Union[Input, str]] = None,
+    validation_data: Optional[Union[Input, str]] = None,
+    teacher_model_settings: Optional[TeacherModelSettings] = None,
+    prompt_settings: Optional[PromptSettings] = None,
+    hyperparameters: Optional[Dict] = None,
+    resources: Optional[ResourceConfiguration] = None,
+    **kwargs: Any,
+) -> "DistillationJob":
+    """Function to create a Distillation job.
+
+    A distillation job is used to transfer knowledge from a teacher model to a student model through a two-step
+    process: generating synthetic data with the teacher model and then fine-tuning the student model on the
+    generated synthetic data.
+
+    :param experiment_name: The name of the experiment.
+    :type experiment_name: str
+    :param data_generation_type: The type of data generation to perform.
+
+        Acceptable values: label_generation
+    :type data_generation_type: str
+    :param data_generation_task_type: The type of data to generate.
+
+        Acceptable values: NLI, NLU_QA, CONVERSATION, MATH, SUMMARIZATION
+    :type data_generation_task_type: str
+    :param teacher_model_endpoint_connection: The workspace connection to the teacher model endpoint, which includes
+        the name, endpoint url, and api_key.
+    :type teacher_model_endpoint_connection: WorkspaceConnection
+    :param student_model: The model to train.
+    :type student_model: typing.Union[Input, str]
+    :param training_data: The training data to use. Should contain the questions but not the labels, defaults to None
+    :type training_data: typing.Optional[typing.Union[Input, str]], optional
+    :param validation_data: The validation data to use. Should contain the questions but not the labels, defaults to
+        None
+    :type validation_data: typing.Optional[typing.Union[Input, str]], optional
+    :param teacher_model_settings: The settings for the teacher model. Accepts both the inference parameters and
+        endpoint settings, defaults to None
+
+        Acceptable keys for inference parameters: temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
+        stop
+    :type teacher_model_settings: typing.Optional[TeacherModelSettings], optional
+    :param prompt_settings: The settings for the prompt that affect the system prompt used for data generation,
+        defaults to None
+    :type prompt_settings: typing.Optional[PromptSettings], optional
+    :param hyperparameters: The hyperparameters to use for fine-tuning, defaults to None
+    :type hyperparameters: typing.Optional[typing.Dict], optional
+    :param resources: The compute resource to use for the data generation step in the distillation job, defaults to
+        None
+    :type resources: typing.Optional[ResourceConfiguration], optional
+    :raises ValueError: Raises ValueError if training or validation data is None when data generation type is 'label_generation'
+    :return: A DistillationJob to submit
+    :rtype: DistillationJob
+    """
+    if isinstance(student_model, str):
+        student_model = Input(type=AssetTypes.URI_FILE, path=student_model)
+    if isinstance(training_data, str):
+        training_data = Input(type=AssetTypes.URI_FILE, path=training_data)
+    if isinstance(validation_data, str):
+        validation_data = Input(type=AssetTypes.URI_FILE, path=validation_data)
+
+    if training_data is None and data_generation_type == DataGenerationType.LABEL_GENERATION:
+        raise ValueError(
+            f"Training data cannot be None when data generation type is set to "
+            f"{DataGenerationType.LABEL_GENERATION}."
+        )
+
+    if validation_data is None and data_generation_type == DataGenerationType.LABEL_GENERATION:
+        raise ValueError(
+            f"Validation data cannot be None when data generation type is set to "
+            f"{DataGenerationType.LABEL_GENERATION}."
+        )
+
+    return DistillationJob(
+        data_generation_type=data_generation_type,
+        data_generation_task_type=data_generation_task_type,
+        teacher_model_endpoint_connection=teacher_model_endpoint_connection,
+        student_model=student_model,
+        training_data=training_data,
+        validation_data=validation_data,
+        teacher_model_settings=teacher_model_settings,
+        prompt_settings=prompt_settings,
+        hyperparameters=hyperparameters,
+        resources=resources,
+        experiment_name=experiment_name,
+        **kwargs,
+    )
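For reference, a minimal usage sketch of the distillation() entrypoint added in this diff. Only distillation() and its keyword arguments come from the source above; the workspace details, connection name, data paths, and student model reference are placeholder assumptions, and the job is submitted through the standard MLClient jobs API.

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential
from azure.ai.ml.model_customization import distillation

# Placeholder workspace details (assumptions, not taken from the diff).
ml_client = MLClient(
    DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

# Reuse an existing workspace connection to the teacher model endpoint;
# the connection name here is hypothetical.
teacher_connection = ml_client.connections.get("my-teacher-endpoint-connection")

# Both training and validation data are required when data_generation_type is
# "label_generation" (see the ValueError checks above); they should contain the
# questions but not the labels. String paths are converted to URI_FILE Inputs.
job = distillation(
    experiment_name="distillation-example",
    data_generation_type="label_generation",
    data_generation_task_type="MATH",
    teacher_model_endpoint_connection=teacher_connection,
    student_model="azureml://registries/<registry>/models/<student-model>/versions/1",
    training_data="azureml://datastores/workspaceblobstore/paths/train.jsonl",
    validation_data="azureml://datastores/workspaceblobstore/paths/valid.jsonl",
)

returned_job = ml_client.jobs.create_or_update(job)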