diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/spark_component.py | 79 |
1 file changed, 79 insertions, 0 deletions
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument,protected-access

from copy import deepcopy

import yaml
from marshmallow import INCLUDE, fields, post_dump, post_load

from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
from azure.ai.ml._schema.component.component import ComponentSchema
from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
from azure.ai.ml.constants._component import ComponentSource, NodeType

from ..job.parameterized_spark import ParameterizedSparkSchema


class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
    """Marshmallow schema for a spark component definition."""

    type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
    additional_includes = fields.List(fields.Str())

    @post_dump
    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
        """Drop ``additional_includes`` from the dumped dict when it is present but empty."""
        includes = component_schema_dict.get("additional_includes")
        if includes is not None and len(includes) == 0:
            del component_schema_dict["additional_includes"]
        return component_schema_dict


class RestSparkComponentSchema(SparkComponentSchema):
    """Schema used when a component is loaded from REST.

    The name field is not validated locally, since existing service-side
    components may carry names that would fail current validation rules.
    """

    name = fields.Str(required=True)


class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
    """Schema for an anonymous (inline) spark component.

    Base-class order is significant: AnonymousAssetSchema comes first so
    its dump-only ``name`` and ``version`` fields take precedence when
    marshmallow collects fields along the method resolution order.
    """

    @post_load
    def make(self, data, **kwargs):
        """Construct a SparkComponent entity from the validated data."""
        from azure.ai.ml.entities._component.spark_component import SparkComponent

        # Inline definitions default to YAML.JOB; only a full standalone
        # component file is treated as YAML.COMPONENT (set by the caller).
        source = kwargs.pop("_source", ComponentSource.YAML_JOB)
        return SparkComponent(
            base_path=self.context[BASE_PATH_CONTEXT_KEY],
            _source=source,
            **data,
        )


class SparkComponentFileRefField(FileRefField):
    """File-reference field that resolves and loads a spark component YAML file."""

    def _deserialize(self, value, attr, data, **kwargs):
        """Read the referenced component YAML and return the loaded SparkComponent."""
        file_content = super()._deserialize(value, attr, data, **kwargs)
        loaded_yaml = yaml.safe_load(file_content)
        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value

        # Re-root the schema context so relative paths inside the component
        # file resolve against the component file's own directory.
        child_context = deepcopy(self.context)
        child_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
        component = AnonymousSparkComponentSchema(context=child_context).load(
            loaded_yaml, unknown=INCLUDE
        )
        component._source_path = source_path
        component._source = ComponentSource.YAML_COMPONENT
        return component