aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import INCLUDE, fields, post_load, pre_dump

from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
from ..._schema.core.fields import DumpableEnumField
from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
from ...constants._common import AzureMLResourceType
from .component import InternalComponentSchema, NodeType


class InternalBaseNodeSchema(BaseNodeSchema):
    class Meta:
        unknown = INCLUDE

    component = UnionField(
        [
            # for registry type assets
            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
            # existing component
            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
            # inline component or component file reference starting with FILE prefix
            NestedField(InternalComponentSchema, unknown=INCLUDE),
        ],
        required=True,
    )
    type = DumpableEnumField(
        allowed_values=NodeType.all_values(),
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument
        from ...entities._builders import parse_inputs_outputs

        # parse inputs/outputs
        data = parse_inputs_outputs(data)

        # dict to node object
        from ...entities._job.pipeline._load_component import pipeline_node_factory

        return pipeline_node_factory.load_from_dict(data=data)

    @pre_dump
    def resolve_inputs_outputs(self, job, **kwargs):  # pylint: disable=unused-argument
        return _resolve_inputs_outputs(job)


class ScopeSchema(InternalBaseNodeSchema):
    type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
    adla_account_name = fields.Str(required=True)
    scope_param = fields.Str()
    custom_job_name_suffix = fields.Str()
    priority = fields.Int()
    auto_token = fields.Int()
    tokens = fields.Int()
    vcp = fields.Float()


class HDInsightSchema(InternalBaseNodeSchema):
    type = DumpableEnumField(allowed_values=[NodeType.HDI])

    compute_name = fields.Str()
    queue = fields.Str()
    driver_memory = fields.Str()
    driver_cores = fields.Int()
    executor_memory = fields.Str()
    executor_cores = fields.Int()
    number_executors = fields.Int()
    conf = UnionField(
        # dictionary or json string
        union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
    )
    hdinsight_spark_job_name = fields.Str()