aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py17
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py84
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py27
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py29
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py29
5 files changed, 186 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py
new file mode 100644
index 00000000..437d8743
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .distillation_job import DistillationJobSchema
+from .endpoint_request_settings import EndpointRequestSettingsSchema
+from .prompt_settings import PromptSettingsSchema
+from .teacher_model_settings import TeacherModelSettingsSchema
+
+__all__ = [
+ "DistillationJobSchema",
+ "PromptSettingsSchema",
+ "EndpointRequestSettingsSchema",
+ "TeacherModelSettingsSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py
new file mode 100644
index 00000000..d72f2457
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py
@@ -0,0 +1,84 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._distillation.prompt_settings import PromptSettingsSchema
+from azure.ai.ml._schema._distillation.teacher_model_settings import TeacherModelSettingsSchema
+from azure.ai.ml._schema.core.fields import (
+ ArmVersionedStr,
+ LocalPathField,
+ NestedField,
+ RegistryStr,
+ StringTransformedEnum,
+ UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
+from azure.ai.ml._schema.job_resource_configuration import ResourceConfigurationSchema
+from azure.ai.ml._schema.workspace.connections import ServerlessConnectionSchema, WorkspaceConnectionSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DataGenerationTaskType, DataGenerationType, JobType
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+
+@experimental
+class DistillationJobSchema(BaseJobSchema):
+ type = StringTransformedEnum(required=True, allowed_values=JobType.DISTILLATION)
+ data_generation_type = StringTransformedEnum(
+ allowed_values=[DataGenerationType.LABEL_GENERATION, DataGenerationType.DATA_GENERATION],
+ required=True,
+ )
+ data_generation_task_type = StringTransformedEnum(
+ allowed_values=[
+ DataGenerationTaskType.NLI,
+ DataGenerationTaskType.NLU_QA,
+ DataGenerationTaskType.CONVERSATION,
+ DataGenerationTaskType.MATH,
+ DataGenerationTaskType.SUMMARIZATION,
+ ],
+ casing_transform=str.upper,
+ required=True,
+ )
+ teacher_model_endpoint_connection = UnionField(
+ [NestedField(WorkspaceConnectionSchema), NestedField(ServerlessConnectionSchema)], required=True
+ )
+ student_model = UnionField(
+ [
+ NestedField(ModelInputSchema),
+ RegistryStr(azureml_type=AzureMLResourceType.MODEL),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True),
+ ],
+ required=True,
+ )
+ training_data = UnionField(
+ [
+ NestedField(DataInputSchema),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ]
+ )
+ validation_data = UnionField(
+ [
+ NestedField(DataInputSchema),
+ ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+ fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+ fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+ LocalPathField(pattern=r"^file:.*"),
+ LocalPathField(
+ pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+ ),
+ ]
+ )
+ teacher_model_settings = NestedField(TeacherModelSettingsSchema)
+ prompt_settings = NestedField(PromptSettingsSchema)
+ hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+ resources = NestedField(ResourceConfigurationSchema)
+ outputs = OutputsField()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py
new file mode 100644
index 00000000..960e7d2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class EndpointRequestSettingsSchema(metaclass=PatchedSchemaMeta):
+ request_batch_size = fields.Int()
+ min_endpoint_success_ratio = fields.Number()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Post-load processing of the schema data
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+ :return: EndpointRequestSettings made from the yaml
+ :rtype: EndpointRequestSettings
+ """
+ from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+
+ return EndpointRequestSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py
new file mode 100644
index 00000000..3b21908a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class PromptSettingsSchema(metaclass=PatchedSchemaMeta):
+ enable_chain_of_thought = fields.Bool()
+ enable_chain_of_density = fields.Bool()
+ max_len_summary = fields.Int()
+ # custom_prompt = fields.Str()
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Post-load processing of the schema data
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+ :return: PromptSettings made from the yaml
+ :rtype: PromptSettings
+ """
+ from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+
+ return PromptSettings(**data)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py
new file mode 100644
index 00000000..ecf32047
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema._distillation.endpoint_request_settings import EndpointRequestSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class TeacherModelSettingsSchema(metaclass=PatchedSchemaMeta):
+ inference_parameters = fields.Dict(keys=fields.Str(), values=fields.Raw())
+ endpoint_request_settings = NestedField(EndpointRequestSettingsSchema)
+
+ @post_load
+ def make(self, data, **kwargs): # pylint: disable=unused-argument
+ """Post-load processing of the schema data
+
+ :param data: Dictionary of parsed values from the yaml.
+ :type data: typing.Dict
+ :return: TeacherModelSettings made from the yaml
+ :rtype: TeacherModelSettings
+ """
+ from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+
+ return TeacherModelSettings(**data)