Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py  108
1 file changed, 108 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
new file mode 100644
index 00000000..70f286a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/parallel_component.py
@@ -0,0 +1,108 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.parallel_task import ComponentParallelTaskSchema
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.component.retry_settings import RetrySettingsSchema
+from azure.ai.ml._schema.core.fields import DumpableEnumField, FileRefField, NestedField, StringTransformedEnum
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LoggingLevel
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class ParallelComponentSchema(ComponentSchema):
+    type = StringTransformedEnum(allowed_values=[NodeType.PARALLEL], required=True)
+    resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+    logging_level = DumpableEnumField(
+        allowed_values=[LoggingLevel.DEBUG, LoggingLevel.INFO, LoggingLevel.WARN],
+        dump_default=LoggingLevel.INFO,
+        metadata={
+            "description": "A string of the logging level name, which is defined in 'logging'. \
+            Possible values are 'WARNING', 'INFO', and 'DEBUG'."
+        },
+    )
+    task = NestedField(ComponentParallelTaskSchema, unknown=INCLUDE)
+    mini_batch_size = fields.Str(
+        metadata={"description": "The The batch size of current job."},
+    )
+    partition_keys = fields.List(
+        fields.Str(), metadata={"description": "The keys used to partition input data into mini-batches."}
+    )
+
+    input_data = fields.Str()
+    retry_settings = NestedField(RetrySettingsSchema, unknown=INCLUDE)
+    max_concurrency_per_instance = fields.Integer(
+        dump_default=1,
+        metadata={"description": "The max parallellism that each compute instance has."},
+    )
+    error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": "The number of item processing failures should be ignored. \
+            If the error_threshold is reached, the job terminates. \
+            For a list of files as inputs, one item means one file reference. \
+            This setting doesn't apply to command parallelization."
+        },
+    )
+    mini_batch_error_threshold = fields.Integer(
+        dump_default=-1,
+        metadata={
+            "description": "The number of mini batch processing failures should be ignored. \
+            If the mini_batch_error_threshold is reached, the job terminates. \
+            For a list of files as inputs, one item means one file reference. \
+            This setting can be used by either command or Python function parallelization. \
+            Only one error_threshold setting can be used in one job."
+        },
+    )
+
+
+class RestParallelComponentSchema(ParallelComponentSchema):
+    """When component load from rest, won't validate on name since there might be existing component with invalid
+    name."""
+
+    name = fields.Str(required=True)
+
+
+class AnonymousParallelComponentSchema(AnonymousAssetSchema, ParallelComponentSchema):
+    """Anonymous parallel component schema.
+
+    Note that the inheritance order is AnonymousAssetSchema, ParallelComponentSchema because name and version need to
+    be dump_only (marshmallow collects fields following the method resolution order).
+    """
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities._component.parallel_component import ParallelComponent
+
+        return ParallelComponent(
+            base_path=self.context[BASE_PATH_CONTEXT_KEY],
+            _source=kwargs.pop("_source", ComponentSource.YAML_JOB),
+            **data,
+        )
+
+
+class ParallelComponentFileRefField(FileRefField):
+    def _deserialize(self, value, attr, data, **kwargs):
+        # Get component info from component yaml file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        component_dict = yaml.safe_load(data)
+        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+        # Update base_path to parent path of component file.
+        component_schema_context = deepcopy(self.context)
+        component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+        component = AnonymousParallelComponentSchema(context=component_schema_context).load(
+            component_dict, unknown=INCLUDE
+        )
+        component._source_path = source_path
+        component._source = ComponentSource.YAML_COMPONENT
+        return component
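
For context, the sketch below (separate from the file added above) shows one way AnonymousParallelComponentSchema could be exercised directly against an in-memory YAML document. The component definition, paths, and environment reference are illustrative assumptions, and the snippet has not been verified against the SDK's full set of validation rules.

# Minimal usage sketch; all field values below are hypothetical examples.
from pathlib import Path

import yaml
from marshmallow import INCLUDE

from azure.ai.ml._schema.component.parallel_component import AnonymousParallelComponentSchema
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY

COMPONENT_YAML = """
type: parallel
display_name: Batch scoring
inputs:
  score_input:
    type: mltable
outputs:
  scored_result:
    type: uri_folder
input_data: ${{inputs.score_input}}
mini_batch_size: "10kb"
error_threshold: 5
task:
  type: run_function
  code: ./src
  entry_script: score.py
  environment: azureml:my-scoring-env:1
"""

component_dict = yaml.safe_load(COMPONENT_YAML)

# The schema reads a base path from its context so relative paths (e.g. task.code)
# can be resolved; the current directory stands in for the YAML file's folder here.
context = {BASE_PATH_CONTEXT_KEY: Path(".")}
component = AnonymousParallelComponentSchema(context=context).load(component_dict, unknown=INCLUDE)

# The post_load hook returns a ParallelComponent entity rather than a plain dict.
print(type(component).__name__)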