about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
new file mode 100644
index 00000000..5772a607
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/component/component.py
@@ -0,0 +1,143 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from pathlib import Path
+
+from marshmallow import ValidationError, fields, post_dump, pre_dump, pre_load
+from marshmallow.fields import Field
+
+from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
+from azure.ai.ml._schema.core.fields import (
+    ArmVersionedStr,
+    ExperimentalField,
+    NestedField,
+    PythonFuncNameStr,
+    UnionField,
+)
+from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
+from azure.ai.ml._utils.utils import is_private_preview_enabled, load_yaml
+from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AzureMLResourceType
+
+from .._utils.utils import _resolve_group_inputs_for_component
+from ..assets.asset import AssetSchema
+from ..core.fields import RegistryStr
+
+
+class ComponentNameStr(PythonFuncNameStr):
+    def _get_field_name(self):
+        return "Component"
+
+
+class ComponentYamlRefField(Field):
+    """Allows you to nest a :class:`Schema <marshmallow.Schema>`
+    inside a yaml ref field.
+    """
+
+    def _jsonschema_type_mapping(self):
+        schema = {"type": "string"}
+        if self.name is not None:
+            schema["title"] = self.name
+        if self.dump_only:
+            schema["readonly"] = True
+        return schema
+
+    def _deserialize(self, value, attr, data, **kwargs):
+        if not isinstance(value, str):
+            raise ValidationError(f"Nested yaml ref field expected a string but got {type(value)}.")
+
+        base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
+
+        source_path = Path(value)
+        # raise if the string is not a valid path, like "azureml:xxx"
+        try:
+            source_path.resolve()
+        except OSError as ex:
+            raise ValidationError(f"Nested file ref field expected a local path but got {value}.") from ex
+
+        if not source_path.is_absolute():
+            source_path = base_path / source_path
+
+        if not source_path.is_file():
+            raise ValidationError(
+                f"Nested yaml ref field expected a local path but can't find {value} based on {base_path.as_posix()}."
+            )
+
+        loaded_value = load_yaml(source_path)
+
+        # local import to avoid circular import
+        from azure.ai.ml.entities import Component
+
+        component = Component._load(data=loaded_value, yaml_path=source_path)  # pylint: disable=protected-access
+        return component
+
+    def _serialize(self, value, attr, obj, **kwargs):
+        raise ValidationError("Serialize on RefField is not supported.")
+
+
+class ComponentSchema(AssetSchema):
+    schema = fields.Str(data_key="$schema", attribute="_schema")
+    name = ComponentNameStr(required=True)
+    id = UnionField(
+        [
+            RegistryStr(dump_only=True),
+            ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT, dump_only=True),
+        ]
+    )
+    display_name = fields.Str()
+    description = fields.Str()
+    tags = fields.Dict(keys=fields.Str(), values=fields.Str())
+    is_deterministic = fields.Bool()
+    inputs = fields.Dict(
+        keys=fields.Str(),
+        values=UnionField(
+            [
+                NestedField(ParameterSchema),
+                NestedField(InputPortSchema),
+            ]
+        ),
+    )
+    outputs = fields.Dict(
+        keys=fields.Str(),
+        values=NestedField(OutputPortSchema),
+    )
+    # hide in private preview
+    if is_private_preview_enabled():
+        intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema))
+
+    def __init__(self, *args, **kwargs):
+        # Remove schema_ignored to enable serialize and deserialize schema.
+        self._declared_fields.pop("schema_ignored", None)
+        super().__init__(*args, **kwargs)
+
+    @pre_load
+    def convert_version_to_str(self, data, **kwargs):  # pylint: disable=unused-argument
+        if isinstance(data, dict) and data.get("version", None):
+            data["version"] = str(data["version"])
+        return data
+
+    @pre_dump
+    def add_private_fields_to_dump(self, data, **kwargs):  # pylint: disable=unused-argument
+        # The ipp field is set on the component object as "_intellectual_property".
+        # We need to set it as "intellectual_property" before dumping so that Marshmallow
+        # can pick up the field correctly on dump and show it back to the user.
+        ipp_field = data._intellectual_property  # pylint: disable=protected-access
+        if ipp_field:
+            setattr(data, "intellectual_property", ipp_field)
+        return data
+
+    @post_dump
+    def convert_input_value_to_str(self, data, **kwargs):  # pylint:disable=unused-argument
+        if isinstance(data, dict) and data.get("inputs", None):
+            input_dict = data["inputs"]
+            for input_value in input_dict.values():
+                input_type = input_value.get("type", None)
+                if isinstance(input_type, str) and input_type.lower() == "float":
+                    # Convert number to string to avoid precision issue
+                    for key in ["default", "min", "max"]:
+                        if input_value.get(key, None) is not None:
+                            input_value[key] = str(input_value[key])
+        return data
+
+    @pre_dump
+    def flatten_group_inputs(self, data, **kwargs):  # pylint: disable=unused-argument
+        return _resolve_group_inputs_for_component(data)