diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/pipeline/automl_node.py | 148 |
1 files changed, 148 insertions, 0 deletions
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument,protected-access
from typing import List

from marshmallow import fields, post_dump, post_load, pre_dump

from azure.ai.ml._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
from azure.ai.ml._schema.automl import AutoMLClassificationSchema, AutoMLForecastingSchema, AutoMLRegressionSchema
from azure.ai.ml._schema.automl.image_vertical.image_classification import (
    ImageClassificationMultilabelSchema,
    ImageClassificationSchema,
)
from azure.ai.ml._schema.automl.image_vertical.image_object_detection import (
    ImageInstanceSegmentationSchema,
    ImageObjectDetectionSchema,
)
from azure.ai.ml._schema.automl.nlp_vertical.text_classification import TextClassificationSchema
from azure.ai.ml._schema.automl.nlp_vertical.text_classification_multilabel import TextClassificationMultilabelSchema
from azure.ai.ml._schema.automl.nlp_vertical.text_ner import TextNerSchema
from azure.ai.ml._schema.core.fields import ComputeField, NestedField, UnionField
from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, OutputSchema
from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr


class AutoMLNodeMixin(PathAwareSchema):
    """Mixin that turns an AutoML job schema into an AutoML pipeline-node schema.

    For example, compute is required for an AutoML job but is optional for an
    AutoML node inside a pipeline, so ``compute`` is redeclared here with
    ``required=False``.

    Note: list this mixin before the job schema in the bases of a node schema
    so that the optional declarations made here take effect.
    """

    # Redeclared as optional for pipeline nodes.
    compute = ComputeField(required=False)
    # Each output value is a nested OutputSchema, an output binding string, or None.
    outputs = fields.Dict(
        keys=fields.Str(),
        values=UnionField([NestedField(OutputSchema), OutputBindingStr], allow_none=True),
    )

    def __init__(self, **kwargs):
        """Build the schema and add data binding support to its fields.

        :keyword kwargs: Forwarded unchanged to the parent schema constructor.
        """
        super().__init__(**kwargs)
        # Update the field objects to accept data binding expressions; the
        # task type and type fields are skipped and never bound.
        support_data_binding_expression_for_fields(self, attrs_to_skip=["task_type", "type"])

    @pre_dump
    def resolve_outputs(self, job: "AutoMLJob", **kwargs):
        """Return a shallow copy of ``job`` whose outputs have been built.

        The object passed in is not modified; the built outputs are stored on
        the copy only.

        :param job: The object about to be dumped.
        :type job: AutoMLJob
        :return: A shallow copy of ``job`` with ``_outputs`` replaced by
            ``job._build_outputs()``.
        """
        from copy import copy as shallow_copy

        resolved = shallow_copy(job)
        resolved._outputs = job._build_outputs()
        return resolved

    @post_dump(pass_original=True)
    # pylint: disable-next=docstring-missing-param,docstring-missing-return,docstring-missing-rtype
    def resolve_nested_data(self, job_dict: dict, job: "AutoMLJob", **kwargs):
        """Resolve nested data into a flattened format.

        When the original object is an ``AutoMLJob``, the dumped ``outputs``
        entry is replaced by ``job._to_rest_outputs()``; for any other object
        the dumped dict is returned untouched.
        """
        # Imported here rather than at module level (as in the original code).
        from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob

        if isinstance(job, AutoMLJob):
            # change output to rest output dicts
            job_dict["outputs"] = job._to_rest_outputs()
        return job_dict

    @post_load
    def make(self, data, **kwargs):
        """Move the loaded ``task_type`` entry to the ``task`` key.

        :param data: The loaded data; it is modified in place.
        :return: The same ``data`` object with ``task_type`` renamed to ``task``.
        """
        # pop() is called without a default, so a missing "task_type" raises.
        task = data.pop("task_type")
        data["task"] = task
        return data


class AutoMLClassificationNodeSchema(AutoMLNodeMixin, AutoMLClassificationSchema):
    """AutoML classification node schema.

    Each data input is either a plain string or a nested ``MLTableInputSchema``.
    """

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class AutoMLRegressionNodeSchema(AutoMLNodeMixin, AutoMLRegressionSchema):
    """AutoML regression node schema.

    Each data input is either a plain string or a nested ``MLTableInputSchema``.
    """

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class AutoMLForecastingNodeSchema(AutoMLNodeMixin, AutoMLForecastingSchema):
    """AutoML forecasting node schema.

    Each data input is either a plain string or a nested ``MLTableInputSchema``.
    """

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    test_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class AutoMLTextClassificationNode(AutoMLNodeMixin, TextClassificationSchema):
    """AutoML text classification node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class AutoMLTextClassificationMultilabelNode(AutoMLNodeMixin, TextClassificationMultilabelSchema):
    """AutoML multilabel text classification node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class AutoMLTextNerNode(AutoMLNodeMixin, TextNerSchema):
    """AutoML text NER node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class ImageClassificationMulticlassNodeSchema(AutoMLNodeMixin, ImageClassificationSchema):
    """AutoML image classification node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class ImageClassificationMultilabelNodeSchema(AutoMLNodeMixin, ImageClassificationMultilabelSchema):
    """AutoML multilabel image classification node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class ImageObjectDetectionNodeSchema(AutoMLNodeMixin, ImageObjectDetectionSchema):
    """AutoML image object detection node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


class ImageInstanceSegmentationNodeSchema(AutoMLNodeMixin, ImageInstanceSegmentationSchema):
    """AutoML image instance segmentation node schema (string or MLTable data inputs)."""

    training_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])
    validation_data = UnionField([fields.Str(), NestedField(MLTableInputSchema)])


def AutoMLNodeSchema(**kwargs) -> List[fields.Field]:
    """Get the list of all nested schema for all AutoML nodes.

    One ``NestedField`` is created per node schema, in a fixed order: tabular
    schemas first, then vision, then NLP.

    :keyword kwargs: Forwarded unchanged to every ``NestedField``.
    :return: The list of fields
    :rtype: List[fields.Field]
    """
    node_schemas = (
        # region: automl node schemas
        AutoMLClassificationNodeSchema,
        AutoMLRegressionNodeSchema,
        AutoMLForecastingNodeSchema,
        # Vision
        ImageClassificationMulticlassNodeSchema,
        ImageClassificationMultilabelNodeSchema,
        ImageObjectDetectionNodeSchema,
        ImageInstanceSegmentationNodeSchema,
        # NLP
        AutoMLTextClassificationNode,
        AutoMLTextClassificationMultilabelNode,
        AutoMLTextNerNode,
        # endregion
    )
    return [NestedField(schema_cls, **kwargs) for schema_cls in node_schemas]