about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py107
1 files changed, 107 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
new file mode 100644
index 00000000..848220d3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/flow.py
@@ -0,0 +1,107 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from azure.ai.ml._schema import YamlFileSchema
+from azure.ai.ml._schema.component import ComponentSchema
+from azure.ai.ml._schema.component.component import ComponentNameStr
+from azure.ai.ml._schema.core.fields import (
+    ArmVersionedStr,
+    EnvironmentField,
+    LocalPathField,
+    NestedField,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml.constants._common import AzureMLResourceType
+from azure.ai.ml.constants._component import NodeType
+
+
+class _ComponentMetadataSchema(metaclass=PatchedSchemaMeta):
+    """Schema to recognize metadata of a flow as a component."""
+
+    name = ComponentNameStr()
+    version = fields.Str()
+    display_name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+
+
+class _FlowAttributesSchema(metaclass=PatchedSchemaMeta):
+    """Schema to recognize attributes of a flow."""
+
+    variant = fields.Str()
+    column_mappings = fields.Dict(
+        fields.Str(),
+        fields.Str(),
+    )
+    connections = fields.Dict(
+        keys=fields.Str(),
+        values=fields.Dict(
+            keys=fields.Str(),
+            values=fields.Str(),
+        ),
+    )
+    environment_variables = fields.Dict(
+        fields.Str(),
+        fields.Str(),
+    )
+
+
+class _FLowComponentOverridesSchema(metaclass=PatchedSchemaMeta):
+    environment = EnvironmentField()
+    is_deterministic = fields.Bool()
+
+
+class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta):
+    # the field name must be the same as azure.ai.ml.constants._common.PROMPTFLOW_AZUREML_OVERRIDE_KEY
+    azureml = NestedField(_FLowComponentOverridesSchema)
+
+
+class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema):
+    """Schema for flow.dag.yaml file."""
+
+    environment_variables = fields.Dict(
+        fields.Str(),
+        fields.Str(),
+    )
+    additional_includes = fields.List(LocalPathField())
+
+
+class RunSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowAttributesSchema, _FlowComponentOverridableSchema):
+    """Schema for run.yaml file."""
+
+    flow = LocalPathField(required=True)
+
+
+class FlowComponentSchema(ComponentSchema, _FlowAttributesSchema, _FLowComponentOverridesSchema):
+    """FlowSchema and FlowRunSchema are used to load flow while FlowComponentSchema is used to dump flow."""
+
+    class Meta:
+        """Override this to exclude inputs & outputs as component doesn't have them."""
+
+        exclude = ["inputs", "outputs"]  # component doesn't have inputs & outputs
+
+    # TODO: name should be required?
+    name = ComponentNameStr()
+
+    type = StringTransformedEnum(allowed_values=[NodeType.FLOW_PARALLEL], required=True)
+
+    # name, version, tags, display_name and is_deterministic are inherited from ComponentSchema
+    properties = fields.Dict(
+        fields.Str(),
+        fields.Str(),
+    )
+
+    # this is different from regular CodeField
+    code = UnionField(
+        [
+            LocalPathField(),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.CODE),
+        ],
+        metadata={"description": "A local path or http:, https:, azureml: url pointing to a remote location."},
+    )
+    additional_includes = fields.List(LocalPathField(), load_only=True)