From 901159da45695da24a5206125910f02fc50169ce Mon Sep 17 00:00:00 2001 From: Efraim Flashner Date: Thu, 23 Apr 2020 15:50:37 +0300 Subject: [PATCH] add keras metrics --- keras/backend/tensorflow_backend.py | 12 + keras/metrics.py | 584 ++++++++++++++++++++++++++++ keras/utils/__init__.py | 2 + keras/utils/losses_utils.py | 177 +++++++++ keras/utils/metrics_utils.py | 278 +++++++++++++ 5 files changed, 1053 insertions(+) create mode 100644 keras/utils/losses_utils.py create mode 100644 keras/utils/metrics_utils.py diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index bcb8be0..a2870f5 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -4453,3 +4453,15 @@ def local_conv2d(inputs, kernel, kernel_size, strides, output_shape, data_format else: output = permute_dimensions(output, (2, 0, 1, 3)) return output + +#get_graph = tf_keras_backend.get_graph + +#def is_symbolic(x): +# return isinstance(x, tf.Tensor) and hasattr(x, 'op') + +def size(x, name=None): +# if is_symbolic(x): +# with get_graph().as_default(): +# return tf.size(x) + return tf.size(x, name=name) + diff --git a/keras/metrics.py b/keras/metrics.py index 8e3df1f..8f57910 100644 --- a/keras/metrics.py +++ b/keras/metrics.py @@ -4,8 +4,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import six +import types + from . 
import backend as K +from .engine.base_layer import Layer from .losses import mean_squared_error from .losses import mean_absolute_error from .losses import mean_absolute_percentage_error @@ -19,10 +23,201 @@ from .losses import binary_crossentropy from .losses import kullback_leibler_divergence from .losses import poisson from .losses import cosine_proximity +from .utils import losses_utils +from .utils import metrics_utils from .utils.generic_utils import deserialize_keras_object from .utils.generic_utils import serialize_keras_object +@six.add_metaclass(abc.ABCMeta) +class Metric(Layer): + """Encapsulates metric logic and state. + + Standalone usage: + ```python + m = SomeMetric(...) + for input in ...: + m.update_state(input) + m.result() + ``` + + Usage with the `compile` API: + ```python + model.compile(optimizer='rmsprop', + loss=keras.losses.categorical_crossentropy, + metrics=[keras.metrics.CategoricalAccuracy()]) + ``` + + To be implemented by subclasses: + * `__init__()`: All state variables should be created in this method by + calling `self.add_weight()` like: `self.var = self.add_weight(...)` + * `update_state()`: Has all updates to the state variables like: + self.var.assign_add(...). + * `result()`: Computes and returns a value for the metric + from the state variables. + """ + + def __init__(self, name=None, dtype=None, **kwargs): + super(Metric, self).__init__(name=name, dtype=dtype, **kwargs) + self.stateful = True # All metric layers are stateful. 
+ self.built = True + self.dtype = K.floatx() if dtype is None else dtype + + def __new__(cls, *args, **kwargs): + obj = super(Metric, cls).__new__(cls) + update_state_fn = obj.update_state + + obj.update_state = types.MethodType( + metrics_utils.update_state_wrapper(update_state_fn), obj) + return obj + + def __call__(self, *args, **kwargs): + """Accumulates statistics and then computes metric result value.""" + update_op = self.update_state(*args, **kwargs) + return self.result() + + def get_config(self): + """Returns the serializable config of the metric.""" + return {'name': self.name, 'dtype': self.dtype} + + def reset_states(self): + """Resets all of the metric state variables. + This function is called between epochs/steps, + when a metric is evaluated during training. + """ + K.batch_set_value([(v, 0) for v in self.weights]) + + @abc.abstractmethod + def update_state(self, *args, **kwargs): + """Accumulates statistics for the metric. """ + raise NotImplementedError('Must be implemented in subclasses.') + + @abc.abstractmethod + def result(self): + """Computes and returns the metric value tensor. + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + """ + raise NotImplementedError('Must be implemented in subclasses.') + + # For use by subclasses # + def add_weight(self, + name, + shape=(), + initializer=None, + dtype=None): + """Adds state variable. Only for use by subclasses.""" + return super(Metric, self).add_weight( + name=name, + shape=shape, + dtype=self.dtype if dtype is None else dtype, + trainable=False, + initializer=initializer) + + # End: For use by subclasses ### + + +class Reduce(Metric): + """Encapsulates metrics that perform a reduce operation on the values.""" + + def __init__(self, reduction, name, dtype=None): + """Creates a `Reduce` instance. + # Arguments + reduction: a metrics `Reduction` enum value. + name: string name of the metric instance. 
+ dtype: (Optional) data type of the metric result. + """ + super(Reduce, self).__init__(name=name, dtype=dtype) + self.reduction = reduction + self.total = self.add_weight('total', initializer='zeros') + if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, + metrics_utils.Reduction.WEIGHTED_MEAN]: + self.count = self.add_weight('count', initializer='zeros') + + def update_state(self, values, sample_weight=None): + """Accumulates statistics for computing the reduction metric. + For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, + then the value of `result()` is 4. If the `sample_weight` is specified as + [1, 1, 0, 0] then value of `result()` would be 2. + # Arguments + values: Per-example value. + sample_weight: Optional weighting of each example. Defaults to 1. + """ + values = K.cast(values, self.dtype) + if sample_weight is not None: + sample_weight = K.cast(sample_weight, self.dtype) + # Update dimensions of weights to match with values if possible. + values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( + values, sample_weight=sample_weight) + + # Broadcast weights if possible. + sample_weight = losses_utils.broadcast_weights(sample_weight, values) + values = values * sample_weight + + value_sum = K.sum(values) + update_total_op = K.update_add(self.total, value_sum) + + # Exit early if the reduction doesn't have a denominator. + if self.reduction == metrics_utils.Reduction.SUM: + return update_total_op + + # Update `count` for reductions that require a denominator. 
+ if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE: + num_values = K.cast(K.size(values), self.dtype) + elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN: + if sample_weight is None: + num_values = K.cast(K.size(values), self.dtype) + else: + num_values = K.sum(sample_weight) + else: + raise NotImplementedError( + 'reduction [%s] not implemented' % self.reduction) + + with K.control_dependencies([update_total_op]): + return K.update_add(self.count, num_values) + + def result(self): + if self.reduction == metrics_utils.Reduction.SUM: + return self.total + elif self.reduction in [ + metrics_utils.Reduction.WEIGHTED_MEAN, + metrics_utils.Reduction.SUM_OVER_BATCH_SIZE + ]: + return self.total / self.count + else: + raise NotImplementedError( + 'reduction [%s] not implemented' % self.reduction) + + +class Sum(Reduce): + """Computes the (weighted) sum of the given values. + + For example, if values is [1, 3, 5, 7] then the sum is 16. + If the weights were specified as [1, 1, 0, 0] then the sum would be 4. + + This metric creates one variable, `total`, that is used to compute the sum of + `values`. This is ultimately returned as `sum`. + If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 + to mask values. + + Standalone usage: + ```python + m = keras.metrics.Sum() + m.update_state([1, 3, 5, 7]) + m.result() + ``` + """ + + def __init__(self, name='sum', dtype=None): + """Creates a `Sum` instance. + # Arguments + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. 
+ """ + super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM, + name=name, dtype=dtype) + + def binary_accuracy(y_true, y_pred): return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1) @@ -49,6 +244,395 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k), axis=-1) +class SensitivitySpecificityBase(Metric): + """Abstract base class for computing sensitivity and specificity. + + For additional information about specificity and sensitivity, see the + following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity + """ + + def __init__(self, value, num_thresholds=200, name=None, dtype=None): + super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype) + if num_thresholds <= 0: + raise ValueError('`num_thresholds` must be > 0.') + self.value = value + self.true_positives = self.add_weight( + 'true_positives', + shape=(num_thresholds,), + initializer='zeros') + self.true_negatives = self.add_weight( + 'true_negatives', + shape=(num_thresholds,), + initializer='zeros') + self.false_positives = self.add_weight( + 'false_positives', + shape=(num_thresholds,), + initializer='zeros') + self.false_negatives = self.add_weight( + 'false_negatives', + shape=(num_thresholds,), + initializer='zeros') + + # Compute `num_thresholds` thresholds in [0, 1] + if num_thresholds == 1: + self.thresholds = [0.5] + else: + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + self.thresholds = [0.0] + thresholds + [1.0] + + def update_state(self, y_true, y_pred, sample_weight=None): + return metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, + }, + y_true, + y_pred, 
+ thresholds=self.thresholds, + sample_weight=sample_weight) + + def reset_states(self): + num_thresholds = len(self.thresholds) + K.batch_set_value( + [(v, np.zeros((num_thresholds,))) for v in self.variables]) + + +class SensitivityAtSpecificity(SensitivitySpecificityBase): + """Computes the sensitivity at a given specificity. + + `Sensitivity` measures the proportion of actual positives that are correctly + identified as such (tp / (tp + fn)). + `Specificity` measures the proportion of actual negatives that are correctly + identified as such (tn / (tn + fp)). + + This metric creates four local variables, `true_positives`, `true_negatives`, + `false_positives` and `false_negatives` that are used to compute the + sensitivity at the given specificity. The threshold for the given specificity + value is computed and used to evaluate the corresponding sensitivity. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + For additional information about specificity and sensitivity, see the + following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity + + Usage with the compile API: + + ```python + model = keras.Model(inputs, outputs) + model.compile( + 'sgd', + loss='mse', + metrics=[keras.metrics.SensitivityAtSpecificity()]) + ``` + + # Arguments + specificity: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given specificity. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. 
+ """ + + def __init__(self, specificity, num_thresholds=200, name=None, dtype=None): + if specificity < 0 or specificity > 1: + raise ValueError('`specificity` must be in the range [0, 1].') + self.specificity = specificity + self.num_thresholds = num_thresholds + super(SensitivityAtSpecificity, self).__init__( + specificity, num_thresholds=num_thresholds, name=name, dtype=dtype) + + def result(self): + # Calculate specificities at all the thresholds. + specificities = K.switch( + K.greater(self.true_negatives + self.false_positives, 0), + (self.true_negatives / (self.true_negatives + self.false_positives)), + K.zeros_like(self.thresholds)) + + # Find the index of the threshold where the specificity is closest to the + # given specificity. + min_index = K.argmin( + K.abs(specificities - self.value), axis=0) + min_index = K.cast(min_index, 'int32') + + # Compute sensitivity at that index. + return K.switch( + K.greater((self.true_positives[min_index] + + self.false_negatives[min_index]), 0), + (self.true_positives[min_index] / + (self.true_positives[min_index] + self.false_negatives[min_index])), + K.zeros_like(self.true_positives[min_index])) + + def get_config(self): + config = { + 'num_thresholds': self.num_thresholds, + 'specificity': self.specificity + } + base_config = super(SensitivityAtSpecificity, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class AUC(Metric): + """Computes the approximate AUC (Area under the curve) via a Riemann sum. + + This metric creates four local variables, `true_positives`, `true_negatives`, + `false_positives` and `false_negatives` that are used to compute the AUC. + To discretize the AUC curve, a linearly spaced set of thresholds is used to + compute pairs of recall and precision values. 
The area under the ROC-curve is + therefore computed using the height of the recall values by the false positive + rate, while the area under the PR-curve is computed using the height of + the precision values by the recall. + + This value is ultimately returned as `auc`, an idempotent operation that + computes the area under a discretized curve of precision versus recall values + (computed using the aforementioned variables). The `num_thresholds` variable + controls the degree of discretization with larger numbers of thresholds more + closely approximating the true AUC. The quality of the approximation may vary + dramatically depending on `num_thresholds`. The `thresholds` parameter can be + used to manually specify thresholds which split the predictions more evenly. + + For best results, `predictions` should be distributed approximately uniformly + in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC + approximation may be poor if this is not the case. Setting `summation_method` + to 'minoring' or 'majoring' can help quantify the error in the approximation + by providing a lower or upper bound estimate of the AUC. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Usage with the compile API: + + ```python + model = keras.Model(inputs, outputs) + model.compile('sgd', loss='mse', metrics=[keras.metrics.AUC()]) + ``` + + # Arguments + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use when discretizing the ROC curve. Values must be > 1. + curve: (Optional) Specifies the name of the curve to be computed, 'ROC' + [default] or 'PR' for the Precision-Recall-curve. + summation_method: (Optional) Specifies the Riemann summation method used + (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default], + applies mid-point summation scheme for `ROC`. 
For PR-AUC, interpolates + (true/false) positives but not the ratio that is precision (see Davis + & Goadrich 2006 for details); 'minoring' that applies left summation + for increasing intervals and right summation for decreasing intervals; + 'majoring' that does the opposite. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + thresholds: (Optional) A list of floating point values to use as the + thresholds for discretizing the curve. If set, the `num_thresholds` + parameter is ignored. Values should be in [0, 1]. Endpoint thresholds + equal to {-epsilon, 1+epsilon} for a small positive epsilon value will + be automatically included with these to correctly handle predictions + equal to exactly 0 or 1. + """ + + def __init__(self, + num_thresholds=200, + curve='ROC', + summation_method='interpolation', + name=None, + dtype=None, + thresholds=None): + # Validate configurations. + if (isinstance(curve, metrics_utils.AUCCurve) and + curve not in list(metrics_utils.AUCCurve)): + raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format( + curve, list(metrics_utils.AUCCurve))) + if isinstance( + summation_method, + metrics_utils.AUCSummationMethod) and summation_method not in list( + metrics_utils.AUCSummationMethod): + raise ValueError( + 'Invalid summation method: "{}". Valid options are: "{}"'.format( + summation_method, list(metrics_utils.AUCSummationMethod))) + + # Update properties. + if thresholds is not None: + # If specified, use the supplied thresholds. + self.num_thresholds = len(thresholds) + 2 + thresholds = sorted(thresholds) + else: + if num_thresholds <= 1: + raise ValueError('`num_thresholds` must be > 1.') + + # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in + # (0, 1). 
+ self.num_thresholds = num_thresholds + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + + # Add an endpoint "threshold" below zero and above one for either + # threshold method to account for floating point imprecisions. + self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()] + + if isinstance(curve, metrics_utils.AUCCurve): + self.curve = curve + else: + self.curve = metrics_utils.AUCCurve.from_str(curve) + if isinstance(summation_method, metrics_utils.AUCSummationMethod): + self.summation_method = summation_method + else: + self.summation_method = metrics_utils.AUCSummationMethod.from_str( + summation_method) + super(AUC, self).__init__(name=name, dtype=dtype) + + # Create metric variables + self.true_positives = self.add_weight( + 'true_positives', + shape=(self.num_thresholds,), + initializer='zeros') + self.true_negatives = self.add_weight( + 'true_negatives', + shape=(self.num_thresholds,), + initializer='zeros') + self.false_positives = self.add_weight( + 'false_positives', + shape=(self.num_thresholds,), + initializer='zeros') + self.false_negatives = self.add_weight( + 'false_negatives', + shape=(self.num_thresholds,), + initializer='zeros') + + def update_state(self, y_true, y_pred, sample_weight=None): + return metrics_utils.update_confusion_matrix_variables({ + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, + }, y_true, y_pred, self.thresholds, sample_weight=sample_weight) + + def interpolate_pr_auc(self): + """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 
+ + https://www.biostat.wisc.edu/~page/rocpr.pdf + + Note here we derive & use a closed formula not present in the paper + as follows: + + Precision = TP / (TP + FP) = TP / P + + Modeling all of TP (true positive), FP (false positive) and their sum + P = TP + FP (predicted positive) as varying linearly within each interval + [A, B] between successive thresholds, we get + + Precision slope = dTP / dP + = (TP_B - TP_A) / (P_B - P_A) + = (TP - TP_A) / (P - P_A) + Precision = (TP_A + slope * (P - P_A)) / P + + The area within the interval is (slope / total_pos_weight) times + + int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} + int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} + + where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + + int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) + + Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + + slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight + + where dTP == TP_B - TP_A. + + Note that when P_A == 0 the above calculation simplifies into + + int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) + + which is really equivalent to imputing constant precision throughout the + first bucket having >0 true positives. + + # Returns + pr_auc: an approximation of the area under the P-R curve. 
+ """ + dtp = self.true_positives[:self.num_thresholds - + 1] - self.true_positives[1:] + p = self.true_positives + self.false_positives + dp = p[:self.num_thresholds - 1] - p[1:] + + prec_slope = dtp / K.maximum(dp, 0) + intercept = self.true_positives[1:] - (prec_slope * p[1:]) + + # Logical and + pMin = K.expand_dims(p[:self.num_thresholds - 1] > 0, 0) + pMax = K.expand_dims(p[1:] > 0, 0) + are_different = K.concatenate([pMin, pMax], axis=0) + switch_condition = K.all(are_different, axis=0) + + safe_p_ratio = K.switch( + switch_condition, + p[:self.num_thresholds - 1] / K.maximum(p[1:], 0), + K.ones_like(p[1:])) + + numer = prec_slope * (dtp + intercept * K.log(safe_p_ratio)) + denom = K.maximum(self.true_positives[1:] + self.false_negatives[1:], 0) + return K.sum((numer / denom)) + + def result(self): + if (self.curve == metrics_utils.AUCCurve.PR and + (self.summation_method == + metrics_utils.AUCSummationMethod.INTERPOLATION)): + # This use case is different and is handled separately. + return self.interpolate_pr_auc() + + # Set `x` and `y` values for the curves based on `curve` config. + recall = K.switch( + K.greater((self.true_positives), 0), + (self.true_positives / + (self.true_positives + self.false_negatives)), + K.zeros_like(self.true_positives)) + if self.curve == metrics_utils.AUCCurve.ROC: + fp_rate = K.switch( + K.greater((self.false_positives), 0), + (self.false_positives / + (self.false_positives + self.true_negatives)), + K.zeros_like(self.false_positives)) + x = fp_rate + y = recall + else: # curve == 'PR'. + precision = K.switch( + K.greater((self.true_positives), 0), + (self.true_positives / (self.true_positives + self.false_positives)), + K.zeros_like(self.true_positives)) + x = recall + y = precision + + # Find the rectangle heights based on `summation_method`. + if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION: + # Note: the case ('PR', 'interpolation') has been handled above. 
+ heights = (y[:self.num_thresholds - 1] + y[1:]) / 2. + elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: + heights = K.minimum(y[:self.num_thresholds - 1], y[1:]) + else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: + heights = K.maximum(y[:self.num_thresholds - 1], y[1:]) + + # Sum up the areas of all the rectangles. + return K.sum((x[:self.num_thresholds - 1] - x[1:]) * heights) + + def reset_states(self): + K.batch_set_value( + [(v, np.zeros((self.num_thresholds,))) for v in self.variables]) + + def get_config(self): + config = { + 'num_thresholds': self.num_thresholds, + 'curve': self.curve.value, + 'summation_method': self.summation_method.value, + # We remove the endpoint thresholds as an inverse of how the thresholds + # were initialized. This ensures that a metric initialized from this + # config has the same thresholds. + 'thresholds': self.thresholds[1:-1], + } + base_config = super(AUC, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + # Aliases diff --git a/keras/utils/__init__.py b/keras/utils/__init__.py index 8cc39d5..65af329 100644 --- a/keras/utils/__init__.py +++ b/keras/utils/__init__.py @@ -4,6 +4,8 @@ from . import generic_utils from . import data_utils from . import io_utils from . import conv_utils +from . import losses_utils +from . import metrics_utils # Globally-importable utils. from .io_utils import HDF5Matrix diff --git a/keras/utils/losses_utils.py b/keras/utils/losses_utils.py new file mode 100644 index 0000000..617ebb7 --- /dev/null +++ b/keras/utils/losses_utils.py @@ -0,0 +1,177 @@ +"""Utilities related to losses.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from .. import backend as K + + +class Reduction(object): + """Types of loss reduction. + + Contains the following values: + + * `NONE`: Un-reduced weighted losses with the same shape as input. 
When this + reduction type used with built-in Keras training loops like + `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer but + the reported loss will be a scalar value. + * `SUM`: Scalar sum of weighted losses. + * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses. + """ + + NONE = 'none' + SUM = 'sum' + SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' + + @classmethod + def all(cls): + return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE) + + @classmethod + def validate(cls, key): + if key not in cls.all(): + raise ValueError('Invalid Reduction Key %s.' % key) + + +def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None): + """Squeeze or expand last dimension if needed. + + 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1. + 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 + from the new rank of `y_pred`. + If `sample_weight` is scalar, it is kept scalar. + + # Arguments + y_pred: Predicted values, a `Tensor` of arbitrary dimensions. + y_true: Optional label `Tensor` whose dimensions match `y_pred`. + sample_weight: Optional weight scalar or `Tensor` whose dimensions match + `y_pred`. + + # Returns + Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has + the last dimension squeezed, `sample_weight` could be extended by one + dimension. 
+ """ + if y_true is not None: + y_pred_rank = K.ndim(y_pred) + y_pred_shape = K.int_shape(y_pred) + y_true_rank = K.ndim(y_true) + y_true_shape = K.int_shape(y_true) + + if (y_pred_rank - y_true_rank == 1) and (y_pred_shape[-1] == 1): + y_pred = K.squeeze(y_pred, -1) + elif (y_true_rank - y_pred_rank == 1) and (y_true_shape[-1] == 1): + y_true = K.squeeze(y_true, -1) + + if sample_weight is None: + return y_pred, y_true + + y_pred_rank = K.ndim(y_pred) + weights_rank = K.ndim(sample_weight) + if weights_rank != 0: + if weights_rank - y_pred_rank == 1: + sample_weight = K.squeeze(sample_weight, -1) + elif y_pred_rank - weights_rank == 1: + sample_weight = K.expand_dims(sample_weight, -1) + return y_pred, y_true, sample_weight + + +def _num_elements(losses): + """Computes the number of elements in `losses` tensor.""" + with K.name_scope('num_elements') as scope: + return K.cast(K.size(losses, name=scope), losses.dtype) + + +def reduce_weighted_loss(weighted_losses, reduction=Reduction.SUM_OVER_BATCH_SIZE): + """Reduces the individual weighted loss measurements.""" + if reduction == Reduction.NONE: + loss = weighted_losses + else: + loss = K.sum(weighted_losses) + if reduction == Reduction.SUM_OVER_BATCH_SIZE: + loss = loss / _num_elements(weighted_losses) + return loss + + +def broadcast_weights(values, sample_weight): + # Broadcast weights if possible. + weights_shape = K.int_shape(sample_weight) + values_shape = K.int_shape(values) + + if values_shape != weights_shape: + weights_rank = K.ndim(sample_weight) + values_rank = K.ndim(values) + + # Raise error if ndim of weights is > values. + if weights_rank > values_rank: + raise ValueError( + 'Incompatible shapes: `values` {} vs `sample_weight` {}'.format( + values_shape, weights_shape)) + + # Expand dim of weights to match ndim of values, if required. 
+ for i in range(weights_rank, values_rank): + sample_weight = K.expand_dims(sample_weight, axis=i) + + if weights_shape is not None and values_shape is not None: + for i in range(weights_rank): + if (weights_shape[i] is not None and + values_shape[i] is not None and + weights_shape[i] != values_shape[i]): + # Cannot be broadcasted. + if weights_shape[i] != 1: + raise ValueError( + 'Incompatible shapes: `values` {} vs ' + '`sample_weight` {}'.format( + values_shape, weights_shape)) + sample_weight = K.repeat_elements( + sample_weight, values_shape[i], axis=i) + return sample_weight + + +def compute_weighted_loss(losses, + sample_weight=None, + reduction=Reduction.SUM_OVER_BATCH_SIZE, + name=None): + """Computes the weighted loss. + + # Arguments + losses: `Tensor` of shape `[batch_size, d1, ... dN]`. + sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as + ` losses`, or be broadcastable to `losses`. + reduction: (Optional) Type of Reduction to apply to loss. + Default value is `SUM_OVER_BATCH_SIZE`. + name: Optional name for the op. + + # Raises + ValueError: If the shape of `sample_weight` is not compatible with `losses`. + + # Returns + Weighted loss `Tensor` of the same type as `losses`. If `reduction` is + `NONE`, this has the same shape as `losses`; otherwise, it is scalar. + """ + Reduction.validate(reduction) + if sample_weight is None: + sample_weight = 1.0 + with K.name_scope(name or 'weighted_loss'): + input_dtype = K.dtype(losses) + losses = K.cast(losses, K.floatx()) + sample_weight = K.cast(sample_weight, K.floatx()) + + # Update dimensions of `sample_weight` to match with `losses` if possible. + losses, _, sample_weight = squeeze_or_expand_dimensions( + losses, None, sample_weight) + + # Broadcast weights if possible. + sample_weight = broadcast_weights(losses, sample_weight) + + # Apply weights to losses. + weighted_losses = sample_weight * losses + + # Apply reduction function to the individual weighted losses. 
# Sentinel threshold value. `update_confusion_matrix_variables` documents it
# as the threshold to pass when `top_k` is set.
NEG_INF = -1e10


