- From 901159da45695da24a5206125910f02fc50169ce Mon Sep 17 00:00:00 2001
- From: Efraim Flashner <efraim@flashner.co.il>
- Date: Thu, 23 Apr 2020 15:50:37 +0300
- Subject: [PATCH] add keras metrics
-
- ---
- keras/backend/tensorflow_backend.py | 12 +
- keras/metrics.py | 585 ++++++++++++++++++++++++++++
- keras/utils/__init__.py | 2 +
- keras/utils/losses_utils.py | 177 +++++++++
- keras/utils/metrics_utils.py | 278 +++++++++++++
- 5 files changed, 1054 insertions(+)
- create mode 100644 keras/utils/losses_utils.py
- create mode 100644 keras/utils/metrics_utils.py
-
- diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
- index bcb8be0..a2870f5 100644
- --- a/keras/backend/tensorflow_backend.py
- +++ b/keras/backend/tensorflow_backend.py
- @@ -4453,3 +4453,15 @@ def local_conv2d(inputs, kernel, kernel_size, strides, output_shape, data_format
- else:
- output = permute_dimensions(output, (2, 0, 1, 3))
- return output
- +
- +#get_graph = tf_keras_backend.get_graph
- +
- +#def is_symbolic(x):
- +# return isinstance(x, tf.Tensor) and hasattr(x, 'op')
- +
- +def size(x, name=None):
- +# if is_symbolic(x):
- +# with get_graph().as_default():
- +# return tf.size(x)
- + return tf.size(x, name=name)
- +
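The `size()` helper added above is a thin wrapper around `tf.size` and is what the new metrics and loss utilities reach through `K.size()`. A minimal sketch of its behaviour, assuming the patched backend is importable with TensorFlow installed (not part of the patch):

```python
# Hypothetical check of the size() wrapper added above; assumes the patched
# keras.backend with the TensorFlow backend selected.
import numpy as np
from keras import backend as K

x = K.constant(np.arange(6).reshape(2, 3))
n = K.size(x)        # delegates to tf.size(x, name=None)
print(K.eval(n))     # expected: 6, the total number of elements
```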
- diff --git a/keras/metrics.py b/keras/metrics.py
- index 8e3df1f..8f57910 100644
- --- a/keras/metrics.py
- +++ b/keras/metrics.py
- @@ -4,8 +4,13 @@ from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- +import abc
- import six
- +import types
- +import numpy as np
- +
- from . import backend as K
- +from .engine.base_layer import Layer
- from .losses import mean_squared_error
- from .losses import mean_absolute_error
- from .losses import mean_absolute_percentage_error
- @@ -19,10 +24,201 @@ from .losses import binary_crossentropy
- from .losses import kullback_leibler_divergence
- from .losses import poisson
- from .losses import cosine_proximity
- +from .utils import losses_utils
- +from .utils import metrics_utils
- from .utils.generic_utils import deserialize_keras_object
- from .utils.generic_utils import serialize_keras_object
-
-
- +@six.add_metaclass(abc.ABCMeta)
- +class Metric(Layer):
- + """Encapsulates metric logic and state.
- +
- + Standalone usage:
- + ```python
- + m = SomeMetric(...)
- + for input in ...:
- + m.update_state(input)
- + m.result()
- + ```
- +
- + Usage with the `compile` API:
- + ```python
- + model.compile(optimizer='rmsprop',
- + loss=keras.losses.categorical_crossentropy,
- + metrics=[keras.metrics.CategoricalAccuracy()])
- + ```
- +
- + To be implemented by subclasses:
- + * `__init__()`: All state variables should be created in this method by
- + calling `self.add_weight()` like: `self.var = self.add_weight(...)`
- + * `update_state()`: Has all updates to the state variables like:
- + self.var.assign_add(...).
- + * `result()`: Computes and returns a value for the metric
- + from the state variables.
- + """
- +
- + def __init__(self, name=None, dtype=None, **kwargs):
- + super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
- + self.stateful = True # All metric layers are stateful.
- + self.built = True
- + self.dtype = K.floatx() if dtype is None else dtype
- +
- + def __new__(cls, *args, **kwargs):
- + obj = super(Metric, cls).__new__(cls)
- + update_state_fn = obj.update_state
- +
- + obj.update_state = types.MethodType(
- + metrics_utils.update_state_wrapper(update_state_fn), obj)
- + return obj
- +
- + def __call__(self, *args, **kwargs):
- + """Accumulates statistics and then computes metric result value."""
- + update_op = self.update_state(*args, **kwargs)
- + return self.result()
- +
- + def get_config(self):
- + """Returns the serializable config of the metric."""
- + return {'name': self.name, 'dtype': self.dtype}
- +
- + def reset_states(self):
- + """Resets all of the metric state variables.
- + This function is called between epochs/steps,
- + when a metric is evaluated during training.
- + """
- + K.batch_set_value([(v, 0) for v in self.weights])
- +
- + @abc.abstractmethod
- + def update_state(self, *args, **kwargs):
- + """Accumulates statistics for the metric. """
- + raise NotImplementedError('Must be implemented in subclasses.')
- +
- + @abc.abstractmethod
- + def result(self):
- + """Computes and returns the metric value tensor.
- + Result computation is an idempotent operation that simply calculates the
- + metric value using the state variables.
- + """
- + raise NotImplementedError('Must be implemented in subclasses.')
- +
- + # For use by subclasses #
- + def add_weight(self,
- + name,
- + shape=(),
- + initializer=None,
- + dtype=None):
- + """Adds state variable. Only for use by subclasses."""
- + return super(Metric, self).add_weight(
- + name=name,
- + shape=shape,
- + dtype=self.dtype if dtype is None else dtype,
- + trainable=False,
- + initializer=initializer)
- +
- + # End: For use by subclasses ###
- +
- +
- +class Reduce(Metric):
- + """Encapsulates metrics that perform a reduce operation on the values."""
- +
- + def __init__(self, reduction, name, dtype=None):
- + """Creates a `Reduce` instance.
- + # Arguments
- + reduction: a metrics `Reduction` enum value.
- + name: string name of the metric instance.
- + dtype: (Optional) data type of the metric result.
- + """
- + super(Reduce, self).__init__(name=name, dtype=dtype)
- + self.reduction = reduction
- + self.total = self.add_weight('total', initializer='zeros')
- + if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
- + metrics_utils.Reduction.WEIGHTED_MEAN]:
- + self.count = self.add_weight('count', initializer='zeros')
- +
- + def update_state(self, values, sample_weight=None):
- + """Accumulates statistics for computing the reduction metric.
- + For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE,
- + then the value of `result()` is 4. If `sample_weight` is specified as
- + [1, 1, 0, 0] then the value of `result()` would be 1, since the weighted
- + sum (4) is still divided by the number of elements (4).
- + # Arguments
- + values: Per-example value.
- + sample_weight: Optional weighting of each example. Defaults to 1.
- + """
- + values = K.cast(values, self.dtype)
- + if sample_weight is not None:
- + sample_weight = K.cast(sample_weight, self.dtype)
- + # Update dimensions of weights to match with values if possible.
- + values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
- + values, sample_weight=sample_weight)
- +
- + # Broadcast weights if possible.
- + sample_weight = losses_utils.broadcast_weights(values, sample_weight)
- + values = values * sample_weight
- +
- + value_sum = K.sum(values)
- + update_total_op = K.update_add(self.total, value_sum)
- +
- + # Exit early if the reduction doesn't have a denominator.
- + if self.reduction == metrics_utils.Reduction.SUM:
- + return update_total_op
- +
- + # Update `count` for reductions that require a denominator.
- + if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
- + num_values = K.cast(K.size(values), self.dtype)
- + elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
- + if sample_weight is None:
- + num_values = K.cast(K.size(values), self.dtype)
- + else:
- + num_values = K.sum(sample_weight)
- + else:
- + raise NotImplementedError(
- + 'reduction [%s] not implemented' % self.reduction)
- +
- + with K.control_dependencies([update_total_op]):
- + return K.update_add(self.count, num_values)
- +
- + def result(self):
- + if self.reduction == metrics_utils.Reduction.SUM:
- + return self.total
- + elif self.reduction in [
- + metrics_utils.Reduction.WEIGHTED_MEAN,
- + metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
- + ]:
- + return self.total / self.count
- + else:
- + raise NotImplementedError(
- + 'reduction [%s] not implemented' % self.reduction)
- +
- +
- +class Sum(Reduce):
- + """Computes the (weighted) sum of the given values.
- +
- + For example, if values is [1, 3, 5, 7] then the sum is 16.
- + If the weights were specified as [1, 1, 0, 0] then the sum would be 4.
- +
- + This metric creates one variable, `total`, that is used to compute the sum of
- + `values`. This is ultimately returned as `sum`.
- + If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0
- + to mask values.
- +
- + Standalone usage:
- + ```python
- + m = keras.metrics.Sum()
- + m.update_state([1, 3, 5, 7])
- + m.result()
- + ```
- + """
- +
- + def __init__(self, name='sum', dtype=None):
- + """Creates a `Sum` instance.
- + # Arguments
- + name: (Optional) string name of the metric instance.
- + dtype: (Optional) data type of the metric result.
- + """
- + super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
- + name=name, dtype=dtype)
- +
- +
- def binary_accuracy(y_true, y_pred):
- return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
-
- @@ -49,6 +245,395 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
- axis=-1)
-
- +class SensitivitySpecificityBase(Metric):
- + """Abstract base class for computing sensitivity and specificity.
- +
- + For additional information about specificity and sensitivity, see the
- + following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
- + """
- +
- + def __init__(self, value, num_thresholds=200, name=None, dtype=None):
- + super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
- + if num_thresholds <= 0:
- + raise ValueError('`num_thresholds` must be > 0.')
- + self.value = value
- + self.true_positives = self.add_weight(
- + 'true_positives',
- + shape=(num_thresholds,),
- + initializer='zeros')
- + self.true_negatives = self.add_weight(
- + 'true_negatives',
- + shape=(num_thresholds,),
- + initializer='zeros')
- + self.false_positives = self.add_weight(
- + 'false_positives',
- + shape=(num_thresholds,),
- + initializer='zeros')
- + self.false_negatives = self.add_weight(
- + 'false_negatives',
- + shape=(num_thresholds,),
- + initializer='zeros')
- +
- + # Compute `num_thresholds` thresholds in [0, 1]
- + if num_thresholds == 1:
- + self.thresholds = [0.5]
- + else:
- + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- + for i in range(num_thresholds - 2)]
- + self.thresholds = [0.0] + thresholds + [1.0]
- +
- + def update_state(self, y_true, y_pred, sample_weight=None):
- + return metrics_utils.update_confusion_matrix_variables(
- + {
- + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
- + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
- + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
- + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
- + },
- + y_true,
- + y_pred,
- + thresholds=self.thresholds,
- + sample_weight=sample_weight)
- +
- + def reset_states(self):
- + num_thresholds = len(self.thresholds)
- + K.batch_set_value(
- + [(v, np.zeros((num_thresholds,))) for v in self.weights])
- +
- +
- +class SensitivityAtSpecificity(SensitivitySpecificityBase):
- + """Computes the sensitivity at a given specificity.
- +
- + `Sensitivity` measures the proportion of actual positives that are correctly
- + identified as such (tp / (tp + fn)).
- + `Specificity` measures the proportion of actual negatives that are correctly
- + identified as such (tn / (tn + fp)).
- +
- + This metric creates four local variables, `true_positives`, `true_negatives`,
- + `false_positives` and `false_negatives` that are used to compute the
- + sensitivity at the given specificity. The threshold for the given specificity
- + value is computed and used to evaluate the corresponding sensitivity.
- +
- + If `sample_weight` is `None`, weights default to 1.
- + Use `sample_weight` of 0 to mask values.
- +
- + For additional information about specificity and sensitivity, see the
- + following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
- +
- + Usage with the compile API:
- +
- + ```python
- + model = keras.Model(inputs, outputs)
- + model.compile(
- + 'sgd',
- + loss='mse',
- + metrics=[keras.metrics.SensitivityAtSpecificity(specificity=0.5)])
- + ```
- +
- + # Arguments
- + specificity: A scalar value in range `[0, 1]`.
- + num_thresholds: (Optional) Defaults to 200. The number of thresholds to
- + use for matching the given specificity.
- + name: (Optional) string name of the metric instance.
- + dtype: (Optional) data type of the metric result.
- + """
- +
- + def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
- + if specificity < 0 or specificity > 1:
- + raise ValueError('`specificity` must be in the range [0, 1].')
- + self.specificity = specificity
- + self.num_thresholds = num_thresholds
- + super(SensitivityAtSpecificity, self).__init__(
- + specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
- +
- + def result(self):
- + # Calculate specificities at all the thresholds.
- + specificities = K.switch(
- + K.greater(self.true_negatives + self.false_positives, 0),
- + (self.true_negatives / (self.true_negatives + self.false_positives)),
- + K.zeros_like(self.thresholds))
- +
- + # Find the index of the threshold where the specificity is closest to the
- + # given specificity.
- + min_index = K.argmin(
- + K.abs(specificities - self.value), axis=0)
- + min_index = K.cast(min_index, 'int32')
- +
- + # Compute sensitivity at that index.
- + return K.switch(
- + K.greater((self.true_positives[min_index] +
- + self.false_negatives[min_index]), 0),
- + (self.true_positives[min_index] /
- + (self.true_positives[min_index] + self.false_negatives[min_index])),
- + K.zeros_like(self.true_positives[min_index]))
- +
- + def get_config(self):
- + config = {
- + 'num_thresholds': self.num_thresholds,
- + 'specificity': self.specificity
- + }
- + base_config = super(SensitivityAtSpecificity, self).get_config()
- + return dict(list(base_config.items()) + list(config.items()))
- +
- +
- +class AUC(Metric):
- + """Computes the approximate AUC (Area under the curve) via a Riemann sum.
- +
- + This metric creates four local variables, `true_positives`, `true_negatives`,
- + `false_positives` and `false_negatives` that are used to compute the AUC.
- + To discretize the AUC curve, a linearly spaced set of thresholds is used to
- + compute pairs of recall and precision values. The area under the ROC-curve is
- + therefore computed using the height of the recall values by the false positive
- + rate, while the area under the PR-curve is computed using the height of
- + the precision values by the recall.
- +
- + This value is ultimately returned as `auc`, an idempotent operation that
- + computes the area under a discretized curve of precision versus recall values
- + (computed using the aforementioned variables). The `num_thresholds` variable
- + controls the degree of discretization with larger numbers of thresholds more
- + closely approximating the true AUC. The quality of the approximation may vary
- + dramatically depending on `num_thresholds`. The `thresholds` parameter can be
- + used to manually specify thresholds which split the predictions more evenly.
- +
- + For best results, `predictions` should be distributed approximately uniformly
- + in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
- + approximation may be poor if this is not the case. Setting `summation_method`
- + to 'minoring' or 'majoring' can help quantify the error in the approximation
- + by providing a lower or upper bound estimate of the AUC.
- +
- + If `sample_weight` is `None`, weights default to 1.
- + Use `sample_weight` of 0 to mask values.
- +
- + Usage with the compile API:
- +
- + ```python
- + model = keras.Model(inputs, outputs)
- + model.compile('sgd', loss='mse', metrics=[keras.metrics.AUC()])
- + ```
- +
- + # Arguments
- + num_thresholds: (Optional) Defaults to 200. The number of thresholds to
- + use when discretizing the ROC curve. Values must be > 1.
- + curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
- + [default] or 'PR' for the Precision-Recall-curve.
- + summation_method: (Optional) Specifies the Riemann summation method used
- + (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default],
- + applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates
- + (true/false) positives but not the ratio that is precision (see Davis
- + & Goadrich 2006 for details); 'minoring' that applies left summation
- + for increasing intervals and right summation for decreasing intervals;
- + 'majoring' that does the opposite.
- + name: (Optional) string name of the metric instance.
- + dtype: (Optional) data type of the metric result.
- + thresholds: (Optional) A list of floating point values to use as the
- + thresholds for discretizing the curve. If set, the `num_thresholds`
- + parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
- + equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
- + be automatically included with these to correctly handle predictions
- + equal to exactly 0 or 1.
- + """
- +
- + def __init__(self,
- + num_thresholds=200,
- + curve='ROC',
- + summation_method='interpolation',
- + name=None,
- + dtype=None,
- + thresholds=None):
- + # Validate configurations.
- + if (isinstance(curve, metrics_utils.AUCCurve) and
- + curve not in list(metrics_utils.AUCCurve)):
- + raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
- + curve, list(metrics_utils.AUCCurve)))
- + if isinstance(
- + summation_method,
- + metrics_utils.AUCSummationMethod) and summation_method not in list(
- + metrics_utils.AUCSummationMethod):
- + raise ValueError(
- + 'Invalid summation method: "{}". Valid options are: "{}"'.format(
- + summation_method, list(metrics_utils.AUCSummationMethod)))
- +
- + # Update properties.
- + if thresholds is not None:
- + # If specified, use the supplied thresholds.
- + self.num_thresholds = len(thresholds) + 2
- + thresholds = sorted(thresholds)
- + else:
- + if num_thresholds <= 1:
- + raise ValueError('`num_thresholds` must be > 1.')
- +
- + # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
- + # (0, 1).
- + self.num_thresholds = num_thresholds
- + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- + for i in range(num_thresholds - 2)]
- +
- + # Add an endpoint "threshold" below zero and above one for either
- + # threshold method to account for floating point imprecisions.
- + self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
- +
- + if isinstance(curve, metrics_utils.AUCCurve):
- + self.curve = curve
- + else:
- + self.curve = metrics_utils.AUCCurve.from_str(curve)
- + if isinstance(summation_method, metrics_utils.AUCSummationMethod):
- + self.summation_method = summation_method
- + else:
- + self.summation_method = metrics_utils.AUCSummationMethod.from_str(
- + summation_method)
- + super(AUC, self).__init__(name=name, dtype=dtype)
- +
- + # Create metric variables
- + self.true_positives = self.add_weight(
- + 'true_positives',
- + shape=(self.num_thresholds,),
- + initializer='zeros')
- + self.true_negatives = self.add_weight(
- + 'true_negatives',
- + shape=(self.num_thresholds,),
- + initializer='zeros')
- + self.false_positives = self.add_weight(
- + 'false_positives',
- + shape=(self.num_thresholds,),
- + initializer='zeros')
- + self.false_negatives = self.add_weight(
- + 'false_negatives',
- + shape=(self.num_thresholds,),
- + initializer='zeros')
- +
- + def update_state(self, y_true, y_pred, sample_weight=None):
- + return metrics_utils.update_confusion_matrix_variables({
- + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
- + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
- + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
- + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
- + }, y_true, y_pred, self.thresholds, sample_weight=sample_weight)
- +
- + def interpolate_pr_auc(self):
- + """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
- +
- + https://www.biostat.wisc.edu/~page/rocpr.pdf
- +
- + Note here we derive & use a closed formula not present in the paper
- + as follows:
- +
- + Precision = TP / (TP + FP) = TP / P
- +
- + Modeling all of TP (true positive), FP (false positive) and their sum
- + P = TP + FP (predicted positive) as varying linearly within each interval
- + [A, B] between successive thresholds, we get
- +
- + Precision slope = dTP / dP
- + = (TP_B - TP_A) / (P_B - P_A)
- + = (TP - TP_A) / (P - P_A)
- + Precision = (TP_A + slope * (P - P_A)) / P
- +
- + The area within the interval is (slope / total_pos_weight) times
- +
- + int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
- + int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
- +
- + where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
- +
- + int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
- +
- + Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
- +
- + slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
- +
- + where dTP == TP_B - TP_A.
- +
- + Note that when P_A == 0 the above calculation simplifies into
- +
- + int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
- +
- + which is really equivalent to imputing constant precision throughout the
- + first bucket having >0 true positives.
- +
- + # Returns
- + pr_auc: an approximation of the area under the P-R curve.
- + """
- + dtp = self.true_positives[:self.num_thresholds -
- + 1] - self.true_positives[1:]
- + p = self.true_positives + self.false_positives
- + dp = p[:self.num_thresholds - 1] - p[1:]
- +
- + prec_slope = dtp / K.maximum(dp, 0)
- + intercept = self.true_positives[1:] - (prec_slope * p[1:])
- +
- + # Logical and
- + pMin = K.expand_dims(p[:self.num_thresholds - 1] > 0, 0)
- + pMax = K.expand_dims(p[1:] > 0, 0)
- + are_different = K.concatenate([pMin, pMax], axis=0)
- + switch_condition = K.all(are_different, axis=0)
- +
- + safe_p_ratio = K.switch(
- + switch_condition,
- + p[:self.num_thresholds - 1] / K.maximum(p[1:], 0),
- + K.ones_like(p[1:]))
- +
- + numer = prec_slope * (dtp + intercept * K.log(safe_p_ratio))
- + denom = K.maximum(self.true_positives[1:] + self.false_negatives[1:], 0)
- + return K.sum((numer / denom))
- +
- + def result(self):
- + if (self.curve == metrics_utils.AUCCurve.PR and
- + (self.summation_method ==
- + metrics_utils.AUCSummationMethod.INTERPOLATION)):
- + # This use case is different and is handled separately.
- + return self.interpolate_pr_auc()
- +
- + # Set `x` and `y` values for the curves based on `curve` config.
- + recall = K.switch(
- + K.greater((self.true_positives), 0),
- + (self.true_positives /
- + (self.true_positives + self.false_negatives)),
- + K.zeros_like(self.true_positives))
- + if self.curve == metrics_utils.AUCCurve.ROC:
- + fp_rate = K.switch(
- + K.greater((self.false_positives), 0),
- + (self.false_positives /
- + (self.false_positives + self.true_negatives)),
- + K.zeros_like(self.false_positives))
- + x = fp_rate
- + y = recall
- + else: # curve == 'PR'.
- + precision = K.switch(
- + K.greater((self.true_positives), 0),
- + (self.true_positives / (self.true_positives + self.false_positives)),
- + K.zeros_like(self.true_positives))
- + x = recall
- + y = precision
- +
- + # Find the rectangle heights based on `summation_method`.
- + if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
- + # Note: the case ('PR', 'interpolation') has been handled above.
- + heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
- + elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
- + heights = K.minimum(y[:self.num_thresholds - 1], y[1:])
- + else: # self.summation_method == metrics_utils.AUCSummationMethod.MAJORING
- + heights = K.maximum(y[:self.num_thresholds - 1], y[1:])
- +
- + # Sum up the areas of all the rectangles.
- + return K.sum((x[:self.num_thresholds - 1] - x[1:]) * heights)
- +
- + def reset_states(self):
- + K.batch_set_value(
- + [(v, np.zeros((self.num_thresholds,))) for v in self.weights])
- +
- + def get_config(self):
- + config = {
- + 'num_thresholds': self.num_thresholds,
- + 'curve': self.curve.value,
- + 'summation_method': self.summation_method.value,
- + # We remove the endpoint thresholds as an inverse of how the thresholds
- + # were initialized. This ensures that a metric initialized from this
- + # config has the same thresholds.
- + 'thresholds': self.thresholds[1:-1],
- + }
- + base_config = super(AUC, self).get_config()
- + return dict(list(base_config.items()) + list(config.items()))
- +
-
- # Aliases
-
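The `Metric` base class added above defines the subclassing contract: create state in `__init__()` via `add_weight()`, accumulate in `update_state()`, and compute the value in `result()`. The sketch below is a hypothetical subclass that follows this contract; the class name and thresholding logic are illustrative and not part of the patch.

```python
# Hypothetical Metric subclass illustrating the contract described in the
# Metric docstring above; assumes the patched keras.metrics is importable.
from keras import backend as K
from keras import metrics


class TruePositiveCount(metrics.Metric):
    def __init__(self, threshold=0.5, name='true_positive_count', dtype=None):
        super(TruePositiveCount, self).__init__(name=name, dtype=dtype)
        self.threshold = threshold
        # Scalar state variable created through Metric.add_weight().
        self.true_positives = self.add_weight('true_positives',
                                              initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = K.cast(y_true, self.dtype)
        is_pos = K.cast(K.greater(K.cast(y_pred, self.dtype), self.threshold),
                        self.dtype)
        # 1.0 only where the label is positive and the prediction clears the
        # threshold; optionally weighted per sample.
        values = y_true * is_pos
        if sample_weight is not None:
            values = values * K.cast(sample_weight, self.dtype)
        # Return the update op; update_state_wrapper() registers it through
        # add_update() when the metric is used inside a model.
        return K.update_add(self.true_positives, K.sum(values))

    def result(self):
        return self.true_positives
```

Used with `compile`, such an instance would sit next to the built-in classes added above, e.g. `metrics=[TruePositiveCount(), keras.metrics.AUC()]`.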
- diff --git a/keras/utils/__init__.py b/keras/utils/__init__.py
- index 8cc39d5..65af329 100644
- --- a/keras/utils/__init__.py
- +++ b/keras/utils/__init__.py
- @@ -4,6 +4,8 @@ from . import generic_utils
- from . import data_utils
- from . import io_utils
- from . import conv_utils
- +from . import losses_utils
- +from . import metrics_utils
-
- # Globally-importable utils.
- from .io_utils import HDF5Matrix
- diff --git a/keras/utils/losses_utils.py b/keras/utils/losses_utils.py
- new file mode 100644
- index 0000000..617ebb7
- --- /dev/null
- +++ b/keras/utils/losses_utils.py
- @@ -0,0 +1,177 @@
- +"""Utilities related to losses."""
- +from __future__ import absolute_import
- +from __future__ import division
- +from __future__ import print_function
- +
- +import numpy as np
- +
- +from .. import backend as K
- +
- +
- +class Reduction(object):
- + """Types of loss reduction.
- +
- + Contains the following values:
- +
- + * `NONE`: Un-reduced weighted losses with the same shape as input. When this
- + reduction type is used with built-in Keras training loops like
- + `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer but
- + the reported loss will be a scalar value.
- + * `SUM`: Scalar sum of weighted losses.
- + * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
- + """
- +
- + NONE = 'none'
- + SUM = 'sum'
- + SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
- +
- + @classmethod
- + def all(cls):
- + return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
- +
- + @classmethod
- + def validate(cls, key):
- + if key not in cls.all():
- + raise ValueError('Invalid Reduction Key %s.' % key)
- +
- +
- +def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
- + """Squeeze or expand last dimension if needed.
- +
- + 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1.
- + 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
- + from the new rank of `y_pred`.
- + If `sample_weight` is scalar, it is kept scalar.
- +
- + # Arguments
- + y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
- + y_true: Optional label `Tensor` whose dimensions match `y_pred`.
- + sample_weight: Optional weight scalar or `Tensor` whose dimensions match
- + `y_pred`.
- +
- + # Returns
- + Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
- + the last dimension squeezed; `sample_weight` could be extended by one
- + dimension.
- + """
- + if y_true is not None:
- + y_pred_rank = K.ndim(y_pred)
- + y_pred_shape = K.int_shape(y_pred)
- + y_true_rank = K.ndim(y_true)
- + y_true_shape = K.int_shape(y_true)
- +
- + if (y_pred_rank - y_true_rank == 1) and (y_pred_shape[-1] == 1):
- + y_pred = K.squeeze(y_pred, -1)
- + elif (y_true_rank - y_pred_rank == 1) and (y_true_shape[-1] == 1):
- + y_true = K.squeeze(y_true, -1)
- +
- + if sample_weight is None:
- + return y_pred, y_true
- +
- + y_pred_rank = K.ndim(y_pred)
- + weights_rank = K.ndim(sample_weight)
- + if weights_rank != 0:
- + if weights_rank - y_pred_rank == 1:
- + sample_weight = K.squeeze(sample_weight, -1)
- + elif y_pred_rank - weights_rank == 1:
- + sample_weight = K.expand_dims(sample_weight, -1)
- + return y_pred, y_true, sample_weight
- +
- +
- +def _num_elements(losses):
- + """Computes the number of elements in `losses` tensor."""
- + with K.name_scope('num_elements') as scope:
- + return K.cast(K.size(losses, name=scope), losses.dtype)
- +
- +
- +def reduce_weighted_loss(weighted_losses, reduction=Reduction.SUM_OVER_BATCH_SIZE):
- + """Reduces the individual weighted loss measurements."""
- + if reduction == Reduction.NONE:
- + loss = weighted_losses
- + else:
- + loss = K.sum(weighted_losses)
- + if reduction == Reduction.SUM_OVER_BATCH_SIZE:
- + loss = loss / _num_elements(weighted_losses)
- + return loss
- +
- +
- +def broadcast_weights(values, sample_weight):
- + # Broadcast weights if possible.
- + weights_shape = K.int_shape(sample_weight)
- + values_shape = K.int_shape(values)
- +
- + if values_shape != weights_shape:
- + weights_rank = K.ndim(sample_weight)
- + values_rank = K.ndim(values)
- +
- + # Raise error if ndim of weights is > values.
- + if weights_rank > values_rank:
- + raise ValueError(
- + 'Incompatible shapes: `values` {} vs `sample_weight` {}'.format(
- + values_shape, weights_shape))
- +
- + # Expand dim of weights to match ndim of values, if required.
- + for i in range(weights_rank, values_rank):
- + sample_weight = K.expand_dims(sample_weight, axis=i)
- +
- + if weights_shape is not None and values_shape is not None:
- + for i in range(weights_rank):
- + if (weights_shape[i] is not None and
- + values_shape[i] is not None and
- + weights_shape[i] != values_shape[i]):
- + # Cannot be broadcasted.
- + if weights_shape[i] != 1:
- + raise ValueError(
- + 'Incompatible shapes: `values` {} vs '
- + '`sample_weight` {}'.format(
- + values_shape, weights_shape))
- + sample_weight = K.repeat_elements(
- + sample_weight, values_shape[i], axis=i)
- + return sample_weight
- +
- +
- +def compute_weighted_loss(losses,
- + sample_weight=None,
- + reduction=Reduction.SUM_OVER_BATCH_SIZE,
- + name=None):
- + """Computes the weighted loss.
- +
- + # Arguments
- + losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
- + sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
- + `losses`, or broadcastable to `losses`.
- + reduction: (Optional) Type of Reduction to apply to loss.
- + Default value is `SUM_OVER_BATCH_SIZE`.
- + name: Optional name for the op.
- +
- + # Raises
- + ValueError: If the shape of `sample_weight` is not compatible with `losses`.
- +
- + # Returns
- + Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
- + `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
- + """
- + Reduction.validate(reduction)
- + if sample_weight is None:
- + sample_weight = 1.0
- + with K.name_scope(name or 'weighted_loss'):
- + input_dtype = K.dtype(losses)
- + losses = K.cast(losses, K.floatx())
- + sample_weight = K.cast(sample_weight, K.floatx())
- +
- + # Update dimensions of `sample_weight` to match with `losses` if possible.
- + losses, _, sample_weight = squeeze_or_expand_dimensions(
- + losses, None, sample_weight)
- +
- + # Broadcast weights if possible.
- + sample_weight = broadcast_weights(losses, sample_weight)
- +
- + # Apply weights to losses.
- + weighted_losses = sample_weight * losses
- +
- + # Apply reduction function to the individual weighted losses.
- + loss = reduce_weighted_loss(weighted_losses, reduction)
- + # Convert the result back to the input type.
- + loss = K.cast(loss, input_dtype)
- + return loss
- +
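A short sketch of how the helpers above compose: `compute_weighted_loss()` squeezes or expands `sample_weight` against `losses`, broadcasts it, applies it, then reduces. The values below are toy data; the module path assumes the patched `keras.utils.losses_utils` with a TensorFlow backend.

```python
# Sketch of compute_weighted_loss() on toy values; not part of the patch.
from keras import backend as K
from keras.utils import losses_utils

per_sample = K.constant([1., 3., 5., 7.])   # per-sample losses, shape (4,)
weights = K.constant([1., 1., 0., 0.])      # a zero weight masks a sample

loss = losses_utils.compute_weighted_loss(
    per_sample,
    sample_weight=weights,
    reduction=losses_utils.Reduction.SUM_OVER_BATCH_SIZE)

# Weighted losses are [1, 3, 0, 0]; SUM_OVER_BATCH_SIZE divides their sum (4)
# by the number of elements (4), so this should evaluate to 1.0.
print(K.eval(loss))
```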
- diff --git a/keras/utils/metrics_utils.py b/keras/utils/metrics_utils.py
- new file mode 100644
- index 0000000..e6a5bb0
- --- /dev/null
- +++ b/keras/utils/metrics_utils.py
- @@ -0,0 +1,278 @@
- +"""Utilities related to metrics."""
- +from __future__ import absolute_import
- +from __future__ import division
- +from __future__ import print_function
- +
- +from enum import Enum
- +
- +from .. import backend as K
- +from . import losses_utils
- +
- +NEG_INF = -1e10
- +
- +class Reduction(object):
- + """Types of metrics reduction.
- + Contains the following values:
- + * `SUM`: Scalar sum of weighted values.
- + * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` of weighted values divided by
- + number of elements in values.
- + * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights.
- + """
- +
- + SUM = 'sum'
- + SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
- + WEIGHTED_MEAN = 'weighted_mean'
- +
- +
- +def update_state_wrapper(update_state_fn):
- + """Decorator to wrap metric `update_state()` with `add_update()`.
- + # Arguments
- + update_state_fn: function that accumulates metric statistics.
- + # Returns
- + Decorated function that wraps `update_state_fn()` with `add_update()`.
- + """
- + def decorated(metric_obj, *args, **kwargs):
- + """Decorated function with `add_update()`."""
- +
- + update_op = update_state_fn(*args, **kwargs)
- + metric_obj.add_update(update_op)
- + return update_op
- +
- + return decorated
- +
- +def result_wrapper(result_fn):
- + """Decorator to wrap metric `result()` with identity op.
- + Wrapping the result in an identity op ensures that the control dependency
- + between the update op from `update_state` and the result holds even when
- + `result` returns a tensor.
- + # Arguments
- + result_fn: function that computes the metric result.
- + # Returns
- + Decorated function that wraps `result()` with identity op.
- + """
- + def decorated(metric_obj, *args, **kwargs):
- + result_t = K.identity(result_fn(*args, **kwargs))
- + metric_obj._call_result = result_t
- + result_t._is_metric = True
- + return result_t
- + return decorated
- +
- +
- +def to_list(x):
- + if isinstance(x, list):
- + return x
- + return [x]
- +
- +
- +def assert_thresholds_range(thresholds):
- + if thresholds is not None:
- + invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1]
- + if invalid_thresholds:
- + raise ValueError(
- + 'Threshold values must be in [0, 1]. Invalid values: {}'.format(
- + invalid_thresholds))
- +
- +
- +def parse_init_thresholds(thresholds, default_threshold=0.5):
- + if thresholds is not None:
- + assert_thresholds_range(to_list(thresholds))
- + thresholds = to_list(default_threshold if thresholds is None else thresholds)
- + return thresholds
- +
- +class ConfusionMatrix(Enum):
- + TRUE_POSITIVES = 'tp'
- + FALSE_POSITIVES = 'fp'
- + TRUE_NEGATIVES = 'tn'
- + FALSE_NEGATIVES = 'fn'
- +
- +class AUCCurve(Enum):
- + """Type of AUC Curve (ROC or PR)."""
- + ROC = 'ROC'
- + PR = 'PR'
- +
- + @staticmethod
- + def from_str(key):
- + if key in ('pr', 'PR'):
- + return AUCCurve.PR
- + elif key in ('roc', 'ROC'):
- + return AUCCurve.ROC
- + else:
- + raise ValueError('Invalid AUC curve value "%s".' % key)
- +
- +
- +class AUCSummationMethod(Enum):
- + """Type of AUC summation method.
- +
- + https://en.wikipedia.org/wiki/Riemann_sum
- +
- + Contains the following values:
- + * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
- + `PR` curve, interpolates (true/false) positives but not the ratio that is
- + precision (see Davis & Goadrich 2006 for details).
- + * 'minoring': Applies left summation for increasing intervals and right
- + summation for decreasing intervals.
- + * 'majoring': Applies right summation for increasing intervals and left
- + summation for decreasing intervals.
- + """
- + INTERPOLATION = 'interpolation'
- + MAJORING = 'majoring'
- + MINORING = 'minoring'
- +
- + @staticmethod
- + def from_str(key):
- + if key in ('interpolation', 'Interpolation'):
- + return AUCSummationMethod.INTERPOLATION
- + elif key in ('majoring', 'Majoring'):
- + return AUCSummationMethod.MAJORING
- + elif key in ('minoring', 'Minoring'):
- + return AUCSummationMethod.MINORING
- + else:
- + raise ValueError('Invalid AUC summation method value "%s".' % key)
- +
- +def weighted_assign_add(label, pred, weights, var):
- + # Logical and
- + label = K.expand_dims(label, 0)
- + pred = K.expand_dims(pred, 0)
- + are_different = K.concatenate([label, pred], axis=0)
- + label_and_pred = K.all(are_different, axis=0)
- + label_and_pred = K.cast(label_and_pred, dtype=K.floatx())
- + if weights is not None:
- + label_and_pred *= weights
- + return var.assign_add(K.sum(label_and_pred, 1))
- +
- +def update_confusion_matrix_variables(variables_to_update,
- + y_true,
- + y_pred,
- + thresholds,
- + top_k=None,
- + class_id=None,
- + sample_weight=None):
- + """Returns op to update the given confusion matrix variables.
- + For every pair of values in y_true and y_pred:
- + true_positives: y_true == True and y_pred > thresholds
- + false_negatives: y_true == True and y_pred <= thresholds
- + true_negatives: y_true == False and y_pred <= thresholds
- + false_positives: y_true == False and y_pred > thresholds
- + The results will be weighted and added together. When multiple thresholds are
- + provided, we will repeat the same for every threshold.
- + For estimation of these metrics over a stream of data, the function creates an
- + `update_op` operation that updates the given variables.
- + If `sample_weight` is `None`, weights default to 1.
- + Use weights of 0 to mask values.
- + # Arguments
- + variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
- + and corresponding variables to update as values.
- + y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
- + y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
- + the range `[0, 1]`.
- + thresholds: A float value or a python list or tuple of float thresholds in
- + `[0, 1]`, or NEG_INF (used when top_k is set).
- + top_k: Optional int, indicates that the positive labels should be limited to
- + the top k predictions.
- + class_id: Optional int, limits the prediction and labels to the class
- + specified by this argument.
- + sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
- + `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
- + be either `1`, or the same as the corresponding `y_true` dimension).
- + # Returns
- + Update ops.
- + # Raises
- + ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
- + `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
- + `variables_to_update` contains invalid keys.
- + """
- + if variables_to_update is None:
- + return
- + y_true = K.cast(y_true, dtype=K.floatx())
- + y_pred = K.cast(y_pred, dtype=K.floatx())
- + if sample_weight is not None:
- + sample_weight = K.cast(sample_weight, dtype=K.floatx())
- +
- + if not any(key
- + for key in variables_to_update
- + if key in list(ConfusionMatrix)):
- + raise ValueError(
- + 'Please provide at least one valid confusion matrix '
- + 'variable to update. Valid variable key options are: "{}". '
- + 'Received: "{}"'.format(
- + list(ConfusionMatrix), variables_to_update.keys()))
- +
- + invalid_keys = [
- + key for key in variables_to_update if key not in list(ConfusionMatrix)
- + ]
- + if invalid_keys:
- + raise ValueError(
- + 'Invalid keys: {}. Valid variable key options are: "{}"'.format(
- + invalid_keys, list(ConfusionMatrix)))
- +
- + if sample_weight is None:
- + y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
- + y_pred, y_true=y_true)
- + else:
- + y_pred, y_true, sample_weight = (
- + losses_utils.squeeze_or_expand_dimensions(
- + y_pred, y_true=y_true, sample_weight=sample_weight))
- +
- + if top_k is not None:
- + y_pred = _filter_top_k(y_pred, top_k)
- + if class_id is not None:
- + y_true = y_true[..., class_id]
- + y_pred = y_pred[..., class_id]
- +
- + thresholds = to_list(thresholds)
- + num_thresholds = len(thresholds)
- + num_predictions = K.size(y_pred)
- +
- + # Reshape predictions and labels.
- + predictions_2d = K.reshape(y_pred, [1, -1])
- + labels_2d = K.reshape(
- + K.cast(y_true, dtype='bool'), [1, -1])
- +
- + # Tile the thresholds for every prediction.
- + thresh_tiled = K.tile(
- + K.expand_dims(K.constant(thresholds), 1),
- + K.stack([1, num_predictions]))
- +
- + # Tile the predictions for every threshold.
- + preds_tiled = K.tile(predictions_2d, [num_thresholds, 1])
- +
- + # Compare predictions and threshold.
- + pred_is_pos = K.greater(preds_tiled, thresh_tiled)
- + pred_is_neg = K.greater(thresh_tiled, preds_tiled)
- +
- + # Tile labels by number of thresholds
- + label_is_pos = K.tile(labels_2d, [num_thresholds, 1])
- +
- + if sample_weight is not None:
- + weights = losses_utils.broadcast_weights(
- + y_pred, K.cast(sample_weight, dtype=K.floatx()))
- + weights_tiled = K.tile(
- + K.reshape(weights, [1, -1]), [num_thresholds, 1])
- + else:
- + weights_tiled = None
- +
- + update_ops = []
- + loop_vars = {
- + ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
- + }
- + update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
- + update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
- + update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
- +
- + if update_fn or update_tn:
- + loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)
- +
- + if update_fp or update_tn:
- + label_is_neg = K.equal(
- + label_is_pos, K.zeros_like(label_is_pos, dtype=label_is_pos.dtype))
- + loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
- + if update_tn:
- + loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)
- +
- + for matrix_cond, (label, pred) in loop_vars.items():
- + if matrix_cond in variables_to_update:
- + update_ops.append(
- + weighted_assign_add(label, pred, weights_tiled,
- + variables_to_update[matrix_cond]))
- + return update_ops
- +
- --
- 2.26.2
-
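`update_confusion_matrix_variables()` is the workhorse behind `SensitivitySpecificityBase` and `AUC` above: it tiles the predictions against every threshold and returns one update op per requested confusion-matrix variable. The sketch below drives it directly, mirroring how `update_state()` uses it in the patch; it assumes the patched modules are importable with a TensorFlow backend and is not part of the patch.

```python
# Sketch of calling update_confusion_matrix_variables() directly.
from keras import backend as K
from keras.utils import metrics_utils

thresholds = [0.0, 0.5, 1.0]
# K.zeros() returns initialized backend variables, one count per threshold.
true_positives = K.zeros((len(thresholds),))
false_positives = K.zeros((len(thresholds),))

y_true = K.constant([0., 0., 1., 1.])
y_pred = K.constant([0.2, 0.6, 0.4, 0.9])

update_ops = metrics_utils.update_confusion_matrix_variables(
    {
        metrics_utils.ConfusionMatrix.TRUE_POSITIVES: true_positives,
        metrics_utils.ConfusionMatrix.FALSE_POSITIVES: false_positives,
    },
    y_true, y_pred, thresholds=thresholds)

# Each op in update_ops adds the per-threshold counts to its variable; in
# graph mode the ops still have to be run (or registered via add_update(),
# as update_state_wrapper() does) before the counts take effect.
```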