about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py19
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py54
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py18
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py17
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py35
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py21
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py73
7 files changed, 237 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py
new file mode 100644
index 00000000..e47aa230
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/__init__.py
@@ -0,0 +1,19 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+from .azure_openai_finetuning import AzureOpenAIFineTuningSchema
+from .azure_openai_hyperparameters import AzureOpenAIHyperparametersSchema
+from .custom_model_finetuning import CustomModelFineTuningSchema
+from .finetuning_job import FineTuningJobSchema
+from .finetuning_vertical import FineTuningVerticalSchema
+
+__all__ = [
+    "AzureOpenAIFineTuningSchema",
+    "AzureOpenAIHyperparametersSchema",
+    "CustomModelFineTuningSchema",
+    "FineTuningJobSchema",
+    "FineTuningVerticalSchema",
+]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py
new file mode 100644
index 00000000..f6d2a58d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_finetuning.py
@@ -0,0 +1,54 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+from marshmallow import post_load
+
+
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelProvider
+from azure.ai.ml._schema._finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparametersSchema
+from azure.ai.ml._schema._finetuning.finetuning_vertical import FineTuningVerticalSchema
+from azure.ai.ml.entities._job.finetuning.azure_openai_hyperparameters import AzureOpenAIHyperparameters
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml._schema.core.fields import NestedField
+from azure.ai.ml.constants._job.finetuning import FineTuningConstants
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIFineTuningSchema(FineTuningVerticalSchema):
+    # This is meant to match the yaml definition NOT the models defined in _restclient
+
+    model_provider = StringTransformedEnum(
+        required=True, allowed_values=ModelProvider.AZURE_OPEN_AI, casing_transform=camel_to_snake
+    )
+    hyperparameters = NestedField(AzureOpenAIHyperparametersSchema(), data_key=FineTuningConstants.HyperParameters)
+
+    @post_load
+    def post_load_processing(self, data: Dict, **kwargs) -> Dict[str, Any]:
+        """Post load processing for the schema.
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+
+        :return Dictionary of parsed values from the yaml.
+        :rtype Dict[str, Any]
+        """
+        data.pop("model_provider")
+        hyperaparameters = data.pop("hyperparameters", None)
+
+        if hyperaparameters and not isinstance(hyperaparameters, AzureOpenAIHyperparameters):
+            hyperaparameters_dict = {}
+            for key, value in hyperaparameters.items():
+                hyperaparameters_dict[key] = value
+            azure_openai_hyperparameters = AzureOpenAIHyperparameters(
+                batch_size=hyperaparameters_dict.get("batch_size", None),
+                learning_rate_multiplier=hyperaparameters_dict.get("learning_rate_multiplier", None),
+                n_epochs=hyperaparameters_dict.get("n_epochs", None),
+            )
+            data["hyperparameters"] = azure_openai_hyperparameters
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py
new file mode 100644
index 00000000..f421188d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/azure_openai_hyperparameters.py
@@ -0,0 +1,18 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class AzureOpenAIHyperparametersSchema(metaclass=PatchedSchemaMeta):
+    n_epochs = fields.Int()
+    learning_rate_multiplier = fields.Float()
+    batch_size = fields.Int()
+    # TODO: Should be dict<string,string>, check schema for the same.
+    # For now not exposing as we dont have REST layer representation exposed.
+    # Need to check with the team.
+    # additional_parameters = fields.Dict()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py
new file mode 100644
index 00000000..3e14dca4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/constants.py
@@ -0,0 +1,17 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+
+class SnakeCaseFineTuningTaskTypes:
+    CHAT_COMPLETION = "chat_completion"
+    TEXT_COMPLETION = "text_completion"
+    TEXT_CLASSIFICATION = "text_classification"
+    QUESTION_ANSWERING = "question_answering"
+    TEXT_SUMMARIZATION = "text_summarization"
+    TOKEN_CLASSIFICATION = "token_classification"
+    TEXT_TRANSLATION = "text_translation"
+    IMAGE_CLASSIFICATION = "image_classification"
+    IMAGE_INSTANCE_SEGMENTATION = "image_instance_segmentation"
+    IMAGE_OBJECT_DETECTION = "image_object_detection"
+    VIDEO_MULTI_OBJECT_TRACKING = "video_multi_object_tracking"
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py
new file mode 100644
index 00000000..9d5b22a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/custom_model_finetuning.py
@@ -0,0 +1,35 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument
+
+from typing import Any, Dict
+from marshmallow import fields, post_load
+
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ModelProvider
+from azure.ai.ml._schema._finetuning.finetuning_vertical import FineTuningVerticalSchema
+from azure.ai.ml._schema.core.fields import StringTransformedEnum
+from azure.ai.ml._utils._experimental import experimental
+
+
+@experimental
+class CustomModelFineTuningSchema(FineTuningVerticalSchema):
+    # This is meant to match the yaml definition NOT the models defined in _restclient
+
+    model_provider = StringTransformedEnum(required=True, allowed_values=ModelProvider.CUSTOM)
+    hyperparameters = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True))
+
+    @post_load
+    def post_load_processing(self, data: Dict, **kwargs) -> Dict[str, Any]:
+        """Post-load processing for the schema.
+
+        :param data: Dictionary of parsed values from the yaml.
+        :type data: typing.Dict
+
+        :return Dictionary of parsed values from the yaml.
+        :rtype Dict[str, Any]
+        """
+
+        data.pop("model_provider")
+        return data
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py
new file mode 100644
index 00000000..e1b2270e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_job.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from azure.ai.ml._schema.job import BaseJobSchema
+from azure.ai.ml._schema.job.input_output_fields_provider import OutputsField
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._schema.core.fields import (
+    NestedField,
+)
+from ..queue_settings import QueueSettingsSchema
+from ..job_resources import JobResourcesSchema
+
+# This is meant to match the yaml definition NOT the models defined in _restclient
+
+
+@experimental
+class FineTuningJobSchema(BaseJobSchema):
+    outputs = OutputsField()
+    queue_settings = NestedField(QueueSettingsSchema)
+    resources = NestedField(JobResourcesSchema)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py
new file mode 100644
index 00000000..10ac51ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_finetuning/finetuning_vertical.py
@@ -0,0 +1,73 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema._finetuning.finetuning_job import FineTuningJobSchema
+from azure.ai.ml._schema._finetuning.constants import SnakeCaseFineTuningTaskTypes
+from azure.ai.ml._schema.core.fields import (
+    ArmVersionedStr,
+    LocalPathField,
+    NestedField,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml.constants import JobType
+from azure.ai.ml._utils.utils import snake_to_camel
+from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, ModelInputSchema
+from azure.ai.ml.constants._job.finetuning import FineTuningConstants
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml.constants._common import AzureMLResourceType
+
+
+# This is meant to match the yaml definition NOT the models defined in _restclient
+
+
+@experimental
+class FineTuningVerticalSchema(FineTuningJobSchema):
+    type = StringTransformedEnum(required=True, allowed_values=JobType.FINE_TUNING)
+    model = NestedField(ModelInputSchema, required=True)
+    training_data = UnionField(
+        [
+            NestedField(DataInputSchema),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+            fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+            fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+            LocalPathField(pattern=r"^file:.*"),
+            LocalPathField(
+                pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+            ),
+        ]
+    )
+    validation_data = UnionField(
+        [
+            NestedField(DataInputSchema),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.DATA),
+            fields.Str(metadata={"pattern": r"^(http(s)?):.*"}),
+            fields.Str(metadata={"pattern": r"^(wasb(s)?):.*"}),
+            LocalPathField(pattern=r"^file:.*"),
+            LocalPathField(
+                pattern=r"^(?!(azureml|http(s)?|wasb(s)?|file):).*",
+            ),
+        ]
+    )
+
+    task = StringTransformedEnum(
+        allowed_values=[
+            SnakeCaseFineTuningTaskTypes.CHAT_COMPLETION,
+            SnakeCaseFineTuningTaskTypes.TEXT_COMPLETION,
+            SnakeCaseFineTuningTaskTypes.TEXT_CLASSIFICATION,
+            SnakeCaseFineTuningTaskTypes.QUESTION_ANSWERING,
+            SnakeCaseFineTuningTaskTypes.TEXT_SUMMARIZATION,
+            SnakeCaseFineTuningTaskTypes.TOKEN_CLASSIFICATION,
+            SnakeCaseFineTuningTaskTypes.TEXT_TRANSLATION,
+            SnakeCaseFineTuningTaskTypes.IMAGE_CLASSIFICATION,
+            SnakeCaseFineTuningTaskTypes.IMAGE_INSTANCE_SEGMENTATION,
+            SnakeCaseFineTuningTaskTypes.IMAGE_OBJECT_DETECTION,
+            SnakeCaseFineTuningTaskTypes.VIDEO_MULTI_OBJECT_TRACKING,
+        ],
+        casing_transform=snake_to_camel,
+        data_key=FineTuningConstants.TaskType,
+        required=True,
+    )