# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import fields

from azure.ai.ml._schema import YamlFileSchema
from azure.ai.ml._schema.component import ComponentSchema
from azure.ai.ml._schema.component.component import ComponentNameStr
from azure.ai.ml._schema.core.fields import (
    ArmVersionedStr,
    EnvironmentField,
    LocalPathField,
    NestedField,
    StringTransformedEnum,
    UnionField,
)
from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType


class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta):
    """Schema to recognize metadata of a flow as a component.

    Declares the metadata fields (name, version, display name, description
    and tags) that FlowSchema and RunSchema pick up by inheriting from this
    mixin.
    """

    # Name is handled by the project's ComponentNameStr field rather than a plain string field.
    name = ComponentNameStr()
    version = fields.Str()
    display_name = fields.Str()
    description = fields.Str()
    # Mapping with string keys and string values.
    tags = fields.Dict(keys=fields.Str(), values=fields.Str())


class _FlowAttributesSchema(metaclass=PatchedSchemaMeta):
    """Schema to recognize attributes of a flow.

    Declares the variant, column mappings, connections and environment
    variables of a flow. RunSchema and FlowComponentSchema inherit these
    fields from this mixin.
    """

    variant = fields.Str()
    # Flat mapping with string keys and string values.
    column_mappings = fields.Dict(keys=fields.Str(), values=fields.Str())
    # Two-level mapping: every string key holds its own string-to-string mapping.
    connections = fields.Dict(
        keys=fields.Str(),
        values=fields.Dict(keys=fields.Str(), values=fields.Str()),
    )
    # Flat mapping with string keys and string values.
    environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str())


class _FLowComponentOverridesSchema(metaclass=PatchedSchemaMeta):
    """Schema for the component settings that a flow definition may override.

    Used as the nested schema of the ``azureml`` field in
    _FlowComponentOverridableSchema, and mixed directly into
    FlowComponentSchema.
    """

    # Environment setting, handled by the project's EnvironmentField.
    environment = EnvironmentField()
    is_deterministic = fields.Bool()


class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta):
    """Mixin that adds an ``azureml`` section holding component overrides.

    The section is described by _FLowComponentOverridesSchema. FlowSchema and
    RunSchema inherit this mixin.
    """

    # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY
    azureml = NestedField(_FLowComponentOverridesSchema)


class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema):
    """Schema for flow.dag.yaml file.

    Combines the component metadata fields and the ``azureml`` override
    section from its mixins with the two fields declared below.
    """

    # Mapping with string keys and string values.
    environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str())
    # List of local paths.
    additional_includes = fields.List(LocalPathField())


class RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema):
    """Schema for run.yaml file.

    Inherits the component metadata, flow attributes and ``azureml`` override
    section from its mixins and adds the mandatory ``flow`` field.
    """

    # Local path field; required, so it must be present in the data being loaded.
    flow = LocalPathField(required=True)


class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FLowComponentOverridesSchema):
    """Schema used to dump a flow as a component.

    FlowSchema and RunSchema are used to load a flow, while FlowComponentSchema is used to dump it.
    """

    class Meta:
        """Override this to exclude inputs & outputs as component doesn't have them."""

        # component doesn't have inputs & outputs
        exclude = ["inputs", "outputs"]

    # TODO: name should be required?
    name = ComponentNameStr()

    # required field; allowed_values lists NodeType.FLOW_PARALLEL only
    type = StringTransformedEnum(
        allowed_values=[NodeType.FLOW_PARALLEL],
        required=True,
    )

    # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema
    properties = fields.Dict(keys=fields.Str(), values=fields.Str())

    # this is different from regular CodeField
    code = UnionField(
        [
            LocalPathField(),
            ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
        ],
        metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
    )
    # load_only: read when loading, skipped when dumping
    additional_includes = fields.List(LocalPathField(), load_only=True)