class Reduction(object):
    """Types of metrics reduction.

    Contains the following values:
    * `SUM`: Scalar sum of weighted values.
    * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` of weighted values divided by
        number of elements in values.
    * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of
        weights.
    """

    SUM = 'sum'
    SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
    WEIGHTED_MEAN = 'weighted_mean'


def update_state_wrapper(update_state_fn):
    """Decorator to wrap metric `update_state()` with `add_update()`.

    # Arguments
        update_state_fn: function that accumulates metric statistics. It is
            called with the positional and keyword arguments only (the
            metric object is not forwarded), so it must already be bound to
            the metric if it needs it.

    # Returns
        Decorated function that wraps `update_state_fn()` with `add_update()`.
        Its first parameter is the metric object the update is registered on.
    """
    def decorated(metric_obj, *args, **kwargs):
        """Decorated function with `add_update()`."""
        update_op = update_state_fn(*args, **kwargs)
        # Register the op on the metric and hand it back to the caller.
        metric_obj.add_update(update_op)
        return update_op

    return decorated


def result_wrapper(result_fn):
    """Decorator to wrap metric `result()` with identity op.

    Wrapping result in identity so that control dependency between
    update_op from `update_state` and result works in case result returns
    a tensor.

    # Arguments
        result_fn: function that computes the metric result. It is called
            without the metric object, so it must already be bound to it if
            it needs it.

    # Returns
        Decorated function that wraps `result()` with identity op.
    """
    def decorated(metric_obj, *args, **kwargs):
        """Decorated function returning the identity-wrapped result."""
        result_t = K.identity(result_fn(*args, **kwargs))
        # Keep the latest result on the metric and tag the tensor itself as
        # a metric output.
        metric_obj._call_result = result_t
        result_t._is_metric = True
        return result_t

    return decorated


