aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py108
1 files changed, 108 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
new file mode 100644
index 00000000..70f286a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
@@ -0,0 +1,108 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LoggingLevel
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
class ParallelComponentSchema(ComponentSchema):
    """Schema for a parallel component (``type: parallel``).

    Extends :class:`ComponentSchema` with the parallel-specific fields:
    the task definition, resources, mini-batch settings, retry settings,
    error thresholds and per-instance concurrency.
    """

    # Discriminator: only the "parallel" node type is accepted here.
    type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL], required=True)
    # Nested resource configuration; unknown keys are kept (INCLUDE).
    resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
    logging_level = DumpableEnumField(
        allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
        dump_default=LoggingLevel.INFO,
        metadata={
            "description": "A string of the logging level name, which is defined in 'logging'. \
            Possible values are 'WARNING', 'INFO', and 'DEBUG'."
        },
    )
    # The parallel task definition (entry script / program arguments etc.).
    task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
    mini_batch_size = fields.Str(
        # NOTE(review): fixed duplicated word in the original description
        # ("The The batch size").
        metadata={"description": "The batch size of current job."},
    )
    partition_keys = fields.List(
        fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches"}
    )

    # Name of the input that the mini-batches are produced from.
    input_data = fields.Str()
    retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
    max_concurrency_per_instance = fields.Integer(
        dump_default=1,
        # NOTE(review): fixed misspelling in the original description
        # ("parallellism").
        metadata={"description": "The max parallelism that each compute instance has."},
    )
    error_threshold = fields.Integer(
        dump_default=-1,
        metadata={
            "description": "The number of item processing failures should be ignored. \
            If the error_threshold is reached, the job terminates. \
            For a list of files as inputs, one item means one file reference. \
            This setting doesn't apply to command parallelization."
        },
    )
    mini_batch_error_threshold = fields.Integer(
        dump_default=-1,
        metadata={
            "description": "The number of mini batch processing failures should be ignored. \
            If the mini_batch_error_threshold is reached, the job terminates. \
            For a list of files as inputs, one item means one file reference. \
            This setting can be used by either command or python function parallelization. \
            Only one error_threshold setting can be used in one job."
        },
    )
+
+
class RestParallelComponentSchema(ParallelComponentSchema):
    """Schema used when a component is loaded from the REST service.

    Existing server-side components may carry names that would fail the
    current name validation, so ``name`` is declared here as a plain
    required string to skip that validation.
    """

    name = fields.Str(required=True)
+
+
class AnonymousParallelComponentSchema(AnonymousAssetSchema, ParallelComponentSchema):
    """Anonymous parallel component schema.

    The base-class order matters: ``AnonymousAssetSchema`` comes first so
    that ``name`` and ``version`` end up dump_only (marshmallow collects
    fields following the method resolution order).
    """

    @post_load
    def make(self, data, **kwargs):
        # Local import — presumably avoids a circular import with the
        # entities package; confirm before hoisting to module level.
        from azure.ai.ml.entities._component.parallel_component import ParallelComponent

        source = kwargs.pop("_source", ComponentSource.YAML_JOB)
        return ParallelComponent(
            base_path=self.context[BASE_PATH_CONTEXT_KEY],
            _source=source,
            **data,
        )
+
+
class ParallelComponentFileRefField(FileRefField):
    """File-reference field that loads a parallel component from its YAML file."""

    def _deserialize(self, value, attr, data, **kwargs):
        # FileRefField resolves the reference and returns the file's raw text.
        raw_yaml = super()._deserialize(value, attr, data, **kwargs)
        loaded = yaml.safe_load(raw_yaml)
        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value

        # Re-base the schema context onto the component file's directory so
        # that relative paths inside the file resolve correctly.
        ctx = deepcopy(self.context)
        ctx[BASE_PATH_CONTEXT_KEY] = source_path.parent
        schema = AnonymousParallelComponentSchema(context=ctx)
        component = schema.load(loaded, unknown=INCLUDE)

        # Record where the component came from.
        component._source_path = source_path
        component._source = ComponentSource.YAML_COMPONENT
        return component