diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation')
5 files changed, 186 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py new file mode 100644 index 00000000..437d8743 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/__init__.py @@ -0,0 +1,17 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from .distillation_job import DistillationJobSchema +from .endpoint_request_settings import EndpointRequestSettingsSchema +from .prompt_settings import PromptSettingsSchema +from .teacher_model_settings import TeacherModelSettingsSchema + +__all__ = [ + "DistillationJobSchema", + "PromptSettingsSchema", + "EndpointRequestSettingsSchema", + "TeacherModelSettingsSchema", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py new file mode 100644 index 00000000..d72f2457 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/distillation_job.py @@ -0,0 +1,84 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema._distillation.prompt_settings import PromptSettingsSchema +from azure.ai.ml._schema._distillation.teacher_model_settings import TeacherModelSettingsSchema +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + LocalPathField, + NestedField, + RegistryStr, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.job import BaseJobSchema +from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema +from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField +from azure.ai.ml._schema.job_resource_configuration import ResourceConfigurationSchema +from azure.ai.ml._schema.workspace.connections import ServerlessConnectionSchema, WorkspaceConnectionSchema +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml.constants import DataGenerationTaskType, DataGenerationType, JobType +from azure.ai.ml.constants._common import AzureMLResourceType + + +@experimental +class DistillationJobSchema(BaseJobSchema): + type = StringTransformedEnum(required=True, allowed_values=JobType.DISTILLATION) + data_generation_type = StringTransformedEnum( + allowed_values=[DataGenerationType.LABEL_GENERATION, DataGenerationType.DATA_GENERATION], + required=True, + ) + data_generation_task_type = StringTransformedEnum( + allowed_values=[ + DataGenerationTaskType.NLI, + DataGenerationTaskType.NLU_QA, + DataGenerationTaskType.CONVERSATION, + DataGenerationTaskType.MATH, + DataGenerationTaskType.SUMMARIZATION, + ], + casing_transform=str.upper, + required=True, + ) + teacher_model_endpoint_connection = UnionField( + [NestedField(WorkspaceConnectionSchema), NestedField(ServerlessConnectionSchema)], required=True + ) + student_model = UnionField( + [ + NestedField(ModelInputSchema), + RegistryStr(azureml_type=AzureMLResourceType.MODEL), + ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL, allow_default_version=True), + ], + required=True, + ) + training_data = UnionField( + [ + NestedField(DataInputSchema), + ArmVersionedStr(azureml_type=AzureMLResourceType.DATA), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ] + ) + validation_data = UnionField( + [ + NestedField(DataInputSchema), + ArmVersionedStr(azureml_type=AzureMLResourceType.DATA), + fields.Str(metadata={"pattern": r"^(http(s)?):.*"}), + fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}), + LocalPathField(pattern=r"^file:.*"), + LocalPathField( + pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*", + ), + ] + ) + teacher_model_settings = NestedField(TeacherModelSettingsSchema) + prompt_settings = NestedField(PromptSettingsSchema) + hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) + resources = NestedField(ResourceConfigurationSchema) + outputs = OutputsField() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py new file mode 100644 index 00000000..960e7d2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/endpoint_request_settings.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class EndpointRequestSettingsSchema(metaclass=PatchedSchemaMeta): + request_batch_size = fields.Int() + min_endpoint_success_ratio = fields.Number() + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Post-load processing of the schema data + + :param data: Dictionary of parsed values from the yaml. + :type data: typing.Dict + :return: EndpointRequestSettings made from the yaml + :rtype: EndpointRequestSettings + """ + from azure.ai.ml.entities._job.distillation.endpoint_request_settings import EndpointRequestSettings + + return EndpointRequestSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py new file mode 100644 index 00000000..3b21908a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/prompt_settings.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class PromptSettingsSchema(metaclass=PatchedSchemaMeta): + enable_chain_of_thought = fields.Bool() + enable_chain_of_density = fields.Bool() + max_len_summary = fields.Int() + # custom_prompt = fields.Str() + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Post-load processing of the schema data + + :param data: Dictionary of parsed values from the yaml. + :type data: typing.Dict + :return: PromptSettings made from the yaml + :rtype: PromptSettings + """ + from azure.ai.ml.entities._job.distillation.prompt_settings import PromptSettings + + return PromptSettings(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py new file mode 100644 index 00000000..ecf32047 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_distillation/teacher_model_settings.py @@ -0,0 +1,29 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields, post_load + +from azure.ai.ml._schema._distillation.endpoint_request_settings import EndpointRequestSettingsSchema +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._utils._experimental import experimental + + +@experimental +class TeacherModelSettingsSchema(metaclass=PatchedSchemaMeta): + inference_parameters = fields.Dict(keys=fields.Str(), values=fields.Raw()) + endpoint_request_settings = NestedField(EndpointRequestSettingsSchema) + + @post_load + def make(self, data, **kwargs): # pylint: disable=unused-argument + """Post-load processing of the schema data + + :param data: Dictionary of parsed values from the yaml. + :type data: typing.Dict + :return: TeacherModelSettings made from the yaml + :rtype: TeacherModelSettings + """ + from azure.ai.ml.entities._job.distillation.teacher_model_settings import TeacherModelSettings + + return TeacherModelSettings(**data) |