aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from marshmallow import fields

from ..._schema import NestedField
from ..._schema.core.fields import DumpableEnumField, EnvironmentField
from ..._schema.job import ParameterizedCommandSchema, ParameterizedParallelSchema
from ..._schema.job.job_limits import CommandJobLimitsSchema
from .._schema.node import InternalBaseNodeSchema, NodeType


class CommandSchema(InternalBaseNodeSchema, ParameterizedCommandSchema):
    """Schema for an internal command node.

    Combines the internal base node fields with the parameterized command
    fields. It drops ``code`` and ``distribution`` via ``Meta.exclude`` and
    restricts ``type`` to ``NodeType.COMMAND``.
    """

    class Meta:
        # Internal command nodes have neither a ``code`` nor a ``distribution`` field.
        exclude = ["code", "distribution"]

    # Environment is (de)serialized through the shared EnvironmentField.
    environment = EnvironmentField()
    # Only NodeType.COMMAND is accepted as the node type for this schema.
    type = DumpableEnumField(allowed_values=[NodeType.COMMAND])
    # Limits are handled by the nested CommandJobLimitsSchema.
    limits = NestedField(CommandJobLimitsSchema)


class DistributedSchema(CommandSchema):
    """Schema for an internal distributed node.

    Reuses the fields of CommandSchema, with two differences: ``type`` is
    restricted to ``NodeType.DISTRIBUTED``, and the ``distribution`` field is
    kept rather than excluded.
    """

    class Meta:
        # Redefining Meta here replaces CommandSchema.Meta (it does not subclass it).
        # As a result, only ``code`` is excluded and ``distribution`` is enabled again.
        exclude = ["code"]

    # Only NodeType.DISTRIBUTED is accepted as the node type for this schema.
    type = DumpableEnumField(allowed_values=[NodeType.DISTRIBUTED])


class ParallelSchema(InternalBaseNodeSchema, ParameterizedParallelSchema):
    """Schema for an internal parallel node.

    Combines the internal base node fields with the parameterized parallel
    fields. It excludes several of the latter (see ``Meta.exclude``) and
    restricts ``type`` to ``NodeType.PARALLEL``.
    """

    class Meta:
        # ``partition_keys`` is excluded from the schema. It can still be supplied,
        # but it is then treated as an unknown field and a warning is emitted.
        # Set it only after the dump has been done.
        exclude = ["task", "input_data", "mini_batch_error_threshold", "partition_keys"]

    # Only NodeType.PARALLEL is accepted as the node type for this schema.
    type = DumpableEnumField(allowed_values=[NodeType.PARALLEL])
    # compute and environment are plain strings here. Note that CommandSchema
    # uses EnvironmentField for environment instead; NOTE(review): confirm the
    # difference is intentional.
    compute = fields.Str()
    environment = fields.Str()
    # Limits reuse the nested CommandJobLimitsSchema, the same schema CommandSchema uses.
    limits = NestedField(CommandJobLimitsSchema)