about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py5
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py88
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py94
3 files changed, 187 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
new file mode 100644
index 00000000..29a4fcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/__init__.py
@@ -0,0 +1,5 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)  # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py
new file mode 100644
index 00000000..611c80a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/data_binding_expression.py
@@ -0,0 +1,88 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from typing import Union
+
+from marshmallow import Schema, fields
+
+from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField
+from azure.ai.ml._schema.core.schema import PathAwareSchema
+
+DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported"
+
+
+def _is_literal(field):
+    return not isinstance(field, (NestedField, fields.List, fields.Dict, UnionField))
+
+
+def _add_data_binding_to_field(field, attrs_to_skip, schema_stack):
+    if hasattr(field, DATA_BINDING_SUPPORTED_KEY) and getattr(field, DATA_BINDING_SUPPORTED_KEY):
+        return field
+    data_binding_field = DataBindingStr()
+    if isinstance(field, UnionField):
+        for field_obj in field.union_fields:
+            if not _is_literal(field_obj):
+                _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+        field.insert_union_field(data_binding_field)
+    elif isinstance(field, fields.Dict):
+        # handle dict, dict value can be None
+        if field.value_field is not None:
+            field.value_field = _add_data_binding_to_field(field.value_field, attrs_to_skip, schema_stack=schema_stack)
+    elif isinstance(field, fields.List):
+        # handle list
+        field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack)
+    elif isinstance(field, ExperimentalField):
+        field = ExperimentalField(
+            _add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack),
+            data_key=field.data_key,
+            attribute=field.attribute,
+            dump_only=field.dump_only,
+            required=field.required,
+            allow_none=field.allow_none,
+        )
+    elif isinstance(field, NestedField):
+        # handle nested field
+        support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack)
+    else:
+        # change basic fields to union
+        field = UnionField(
+            [data_binding_field, field],
+            data_key=field.data_key,
+            attribute=field.attribute,
+            dump_only=field.dump_only,
+            required=field.required,
+            allow_none=field.allow_none,
+        )
+
+    setattr(field, DATA_BINDING_SUPPORTED_KEY, True)
+    return field
+
+
+# pylint: disable-next=docstring-missing-param
+def support_data_binding_expression_for_fields(  # pylint: disable=name-too-long
+    schema: Union[PathAwareSchema, Schema], attrs_to_skip=None, schema_stack=None
+):
+    """Update fields inside schema to support data binding string.
+
+    Only first layer of recursive schema is supported now.
+    """
+    if hasattr(schema, DATA_BINDING_SUPPORTED_KEY) and getattr(schema, DATA_BINDING_SUPPORTED_KEY):
+        return
+
+    setattr(schema, DATA_BINDING_SUPPORTED_KEY, True)
+
+    if attrs_to_skip is None:
+        attrs_to_skip = []
+    if schema_stack is None:
+        schema_stack = []
+    schema_type_name = type(schema).__name__
+    if schema_type_name in schema_stack:
+        return
+    schema_stack.append(schema_type_name)
+    for attr, field_obj in schema.load_fields.items():
+        if attr not in attrs_to_skip:
+            schema.load_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+    for attr, field_obj in schema.dump_fields.items():
+        if attr not in attrs_to_skip:
+            schema.dump_fields[attr] = _add_data_binding_to_field(field_obj, attrs_to_skip, schema_stack=schema_stack)
+    schema_stack.pop()
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
new file mode 100644
index 00000000..c1ee3568
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+import logging
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Union
+
+from marshmallow.exceptions import ValidationError
+
+module_logger = logging.getLogger(__name__)
+
+
+class ArmId(str):
+    def __new__(cls, content):
+        validate_arm_str(content)
+        return str.__new__(cls, content)
+
+
+def validate_arm_str(arm_str: Union[ArmId, str]) -> bool:
+    """Validate whether the given string is in fact in the format of an ARM ID.
+
+    :param arm_str: The string to validate.
+    :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting).
+    :returns: True if the string is correctly formatted, False otherwise.
+    :rtype: bool
+    """
+    reg_str = (
+        r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*"
+    )
+    lowered = arm_str.lower()
+    match = re.match(reg_str, lowered)
+    if match and match.group() == lowered:
+        return True
+    raise ValidationError(f"ARM string {arm_str} is not formatted correctly.")
+
+
+def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str:
+    if vnet_name and not subnet:
+        raise ValidationError("Subnet is required when vnet name is specified.")
+    try:
+        validate_arm_str(subnet)
+        return subnet
+    except ValidationError:
+        return (
+            f"/subscriptions/{sub_id}/resourceGroups/{rg}/"
+            f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}"
+        )
+
+
+def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any):
+    if not odict or old_key not in odict:
+        return odict
+    return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()])
+
+
+# This is temporary until deployments(batch/K8S) support registry references
+def exit_if_registry_assets(data: Dict, caller: str) -> None:
+    startswith = "azureml://registries/"
+    if (
+        "environment" in data
+        and data["environment"]
+        and isinstance(data["environment"], str)
+        and data["environment"].startswith(startswith)
+    ):
+        raise ValidationError(f"Registry reference for environments is not supported for {caller}")
+    if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith):
+        raise ValidationError(f"Registry reference for models is not supported for {caller}")
+    if (
+        "code_configuration" in data
+        and data["code_configuration"].code
+        and isinstance(data["code_configuration"].code, str)
+        and data["code_configuration"].code.startswith(startswith)
+    ):
+        raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}")
+
+
+def _resolve_group_inputs_for_component(component, **kwargs):  # pylint: disable=unused-argument
+    # Try resolve object's inputs & outputs and return a resolved new object
+    from azure.ai.ml.entities._inputs_outputs import GroupInput
+
+    result = copy.copy(component)
+
+    flatten_inputs = {}
+    for key, val in result.inputs.items():
+        if isinstance(val, GroupInput):
+            flatten_inputs.update(val.flatten(group_parameter_name=key))
+            continue
+        flatten_inputs[key] = val
+
+    # Flatten group inputs
+    result._inputs = flatten_inputs  # pylint: disable=protected-access
+    return result