author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

    two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema')
 .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py     |   3 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py      |  37 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py    | 232 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py  |  21 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py | 113 +
 .venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py         |  75 +
 6 files changed, 481 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py
new file mode 100644
index 00000000..d540fd20
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/__init__.py
@@ -0,0 +1,3 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py
new file mode 100644
index 00000000..2dddf02b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/command.py
@@ -0,0 +1,37 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from marshmallow import fields
+
+from ..._schema import NestedField
+from ..._schema.core.fields import DumpableEnumField, EnvironmentField
+from ..._schema.job import ParameterizedCommandSchema, ParameterizedParallelSchema
+from ..._schema.job.job_limits import CommandJobLimitsSchema
+from .node import InternalBaseNodeSchema, NodeType
+
+
+class CommandSchema(InternalBaseNodeSchema, ParameterizedCommandSchema):
+    class Meta:
+        exclude = ["code", "distribution"]  # internal command doesn't have code & distribution
+
+    environment = EnvironmentField()
+    type = DumpableEnumField(allowed_values=[NodeType.COMMAND])
+    limits = NestedField(CommandJobLimitsSchema)
+
+
+class DistributedSchema(CommandSchema):
+    class Meta:
+        exclude = ["code"]  # need to enable distribution comparing to CommandSchema
+
+    type = DumpableEnumField(allowed_values=[NodeType.DISTRIBUTED])
+
+
+class ParallelSchema(InternalBaseNodeSchema, ParameterizedParallelSchema):
+    class Meta:
+        # partition_keys can still be used (with an unknown-field warning), but a dump is needed before setting it
+        exclude = ["task", "input_data", "mini_batch_error_threshold", "partition_keys"]
+
+    type = DumpableEnumField(allowed_values=[NodeType.PARALLEL])
+    compute = fields.Str()
+    environment = fields.Str()
+    limits = NestedField(CommandJobLimitsSchema)
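
The Meta.exclude pattern above drives which inherited fields each node schema exposes. A minimal sketch in plain marshmallow (CommandLike and DistributedLike are hypothetical names, not SDK classes) of how a subclass re-enables a field by shrinking the exclude list:

from marshmallow import Schema, fields

class CommandLike(Schema):
    class Meta:
        exclude = ["code", "distribution"]  # mirror CommandSchema: drop both fields

    code = fields.Str()
    distribution = fields.Str()
    command = fields.Str()

class DistributedLike(CommandLike):
    class Meta:
        exclude = ["code"]  # mirror DistributedSchema: distribution is back in play

print(sorted(CommandLike().fields))      # ['command']
print(sorted(DistributedLike().fields))  # ['command', 'distribution']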
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
new file mode 100644
index 00000000..11d4bb56
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/component.py
@@ -0,0 +1,232 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import os.path
+
+import pydash
+from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
+
+from ..._schema import NestedField, StringTransformedEnum, UnionField
+from ..._schema.component.component import ComponentSchema
+from ..._schema.core.fields import ArmVersionedStr, CodeField, EnvironmentField, RegistryStr
+from ..._schema.job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
+from ..._utils._arm_id_utils import parse_name_label
+from ..._utils.utils import get_valid_dot_keys_with_wildcard
+from ...constants._common import (
+    LABELLED_RESOURCE_NAME,
+    SOURCE_PATH_CONTEXT_KEY,
+    AzureMLResourceType,
+    DefaultOpenEncoding,
+)
+from ...constants._component import NodeType as PublicNodeType
+from .._utils import yaml_safe_load_with_base_resolver
+from .environment import InternalEnvironmentSchema
+from .input_output import (
+    InternalEnumParameterSchema,
+    InternalInputPortSchema,
+    InternalOutputPortSchema,
+    InternalParameterSchema,
+    InternalPrimitiveOutputSchema,
+    InternalSparkParameterSchema,
+)
+
+
+class NodeType:
+    COMMAND = "CommandComponent"
+    DATA_TRANSFER = "DataTransferComponent"
+    DISTRIBUTED = "DistributedComponent"
+    HDI = "HDInsightComponent"
+    SCOPE_V2 = "scope"
+    HDI_V2 = "hdinsight"
+    HEMERA_V2 = "hemera"
+    STARLITE_V2 = "starlite"
+    AE365EXEPOOL_V2 = "ae365exepool"
+    AETHER_BRIDGE_V2 = "aetherbridge"
+    PARALLEL = "ParallelComponent"
+    SCOPE = "ScopeComponent"
+    STARLITE = "StarliteComponent"
+    SWEEP = "SweepComponent"
+    PIPELINE = "PipelineComponent"
+    HEMERA = "HemeraComponent"
+    AE365EXEPOOL = "AE365ExePoolComponent"
+    IPP = "IntellectualPropertyProtectedComponent"
+    # the internal spark component's type value conflicts with the v2 spark component;
+    # this enum value is used to identify its create_function in factories
+    SPARK = "DummySpark"
+    AETHER_BRIDGE = "AetherBridgeComponent"
+
+    @classmethod
+    def all_values(cls):
+        all_values = []
+        for key, value in vars(cls).items():
+            if not key.startswith("_") and isinstance(value, str):
+                all_values.append(value)
+        return all_values
+
+
+class InternalComponentSchema(ComponentSchema):
+    class Meta:
+        unknown = INCLUDE
+
+    # override name, as 1P components allow "." in the name, which is not allowed in v2 components
+    name = fields.Str()
+
+    # override to allow empty properties
+    tags = fields.Dict(keys=fields.Str())
+
+    # override inputs & outputs to support 1P inputs & outputs, may need to do strict validation later
+    # no need to check io type match since server will do that
+    inputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(InternalParameterSchema),
+                NestedField(InternalEnumParameterSchema),
+                NestedField(InternalInputPortSchema),
+            ]
+        ),
+    )
+    # support primitive output for all internal components for now
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
+                NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
+            ]
+        ),
+    )
+
+    # type field is required for registration
+    type = StringTransformedEnum(
+        allowed_values=NodeType.all_values(),
+        casing_transform=lambda x: parse_name_label(x)[0],
+        pass_original=True,
+    )
+
+    # needs resolving, as it can be a local path
+    code = CodeField()
+
+    environment = UnionField(
+        [
+            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            NestedField(InternalEnvironmentSchema),
+        ]
+    )
+
+    def get_skip_fields(self):
+        return ["properties"]
+
+    def _serialize(self, obj, *, many: bool = False):
+        if many and obj is not None:
+            return super(InternalComponentSchema, self)._serialize(obj, many=many)
+        ret = super(InternalComponentSchema, self)._serialize(obj)
+        for attr_name in obj.__dict__.keys():
+            if (
+                not attr_name.startswith("_")
+                and attr_name not in self.get_skip_fields()
+                and attr_name not in self.dump_fields
+            ):
+                ret[attr_name] = self.get_attribute(obj, attr_name, None)
+        return ret
+
+    # override add_param_overrides to ensure that param overrides are applied after reloading the YAML
+    @pre_load
+    def add_param_overrides(self, data, **kwargs):
+        source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
+        if isinstance(data, dict) and source_path and os.path.isfile(source_path):
+
+            def should_node_overwritten(_root, _parts):
+                parts = _parts.copy()
+                parts.pop()
+                parts.append("type")
+                _input_type = pydash.get(_root, parts, None)
+                return isinstance(_input_type, str) and _input_type.lower() not in ["boolean"]
+
+            # do override here
+            with open(source_path, "r", encoding=DefaultOpenEncoding.READ) as f:
+                origin_data = yaml_safe_load_with_base_resolver(f)
+                for dot_key_wildcard, condition_func in [
+                    ("version", None),
+                    ("inputs.*.default", should_node_overwritten),
+                    ("inputs.*.enum", should_node_overwritten),
+                ]:
+                    for dot_key in get_valid_dot_keys_with_wildcard(
+                        origin_data, dot_key_wildcard, validate_func=condition_func
+                    ):
+                        pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
+        return super().add_param_overrides(data, **kwargs)
+
+    @post_dump(pass_original=True)
+    def simplify_input_output_port(self, data, original, **kwargs):  # pylint:disable=unused-argument
+        # remove None in input & output
+        for io_ports in [data["inputs"], data["outputs"]]:
+            for port_name, port_definition in io_ports.items():
+                io_ports[port_name] = dict(filter(lambda item: item[1] is not None, port_definition.items()))
+
+        # hack to match current serialization expectations
+        for port_name, port_definition in data["inputs"].items():
+            if "mode" in port_definition:
+                del port_definition["mode"]
+
+        return data
+
+    @post_dump(pass_original=True)
+    def add_back_type_label(self, data, original, **kwargs):  # pylint:disable=unused-argument
+        type_label = original._type_label  # pylint:disable=protected-access
+        if type_label:
+            data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
+        return data
+
+
+class InternalSparkComponentSchema(InternalComponentSchema):
+    # type field is required for registration
+    type = StringTransformedEnum(
+        allowed_values=PublicNodeType.SPARK,
+        casing_transform=lambda x: parse_name_label(x)[0].lower(),
+        pass_original=True,
+    )
+
+    # override inputs:
+    # https://componentsdk.azurewebsites.net/components/spark_component.html#differences-with-other-component-types
+    inputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(InternalSparkParameterSchema),
+                NestedField(InternalInputPortSchema),
+            ]
+        ),
+    )
+
+    environment = EnvironmentField(
+        extra_fields=[NestedField(InternalEnvironmentSchema)],
+        allow_none=True,
+    )
+
+    jars = UnionField(
+        [
+            fields.List(fields.Str()),
+            fields.Str(),
+        ],
+    )
+    py_files = UnionField(
+        [
+            fields.List(fields.Str()),
+            fields.Str(),
+        ],
+        data_key="pyFiles",
+        attribute="py_files",
+    )
+
+    entry = UnionField(
+        [NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
+        required=True,
+        metadata={"description": "Entry."},
+    )
+
+    files = fields.List(fields.Str(required=True))
+    archives = fields.List(fields.Str(required=True))
+    conf = fields.Dict(keys=fields.Str(), values=fields.Raw())
+    args = fields.Str(metadata={"description": "Command Line arguments."})
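
The add_param_overrides hook above re-applies selected values (version, inputs.*.default, inputs.*.enum) from the original YAML on top of the loaded data, using pydash dot-keys. A minimal sketch of that mechanism, assuming pydash is installed (the data below is illustrative, not from the SDK):

import pydash

# values as they appear in the source YAML
origin_data = {"inputs": {"lr": {"type": "float", "default": 0.01}}}
# values after loading, where the default may have been transformed
data = {"inputs": {"lr": {"type": "float", "default": "0.01"}}}

# copy each resolved dot-key back from the original document
for dot_key in ["inputs.lr.default"]:
    pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))

print(data["inputs"]["lr"]["default"])  # 0.01 -- the original float is restored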
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py
new file mode 100644
index 00000000..f7c20228
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/environment.py
@@ -0,0 +1,21 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields
+
+from ..._schema import PathAwareSchema
+from ..._schema.core.fields import DumpableEnumField, VersionField
+
+
+class InternalEnvironmentSchema(PathAwareSchema):
+    docker = fields.Dict()
+    conda = fields.Dict()
+    os = DumpableEnumField(
+        # use an enum instead of a string transformer here to avoid changing the value
+        allowed_values=["Linux", "Windows", "linux", "windows"],
+        required=False,
+    )
+    name = fields.Str()
+    version = VersionField()
+    python = fields.Dict()
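
The comment on the os field explains why DumpableEnumField lists both casings rather than normalizing: the accepted value must round-trip unchanged. A minimal sketch in plain marshmallow (OsSchema is a hypothetical name, not an SDK class):

from marshmallow import Schema, fields, validate

class OsSchema(Schema):
    # both casings are allowed, and neither is rewritten on load
    os = fields.Str(validate=validate.OneOf(["Linux", "Windows", "linux", "windows"]))

print(OsSchema().load({"os": "linux"}))  # {'os': 'linux'} -- casing preserved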
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py
new file mode 100644
index 00000000..b1fe2188
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/input_output.py
@@ -0,0 +1,113 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import fields, post_dump, post_load
+
+from ..._schema import PatchedSchemaMeta, StringTransformedEnum, UnionField
+from ..._schema.component.input_output import InputPortSchema, ParameterSchema
+from ..._schema.core.fields import DumpableEnumField, PrimitiveValueField
+
+SUPPORTED_INTERNAL_PARAM_TYPES = [
+    "integer",
+    "Integer",
+    "boolean",
+    "Boolean",
+    "string",
+    "String",
+    "float",
+    "Float",
+    "double",
+    "Double",
+]
+
+
+SUPPORTED_INTERNAL_SPARK_PARAM_TYPES = [
+    "integer",
+    "Integer",
+    "boolean",
+    "Boolean",
+    "string",
+    "String",
+    "double",
+    "Double",
+    # remove float and add number
+    "number",
+]
+
+
+class InternalInputPortSchema(InputPortSchema):
+    # skip client-side validate for type enum & support list
+    type = UnionField(
+        [
+            fields.Str(),
+            fields.List(fields.Str()),
+        ],
+        required=True,
+        data_key="type",
+    )
+    is_resource = fields.Bool()
+    datastore_mode = fields.Str()
+
+    @post_dump(pass_original=True)
+    def resolve_list_type(self, data, original_data, **kwargs):  # pylint: disable=unused-argument
+        if isinstance(original_data.type, list):
+            data["type"] = original_data.type
+        return data
+
+
+class InternalOutputPortSchema(metaclass=PatchedSchemaMeta):
+    # skip client-side validation of the type enum
+    type = fields.Str(
+        required=True,
+        data_key="type",
+    )
+    description = fields.Str()
+    is_link_mode = fields.Bool()
+    datastore_mode = fields.Str()
+
+
+class InternalPrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
+        required=True,
+    )
+    description = fields.Str()
+
+
+class InternalParameterSchema(ParameterSchema):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
+        required=True,
+        data_key="type",
+    )
+
+
+class InternalSparkParameterSchema(ParameterSchema):
+    type = DumpableEnumField(
+        allowed_values=SUPPORTED_INTERNAL_SPARK_PARAM_TYPES,
+        required=True,
+        data_key="type",
+    )
+
+
+class InternalEnumParameterSchema(ParameterSchema):
+    type = StringTransformedEnum(
+        allowed_values=["enum"],
+        required=True,
+        data_key="type",
+    )
+    default = PrimitiveValueField()
+    enum = fields.List(
+        PrimitiveValueField(),
+        required=True,
+    )
+
+    @post_dump
+    @post_load
+    def enum_value_to_string(self, data, **kwargs):  # pylint: disable=unused-argument
+        if "enum" in data:
+            data["enum"] = list(map(str, data["enum"]))
+        if "default" in data and data["default"] is not None:
+            data["default"] = str(data["default"])
+        return data
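
The enum_value_to_string hook runs on both dump and load, so enum options and the default are always compared as strings. A standalone sketch of the same normalization (plain Python, mirroring the method body above):

def enum_value_to_string(data):
    # coerce enum options and the default to strings, as the schema hook does
    if "enum" in data:
        data["enum"] = list(map(str, data["enum"]))
    if "default" in data and data["default"] is not None:
        data["default"] = str(data["default"])
    return data

print(enum_value_to_string({"enum": [1, 2, 3], "default": 1}))
# {'enum': ['1', '2', '3'], 'default': '1'}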
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
new file mode 100644
index 00000000..6dbadcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_internal/_schema/node.py
@@ -0,0 +1,75 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from marshmallow import INCLUDE, fields, post_load, pre_dump
+
+from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
+from ..._schema.core.fields import DumpableEnumField
+from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
+from ...constants._common import AzureMLResourceType
+from .component import InternalComponentSchema, NodeType
+
+
+class InternalBaseNodeSchema(BaseNodeSchema):
+    class Meta:
+        unknown = INCLUDE
+
+    component = UnionField(
+        [
+            # for registry type assets
+            RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
+            # existing component
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, allow_default_version=True),
+            # inline component or component file reference starting with FILE prefix
+            NestedField(InternalComponentSchema, unknown=INCLUDE),
+        ],
+        required=True,
+    )
+    type = DumpableEnumField(
+        allowed_values=NodeType.all_values(),
+    )
+
+    @post_load
+    def make(self, data, **kwargs):  # pylint: disable=unused-argument
+        from ...entities._builders import parse_inputs_outputs
+
+        # parse inputs/outputs
+        data = parse_inputs_outputs(data)
+
+        # dict to node object
+        from ...entities._job.pipeline._load_component import pipeline_node_factory
+
+        return pipeline_node_factory.load_from_dict(data=data)
+
+    @pre_dump
+    def resolve_inputs_outputs(self, job, **kwargs):  # pylint: disable=unused-argument
+        return _resolve_inputs_outputs(job)
+
+
+class ScopeSchema(InternalBaseNodeSchema):
+    type = DumpableEnumField(allowed_values=[NodeType.SCOPE])
+    adla_account_name = fields.Str(required=True)
+    scope_param = fields.Str()
+    custom_job_name_suffix = fields.Str()
+    priority = fields.Int()
+    auto_token = fields.Int()
+    tokens = fields.Int()
+    vcp = fields.Float()
+
+
+class HDInsightSchema(InternalBaseNodeSchema):
+    type = DumpableEnumField(allowed_values=[NodeType.HDI])
+
+    compute_name = fields.Str()
+    queue = fields.Str()
+    driver_memory = fields.Str()
+    driver_cores = fields.Int()
+    executor_memory = fields.Str()
+    executor_cores = fields.Int()
+    number_executors = fields.Int()
+    conf = UnionField(
+        # dictionary or json string
+        union_fields=[fields.Dict(keys=fields.Str()), fields.Str()],
+    )
+    hdinsight_spark_job_name = fields.Str()
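
Several fields above (component, conf) rely on UnionField trying candidate fields in order until one deserializes. A minimal sketch of that "first match wins" behavior in plain marshmallow (FirstMatchUnion and ConfSchema are hypothetical, not the SDK's UnionField):

from marshmallow import Schema, ValidationError, fields

class FirstMatchUnion(fields.Field):
    """Deserialize with the first candidate field that accepts the value."""

    def __init__(self, union_fields, **kwargs):
        super().__init__(**kwargs)
        self._union_fields = union_fields

    def _deserialize(self, value, attr, data, **kwargs):
        errors = []
        for candidate in self._union_fields:
            try:
                return candidate.deserialize(value, attr, data, **kwargs)
            except ValidationError as exc:
                errors.append(exc.messages)
        raise ValidationError(errors)

class ConfSchema(Schema):
    # like HDInsightSchema.conf: a dictionary or a JSON string
    conf = FirstMatchUnion([fields.Dict(keys=fields.Str()), fields.Str()])

print(ConfSchema().load({"conf": {"spark.executor.cores": 2}}))
print(ConfSchema().load({"conf": '{"spark.executor.cores": 2}'}))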