diff options
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
|---|---|---|
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-master.tar.gz | |
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring')
9 files changed, 747 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py new file mode 100644 index 00000000..29a4fcd3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py new file mode 100644 index 00000000..bd7fd69c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/alert_notification.py @@ -0,0 +1,19 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class AlertNotificationSchema(metaclass=PatchedSchemaMeta): + emails = fields.List(fields.Str) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.alert_notification import AlertNotification + + return AlertNotification(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py new file mode 100644 index 00000000..483b4ac5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/compute.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class ComputeConfigurationSchema(metaclass=PatchedSchemaMeta): + compute_type = fields.Str(allowed_values=["ServerlessSpark"]) + + +class ServerlessSparkComputeSchema(ComputeConfigurationSchema): + runtime_version = fields.Str() + instance_type = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.compute import ServerlessSparkCompute + + return ServerlessSparkCompute(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py new file mode 100644 index 00000000..d5a6a4f9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/input_data.py @@ -0,0 +1,52 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._monitoring import MonitorDatasetContext +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import NestedField, StringTransformedEnum, UnionField +from azure.ai.ml._schema.job.input_output_entry import MLTableInputSchema, DataInputSchema + + +class MonitorInputDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + target_columns = fields.Dict() + job_type = fields.Str() + uri = fields.Str() + + +class FixedInputDataSchema(MonitorInputDataSchema): + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.input_data import FixedInputData + + return FixedInputData(**data) + + +class TrailingInputDataSchema(MonitorInputDataSchema): + window_size = fields.Str() + window_offset = fields.Str() + pre_processing_component_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.input_data import TrailingInputData + + return TrailingInputData(**data) + + +class StaticInputDataSchema(MonitorInputDataSchema): + pre_processing_component_id = fields.Str() + window_start = fields.String() + window_end = fields.String() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.input_data import StaticInputData + + return StaticInputData(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py new file mode 100644 index 00000000..3fe52c9d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/monitor_definition.py @@ -0,0 +1,53 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._monitoring import AZMONITORING +from azure.ai.ml._schema.monitoring.target import MonitoringTargetSchema +from azure.ai.ml._schema.monitoring.compute import ServerlessSparkComputeSchema +from azure.ai.ml._schema.monitoring.signals import ( + DataDriftSignalSchema, + DataQualitySignalSchema, + PredictionDriftSignalSchema, + FeatureAttributionDriftSignalSchema, + CustomMonitoringSignalSchema, + GenerationSafetyQualitySchema, + ModelPerformanceSignalSchema, + GenerationTokenStatisticsSchema, +) +from azure.ai.ml._schema.monitoring.alert_notification import AlertNotificationSchema +from azure.ai.ml._schema.core.fields import NestedField, UnionField, StringTransformedEnum +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class MonitorDefinitionSchema(metaclass=PatchedSchemaMeta): + compute = NestedField(ServerlessSparkComputeSchema) + monitoring_target = NestedField(MonitoringTargetSchema) + monitoring_signals = fields.Dict( + keys=fields.Str(), + values=UnionField( + union_fields=[ + NestedField(DataDriftSignalSchema), + NestedField(DataQualitySignalSchema), + NestedField(PredictionDriftSignalSchema), + NestedField(FeatureAttributionDriftSignalSchema), + NestedField(CustomMonitoringSignalSchema), + NestedField(GenerationSafetyQualitySchema), + NestedField(ModelPerformanceSignalSchema), + NestedField(GenerationTokenStatisticsSchema), + ] + ), + ) + alert_notification = UnionField( + union_fields=[StringTransformedEnum(allowed_values=AZMONITORING), NestedField(AlertNotificationSchema)] + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.definition import MonitorDefinition + + return MonitorDefinition(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py new file mode 100644 index 00000000..a2034d33 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/schedule.py @@ -0,0 +1,11 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from azure.ai.ml._schema.core.fields import NestedField +from azure.ai.ml._schema.monitoring.monitor_definition import MonitorDefinitionSchema +from azure.ai.ml._schema.schedule.schedule import ScheduleSchema + + +class MonitorScheduleSchema(ScheduleSchema): + create_monitor = NestedField(MonitorDefinitionSchema) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py new file mode 100644 index 00000000..4f55393b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/signals.py @@ -0,0 +1,348 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load, pre_dump, ValidationError + +from azure.ai.ml._schema.job.input_output_entry import DataInputSchema, MLTableInputSchema +from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._monitoring import ( + MonitorSignalType, + ALL_FEATURES, + MonitorModelType, + MonitorDatasetContext, + FADColumnNames, +) +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, UnionField, StringTransformedEnum +from azure.ai.ml._schema.monitoring.thresholds import ( + DataDriftMetricThresholdSchema, + DataQualityMetricThresholdSchema, + PredictionDriftMetricThresholdSchema, + FeatureAttributionDriftMetricThresholdSchema, + ModelPerformanceMetricThresholdSchema, + CustomMonitoringMetricThresholdSchema, + GenerationSafetyQualityMetricThresholdSchema, + GenerationTokenStatisticsMonitorMetricThresholdSchema, +) + + +class DataSegmentSchema(metaclass=PatchedSchemaMeta): + feature_name = fields.Str() + feature_values = fields.List(fields.Str) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataSegment + + return DataSegment(**data) + + +class MonitorFeatureFilterSchema(metaclass=PatchedSchemaMeta): + top_n_feature_importance = fields.Int() + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import MonitorFeatureFilter + + if not isinstance(data, MonitorFeatureFilter): + raise ValidationError("Cannot dump non-MonitorFeatureFilter object into MonitorFeatureFilter") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import MonitorFeatureFilter + + return MonitorFeatureFilter(**data) + + +class BaselineDataRangeSchema(metaclass=PatchedSchemaMeta): + window_start = fields.Str() + window_end = fields.Str() + lookback_window_size = fields.Str() + lookback_window_offset = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import BaselineDataRange + + return BaselineDataRange(**data) + + +class ProductionDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + pre_processing_component = fields.Str() + data_window = NestedField(BaselineDataRangeSchema) + data_column_names = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ProductionData + + return ProductionData(**data) + + +class ReferenceDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + pre_processing_component = fields.Str() + target_column_name = fields.Str() + data_window = NestedField(BaselineDataRangeSchema) + data_column_names = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ReferenceData + + return ReferenceData(**data) + + +class MonitoringSignalSchema(metaclass=PatchedSchemaMeta): + production_data = NestedField(ProductionDataSchema) + reference_data = NestedField(ReferenceDataSchema) + properties = fields.Dict() + alert_enabled = fields.Bool() + + +class DataSignalSchema(MonitoringSignalSchema): + features = UnionField( + union_fields=[ + NestedField(MonitorFeatureFilterSchema), + StringTransformedEnum(allowed_values=ALL_FEATURES), + fields.List(fields.Str), + ] + ) + feature_type_override = fields.Dict() + + +class DataDriftSignalSchema(DataSignalSchema): + type = StringTransformedEnum(allowed_values=MonitorSignalType.DATA_DRIFT, required=True) + metric_thresholds = NestedField(DataDriftMetricThresholdSchema) + data_segment = NestedField(DataSegmentSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataDriftSignal + + if not isinstance(data, DataDriftSignal): + raise ValidationError("Cannot dump non-DataDriftSignal object into DataDriftSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataDriftSignal + + data.pop("type", None) + return DataDriftSignal(**data) + + +class DataQualitySignalSchema(DataSignalSchema): + type = StringTransformedEnum(allowed_values=MonitorSignalType.DATA_QUALITY, required=True) + metric_thresholds = NestedField(DataQualityMetricThresholdSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataQualitySignal + + if not isinstance(data, DataQualitySignal): + raise ValidationError("Cannot dump non-DataQualitySignal object into DataQualitySignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import DataQualitySignal + + data.pop("type", None) + return DataQualitySignal(**data) + + +class PredictionDriftSignalSchema(MonitoringSignalSchema): + type = StringTransformedEnum(allowed_values=MonitorSignalType.PREDICTION_DRIFT, required=True) + metric_thresholds = NestedField(PredictionDriftMetricThresholdSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import PredictionDriftSignal + + if not isinstance(data, PredictionDriftSignal): + raise ValidationError("Cannot dump non-PredictionDriftSignal object into PredictionDriftSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import PredictionDriftSignal + + data.pop("type", None) + return PredictionDriftSignal(**data) + + +class ModelSignalSchema(MonitoringSignalSchema): + model_type = StringTransformedEnum(allowed_values=[MonitorModelType.CLASSIFICATION, MonitorModelType.REGRESSION]) + + +class FADProductionDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_context = StringTransformedEnum(allowed_values=[o.value for o in MonitorDatasetContext]) + data_column_names = fields.Dict( + keys=StringTransformedEnum(allowed_values=[o.value for o in FADColumnNames]), values=fields.Str() + ) + pre_processing_component = fields.Str() + data_window = NestedField(BaselineDataRangeSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import FADProductionData + + return FADProductionData(**data) + + +class FeatureAttributionDriftSignalSchema(metaclass=PatchedSchemaMeta): + production_data = fields.List(NestedField(FADProductionDataSchema)) + reference_data = NestedField(ReferenceDataSchema) + alert_enabled = fields.Bool() + type = StringTransformedEnum(allowed_values=MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT, required=True) + metric_thresholds = NestedField(FeatureAttributionDriftMetricThresholdSchema) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import FeatureAttributionDriftSignal + + if not isinstance(data, FeatureAttributionDriftSignal): + raise ValidationError( + "Cannot dump non-FeatureAttributionDriftSignal object into FeatureAttributionDriftSignal" + ) + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import FeatureAttributionDriftSignal + + data.pop("type", None) + return FeatureAttributionDriftSignal(**data) + + +class ModelPerformanceSignalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=MonitorSignalType.MODEL_PERFORMANCE, required=True) + production_data = NestedField(ProductionDataSchema) + reference_data = NestedField(ReferenceDataSchema) + data_segment = NestedField(DataSegmentSchema) + alert_enabled = fields.Bool() + metric_thresholds = NestedField(ModelPerformanceMetricThresholdSchema) + properties = fields.Dict() + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ModelPerformanceSignal + + if not isinstance(data, ModelPerformanceSignal): + raise ValidationError("Cannot dump non-ModelPerformanceSignal object into ModelPerformanceSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import ModelPerformanceSignal + + data.pop("type", None) + return ModelPerformanceSignal(**data) + + +class ConnectionSchema(metaclass=PatchedSchemaMeta): + environment_variables = fields.Dict(keys=fields.Str(), values=fields.Str()) + secret_config = fields.Dict(keys=fields.Str(), values=fields.Str()) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import Connection + + return Connection(**data) + + +class CustomMonitoringSignalSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=MonitorSignalType.CUSTOM, required=True) + connection = NestedField(ConnectionSchema) + component_id = ArmVersionedStr(azureml_type=AzureMLResourceType.COMPONENT) + metric_thresholds = fields.List(NestedField(CustomMonitoringMetricThresholdSchema)) + input_data = fields.Dict(keys=fields.Str(), values=NestedField(ReferenceDataSchema)) + alert_enabled = fields.Bool() + inputs = fields.Dict( + keys=fields.Str, values=UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + ) + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import CustomMonitoringSignal + + if not isinstance(data, CustomMonitoringSignal): + raise ValidationError("Cannot dump non-CustomMonitoringSignal object into CustomMonitoringSignal") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import CustomMonitoringSignal + + data.pop("type", None) + return CustomMonitoringSignal(**data) + + +class LlmDataSchema(metaclass=PatchedSchemaMeta): + input_data = UnionField(union_fields=[NestedField(DataInputSchema), NestedField(MLTableInputSchema)]) + data_column_names = fields.Dict() + data_window = NestedField(BaselineDataRangeSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import LlmData + + return LlmData(**data) + + +class GenerationSafetyQualitySchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_SAFETY_QUALITY, required=True) + production_data = fields.List(NestedField(LlmDataSchema)) + connection_id = fields.Str() + metric_thresholds = NestedField(GenerationSafetyQualityMetricThresholdSchema) + alert_enabled = fields.Bool() + properties = fields.Dict() + sampling_rate = fields.Float() + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal + + if not isinstance(data, GenerationSafetyQualitySignal): + raise ValidationError("Cannot dump non-GenerationSafetyQuality object into GenerationSafetyQuality") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import GenerationSafetyQualitySignal + + data.pop("type", None) + return GenerationSafetyQualitySignal(**data) + + +class GenerationTokenStatisticsSchema(metaclass=PatchedSchemaMeta): + type = StringTransformedEnum(allowed_values=MonitorSignalType.GENERATION_TOKEN_STATISTICS, required=True) + production_data = NestedField(LlmDataSchema) + metric_thresholds = NestedField(GenerationTokenStatisticsMonitorMetricThresholdSchema) + alert_enabled = fields.Bool() + properties = fields.Dict() + sampling_rate = fields.Float() + + @pre_dump + def predump(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import GenerationTokenStatisticsSignal + + if not isinstance(data, GenerationTokenStatisticsSignal): + raise ValidationError("Cannot dump non-GenerationSafetyQuality object into GenerationSafetyQuality") + return data + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.signals import GenerationTokenStatisticsSignal + + data.pop("type", None) + return GenerationTokenStatisticsSignal(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py new file mode 100644 index 00000000..6d3032ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/target.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument + +from marshmallow import fields, post_load + + +from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._monitoring import MonitorTargetTasks +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta +from azure.ai.ml._schema.core.fields import ArmVersionedStr, StringTransformedEnum + + +class MonitoringTargetSchema(metaclass=PatchedSchemaMeta): + model_id = ArmVersionedStr(azureml_type=AzureMLResourceType.MODEL) + ml_task = StringTransformedEnum(allowed_values=[o.value for o in MonitorTargetTasks]) + endpoint_deployment_id = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.target import MonitoringTarget + + return MonitoringTarget(**data) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py new file mode 100644 index 00000000..b7970fca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_schema/monitoring/thresholds.py @@ -0,0 +1,211 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=unused-argument, name-too-long + +from marshmallow import fields, post_load + +from azure.ai.ml.constants._monitoring import MonitorFeatureType +from azure.ai.ml._schema.core.fields import StringTransformedEnum, NestedField +from azure.ai.ml._schema.core.schema import PatchedSchemaMeta + + +class MetricThresholdSchema(metaclass=PatchedSchemaMeta): + threshold = fields.Number() + + +class NumericalDriftMetricsSchema(metaclass=PatchedSchemaMeta): + jensen_shannon_distance = fields.Number() + normalized_wasserstein_distance = fields.Number() + population_stability_index = fields.Number() + two_sample_kolmogorov_smirnov_test = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import NumericalDriftMetrics + + return NumericalDriftMetrics(**data) + + +class CategoricalDriftMetricsSchema(metaclass=PatchedSchemaMeta): + jensen_shannon_distance = fields.Number() + population_stability_index = fields.Number() + pearsons_chi_squared_test = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import CategoricalDriftMetrics + + return CategoricalDriftMetrics(**data) + + +class DataDriftMetricThresholdSchema(MetricThresholdSchema): + data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL]) + + numerical = NestedField(NumericalDriftMetricsSchema) + categorical = NestedField(CategoricalDriftMetricsSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataDriftMetricThreshold + + return DataDriftMetricThreshold(**data) + + +class DataQualityMetricsNumericalSchema(metaclass=PatchedSchemaMeta): + null_value_rate = fields.Number() + data_type_error_rate = fields.Number() + out_of_bounds_rate = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricsNumerical + + return DataQualityMetricsNumerical(**data) + + +class DataQualityMetricsCategoricalSchema(metaclass=PatchedSchemaMeta): + null_value_rate = fields.Number() + data_type_error_rate = fields.Number() + out_of_bounds_rate = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricsCategorical + + return DataQualityMetricsCategorical(**data) + + +class DataQualityMetricThresholdSchema(MetricThresholdSchema): + data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL]) + numerical = NestedField(DataQualityMetricsNumericalSchema) + categorical = NestedField(DataQualityMetricsCategoricalSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import DataQualityMetricThreshold + + return DataQualityMetricThreshold(**data) + + +class PredictionDriftMetricThresholdSchema(MetricThresholdSchema): + data_type = StringTransformedEnum(allowed_values=[MonitorFeatureType.NUMERICAL, MonitorFeatureType.CATEGORICAL]) + numerical = NestedField(NumericalDriftMetricsSchema) + categorical = NestedField(CategoricalDriftMetricsSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import PredictionDriftMetricThreshold + + return PredictionDriftMetricThreshold(**data) + + +# pylint: disable-next=name-too-long +class FeatureAttributionDriftMetricThresholdSchema(MetricThresholdSchema): + normalized_discounted_cumulative_gain = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import FeatureAttributionDriftMetricThreshold + + return FeatureAttributionDriftMetricThreshold(**data) + + +class ModelPerformanceClassificationThresholdsSchema(metaclass=PatchedSchemaMeta): + accuracy = fields.Number() + precision = fields.Number() + recall = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceClassificationThresholds + + return ModelPerformanceClassificationThresholds(**data) + + +class ModelPerformanceRegressionThresholdsSchema(metaclass=PatchedSchemaMeta): + mae = fields.Number() + mse = fields.Number() + rmse = fields.Number() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceRegressionThresholds + + return ModelPerformanceRegressionThresholds(**data) + + +class ModelPerformanceMetricThresholdSchema(MetricThresholdSchema): + classification = NestedField(ModelPerformanceClassificationThresholdsSchema) + regression = NestedField(ModelPerformanceRegressionThresholdsSchema) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import ModelPerformanceMetricThreshold + + return ModelPerformanceMetricThreshold(**data) + + +class CustomMonitoringMetricThresholdSchema(MetricThresholdSchema): + metric_name = fields.Str() + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import CustomMonitoringMetricThreshold + + return CustomMonitoringMetricThreshold(**data) + + +class GenerationSafetyQualityMetricThresholdSchema(metaclass=PatchedSchemaMeta): # pylint: disable=name-too-long + groundedness = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_groundedness_pass_rate", "acceptable_groundedness_score_per_instance"] + ), + values=fields.Number(), + ) + relevance = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_relevance_pass_rate", "acceptable_relevance_score_per_instance"] + ), + values=fields.Number(), + ) + coherence = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_coherence_pass_rate", "acceptable_coherence_score_per_instance"] + ), + values=fields.Number(), + ) + fluency = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_fluency_pass_rate", "acceptable_fluency_score_per_instance"] + ), + values=fields.Number(), + ) + similarity = fields.Dict( + keys=StringTransformedEnum( + allowed_values=["aggregated_similarity_pass_rate", "acceptable_similarity_score_per_instance"] + ), + values=fields.Number(), + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import GenerationSafetyQualityMonitoringMetricThreshold + + return GenerationSafetyQualityMonitoringMetricThreshold(**data) + + +class GenerationTokenStatisticsMonitorMetricThresholdSchema( + metaclass=PatchedSchemaMeta +): # pylint: disable=name-too-long + totaltoken = fields.Dict( + keys=StringTransformedEnum(allowed_values=["total_token_count", "total_token_count_per_group"]), + values=fields.Number(), + ) + + @post_load + def make(self, data, **kwargs): + from azure.ai.ml.entities._monitoring.thresholds import GenerationTokenStatisticsMonitorMetricThreshold + + return GenerationTokenStatisticsMonitorMetricThreshold(**data) |
