# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument,protected-access

from copy import deepcopy

import yaml
from marshmallow import INCLUDE, fields, post_dump, post_load

from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
from azure.ai.ml._schema.component.component import ComponentSchema
from azure.ai.ml._schema.core.fields import FileRefField, StringTransformedEnum
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
from azure.ai.ml.constants._component import ComponentSource, NodeType

from ..job.parameterized_spark import ParameterizedSparkSchema


class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
    """Schema for a Spark component.

    Combines the generic component fields with the parameterized Spark
    fields, pins ``type`` to the Spark node type and adds an optional list
    of ``additional_includes`` strings.
    """

    # Only the Spark node type is accepted for this schema.
    type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
    additional_includes = fields.List(fields.Str())

    @post_dump
    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
        """Strip an empty ``additional_includes`` entry from the dumped dict.

        :param component_schema_dict: The dictionary produced by the dump.
        :return: The same dictionary, without ``additional_includes`` when
            that entry was present but empty.
        """
        includes = component_schema_dict.get("additional_includes")
        # Nothing to clean up when the entry is absent/None or has content.
        if includes is None or len(includes) != 0:
            return component_schema_dict

        del component_schema_dict["additional_includes"]
        return component_schema_dict


class RestSparkComponentSchema(SparkComponentSchema):
    """Spark component schema used when a component is loaded from REST.

    ``name`` is redeclared here as a plain required string, so the name is
    not validated on load: an already existing component might carry a name
    that would otherwise be considered invalid.
    """

    name = fields.Str(required=True)


class AnonymousSparkComponentSchema(AnonymousAssetSchema, SparkComponentSchema):
    """Schema for an anonymous Spark component.

    The base classes are listed as AnonymousAssetSchema first and
    SparkComponentSchema second on purpose: marshmallow collects fields
    following the method resolution order, and name and version need to be
    dump_only.
    """

    @post_load
    def make(self, data, **kwargs):
        """Build a ``SparkComponent`` entity from the loaded data.

        :param data: The deserialized field values, forwarded as keyword
            arguments to ``SparkComponent``.
        :return: The constructed ``SparkComponent``.
        """
        # Imported locally rather than at module level
        # (NOTE(review): presumably to avoid a circular import - confirm).
        from azure.ai.ml.entities._component.spark_component import SparkComponent

        base_path = self.context[BASE_PATH_CONTEXT_KEY]
        # An inline component gets source=YAML.JOB by default; only a full,
        # separate component file is regarded as YAML.COMPONENT.
        source = kwargs.pop("_source", ComponentSource.YAML_JOB)
        return SparkComponent(base_path=base_path, _source=source, **data)


class SparkComponentFileRefField(FileRefField):
    """File reference field that loads a Spark component from a YAML file."""

    def _deserialize(self, value, attr, data, **kwargs):
        """Load the component described by the referenced YAML file.

        :param value: Path of the component file, joined onto the base path
            stored in the field context.
        :param attr: Forwarded unchanged to the parent field.
        :param data: Forwarded unchanged to the parent field.
        :return: The loaded component, with ``_source_path`` set to the
            component file and ``_source`` set to
            ``ComponentSource.YAML_COMPONENT``.
        """
        # The parent field yields the component YAML content, which is then parsed.
        yaml_content = super()._deserialize(value, attr, data, **kwargs)
        parsed_component = yaml.safe_load(yaml_content)
        component_file = self.context[BASE_PATH_CONTEXT_KEY] / value

        # Work on a copy of the context so this field's own context is left
        # untouched; the nested schema uses the component file's parent
        # directory as its base path.
        nested_context = deepcopy(self.context)
        nested_context[BASE_PATH_CONTEXT_KEY] = component_file.parent

        nested_schema = AnonymousSparkComponentSchema(context=nested_context)
        loaded = nested_schema.load(parsed_component, unknown=INCLUDE)

        # Override the source assigned during load: this component came from
        # its own separate file.
        loaded._source_path = component_file
        loaded._source = ComponentSource.YAML_COMPONENT
        return loaded