# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import functools
import inspect
import logging
import sys
from typing import Callable, Type, TypeVar, Union
from typing_extensions import ParamSpec
from azure.ai.ml.constants._common import (
DOCSTRING_DEFAULT_INDENTATION,
DOCSTRING_TEMPLATE,
EXPERIMENTAL_CLASS_MESSAGE,
EXPERIMENTAL_LINK_MESSAGE,
EXPERIMENTAL_METHOD_MESSAGE,
)
# Warning messages already logged in this session. _is_warning_cached consults and fills this
# set so each distinct experimental warning message is logged only once.
_warning_cache = set()
# Logger used to emit the experimental warnings.
module_logger = logging.getLogger(__name__)
# Type of the object handed to `experimental`: a class or a callable. The decorator returns an
# object of the same type.
TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
# Parameter-spec / return-type variables used to type the wrapped callables.
P = ParamSpec("P")
T = TypeVar("T")
def experimental(wrapped: TExperimental) -> TExperimental:
    """Mark a class or a plain function as experimental.

    Classes are handed to ``_add_class_docstring`` and plain functions to
    ``_add_method_docstring``. Every other object is returned unchanged, for example a bound
    method, a builtin, or a ``staticmethod``/``classmethod``/``property`` descriptor.

    :param wrapped: Either a Class or Function to mark as experimental
    :type wrapped: TExperimental
    :return: The tagged class or wrapped function, or ``wrapped`` itself if it is neither
    :rtype: TExperimental
    """
    if inspect.isclass(wrapped):
        tagger = _add_class_docstring
    elif inspect.isfunction(wrapped):
        tagger = _add_method_docstring
    else:
        # Not something we know how to tag; leave it untouched.
        return wrapped
    return tagger(wrapped)
def _add_class_docstring(cls: Type[T]) -> Type[T]:
    """Tag a class as experimental in its docstring, and warn when it is instantiated.

    The class is modified in place:

    - The experimental note is prepended to ``cls.__doc__``.
    - ``cls.__init__`` is replaced by a wrapper that logs the experimental warning, unless the
      warning is suppressed or was already logged. The wrapper then delegates to the previous
      ``__init__``.

    :param cls: The class to mark as experimental
    :type cls: Type[T]
    :return: The same class object, with updated ``__doc__`` and ``__init__``
    :rtype: Type[T]
    """
    note = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
    # '>' is required when there is no docstring. Otherwise the note section can't be generated.
    cls.__doc__ = _add_note_to_docstring(cls.__doc__, note) if cls.__doc__ else note + ">"

    original_init = cls.__init__

    @functools.wraps(original_init)
    def _init_with_warning(*args, **kwargs):
        warning = "Class {0}: {1} {2}".format(
            cls.__name__, EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE
        )
        # Equivalent to "not skipped and not cached". Because of short-circuiting, the cache is
        # only consulted (and filled) when the warning is not suppressed.
        if not (_should_skip_warning() or _is_warning_cached(warning)):
            module_logger.warning(warning)
        return original_init(*args, **kwargs)

    cls.__init__ = _init_with_warning
    return cls
def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
    """Prepend the experimental note to a function's docstring, and wrap it with a warning.

    ``func.__doc__`` is modified in place before wrapping. ``functools.wraps`` then copies the
    updated docstring onto the returned wrapper.

    :param func: The function to update. The ``None`` default is not usable, because the body
        reads and assigns ``func.__doc__``.
    :type func: Callable[P, T]
    :return: A wrapper that may log the experimental warning, then calls ``func`` and returns
        its result
    :rtype: Callable[P,T]
    """
    note = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
    # '>' is required. Otherwise the note section can't be generated
    func.__doc__ = _add_note_to_docstring(func.__doc__, note) if func.__doc__ else note + ">"

    @functools.wraps(func)
    def _call_with_warning(*args: P.args, **kwargs: P.kwargs) -> T:
        warning = "Method {0}: {1} {2}".format(
            func.__name__, EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE
        )
        # Equivalent to "not skipped and not cached". Because of short-circuiting, the cache is
        # only consulted (and filled) when the warning is not suppressed.
        if not (_should_skip_warning() or _is_warning_cached(warning)):
            module_logger.warning(warning)
        return func(*args, **kwargs)

    return _call_with_warning
def _add_note_to_docstring(doc_string: str, note: str) -> str:
    """Put ``note`` in front of ``doc_string``, re-indenting the docstring's first line.

    ``_get_indentation_size`` ignores the first line, and only the first line is padded here.
    It is padded to the common indentation of the remaining lines, so that once the note sits
    in front the original text lines up as one block.

    :param doc_string: The docstring
    :type doc_string: str
    :param note: The note to add to the docstring
    :type note: str
    :return: ``note`` followed by the padded docstring
    :rtype: str
    """
    padding = " " * _get_indentation_size(doc_string)
    return f"{note}{padding}{doc_string}"
def _get_indentation_size(doc_string: str) -> int:
    """Return the smallest leading-whitespace width among non-blank lines after the first one.

    Tabs are expanded with ``str.expandtabs`` defaults before measuring. Whitespace-only lines
    are ignored.

    :param doc_string: The docstring
    :type doc_string: str
    :return: The minimum indentation, or ``DOCSTRING_DEFAULT_INDENTATION`` when no line after
        the first has content
    :rtype: int
    """
    widths = [
        len(line) - len(line.lstrip())
        for line in doc_string.expandtabs().splitlines()[1:]
        if line.lstrip()
    ]
    if not widths:
        return DOCSTRING_DEFAULT_INDENTATION
    return min(widths)
def _should_skip_warning() -> bool:
    """Decide whether the experimental warning should be suppressed for the current call.

    Cases where we want to suppress the warning:

    1. When converting from REST object to SDK object. This applies when any frame on the
       current call stack belongs to a function named ``_from_rest_object``.

    :return: True if the warning should be suppressed, False otherwise.
    :rtype: bool
    """
    # Walk the raw frame chain instead of calling inspect.stack(). inspect.stack() builds a
    # FrameInfo for every frame, which resolves source file names and reads source context
    # lines. This helper runs on every call of an experimental method or constructor, so that
    # work is pure overhead. We only need each frame's function name.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            if frame.f_code.co_name == "_from_rest_object":
                return True
            frame = frame.f_back
        return False
    finally:
        # Break the reference cycle a local frame reference creates (see the inspect docs).
        del frame
def _is_warning_cached(warning_msg: str) -> bool:
    """Report whether ``warning_msg`` was already seen, and record it if it was not.

    The module-level ``_warning_cache`` makes sure the same warning message is only printed
    once in the same session. This prevents duplicated warnings when a user calls a method, or
    instantiates a class, in a loop.

    :param warning_msg: The fully formatted warning message
    :type warning_msg: str
    :return: True if the message was seen before; False the first time (it is then recorded)
    :rtype: bool
    """
    already_seen = warning_msg in _warning_cache
    if not already_seen:
        _warning_cache.add(warning_msg)
    return already_seen