diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py new file mode 100644 index 00000000..848220d3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py @@ -0,0 +1,107 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from marshmallow import fields + +from azure.ai.ml._schema import YamlFileSchema +from azure.ai.ml._schema.component import ComponentSchema +from azure.ai.ml._schema.component.component import ComponentNameStr +from azure.ai.ml._schema.core.fields import ( + ArmVersionedStr, + EnvironmentField, + LocalPathField, + NestedField, + StringTransformedEnum, + UnionField, +) +from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta +from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._component import NodeType + + +class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta): + """Schema to recognize metadata of a flow as a component.""" + + name = ComponentNameStr() + version = fields.Str() + display_name = fields.Str() + description = fields.Str() + tags = fields.Dict(keys=fields.Str(), values=fields.Str()) + + +class _FlowAttributesSchema(metaclass=PatchedSchemaMeta): + """Schema to recognize attributes of a flow.""" + + variant = fields.Str() + column_mappings = fields.Dict( + fields.Str(), + fields.Str(), + ) + connections = fields.Dict( + keys=fields.Str(), + values=fields.Dict( + keys=fields.Str(), + values=fields.Str(), + ), + ) + environment_variables = fields.Dict( + fields.Str(), + fields.Str(), + ) + + +class _FLowComponentOverridesSchema(metaclass=PatchedSchemaMeta): + environment = EnvironmentField() + is_deterministic = fields.Bool() + + +class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta): + # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY + azureml = NestedField(_FLowComponentOverridesSchema) + + +class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema): + """Schema for flow.dag.yaml file.""" + + environment_variables = fields.Dict( + fields.Str(), + fields.Str(), + ) + additional_includes = fields.List(LocalPathField()) + + +class RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema): + """Schema for run.yaml file.""" + + flow = LocalPathField(required=True) + + +class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FLowComponentOverridesSchema): + """FlowSchema and FlowRunSchema are used to load flow while FlowComponentSchema is used to dump flow.""" + + class Meta: + """Override this to exclude inputs & outputs as component doesn't have them.""" + + exclude = ["inputs", "outputs"] # component doesn't have inputs & outputs + + # TODO: name should be required? + name = ComponentNameStr() + + type = StringTransformedEnum(allowed_values=[NodeType.FLOW_PARALLEL], required=True) + + # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema + properties = fields.Dict( + fields.Str(), + fields.Str(), + ) + + # this is different from regular CodeField + code = UnionField( + [ + LocalPathField(), + ArmVersionedStr(azureml_type=AzureMLResourceType.CODE), + ], + metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."}, + ) + additional_includes = fields.List(LocalPathField(), load_only=True) |