about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py137
1 files changed, 137 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
new file mode 100644
index 00000000..9d688ee0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/command_component.py
@@ -0,0 +1,137 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=unused-argument,protected-access
+from copy import deepcopy
+
+import yaml
+from marshmallow import INCLUDE, fields, post_dump, post_load
+
+from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
+from azure.ai.ml._schema.component.component import ComponentSchema
+from azure.ai.ml._schema.component.input_output import (
+    OutputPortSchema,
+    PrimitiveOutputSchema,
+)
+from azure.ai.ml._schema.component.resource import ComponentResourceSchema
+from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
+from azure.ai.ml._schema.core.fields import (
+    ExperimentalField,
+    FileRefField,
+    NestedField,
+    StringTransformedEnum,
+    UnionField,
+)
+from azure.ai.ml._schema.job.distribution import (
+    MPIDistributionSchema,
+    PyTorchDistributionSchema,
+    TensorFlowDistributionSchema,
+    RayDistributionSchema,
+)
+from azure.ai.ml._schema.job.parameterized_command import ParameterizedCommandSchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureDevopsArtifactsType
+from azure.ai.ml.constants._component import ComponentSource, NodeType
+
+
+class AzureDevopsArtifactsSchema(metaclass=PatchedSchemaMeta):
+    type = StringTransformedEnum(allowed_values=[AzureDevopsArtifactsType.ARTIFACT])
+    feed = fields.Str()
+    name = fields.Str()
+    version = fields.Str()
+    scope = fields.Str()
+    organization = fields.Str()
+    project = fields.Str()
+
+
+class CommandComponentSchema(ComponentSchema, ParameterizedCommandSchema):
+    class Meta:
+        exclude = ["environment_variables"]  # component doesn't have environment variables
+
+    type = StringTransformedEnum(allowed_values=[NodeType.COMMAND])
+    resources = NestedField(ComponentResourceSchema, unknown=INCLUDE)
+    distribution = UnionField(
+        [
+            NestedField(MPIDistributionSchema, unknown=INCLUDE),
+            NestedField(TensorFlowDistributionSchema, unknown=INCLUDE),
+            NestedField(PyTorchDistributionSchema, unknown=INCLUDE),
+            ExperimentalField(NestedField(RayDistributionSchema, unknown=INCLUDE)),
+        ],
+        metadata={"description": "Provides the configuration for a distributed run."},
+    )
+    # primitive output is only supported for command component & pipeline component
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(OutputPortSchema),
+                NestedField(PrimitiveOutputSchema, unknown=INCLUDE),
+            ]
+        ),
+    )
+    properties = fields.Dict(keys=fields.Str(), values=fields.Raw())
+
+    # Note: AzureDevopsArtifactsSchema only available when private preview flag opened before init of command component
+    # schema class.
+    if is_private_preview_enabled():
+        additional_includes = fields.List(UnionField([fields.Str(), NestedField(AzureDevopsArtifactsSchema)]))
+    else:
+        additional_includes = fields.List(fields.Str())
+
+    @post_dump
+    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+        # remove empty properties to keep the component spec unchanged
+        if not component_schema_dict.get("properties"):
+            component_schema_dict.pop("properties", None)
+        if (
+            component_schema_dict.get("additional_includes") is not None
+            and len(component_schema_dict["additional_includes"]) == 0
+        ):
+            component_schema_dict.pop("additional_includes")
+        return component_schema_dict
+
+
+class RestCommandComponentSchema(CommandComponentSchema):
+    """When component load from rest, won't validate on name since there might be existing component with invalid
+    name."""
+
+    name = fields.Str(required=True)
+
+
+class AnonymousCommandComponentSchema(AnonymousAssetSchema, CommandComponentSchema):
+    """Anonymous command component schema.
+
+    Note inheritance follows order: AnonymousAssetSchema, CommandComponentSchema because we need name and version to be
+    dump_only(marshmallow collects fields follows method resolution order).
+    """
+
+    @post_load
+    def make(self, data, **kwargs):
+        from azure.ai.ml.entities import CommandComponent
+
+        # Inline component will have source=YAML.JOB
+        # As we only regard full separate component file as YAML.COMPONENT
+        return CommandComponent(
+            base_path=self.context[BASE_PATH_CONTEXT_KEY],
+            _source=ComponentSource.YAML_JOB,
+            **data,
+        )
+
+
+class ComponentFileRefField(FileRefField):
+    def _deserialize(self, value, attr, data, **kwargs):
+        # Get component info from component yaml file.
+        data = super()._deserialize(value, attr, data, **kwargs)
+        component_dict = yaml.safe_load(data)
+        source_path = self.context[BASE_PATH_CONTEXT_KEY] / value
+
+        # Update base_path to parent path of component file.
+        component_schema_context = deepcopy(self.context)
+        component_schema_context[BASE_PATH_CONTEXT_KEY] = source_path.parent
+        component = AnonymousCommandComponentSchema(context=component_schema_context).load(
+            component_dict, unknown=INCLUDE
+        )
+        component._source_path = source_path
+        component._source = ComponentSource.YAML_COMPONENT
+        return component