about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/entities/_monitoring/definition.py
blob: 3b81be1ee0c24fa2986d5fce780ce10b12629e80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

from typing import Any, Dict, Optional, Union

from typing_extensions import Literal

from azure.ai.ml._restclient.v2023_06_01_preview.models import AzMonMonitoringAlertNotificationSettings
from azure.ai.ml._restclient.v2023_06_01_preview.models import MonitorDefinition as RestMonitorDefinition
from azure.ai.ml.constants._monitoring import (
    AZMONITORING,
    DEFAULT_DATA_DRIFT_SIGNAL_NAME,
    DEFAULT_DATA_QUALITY_SIGNAL_NAME,
    DEFAULT_PREDICTION_DRIFT_SIGNAL_NAME,
    DEFAULT_TOKEN_USAGE_SIGNAL_NAME,
    MonitorTargetTasks,
)
from azure.ai.ml.entities._mixins import RestTranslatableMixin
from azure.ai.ml.entities._monitoring.alert_notification import AlertNotification
from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute
from azure.ai.ml.entities._monitoring.signals import (
    CustomMonitoringSignal,
    DataDriftSignal,
    DataQualitySignal,
    FeatureAttributionDriftSignal,
    GenerationSafetyQualitySignal,
    GenerationTokenStatisticsSignal,
    MonitoringSignal,
    PredictionDriftSignal,
)
from azure.ai.ml.entities._monitoring.target import MonitoringTarget


