aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py148
1 files changed, 148 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py
new file mode 100644
index 00000000..4b815db7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py
@@ -0,0 +1,148 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+from typing import List
+
+from marshmallow import fields, post_dump, post_load, pre_dump
+
+from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
+from azure.ai.ml._schema.automl import AutoMLClassificationSchema, AutoMLForecastingSchema, AutoMLRegressionSchema
+from azure.ai.ml._schema.automl.image_vertical.image_classification import (
+ ImageClassificationMultilabelSchema,
+ ImageClassificationSchema,
+)
+from azure.ai.ml._schema.automl.image_vertical.image_object_detection import (
+ ImageInstanceSegmentationSchema,
+ ImageObjectDetectionSchema,
+)
+from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema
+from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import TextClassificationMultilabelSchema
+from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema
+from azure.ai.ml._schema.core.fields import ComputeField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, OutputSchema
+from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
+
+
+class AutoMLNodeMixin(PathAwareSchema):
+ """Inherit this mixin to change automl job schemas to automl node schema.
+
+ eg: Compute is required for automl job but not required for automl node in pipeline.
+ Note: Inherit this before BaseJobSchema to make sure optional takes affect.
+ """
+
+ def __init__(self, **kwargs):
+ super(AutoMLNodeMixin, self).__init__(**kwargs)
+ # update field objects and add data binding support, won't bind task & type as data binding
+ support_data_binding_expression_for_fields(self, attrs_to_skip=["task_type", "type"])
+
+ compute = ComputeField(required=False)
+ outputs = fields.Dict(
+ keys=fields.Str(),
+ values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True),
+ )
+
+ @pre_dump
+ def resolve_outputs(self, job: "AutoMLJob", **kwargs):
+ # Try resolve object's inputs & outputs and return a resolved new object
+ import copy
+
+ result = copy.copy(job)
+ result._outputs = job._build_outputs()
+ return result
+
+ @post_dump(pass_original=True)
+ # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
+ def resolve_nested_data(self, job_dict: dict, job: "AutoMLJob", **kwargs):
+ """Resolve nested data into flatten format."""
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+
+ if not isinstance(job, AutoMLJob):
+ return job_dict
+ # change output to rest output dicts
+ job_dict["outputs"] = job._to_rest_outputs()
+ return job_dict
+
+ @post_load
+ def make(self, data, **kwargs):
+ data["task"] = data.pop("task_type")
+ return data
+
+
+class AutoMLClassificationNodeSchema(AutoMLNodeMixin, AutoMLClassificationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLRegressionNodeSchema(AutoMLNodeMixin, AutoMLRegressionSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLForecastingNodeSchema(AutoMLNodeMixin, AutoMLForecastingSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLTextClassificationNode(AutoMLNodeMixin, TextClassificationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLTextClassificationMultilabelNode(AutoMLNodeMixin, TextClassificationMultilabelSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class AutoMLTextNerNode(AutoMLNodeMixin, TextNerSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageClassificationMulticlassNodeSchema(AutoMLNodeMixin, ImageClassificationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageClassificationMultilabelNodeSchema(AutoMLNodeMixin, ImageClassificationMultilabelSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageObjectDetectionNodeSchema(AutoMLNodeMixin, ImageObjectDetectionSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+class ImageInstanceSegmentationNodeSchema(AutoMLNodeMixin, ImageInstanceSegmentationSchema):
+ training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+ validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
+
+
+def AutoMLNodeSchema(**kwargs) -> List[fields.Field]:
+ """Get the list of all nested schema for all AutoML nodes.
+
+ :return: The list of fields
+ :rtype: List[fields.Field]
+ """
+ return [
+ # region: automl node schemas
+ NestedField(AutoMLClassificationNodeSchema, **kwargs),
+ NestedField(AutoMLRegressionNodeSchema, **kwargs),
+ NestedField(AutoMLForecastingNodeSchema, **kwargs),
+ # Vision
+ NestedField(ImageClassificationMulticlassNodeSchema, **kwargs),
+ NestedField(ImageClassificationMultilabelNodeSchema, **kwargs),
+ NestedField(ImageObjectDetectionNodeSchema, **kwargs),
+ NestedField(ImageInstanceSegmentationNodeSchema, **kwargs),
+ # NLP
+ NestedField(AutoMLTextClassificationNode, **kwargs),
+ NestedField(AutoMLTextClassificationMultilabelNode, **kwargs),
+ NestedField(AutoMLTextNerNode, **kwargs),
+ # endregion
+ ]