about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation
diff options
context:
space:
mode:
author S. Solomon Darnell 2025-03-28 21:52:21 -0500
committer S. Solomon Darnell 2025-03-28 21:52:21 -0500
commit 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation
parent cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download gn-ai-master.tar.gz
two versions of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation')
-rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py 17
-rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py 84
-rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py 27
-rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py 29
-rw-r--r-- .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py 29
5 files changed, 186 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py
new file mode 100644
index 00000000..437d8743
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # lets this package span multiple directories
+
+from .distillation_job import DistillationJobSchema
+from .endpoint_request_settings import EndpointRequestSettingsSchema
+from .prompt_settings import PromptSettingsSchema
+from .teacher_model_settings import TeacherModelSettingsSchema
+
+__all__ = [  # schema classes re-exported as this package's public names
+    "DistillationJobSchema",
+    "PromptSettingsSchema",
+    "EndpointRequestSettingsSchema",
+    "TeacherModelSettingsSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py
new file mode 100644
index 00000000..d72f2457
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py
@@ -0,0 +1,84 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._distillation.prompt_settings import PromptSettingsSchema
+from azure.ai.ml._schema._distillation.teacher_model_settings import TeacherModelSettingsSchema
+from azure.ai.ml._schema.core.fields import (
+    ArmVersionedStr,
+    LocalPathField,
+    NestedField,
+    RegistryStr,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
+from azure.ai.ml._schema.job_resource_configuration import ResourceConfigurationSchema
+from azure.ai.ml._schema.workspace.connections import ServerlessConnectionSchema, WorkspaceConnectionSchema
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants import DataGenerationTaskType, DataGenerationType, JobType
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+
+@experimental
+class DistillationJobSchema(BaseJobSchema):  # schema for jobs whose type is JobType.DISTILLATION
+    type = StringTransformedEnum(required=True, allowed_values=JobType.DISTILLATION)
+    data_generation_type = StringTransformedEnum(
+        allowed_values=[DataGenerationType.LABEL_GENERATION, DataGenerationType.DATA_GENERATION],
+        required=True,
+    )
+    data_generation_task_type = StringTransformedEnum(
+        allowed_values=[
+            DataGenerationTaskType.NLI,
+            DataGenerationTaskType.NLU_QA,
+            DataGenerationTaskType.CONVERSATION,
+            DataGenerationTaskType.MATH,
+            DataGenerationTaskType.SUMMARIZATION,
+        ],
+        casing_transform=str.upper,  # NOTE(review): presumably upper-cases the value; confirm in StringTransformedEnum
+        required=True,
+    )
+    teacher_model_endpoint_connection = UnionField(  # required; two alternative connection schemas
+        [NestedField(WorkspaceConnectionSchema), NestedField(ServerlessConnectionSchema)], required=True
+    )
+    student_model = UnionField(  # required; three alternative field types for the MODEL resource
+        [
+            NestedField(ModelInputSchema),
+            RegistryStr(azureml_type=AzureMLResourceType.MODEL),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True),
+        ],
+        required=True,
+    )
+    training_data = UnionField(  # no required=True here; same alternatives as validation_data
+        [
+            NestedField(DataInputSchema),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+            fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),  # metadata only; confirm pattern is enforced
+            fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),  # metadata only; confirm pattern is enforced
+            LocalPathField(pattern=r"^file:.*"),
+            LocalPathField(
+                pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",  # i.e. none of the scheme prefixes above
+            ),
+        ]
+    )
+    validation_data = UnionField(  # no required=True here; same alternatives as training_data
+        [
+            NestedField(DataInputSchema),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+            fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+            fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+            LocalPathField(pattern=r"^file:.*"),
+            LocalPathField(
+                pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+            ),
+        ]
+    )
+    teacher_model_settings = NestedField(TeacherModelSettingsSchema)
+    prompt_settings = NestedField(PromptSettingsSchema)
+    hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))  # values may be None
+    resources = NestedField(ResourceConfigurationSchema)
+    outputs = OutputsField()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py
new file mode 100644
index 00000000..960e7d2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py
@@ -0,0 +1,27 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class EndpointRequestSettingsSchema(metaclass=PatchedSchemaMeta):  # loads into EndpointRequestSettings
+    request_batch_size = fields.Int()  # no required=True
+    min_endpoint_success_ratio = fields.Number()  # no required=True; no range check declared here
+
+    @post_load  # marshmallow hook invoked after the input has been deserialized
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        """Post-load processing of the schema data
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+        :return: EndpointRequestSettings made from the yaml
+        :rtype: EndpointRequestSettings
+        """
+        from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings
+
+        return EndpointRequestSettings(**data)  # loaded fields are passed as keyword arguments
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py
new file mode 100644
index 00000000..3b21908a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class PromptSettingsSchema(metaclass=PatchedSchemaMeta):  # loads into PromptSettings
+    enable_chain_of_thought = fields.Bool()  # no required=True
+    enable_chain_of_density = fields.Bool()  # no required=True
+    max_len_summary = fields.Int()  # no required=True
+    # TODO(review): disabled field `custom_prompt = fields.Str()` -- confirm whether to enable or remove it
+
+    @post_load  # marshmallow hook invoked after the input has been deserialized
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        """Post-load processing of the schema data
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+        :return: PromptSettings made from the yaml
+        :rtype: PromptSettings
+        """
+        from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings
+
+        return PromptSettings(**data)  # loaded fields are passed as keyword arguments
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py
new file mode 100644
index 00000000..ecf32047
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py
@@ -0,0 +1,29 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_load
+
+from azure.ai.ml._schema._distillation.endpoint_request_settings import EndpointRequestSettingsSchema
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class TeacherModelSettingsSchema(metaclass=PatchedSchemaMeta):  # loads into TeacherModelSettings
+    inference_parameters = fields.Dict(keys=fields.Str(), values=fields.Raw())  # str keys, raw values
+    endpoint_request_settings = NestedField(EndpointRequestSettingsSchema)  # nested settings schema
+
+    @post_load  # marshmallow hook invoked after the input has been deserialized
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        """Post-load processing of the schema data
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+        :return: TeacherModelSettings made from the yaml
+        :rtype: TeacherModelSettings
+        """
+        from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings
+
+        return TeacherModelSettings(**data)  # loaded fields are passed as keyword arguments