about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py156
1 files changed, 156 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
new file mode 100644
index 00000000..42b0bee6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
@@ -0,0 +1,156 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import functools
+import inspect
+import logging
+import sys
+from typing import Callable, Type, TypeVar, Union
+
+from typing_extensions import ParamSpec
+
+from azure.ai.ml.constants._common import (
+    DOCSTRING_DEFAULT_INDENTATION,
+    DOCSTRING_TEMPLATE,
+    EXPERIMENTAL_CLASS_MESSAGE,
+    EXPERIMENTAL_LINK_MESSAGE,
+    EXPERIMENTAL_METHOD_MESSAGE,
+)
+
+_warning_cache = set()
+module_logger = logging.getLogger(__name__)
+
+TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def experimental(wrapped: TExperimental) -> TExperimental:
+    """Add experimental tag to a class or a method.
+
+    :param wrapped: Either a Class or Function to mark as experimental
+    :type wrapped: TExperimental
+    :return: The wrapped class or method
+    :rtype: TExperimental
+    """
+    if inspect.isclass(wrapped):
+        return _add_class_docstring(wrapped)
+    if inspect.isfunction(wrapped):
+        return _add_method_docstring(wrapped)
+    return wrapped
+
+
+def _add_class_docstring(cls: Type[T]) -> Type[T]:
+    """Add experimental tag to the class doc string.
+
+    :return: The updated class
+    :rtype: Type[T]
+    """
+
+    P2 = ParamSpec("P2")
+
+    def _add_class_warning(func: Callable[P2, None]) -> Callable[P2, None]:
+        """Add warning message for class __init__.
+
+        :param func: The original __init__ function
+        :type func: Callable[P2, None]
+        :return: Updated __init__
+        :rtype: Callable[P2, None]
+        """
+
+        @functools.wraps(func)
+        def wrapped(*args, **kwargs):
+            message = "Class {0}: {1} {2}".format(cls.__name__, EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+            if not _should_skip_warning() and not _is_warning_cached(message):
+                module_logger.warning(message)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    doc_string = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+    if cls.__doc__:
+        cls.__doc__ = _add_note_to_docstring(cls.__doc__, doc_string)
+    else:
+        cls.__doc__ = doc_string + ">"
+    cls.__init__ = _add_class_warning(cls.__init__)
+    return cls
+
+
+def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
+    """Add experimental tag to the method doc string.
+
+    :param func: The function to update
+    :type func: Callable[P, T]
+    :return: A wrapped method marked as experimental
+    :rtype: Callable[P,T]
+    """
+    doc_string = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+    if func.__doc__:
+        func.__doc__ = _add_note_to_docstring(func.__doc__, doc_string)
+    else:
+        # '>' is required. Otherwise the note section can't be generated
+        func.__doc__ = doc_string + ">"
+
+    @functools.wraps(func)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
+        message = "Method {0}: {1} {2}".format(func.__name__, EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+        if not _should_skip_warning() and not _is_warning_cached(message):
+            module_logger.warning(message)
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _add_note_to_docstring(doc_string: str, note: str) -> str:
+    """Adds experimental note to docstring at the top and correctly indents original docstring.
+
+    :param doc_string: The docstring
+    :type doc_string: str
+    :param note: The note to add to the docstring
+    :type note: str
+    :return: Updated docstring
+    :rtype: str
+    """
+    indent = _get_indentation_size(doc_string)
+    doc_string = doc_string.rjust(len(doc_string) + indent)
+    return note + doc_string
+
+
+def _get_indentation_size(doc_string: str) -> int:
+    """Finds the minimum indentation of all non-blank lines after the first line.
+
+    :param doc_string: The docstring
+    :type doc_string: str
+    :return: Minimum number of indentation of the docstring
+    :rtype: int
+    """
+    lines = doc_string.expandtabs().splitlines()
+    indent = sys.maxsize
+    for line in lines[1:]:
+        stripped = line.lstrip()
+        if stripped:
+            indent = min(indent, len(line) - len(stripped))
+    return indent if indent < sys.maxsize else DOCSTRING_DEFAULT_INDENTATION
+
+
+def _should_skip_warning():
+    skip_warning_msg = False
+
+    # Cases where we want to suppress the warning:
+    # 1. When converting from REST object to SDK object
+    for frame in inspect.stack():
+        if frame.function == "_from_rest_object":
+            skip_warning_msg = True
+            break
+
+    return skip_warning_msg
+
+
+def _is_warning_cached(warning_msg):
+    # use cache to make sure we only print same warning message once under same session
+    # this prevents duplicated warnings got printed when user does a loop call on a method or a class
+    if warning_msg in _warning_cache:
+        return True
+    _warning_cache.add(warning_msg)
+    return False