class MonitorDefinition(RestTranslatableMixin):
    """Monitor definition

    :keyword compute: The serverless Spark compute configuration to be associated with the monitor.
    :paramtype compute: ~azure.ai.ml.entities.ServerlessSparkCompute
    :keyword monitoring_target: The ARM ID object associated with the model or deployment that is being monitored.
    :paramtype monitoring_target: Optional[~azure.ai.ml.entities.MonitoringTarget]
    :keyword monitoring_signals: The dictionary of signals to monitor. The key is the name of the signal and the value
        is the signal object. Accepted values for the signal objects are DataDriftSignal, DataQualitySignal,
        PredictionDriftSignal, FeatureAttributionDriftSignal, CustomMonitoringSignal, GenerationSafetyQualitySignal,
        and GenerationTokenStatisticsSignal. Defaults to None.
    :paramtype monitoring_signals: Optional[Dict[str, Union[~azure.ai.ml.entities.DataDriftSignal
        , ~azure.ai.ml.entities.DataQualitySignal, ~azure.ai.ml.entities.PredictionDriftSignal
        , ~azure.ai.ml.entities.FeatureAttributionDriftSignal
        , ~azure.ai.ml.entities.CustomMonitoringSignal
        , ~azure.ai.ml.entities.GenerationSafetyQualitySignal
        , ~azure.ai.ml.entities.GenerationTokenStatisticsSignal
        , ~azure.ai.ml.entities.ModelPerformanceSignal]]]
    :keyword alert_notification: The alert configuration for the monitor. The string ``"azmonitoring"`` selects
        Azure Monitor alert notification settings; any other string is ignored when serializing.
    :paramtype alert_notification: Optional[Union[Literal['azmonitoring'], ~azure.ai.ml.entities.AlertNotification]]

    .. admonition:: Example:

        .. literalinclude:: ../samples/ml_samples_spark_configurations.py
            :start-after: [START spark_monitor_definition]
            :end-before: [END spark_monitor_definition]
            :language: python
            :dedent: 8
            :caption: Creating Monitor definition.
    """

    def __init__(
        self,
        *,
        compute: ServerlessSparkCompute,
        monitoring_target: Optional[MonitoringTarget] = None,
        monitoring_signals: Optional[
            Dict[
                str,
                Union[
                    DataDriftSignal,
                    DataQualitySignal,
                    PredictionDriftSignal,
                    FeatureAttributionDriftSignal,
                    CustomMonitoringSignal,
                    GenerationSafetyQualitySignal,
                    GenerationTokenStatisticsSignal,
                ],
            ]
        ] = None,
        alert_notification: Optional[Union[Literal["azmonitoring"], AlertNotification]] = None,
    ) -> None:
        self.compute = compute
        self.monitoring_target = monitoring_target
        self.monitoring_signals = monitoring_signals
        self.alert_notification = alert_notification

    def _to_rest_object(self, **kwargs: Any) -> RestMonitorDefinition:
        """Convert this monitor definition to its REST representation.

        :keyword default_data_window_size: Forwarded unchanged to every signal's ``_to_rest_object``.
        :keyword ref_data_window_size: Forwarded unchanged to every signal's ``_to_rest_object``.
        :return: The REST monitor definition.
        :rtype: RestMonitorDefinition
        """
        default_data_window_size = kwargs.get("default_data_window_size")
        ref_data_window_size = kwargs.get("ref_data_window_size")

        rest_alert_notification = None
        if self.alert_notification:
            if isinstance(self.alert_notification, str):
                # Only the Azure Monitor keyword is recognized; any other string yields no alert settings.
                if self.alert_notification.lower() == AZMONITORING:
                    rest_alert_notification = AzMonMonitoringAlertNotificationSettings()
            else:
                rest_alert_notification = self.alert_notification._to_rest_object()

        # Bind the name unconditionally. It used to be assigned only inside the `if`, so a definition
        # with monitoring_signals=None raised UnboundLocalError when building the REST object.
        rest_signals = None
        if self.monitoring_signals is not None:
            rest_signals = {
                signal_name: signal._to_rest_object(
                    default_data_window_size=default_data_window_size,
                    ref_data_window_size=ref_data_window_size,
                )
                for signal_name, signal in self.monitoring_signals.items()
            }
        return RestMonitorDefinition(
            compute_configuration=self.compute._to_rest_object(),
            monitoring_target=self.monitoring_target._to_rest_object() if self.monitoring_target else None,
            signals=rest_signals,
            alert_notification_setting=rest_alert_notification,
        )

    @classmethod
    def _from_rest_object(
        cls,
        obj: RestMonitorDefinition,
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> "MonitorDefinition":
        """Build a MonitorDefinition from its REST representation.

        :param obj: The REST monitor definition.
        :type obj: RestMonitorDefinition
        :return: The corresponding entity. Azure Monitor alert settings are mapped back to the
            ``AZMONITORING`` string; a missing ``signals`` mapping becomes an empty dict.
        :rtype: MonitorDefinition
        """
        from_rest_alert_notification: Any = None
        if obj.alert_notification_setting:
            if isinstance(obj.alert_notification_setting, AzMonMonitoringAlertNotificationSettings):
                from_rest_alert_notification = AZMONITORING
            else:
                from_rest_alert_notification = AlertNotification._from_rest_object(obj.alert_notification_setting)

        # Guard against a REST payload without signals instead of failing on None.items().
        _monitoring_signals = {
            signal_name: MonitoringSignal._from_rest_object(signal)
            for signal_name, signal in (obj.signals or {}).items()
        }

        return cls(
            compute=ServerlessSparkCompute._from_rest_object(obj.compute_configuration),
            monitoring_target=(
                MonitoringTarget(
                    endpoint_deployment_id=obj.monitoring_target.deployment_id, ml_task=obj.monitoring_target.task_type
                )
                if obj.monitoring_target
                else None
            ),
            monitoring_signals=_monitoring_signals,  # type: ignore[arg-type]
            alert_notification=from_rest_alert_notification,
        )

    def _populate_default_signal_information(self) -> None:
        """Replace ``monitoring_signals`` with the default signal set for the monitoring target.

        Any existing signals are overwritten.

        - If the target's ``ml_task`` equals ``MonitorTargetTasks.QUESTION_ANSWERING`` (case-insensitive),
          only the default token-statistics signal is used.
        - Otherwise, the defaults are data drift, prediction drift and data quality signals.
        """
        if (
            isinstance(self.monitoring_target, MonitoringTarget)
            and self.monitoring_target.ml_task is not None
            and self.monitoring_target.ml_task.lower()
            == MonitorTargetTasks.QUESTION_ANSWERING.lower()  # type: ignore[union-attr]
        ):
            self.monitoring_signals = {
                DEFAULT_TOKEN_USAGE_SIGNAL_NAME: GenerationTokenStatisticsSignal._get_default_token_statistics_signal(),
            }
        else:
            self.monitoring_signals = {
                DEFAULT_DATA_DRIFT_SIGNAL_NAME: DataDriftSignal._get_default_data_drift_signal(),
                DEFAULT_PREDICTION_DRIFT_SIGNAL_NAME: PredictionDriftSignal._get_default_prediction_drift_signal(),
                DEFAULT_DATA_QUALITY_SIGNAL_NAME: DataQualitySignal._get_default_data_quality_signal(),
            }