Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
new file mode 100644
index 00000000..70f286a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
@@ -0,0 +1,108 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LoggingLevel
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ParallelComponentSchema(ComponentSchema):
+    type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL], required=True)
+    resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+    logging_level = DumpableEnumField(
+        allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+        dump_default=LoggingLevel.INFO,
+        metadata={
+            "description": "A string of the logging level name, as defined in 'logging'. \
+            Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+        },
+    )
+    task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+    mini_batch_size = fields.Str(
+        metadata={"description": "The batch size of the current job."},
+    )
+    partition_keys = fields.List(
+        fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches."}
+    )
+
+    input_data = fields.Str()
+    retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+    max_concurrency_per_instance = fields.Integer(
+        dump_default=1,
+        metadata={"description": "The maximum parallelism that each compute instance has."},
+    )
+    error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": "The number of item processing failures that should be ignored. \
+            If the error_threshold is reached, the job terminates. \
+            For a list of files as inputs, one item means one file reference. \
+            This setting doesn't apply to command parallelization."
+        },
+    )
+    mini_batch_error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": "The number of mini-batch processing failures that should be ignored. \
+            If the mini_batch_error_threshold is reached, the job terminates. \
+            For a list of files as inputs, one item means one file reference. \
+            This setting can be used by either command or python function parallelization. \
+            Only one error_threshold setting can be used in one job."
+        },
+    )
+
+
+class RestParallelComponentSchema(ParallelComponentSchema):
+    """When a component is loaded from REST, don't validate the name, since there might be existing components
+    with invalid names."""
+
+    name = fields.Str(required=True)
+
+
+class AnonymousParallelComponentSchema(AnonymousAssetSchema, ParallelComponentSchema):
+    """Anonymous parallel component schema.
+
+    Note that inheritance follows the order AnonymousAssetSchema, ParallelComponentSchema because we need name and
+    version to be dump_only (marshmallow collects fields following the method resolution order).
+    """
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+        return ParallelComponent(
+            base_path=self.context[BASE_PATH_CONTEXT_KEY],
+            _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+            **data,
+        )
+
+
+class ParallelComponentFileRefField(FileRefField):
+    def _deserialize(self, value, attr, data, **kwargs):
+        # Get component info from the component YAML file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        component_dict = yaml.safe_load(data)
+        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+        # Update base_path to the parent path of the component file.
+        component_schema_context = deepcopy(self.context)
+        component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+        component = AnonymousParallelComponentSchema(context=component_schema_context).load(
+            component_dict, unknown=INCLUDE
+        )
+        component._source_path = source_path
+        component._source = ComponentSource.YAML_COMPONENT
+        return component
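The docstring note on base-class order in AnonymousParallelComponentSchema reflects a general marshmallow behaviour: when two parent schemas declare the same field name, the base listed first (earliest in the method resolution order, after the class itself) supplies the field that survives. A minimal, self-contained sketch of that behaviour, assuming marshmallow 3.x and using made-up schema names rather than anything from azure-ai-ml:

from marshmallow import Schema, fields

class DumpOnlyName(Schema):
    # Plays the role of AnonymousAssetSchema: "name" is dump-only.
    name = fields.Str(dump_only=True)

class EditableName(Schema):
    # Plays the role of ComponentSchema: "name" is a plain, loadable field.
    name = fields.Str()

class Combined(DumpOnlyName, EditableName):
    # The first-listed base wins when marshmallow collects inherited fields.
    pass

print(Combined().fields["name"].dump_only)  # True: DumpOnlyName's declaration is kept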
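ParallelComponentFileRefField above follows a reusable pattern: treat the string value as a reference to a YAML file, rebase the schema context's base path onto that file's folder, and feed the parsed dict to a nested schema. Below is a stripped-down, framework-free sketch of the same pattern, assuming marshmallow 3.x; the names YamlRefField, InnerSchema, and "base_path" are stand-ins, not azure-ai-ml APIs, and unlike the real FileRefField this sketch reads the file contents itself.

from copy import deepcopy
from pathlib import Path

import yaml
from marshmallow import INCLUDE, Schema, fields


class InnerSchema(Schema):
    # Stand-in for AnonymousParallelComponentSchema.
    name = fields.Str()


class YamlRefField(fields.Str):
    """A string field whose value is a path to a YAML file, resolved against the schema context."""

    def _deserialize(self, value, attr, data, **kwargs):
        rel_path = super()._deserialize(value, attr, data, **kwargs)
        source_path = self.context["base_path"] / rel_path
        # Rebase the context so relative references inside the file resolve against its folder.
        inner_context = deepcopy(self.context)
        inner_context["base_path"] = source_path.parent
        inner_dict = yaml.safe_load(source_path.read_text())
        return InnerSchema(context=inner_context).load(inner_dict, unknown=INCLUDE)


class OuterSchema(Schema):
    component = YamlRefField()


if __name__ == "__main__":
    base = Path(".")
    (base / "inner.yaml").write_text("name: my-component\n")
    result = OuterSchema(context={"base_path": base}).load({"component": "inner.yaml"})
    print(result["component"])  # {'name': 'my-component'}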