# --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- from marshmallow import fields from azure.ai.ml._schema import YamlFileSchema from azure.ai.ml._schema.component import ComponentSchema from azure.ai.ml._schema.component.component import ComponentNameStr from azure.ai.ml._schema.core.fields import ( ArmVersionedStr, EnvironmentField, LocalPathField, NestedField, StringTransformedEnum, UnionField, ) from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta from azure.ai.ml.constants._common import AzureMLResourceType from azure.ai.ml.constants._component import NodeType class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta): """Schema to recognize metadata of a flow as a component.""" name = ComponentNameStr() version = fields.Str() display_name = fields.Str() description = fields.Str() tags = fields.Dict(keys=fields.Str(), values=fields.Str()) class _FlowAttributesSchema(metaclass=PatchedSchemaMeta): """Schema to recognize attributes of a flow.""" variant = fields.Str() column_mappings = fields.Dict( fields.Str(), fields.Str(), ) connections = fields.Dict( keys=fields.Str(), values=fields.Dict( keys=fields.Str(), values=fields.Str(), ), ) environment_variables = fields.Dict( fields.Str(), fields.Str(), ) class _FLowComponentOverridesSchema(metaclass=PatchedSchemaMeta): environment = EnvironmentField() is_deterministic = fields.Bool() class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta): # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY azureml = NestedField(_FLowComponentOverridesSchema) class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema): """Schema for flow.dag.yaml file.""" environment_variables = fields.Dict( fields.Str(), fields.Str(), ) additional_includes = fields.List(LocalPathField()) class RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema): """Schema for run.yaml file.""" flow = LocalPathField(required=True) class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FLowComponentOverridesSchema): """FlowSchema and FlowRunSchema are used to load flow while FlowComponentSchema is used to dump flow.""" class Meta: """Override this to exclude inputs & outputs as component doesn't have them.""" exclude = ["inputs", "outputs"] # component doesn't have inputs & outputs # TODO: name should be required? name = ComponentNameStr() type = StringTransformedEnum(allowed_values=[NodeType.FLOW_PARALLEL], required=True) # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema properties = fields.Dict( fields.Str(), fields.Str(), ) # this is different from regular CodeField code = UnionField( [ LocalPathField(), ArmVersionedStr(azureml_type=AzureMLResourceType.CODE), ], metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."}, ) additional_includes = fields.List(LocalPathField(), load_only=True)