about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py94
1 files changed, 94 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
new file mode 100644
index 00000000..c1ee3568
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/_utils/utils.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import copy
+import logging
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Optional, Union
+
+from marshmallow.exceptions import ValidationError
+
+module_logger = logging.getLogger(__name__)
+
+
+class ArmId(str):
+    def __new__(cls, content):
+        validate_arm_str(content)
+        return str.__new__(cls, content)
+
+
+def validate_arm_str(arm_str: Union[ArmId, str]) -> bool:
+    """Validate whether the given string is in fact in the format of an ARM ID.
+
+    :param arm_str: The string to validate.
+    :type arm_str: Either a string (in case of incorrect formatting) or ArmID (in case of correct formatting).
+    :returns: True if the string is correctly formatted, False otherwise.
+    :rtype: bool
+    """
+    reg_str = (
+        r"/subscriptions/[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}?/resourcegroups/.*/providers/[a-z.a-z]*/[a-z]*/.*"
+    )
+    lowered = arm_str.lower()
+    match = re.match(reg_str, lowered)
+    if match and match.group() == lowered:
+        return True
+    raise ValidationError(f"ARM string {arm_str} is not formatted correctly.")
+
+
+def get_subnet_str(vnet_name: str, subnet: str, sub_id: Optional[str] = None, rg: Optional[str] = None) -> str:
+    if vnet_name and not subnet:
+        raise ValidationError("Subnet is required when vnet name is specified.")
+    try:
+        validate_arm_str(subnet)
+        return subnet
+    except ValidationError:
+        return (
+            f"/subscriptions/{sub_id}/resourceGroups/{rg}/"
+            f"providers/Microsoft.Network/virtualNetworks/{vnet_name}/subnets/{subnet}"
+        )
+
+
+def replace_key_in_odict(odict: OrderedDict, old_key: Any, new_key: Any):
+    if not odict or old_key not in odict:
+        return odict
+    return OrderedDict([(new_key, v) if k == old_key else (k, v) for k, v in odict.items()])
+
+
+# This is temporary until deployments(batch/K8S) support registry references
+def exit_if_registry_assets(data: Dict, caller: str) -> None:
+    startswith = "azureml://registries/"
+    if (
+        "environment" in data
+        and data["environment"]
+        and isinstance(data["environment"], str)
+        and data["environment"].startswith(startswith)
+    ):
+        raise ValidationError(f"Registry reference for environments is not supported for {caller}")
+    if "model" in data and data["model"] and isinstance(data["model"], str) and data["model"].startswith(startswith):
+        raise ValidationError(f"Registry reference for models is not supported for {caller}")
+    if (
+        "code_configuration" in data
+        and data["code_configuration"].code
+        and isinstance(data["code_configuration"].code, str)
+        and data["code_configuration"].code.startswith(startswith)
+    ):
+        raise ValidationError(f"Registry reference for code_configuration.code is not supported for {caller}")
+
+
+def _resolve_group_inputs_for_component(component, **kwargs):  # pylint: disable=unused-argument
+    # Try resolve object's inputs & outputs and return a resolved new object
+    from azure.ai.ml.entities._inputs_outputs import GroupInput
+
+    result = copy.copy(component)
+
+    flatten_inputs = {}
+    for key, val in result.inputs.items():
+        if isinstance(val, GroupInput):
+            flatten_inputs.update(val.flatten(group_parameter_name=key))
+            continue
+        flatten_inputs[key] = val
+
+    # Flatten group inputs
+    result._inputs = flatten_inputs  # pylint: disable=protected-access
+    return result