aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py156
1 files changed, 156 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
new file mode 100644
index 00000000..42b0bee6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
@@ -0,0 +1,156 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import functools
+import inspect
+import logging
+import sys
+from typing import Callable, Type, TypeVar, Union
+
+from typing_extensions import ParamSpec
+
+from azure.ai.ml.constants._common import (
+ DOCSTRING_DEFAULT_INDENTATION,
+ DOCSTRING_TEMPLATE,
+ EXPERIMENTAL_CLASS_MESSAGE,
+ EXPERIMENTAL_LINK_MESSAGE,
+ EXPERIMENTAL_METHOD_MESSAGE,
+)
+
+_warning_cache = set()
+module_logger = logging.getLogger(__name__)
+
+TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def experimental(wrapped: TExperimental) -> TExperimental:
+ """Add experimental tag to a class or a method.
+
+ :param wrapped: Either a Class or Function to mark as experimental
+ :type wrapped: TExperimental
+ :return: The wrapped class or method
+ :rtype: TExperimental
+ """
+ if inspect.isclass(wrapped):
+ return _add_class_docstring(wrapped)
+ if inspect.isfunction(wrapped):
+ return _add_method_docstring(wrapped)
+ return wrapped
+
+
+def _add_class_docstring(cls: Type[T]) -> Type[T]:
+ """Add experimental tag to the class doc string.
+
+ :return: The updated class
+ :rtype: Type[T]
+ """
+
+ P2 = ParamSpec("P2")
+
+ def _add_class_warning(func: Callable[P2, None]) -> Callable[P2, None]:
+ """Add warning message for class __init__.
+
+ :param func: The original __init__ function
+ :type func: Callable[P2, None]
+ :return: Updated __init__
+ :rtype: Callable[P2, None]
+ """
+
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ message = "Class {0}: {1} {2}".format(cls.__name__, EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+ if not _should_skip_warning() and not _is_warning_cached(message):
+ module_logger.warning(message)
+ return func(*args, **kwargs)
+
+ return wrapped
+
+ doc_string = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+ if cls.__doc__:
+ cls.__doc__ = _add_note_to_docstring(cls.__doc__, doc_string)
+ else:
+ cls.__doc__ = doc_string + ">"
+ cls.__init__ = _add_class_warning(cls.__init__)
+ return cls
+
+
+def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
+ """Add experimental tag to the method doc string.
+
+ :param func: The function to update
+ :type func: Callable[P, T]
+ :return: A wrapped method marked as experimental
+ :rtype: Callable[P,T]
+ """
+ doc_string = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+ if func.__doc__:
+ func.__doc__ = _add_note_to_docstring(func.__doc__, doc_string)
+ else:
+ # '>' is required. Otherwise the note section can't be generated
+ func.__doc__ = doc_string + ">"
+
+ @functools.wraps(func)
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
+ message = "Method {0}: {1} {2}".format(func.__name__, EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
+ if not _should_skip_warning() and not _is_warning_cached(message):
+ module_logger.warning(message)
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
+def _add_note_to_docstring(doc_string: str, note: str) -> str:
+ """Adds experimental note to docstring at the top and correctly indents original docstring.
+
+ :param doc_string: The docstring
+ :type doc_string: str
+ :param note: The note to add to the docstring
+ :type note: str
+ :return: Updated docstring
+ :rtype: str
+ """
+ indent = _get_indentation_size(doc_string)
+ doc_string = doc_string.rjust(len(doc_string) + indent)
+ return note + doc_string
+
+
+def _get_indentation_size(doc_string: str) -> int:
+ """Finds the minimum indentation of all non-blank lines after the first line.
+
+ :param doc_string: The docstring
+ :type doc_string: str
+ :return: Minimum number of indentation of the docstring
+ :rtype: int
+ """
+ lines = doc_string.expandtabs().splitlines()
+ indent = sys.maxsize
+ for line in lines[1:]:
+ stripped = line.lstrip()
+ if stripped:
+ indent = min(indent, len(line) - len(stripped))
+ return indent if indent < sys.maxsize else DOCSTRING_DEFAULT_INDENTATION
+
+
+def _should_skip_warning():
+ skip_warning_msg = False
+
+ # Cases where we want to suppress the warning:
+ # 1. When converting from REST object to SDK object
+ for frame in inspect.stack():
+ if frame.function == "_from_rest_object":
+ skip_warning_msg = True
+ break
+
+ return skip_warning_msg
+
+
+def _is_warning_cached(warning_msg):
+ # use cache to make sure we only print same warning message once under same session
+ # this prevents duplicated warnings got printed when user does a loop call on a method or a class
+ if warning_msg in _warning_cache:
+ return True
+ _warning_cache.add(warning_msg)
+ return False