about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py
new file mode 100644
index 00000000..08521d7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_job/automl/utils.py
@@ -0,0 +1,47 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import TYPE_CHECKING, Dict, Type, Union
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+if TYPE_CHECKING:
+    from azure.ai.ml.entities._job.automl.image.image_classification_search_space import ImageClassificationSearchSpace
+    from azure.ai.ml.entities._job.automl.image.image_object_detection_search_space import (
+        ImageObjectDetectionSearchSpace,
+    )
+    from azure.ai.ml.entities._job.automl.nlp.nlp_search_space import NlpSearchSpace
+    from azure.ai.ml.entities._job.automl.search_space import SearchSpace
+
+
+def cast_to_specific_search_space(
+    input: Union[Dict, "SearchSpace"],  # pylint: disable=redefined-builtin
+    class_name: Union[
+        Type["ImageClassificationSearchSpace"], Type["ImageObjectDetectionSearchSpace"], Type["NlpSearchSpace"]
+    ],
+    task_type: str,
+) -> Union["ImageClassificationSearchSpace", "ImageObjectDetectionSearchSpace", "NlpSearchSpace"]:
+    def validate_searchspace_args(input_dict: dict) -> None:
+        searchspace = class_name()
+        for key in input_dict:
+            if not hasattr(searchspace, key):
+                msg = f"Received unsupported search space parameter for {task_type} Job."
+                raise ValidationException(
+                    message=msg,
+                    no_personal_data_message=msg,
+                    target=ErrorTarget.AUTOML,
+                    error_category=ErrorCategory.USER_ERROR,
+                )
+
+    if isinstance(input, dict):
+        validate_searchspace_args(input)
+        specific_search_space = class_name(**input)
+    else:
+        validate_searchspace_args(input.__dict__)
+        specific_search_space = class_name._from_search_space_object(input)  # pylint: disable=protected-access
+
+    res: Union["ImageClassificationSearchSpace", "ImageObjectDetectionSearchSpace", "NlpSearchSpace"] = (
+        specific_search_space
+    )
+    return res