aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import functools
import inspect
import logging
import sys
from typing import Callable, Type, TypeVar, Union

from typing_extensions import ParamSpec

from azure.ai.ml.constants._common import (
    DOCSTRING_DEFAULT_INDENTATION,
    DOCSTRING_TEMPLATE,
    EXPERIMENTAL_CLASS_MESSAGE,
    EXPERIMENTAL_LINK_MESSAGE,
    EXPERIMENTAL_METHOD_MESSAGE,
)

_warning_cache = set()
module_logger = logging.getLogger(__name__)

TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
P = ParamSpec("P")
T = TypeVar("T")


def experimental(wrapped: TExperimental) -> TExperimental:
    """Decorator that tags a class or a plain function as experimental.

    Classes get the experimental note in their docstring and a warning on
    instantiation. Functions get the note and a warning on call. Any other
    object is handed back unchanged.

    :param wrapped: Either a Class or Function to mark as experimental
    :type wrapped: TExperimental
    :return: The wrapped class or method
    :rtype: TExperimental
    """
    if inspect.isclass(wrapped):
        tagger = _add_class_docstring
    elif inspect.isfunction(wrapped):
        tagger = _add_method_docstring
    else:
        # Neither a class nor a plain function (e.g. a staticmethod/classmethod
        # object): leave it as it is.
        return wrapped
    return tagger(wrapped)


def _add_class_docstring(cls: Type[T]) -> Type[T]:
    """Mark a class as experimental, in place.

    Prepends the experimental note to the class docstring, or creates a
    docstring holding only the note. Replaces ``cls.__init__`` with a wrapper
    that may log the experimental warning before delegating to the original
    initializer.

    :param cls: The class to mark as experimental; it is modified in place
    :type cls: Type[T]
    :return: The same class object, updated
    :rtype: Type[T]
    """
    P2 = ParamSpec("P2")

    def _warn_on_init(init: Callable[P2, None]) -> Callable[P2, None]:
        """Wrap an ``__init__`` so calling it may first log the experimental warning.

        :param init: The initializer being wrapped
        :type init: Callable[P2, None]
        :return: The wrapping initializer
        :rtype: Callable[P2, None]
        """

        @functools.wraps(init)
        def _init_with_warning(*args, **kwargs):
            text = "Class {0}: {1} {2}".format(cls.__name__, EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
            # Suppressed during REST -> SDK conversion; otherwise logged at most once per message.
            if not (_should_skip_warning() or _is_warning_cached(text)):
                module_logger.warning(text)
            return init(*args, **kwargs)

        return _init_with_warning

    note = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
    # As for methods, a note-only docstring needs a trailing '>' or the note section is not generated.
    cls.__doc__ = _add_note_to_docstring(cls.__doc__, note) if cls.__doc__ else note + ">"
    cls.__init__ = _warn_on_init(cls.__init__)
    return cls


def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
    """Mark a function as experimental.

    Prepends the experimental note to ``func``'s own docstring, in place. It
    then returns a wrapper that may log the experimental warning on each call
    before delegating to ``func``.

    :param func: The function to update
    :type func: Callable[P, T]
    :return: A wrapped method marked as experimental
    :rtype: Callable[P,T]
    """
    note = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
    # For an empty docstring, '>' is required; otherwise the note section can't be generated.
    func.__doc__ = _add_note_to_docstring(func.__doc__, note) if func.__doc__ else note + ">"

    @functools.wraps(func)
    def _experimental_call(*args: P.args, **kwargs: P.kwargs) -> T:
        text = "Method {0}: {1} {2}".format(func.__name__, EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
        # Suppressed during REST -> SDK conversion; otherwise logged at most once per message.
        if not (_should_skip_warning() or _is_warning_cached(text)):
            module_logger.warning(text)
        return func(*args, **kwargs)

    return _experimental_call


def _add_note_to_docstring(doc_string: str, note: str) -> str:
    """Put ``note`` in front of ``doc_string``, re-indenting the docstring's first line.

    The first line is indented by the body's common indentation, as computed by
    ``_get_indentation_size``.

    :param doc_string: The docstring
    :type doc_string: str
    :param note: The note to add to the docstring
    :type note: str
    :return: Updated docstring
    :rtype: str
    """
    # A docstring's first line usually starts right after the opening quotes, so
    # it has no indentation of its own. Pad it to match the body lines.
    # NOTE(review): this assumes ``note`` ends with a line break; DOCSTRING_TEMPLATE
    # is defined elsewhere, so confirm there.
    padding = " " * _get_indentation_size(doc_string)
    return "".join((note, padding, doc_string))


def _get_indentation_size(doc_string: str) -> int:
    """Compute the common leading indentation of a docstring's body.

    This follows the PEP 257 trimming rule:
    - the first line is ignored;
    - tabs are expanded;
    - blank or whitespace-only lines do not count.

    :param doc_string: The docstring
    :type doc_string: str
    :return: The smallest indentation among non-blank lines after the first line,
        or DOCSTRING_DEFAULT_INDENTATION when there is no such line
    :rtype: int
    """
    body_lines = doc_string.expandtabs().splitlines()[1:]
    return min(
        (len(text) - len(text.lstrip()) for text in body_lines if text.lstrip()),
        default=DOCSTRING_DEFAULT_INDENTATION,
    )


def _should_skip_warning():
    """Decide whether the experimental warning should be suppressed for the current call.

    The warning is suppressed when any frame on the current call stack belongs to
    a function named ``_from_rest_object``. In that case the experimental entity is
    being produced while converting a REST object into an SDK object, rather than
    used directly.

    :return: True if a ``_from_rest_object`` frame is on the call stack, otherwise False
    :rtype: bool
    """
    # Walk the raw frame chain instead of calling inspect.stack().
    # inspect.stack() builds a FrameInfo for every frame and loads source-code
    # context lines (via linecache) for each one. That made every experimental
    # call or instantiation needlessly expensive, yet only the function name
    # (f_code.co_name, which is what FrameInfo.function holds) is needed here.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            if frame.f_code.co_name == "_from_rest_object":
                return True
            frame = frame.f_back
        return False
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame


def _is_warning_cached(warning_msg):
    """Report whether ``warning_msg`` was already emitted, recording it if it was not.

    The module-level cache makes each distinct message print only once per session.
    For example, a loop calling an experimental method or constructing an
    experimental class repeatedly does not flood the log.

    :param warning_msg: The warning text about to be logged
    :return: True if the message was seen before, False on first sight (now recorded)
    :rtype: bool
    """
    already_seen = warning_msg in _warning_cache
    if not already_seen:
        _warning_cache.add(warning_msg)
    return already_seen