aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_experimental.py
blob: 42b0bee614ff729971357fcf032a190f0c575cf5 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import functools
import inspect
import logging
import sys
from typing import Callable, Type, TypeVar, Union

from typing_extensions import ParamSpec

from azure.ai.ml.constants._common import (
    DOCSTRING_DEFAULT_INDENTATION,
    DOCSTRING_TEMPLATE,
    EXPERIMENTAL_CLASS_MESSAGE,
    EXPERIMENTAL_LINK_MESSAGE,
    EXPERIMENTAL_METHOD_MESSAGE,
)

# Warning messages that have already been seen in this process; consulted so that the same
# experimental warning is logged only once per session.
_warning_cache = set()
# Logger used to emit the experimental warnings.
module_logger = logging.getLogger(__name__)

# Type variables used by the annotations in this module:
# TExperimental - the decorated object, either a class or a callable.
TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
# P - parameter specification of a decorated function, so a wrapper keeps its signature.
P = ParamSpec("P")
# T - generic type: the decorated class, or the return type of a decorated function.
T = TypeVar("T")


def experimental(wrapped: TExperimental) -> TExperimental:
    """Decorator that tags a class or a function as experimental.

    Classes are handed to ``_add_class_docstring`` and plain Python functions to
    ``_add_method_docstring``. Any other object is returned untouched.

    :param wrapped: Either a Class or Function to mark as experimental
    :type wrapped: TExperimental
    :return: The tagged class or function, or the input itself when it is neither
    :rtype: TExperimental
    """
    if inspect.isclass(wrapped):
        tagged = _add_class_docstring(wrapped)
    elif inspect.isfunction(wrapped):
        tagged = _add_method_docstring(wrapped)
    else:
        # Neither a class nor a plain function: nothing to tag.
        tagged = wrapped
    return tagged


def _add_class_docstring(cls: Type[T]) -> Type[T]:
    """Tag a class as experimental.

    The experimental note is placed at the top of the class docstring, and ``cls.__init__`` is
    replaced by a wrapper that logs an experimental warning before running the original
    initializer. The class is modified in place.

    :param cls: The class to tag
    :type cls: Type[T]
    :return: The updated class
    :rtype: Type[T]
    """

    P2 = ParamSpec("P2")

    def _warn_on_init(func: Callable[P2, None]) -> Callable[P2, None]:
        """Wrap an ``__init__`` so that creating an instance logs the experimental warning.

        :param func: The original __init__ function
        :type func: Callable[P2, None]
        :return: The wrapping __init__
        :rtype: Callable[P2, None]
        """

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            message = "Class {0}: {1} {2}".format(
                cls.__name__,
                EXPERIMENTAL_CLASS_MESSAGE,
                EXPERIMENTAL_LINK_MESSAGE,
            )
            # The skip check runs first; the cache is only consulted (and updated) when
            # the warning is not being suppressed.
            if not (_should_skip_warning() or _is_warning_cached(message)):
                module_logger.warning(message)
            return func(*args, **kwargs)

        return wrapped

    note = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_CLASS_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
    existing_doc = cls.__doc__
    # With no existing docstring the note is used on its own, terminated by '>'.
    cls.__doc__ = _add_note_to_docstring(existing_doc, note) if existing_doc else note + ">"
    cls.__init__ = _warn_on_init(cls.__init__)
    return cls


def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
    """Tag a function or method as experimental.

    The experimental note is placed at the top of the function's docstring, and the function is
    wrapped so that calling it logs an experimental warning before delegating to the original.

    :param func: The function to update
    :type func: Callable[P, T]
    :return: A wrapped method marked as experimental
    :rtype: Callable[P,T]
    """
    note = DOCSTRING_TEMPLATE.format(EXPERIMENTAL_METHOD_MESSAGE, EXPERIMENTAL_LINK_MESSAGE)
    existing_doc = func.__doc__
    # '>' is required when there is no docstring. Otherwise the note section can't be generated
    func.__doc__ = _add_note_to_docstring(existing_doc, note) if existing_doc else note + ">"

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
        message = "Method {0}: {1} {2}".format(
            func.__name__,
            EXPERIMENTAL_METHOD_MESSAGE,
            EXPERIMENTAL_LINK_MESSAGE,
        )
        # The skip check runs first; the cache is only consulted (and updated) when
        # the warning is not being suppressed.
        if not (_should_skip_warning() or _is_warning_cached(message)):
            module_logger.warning(message)
        return func(*args, **kwargs)

    return wrapped


def _add_note_to_docstring(doc_string: str, note: str) -> str:
    """Put the experimental note in front of a docstring.

    The original docstring is left-padded with as many spaces as ``_get_indentation_size``
    reports for it, so that its first line is indented before being appended to the note.

    :param doc_string: The docstring
    :type doc_string: str
    :param note: The note to add to the docstring
    :type note: str
    :return: The note followed by the padded original docstring
    :rtype: str
    """
    padding = " " * _get_indentation_size(doc_string)
    return note + padding + doc_string


def _get_indentation_size(doc_string: str) -> int:
    """Finds the minimum indentation of all non-blank lines after the first line.

    :param doc_string: The docstring
    :type doc_string: str
    :return: Minimum number of indentation of the docstring
    :rtype: int
    """
    lines = doc_string.expandtabs().splitlines()
    indent = sys.maxsize
    for line in lines[1:]:
        stripped = line.lstrip()
        if stripped:
            indent = min(indent, len(line) - len(stripped))
    return indent if indent < sys.maxsize else DOCSTRING_DEFAULT_INDENTATION


def _should_skip_warning():
    skip_warning_msg = False

    # Cases where we want to suppress the warning:
    # 1. When converting from REST object to SDK object
    for frame in inspect.stack():
        if frame.function == "_from_rest_object":
            skip_warning_msg = True
            break

    return skip_warning_msg


def _is_warning_cached(warning_msg):
    """Report whether a warning message was seen before, recording it if it was not.

    The cache ensures the same warning message is only printed once in a session, which avoids
    duplicated warnings when a user calls a method or creates a class in a loop.

    :param warning_msg: The warning message to look up
    :return: True if the message was already in the cache, False if it has just been added
    """
    already_seen = warning_msg in _warning_cache
    if not already_seen:
        _warning_cache.add(warning_msg)
    return already_seen