def to_list(x):
    """Normalizes `x` to a list.

    # Arguments
        x: a list, a tuple, or a single value.

    # Returns
        `x` itself when it is already a list, a new list holding the items of
        `x` when it is a tuple, and the one-element list `[x]` otherwise.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        # Thresholds are documented as "a python list or tuple"; wrapping a
        # tuple as a single element would break the range validation below.
        return list(x)
    return [x]


def assert_thresholds_range(thresholds):
    """Checks that every threshold lies in `[0, 1]`.

    # Arguments
        thresholds: iterable of threshold values, or `None` (no check).

    # Raises
        ValueError: if any value is `None`, below 0 or above 1.
    """
    if thresholds is not None:
        invalid_thresholds = [
            t for t in thresholds if t is None or t < 0 or t > 1]
        if invalid_thresholds:
            raise ValueError(
                'Threshold values must be in [0, 1]. Invalid values: {}'.format(
                    invalid_thresholds))


def parse_init_thresholds(thresholds, default_threshold=0.5):
    """Validates user thresholds and returns them as a list.

    # Arguments
        thresholds: a single value, a list or a tuple of values, or `None`.
        default_threshold: value used when `thresholds` is `None`. It is not
            range-checked.

    # Returns
        List of thresholds.

    # Raises
        ValueError: if a provided threshold is outside `[0, 1]`.
    """
    if thresholds is None:
        return to_list(default_threshold)
    thresholds = to_list(thresholds)
    assert_thresholds_range(thresholds)
    return thresholds


class ConfusionMatrix(Enum):
    """Keys identifying the four confusion matrix variables."""
    TRUE_POSITIVES = 'tp'
    FALSE_POSITIVES = 'fp'
    TRUE_NEGATIVES = 'tn'
    FALSE_NEGATIVES = 'fn'


class AUCCurve(Enum):
    """Type of AUC Curve (ROC or PR)."""
    ROC = 'ROC'
    PR = 'PR'

    @staticmethod
    def from_str(key):
        """Returns the member for `key` ('pr'/'PR' or 'roc'/'ROC').

        # Raises
            ValueError: if `key` is none of the accepted spellings.
        """
        if key in ('pr', 'PR'):
            return AUCCurve.PR
        elif key in ('roc', 'ROC'):
            return AUCCurve.ROC
        else:
            raise ValueError('Invalid AUC curve value "%s".' % key)


class AUCSummationMethod(Enum):
    """Type of AUC summation method.

    (https://en.wikipedia.org/wiki/Riemann_sum)

    Contains the following values:
    * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
        `PR` curve, interpolates (true/false) positives but not the ratio that
        is precision (see Davis & Goadrich 2006 for details).
    * 'minoring': Applies left summation for increasing intervals and right
        summation for decreasing intervals.
    * 'majoring': Applies right summation for increasing intervals and left
        summation for decreasing intervals.
    """
    INTERPOLATION = 'interpolation'
    MAJORING = 'majoring'
    MINORING = 'minoring'

    @staticmethod
    def from_str(key):
        """Returns the member for `key` (lower-case or capitalized name).

        # Raises
            ValueError: if `key` is none of the accepted spellings.
        """
        if key in ('interpolation', 'Interpolation'):
            return AUCSummationMethod.INTERPOLATION
        elif key in ('majoring', 'Majoring'):
            return AUCSummationMethod.MAJORING
        elif key in ('minoring', 'Minoring'):
            return AUCSummationMethod.MINORING
        else:
            raise ValueError(
                'Invalid AUC summation method value "%s".' % key)


def weighted_assign_add(label, pred, weights, var):
    """Adds the weighted count of `label AND pred` to `var`.

    # Arguments
        label: boolean tensor of label conditions.
        pred: boolean tensor of prediction conditions, same shape as `label`.
        weights: optional tensor multiplied into the float-cast conjunction,
            or `None` for no weighting.
        var: variable to accumulate into.

    # Returns
        The op returned by `var.assign_add(...)`.
    """
    # Logical and: stack both conditions on a new leading axis and reduce
    # that axis with `all`.
    stacked = K.concatenate(
        [K.expand_dims(label, 0), K.expand_dims(pred, 0)], axis=0)
    label_and_pred = K.cast(K.all(stacked, axis=0), dtype=K.floatx())
    if weights is not None:
        label_and_pred *= weights
    # Sum over axis 1 so one total is added per row of the inputs.
    return var.assign_add(K.sum(label_and_pred, 1))


def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None):
    """Returns op to update the given confusion matrix variables.

    For every pair of values in y_true and y_pred:

    true_positive: y_true == True and y_pred > thresholds
    false_negatives: y_true == True and y_pred <= thresholds
    true_negatives: y_true == False and y_pred <= thresholds
    false_positive: y_true == False and y_pred > thresholds

    The results will be weighted and added together. When multiple thresholds
    are provided, we will repeat the same for every threshold.

    For estimation of these metrics over a stream of data, the function
    creates an `update_op` operation that updates the given variables.

    If `sample_weight` is `None`, weights default to 1.
    Use weights of 0 to mask values.

    # Arguments
        variables_to_update: Dictionary with `ConfusionMatrix` members as
            keys and the corresponding variables to update as values.
            If `None`, nothing is done and `None` is returned.
        y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to
            `bool`.
        y_pred: A floating point `Tensor` of arbitrary shape and whose values
            are in the range `[0, 1]`.
        thresholds: A float value or a python list or tuple of float
            thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
        top_k: Optional int, indicates that the positive labels should be
            limited to the top k predictions.
        class_id: Optional int, limits the prediction and labels to the class
            specified by this argument.
        sample_weight: Optional `Tensor` whose rank is either 0, or the same
            rank as `y_true`, and must be broadcastable to `y_true` (i.e., all
            dimensions must be either `1`, or the same as the corresponding
            `y_true` dimension).

    # Returns
        List of update ops, one per requested variable.

    # Raises
        ValueError: If `variables_to_update` contains no valid key, or
            contains invalid keys. Shape reconciliation is delegated to
            `losses_utils.squeeze_or_expand_dimensions` and
            `losses_utils.broadcast_weights`.
    """
    if variables_to_update is None:
        return
    y_true = K.cast(y_true, dtype=K.floatx())
    y_pred = K.cast(y_pred, dtype=K.floatx())
    if sample_weight is not None:
        sample_weight = K.cast(sample_weight, dtype=K.floatx())

    valid_keys = list(ConfusionMatrix)
    if not any(key in valid_keys for key in variables_to_update):
        raise ValueError(
            'Please provide at least one valid confusion matrix '
            'variable to update. Valid variable key options are: "{}". '
            'Received: "{}"'.format(
                valid_keys, variables_to_update.keys()))

    invalid_keys = [
        key for key in variables_to_update if key not in valid_keys
    ]
    if invalid_keys:
        raise ValueError(
            'Invalid keys: {}. Valid variable key options are: "{}"'.format(
                invalid_keys, valid_keys))

    if sample_weight is None:
        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
            y_pred, y_true=y_true)
    else:
        y_pred, y_true, sample_weight = (
            losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true=y_true, sample_weight=sample_weight))

    if top_k is not None:
        # NOTE(review): `_filter_top_k` is neither defined nor imported in
        # the visible part of this module; unless it is supplied elsewhere,
        # passing `top_k` raises NameError here. Confirm and add the helper.
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        y_true = y_true[..., class_id]
        y_pred = y_pred[..., class_id]

    thresholds = to_list(thresholds)
    num_thresholds = len(thresholds)
    num_predictions = K.size(y_pred)

    # Reshape predictions and labels.
    predictions_2d = K.reshape(y_pred, [1, -1])
    labels_2d = K.reshape(
        K.cast(y_true, dtype='bool'), [1, -1])

    # Tile the thresholds for every prediction.
    thresh_tiled = K.tile(
        K.expand_dims(K.constant(thresholds), 1),
        K.stack([1, num_predictions]))

    # Tile the predictions for every threshold.
    preds_tiled = K.tile(predictions_2d, [num_thresholds, 1])

    # Compare predictions and threshold. A prediction equal to the threshold
    # is a negative (`y_pred <= thresholds`), so the negative mask must be
    # the exact complement of the strict `>` used for positives.
    pred_is_pos = K.greater(preds_tiled, thresh_tiled)
    pred_is_neg = K.less_equal(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = K.tile(labels_2d, [num_thresholds, 1])

    if sample_weight is not None:
        weights = losses_utils.broadcast_weights(
            y_pred, K.cast(sample_weight, dtype=K.floatx()))
        weights_tiled = K.tile(
            K.reshape(weights, [1, -1]), [num_thresholds, 1])
    else:
        weights_tiled = None

    update_ops = []
    # Maps each confusion matrix entry to the (label, prediction) condition
    # pair whose conjunction it counts. Only the pairs that can be needed
    # are built.
    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (
            label_is_pos, pred_is_neg)

    if update_fp or update_tn:
        # Negate the boolean labels by comparing them with zeros.
        label_is_neg = K.equal(
            label_is_pos,
            K.zeros_like(label_is_pos, dtype=label_is_pos.dtype))
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (
            label_is_neg, pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (
                label_is_neg, pred_is_neg)

    for matrix_cond, (label, pred) in loop_vars.items():
        if matrix_cond in variables_to_update:
            update_ops.append(
                weighted_assign_add(label, pred, weights_tiled,
                                    variables_to_update[matrix_cond]))
    return update_ops