From 901159da45695da24a5206125910f02fc50169ce Mon Sep 17 00:00:00 2001
From: Efraim Flashner <efraim@flashner.co.il>
Date: Thu, 23 Apr 2020 15:50:37 +0300
Subject: [PATCH] add keras metrics

---
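Note for reviewers (this note sits between the `---` separator and the diff, so
`git am` ignores it): the classes added below mirror the upstream Keras 2.3
metrics API. As a quick smoke test of the backported API -- a minimal sketch,
assuming this patch is applied and the TensorFlow backend is in its usual
graph mode -- the update ops must be run by hand when a metric is driven
outside `fit`/`evaluate`:

```python
import keras
import keras.backend as K

# `Sum` accumulates into its single `total` state variable.
m = keras.metrics.Sum()
update_op = m.update_state([1., 3., 5., 7.])
K.get_session().run(update_op)
print(K.get_session().run(m.result()))    # 16.0

# `AUC` tracks true/false positives/negatives at each threshold.
auc = keras.metrics.AUC(num_thresholds=3)
update_ops = auc.update_state([0., 0., 1., 1.], [0.1, 0.4, 0.35, 0.8])
K.get_session().run(update_ops)
print(K.get_session().run(auc.result()))  # approximate ROC AUC
```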
 keras/backend/tensorflow_backend.py |  12 +
 keras/metrics.py                    | 585 ++++++++++++++++++++++++++++
 keras/utils/__init__.py             |   2 +
 keras/utils/losses_utils.py         | 177 +++++++++
 keras/utils/metrics_utils.py        | 278 +++++++++++++
 5 files changed, 1054 insertions(+)
 create mode 100644 keras/utils/losses_utils.py
 create mode 100644 keras/utils/metrics_utils.py
diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index bcb8be0..a2870f5 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -4453,3 +4453,15 @@ def local_conv2d(inputs, kernel, kernel_size, strides, output_shape, data_format
     else:
         output = permute_dimensions(output, (2, 0, 1, 3))
     return output
+
+#get_graph = tf_keras_backend.get_graph
+
+#def is_symbolic(x):
+#    return isinstance(x, tf.Tensor) and hasattr(x, 'op')
+
+def size(x, name=None):
+#    if is_symbolic(x):
+#        with get_graph().as_default():
+#            return tf.size(x)
+    return tf.size(x, name=name)
+
diff --git a/keras/metrics.py b/keras/metrics.py
index 8e3df1f..8f57910 100644
--- a/keras/metrics.py
+++ b/keras/metrics.py
@@ -4,8 +4,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import six
+import types
+import numpy as np
+
 from . import backend as K
+from .engine.base_layer import Layer
 from .losses import mean_squared_error
 from .losses import mean_absolute_error
 from .losses import mean_absolute_percentage_error
@@ -19,10 +24,201 @@ from .losses import binary_crossentropy
 from .losses import kullback_leibler_divergence
 from .losses import poisson
 from .losses import cosine_proximity
+from .utils import losses_utils
+from .utils import metrics_utils
 from .utils.generic_utils import deserialize_keras_object
 from .utils.generic_utils import serialize_keras_object
 
 
+@six.add_metaclass(abc.ABCMeta)
+class Metric(Layer):
+    """Encapsulates metric logic and state.
+
+    Standalone usage:
+    ```python
+    m = SomeMetric(...)
+    for input in ...:
+        m.update_state(input)
+    m.result()
+    ```
+
+    Usage with the `compile` API:
+    ```python
+    model.compile(optimizer='rmsprop',
+                  loss=keras.losses.categorical_crossentropy,
+                  metrics=[keras.metrics.CategoricalAccuracy()])
+    ```
+
+    To be implemented by subclasses:
+    * `__init__()`: All state variables should be created in this method by
+      calling `self.add_weight()` like: `self.var = self.add_weight(...)`
+    * `update_state()`: Has all updates to the state variables like:
+      `self.var.assign_add(...)`.
+    * `result()`: Computes and returns a value for the metric
+      from the state variables.
+    """
+
+    def __init__(self, name=None, dtype=None, **kwargs):
+        super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
+        self.stateful = True  # All metric layers are stateful.
+        self.built = True
+        self.dtype = K.floatx() if dtype is None else dtype
+
+    def __new__(cls, *args, **kwargs):
+        obj = super(Metric, cls).__new__(cls)
+        update_state_fn = obj.update_state
+
+        obj.update_state = types.MethodType(
+            metrics_utils.update_state_wrapper(update_state_fn), obj)
+        return obj
+
+    def __call__(self, *args, **kwargs):
+        """Accumulates statistics and then computes metric result value."""
+        update_op = self.update_state(*args, **kwargs)
+        return self.result()
+
+    def get_config(self):
+        """Returns the serializable config of the metric."""
+        return {'name': self.name, 'dtype': self.dtype}
+
+    def reset_states(self):
+        """Resets all of the metric state variables.
+        This function is called between epochs/steps,
+        when a metric is evaluated during training.
+        """
+        K.batch_set_value([(v, 0) for v in self.weights])
+
+    @abc.abstractmethod
+    def update_state(self, *args, **kwargs):
+        """Accumulates statistics for the metric."""
+        raise NotImplementedError('Must be implemented in subclasses.')
+
+    @abc.abstractmethod
+    def result(self):
+        """Computes and returns the metric value tensor.
+        Result computation is an idempotent operation that simply calculates the
+        metric value using the state variables.
+        """
+        raise NotImplementedError('Must be implemented in subclasses.')
+
+    ### For use by subclasses ###
+    def add_weight(self,
+                   name,
+                   shape=(),
+                   initializer=None,
+                   dtype=None):
+        """Adds state variable. Only for use by subclasses."""
+        return super(Metric, self).add_weight(
+            name=name,
+            shape=shape,
+            dtype=self.dtype if dtype is None else dtype,
+            trainable=False,
+            initializer=initializer)
+
+    ### End: For use by subclasses ###
+
+
+class Reduce(Metric):
+    """Encapsulates metrics that perform a reduce operation on the values."""
+
+    def __init__(self, reduction, name, dtype=None):
+        """Creates a `Reduce` instance.
+        # Arguments
+            reduction: a metrics `Reduction` enum value.
+            name: string name of the metric instance.
+            dtype: (Optional) data type of the metric result.
+        """
+        super(Reduce, self).__init__(name=name, dtype=dtype)
+        self.reduction = reduction
+        self.total = self.add_weight('total', initializer='zeros')
+        if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
+                         metrics_utils.Reduction.WEIGHTED_MEAN]:
+            self.count = self.add_weight('count', initializer='zeros')
+
+    def update_state(self, values, sample_weight=None):
  164. + """Accumulates statistics for computing the reduction metric.
  165. + For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE,
  166. + then the value of `result()` is 4. If the `sample_weight` is specified as
  167. + [1, 1, 0, 0] then value of `result()` would be 2.
  168. + # Arguments
  169. + values: Per-example value.
  170. + sample_weight: Optional weighting of each example. Defaults to 1.
  171. + """
+        values = K.cast(values, self.dtype)
+        if sample_weight is not None:
+            sample_weight = K.cast(sample_weight, self.dtype)
+            # Update dimensions of weights to match with values if possible.
+            values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
+                values, sample_weight=sample_weight)
+
+            # Broadcast weights if possible.
+            sample_weight = losses_utils.broadcast_weights(values, sample_weight)
+            values = values * sample_weight
+
+        value_sum = K.sum(values)
+        update_total_op = K.update_add(self.total, value_sum)
+
+        # Exit early if the reduction doesn't have a denominator.
+        if self.reduction == metrics_utils.Reduction.SUM:
+            return update_total_op
+
+        # Update `count` for reductions that require a denominator.
+        if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
+            num_values = K.cast(K.size(values), self.dtype)
+        elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
+            if sample_weight is None:
+                num_values = K.cast(K.size(values), self.dtype)
+            else:
+                num_values = K.sum(sample_weight)
+        else:
+            raise NotImplementedError(
+                'reduction [%s] not implemented' % self.reduction)
+
+        with K.control_dependencies([update_total_op]):
+            return K.update_add(self.count, num_values)
+
+    def result(self):
+        if self.reduction == metrics_utils.Reduction.SUM:
+            return self.total
+        elif self.reduction in [
+                metrics_utils.Reduction.WEIGHTED_MEAN,
+                metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
+        ]:
+            return self.total / self.count
+        else:
+            raise NotImplementedError(
+                'reduction [%s] not implemented' % self.reduction)
+
+
+class Sum(Reduce):
+    """Computes the (weighted) sum of the given values.
+
+    For example, if values is [1, 3, 5, 7] then the sum is 16.
+    If the weights were specified as [1, 1, 0, 0] then the sum would be 4.
+
+    This metric creates one variable, `total`, that is used to compute the sum of
+    `values`. This is ultimately returned as `sum`.
+    If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0
+    to mask values.
+
+    Standalone usage:
+    ```python
+    m = keras.metrics.Sum()
+    m.update_state([1, 3, 5, 7])
+    m.result()
+    ```
+    """
+
+    def __init__(self, name='sum', dtype=None):
+        """Creates a `Sum` instance.
+        # Arguments
+            name: (Optional) string name of the metric instance.
+            dtype: (Optional) data type of the metric result.
+        """
+        super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
+                                  name=name, dtype=dtype)
+
+
 def binary_accuracy(y_true, y_pred):
     return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
@@ -49,6 +245,395 @@ def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
     return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
                   axis=-1)
+
+
+class SensitivitySpecificityBase(Metric):
+    """Abstract base class for computing sensitivity and specificity.
+
+    For additional information about specificity and sensitivity, see the
+    following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+    """
+
+    def __init__(self, value, num_thresholds=200, name=None, dtype=None):
+        super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
+        if num_thresholds <= 0:
+            raise ValueError('`num_thresholds` must be > 0.')
+        self.value = value
+        self.true_positives = self.add_weight(
+            'true_positives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+        self.true_negatives = self.add_weight(
+            'true_negatives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+        self.false_positives = self.add_weight(
+            'false_positives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+        self.false_negatives = self.add_weight(
+            'false_negatives',
+            shape=(num_thresholds,),
+            initializer='zeros')
+
+        # Compute `num_thresholds` thresholds in [0, 1]
+        if num_thresholds == 1:
+            self.thresholds = [0.5]
+        else:
+            thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                          for i in range(num_thresholds - 2)]
+            self.thresholds = [0.0] + thresholds + [1.0]
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        return metrics_utils.update_confusion_matrix_variables(
+            {
+                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
+                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
+                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
+                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
+            },
+            y_true,
+            y_pred,
+            thresholds=self.thresholds,
+            sample_weight=sample_weight)
+
+    def reset_states(self):
+        num_thresholds = len(self.thresholds)
+        K.batch_set_value(
+            [(v, np.zeros((num_thresholds,))) for v in self.variables])
+
+
+class SensitivityAtSpecificity(SensitivitySpecificityBase):
+    """Computes the sensitivity at a given specificity.
+
+    `Sensitivity` measures the proportion of actual positives that are correctly
+    identified as such (tp / (tp + fn)).
+    `Specificity` measures the proportion of actual negatives that are correctly
+    identified as such (tn / (tn + fp)).
+
+    This metric creates four local variables, `true_positives`, `true_negatives`,
+    `false_positives` and `false_negatives` that are used to compute the
+    sensitivity at the given specificity. The threshold for the given specificity
+    value is computed and used to evaluate the corresponding sensitivity.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    For additional information about specificity and sensitivity, see the
+    following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile(
+        'sgd',
+        loss='mse',
+        metrics=[keras.metrics.SensitivityAtSpecificity(specificity=0.5)])
+    ```
+
+    # Arguments
+        specificity: A scalar value in range `[0, 1]`.
+        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+            use for matching the given specificity.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+    """
+
+    def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
+        if specificity < 0 or specificity > 1:
+            raise ValueError('`specificity` must be in the range [0, 1].')
+        self.specificity = specificity
+        self.num_thresholds = num_thresholds
+        super(SensitivityAtSpecificity, self).__init__(
+            specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
+
+    def result(self):
+        # Calculate specificities at all the thresholds.
+        specificities = K.switch(
+            K.greater(self.true_negatives + self.false_positives, 0),
+            (self.true_negatives / (self.true_negatives + self.false_positives)),
+            K.zeros_like(self.thresholds))
+
+        # Find the index of the threshold where the specificity is closest to the
+        # given specificity.
+        min_index = K.argmin(
+            K.abs(specificities - self.value), axis=0)
+        min_index = K.cast(min_index, 'int32')
+
+        # Compute sensitivity at that index.
+        return K.switch(
+            K.greater((self.true_positives[min_index] +
+                       self.false_negatives[min_index]), 0),
+            (self.true_positives[min_index] /
+             (self.true_positives[min_index] + self.false_negatives[min_index])),
+            K.zeros_like(self.true_positives[min_index]))
+
+    def get_config(self):
+        config = {
+            'num_thresholds': self.num_thresholds,
+            'specificity': self.specificity
+        }
+        base_config = super(SensitivityAtSpecificity, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
+class AUC(Metric):
+    """Computes the approximate AUC (Area under the curve) via a Riemann sum.
+
+    This metric creates four local variables, `true_positives`, `true_negatives`,
+    `false_positives` and `false_negatives` that are used to compute the AUC.
+    To discretize the AUC curve, a linearly spaced set of thresholds is used to
+    compute pairs of recall and precision values. The area under the ROC-curve is
+    therefore computed using the height of the recall values by the false positive
+    rate, while the area under the PR-curve is computed using the height of
+    the precision values by the recall.
+
+    This value is ultimately returned as `auc`, an idempotent operation that
+    computes the area under a discretized curve of precision versus recall values
+    (computed using the aforementioned variables). The `num_thresholds` variable
+    controls the degree of discretization with larger numbers of thresholds more
+    closely approximating the true AUC. The quality of the approximation may vary
+    dramatically depending on `num_thresholds`. The `thresholds` parameter can be
+    used to manually specify thresholds which split the predictions more evenly.
+
+    For best results, `predictions` should be distributed approximately uniformly
+    in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
+    approximation may be poor if this is not the case. Setting `summation_method`
+    to 'minoring' or 'majoring' can help quantify the error in the approximation
+    by providing lower or upper bound estimates of the AUC.
+
+    If `sample_weight` is `None`, weights default to 1.
+    Use `sample_weight` of 0 to mask values.
+
+    Usage with the compile API:
+
+    ```python
+    model = keras.Model(inputs, outputs)
+    model.compile('sgd', loss='mse', metrics=[keras.metrics.AUC()])
+    ```
+
+    # Arguments
+        num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+            use when discretizing the roc curve. Values must be > 1.
+        curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
+            [default] or 'PR' for the Precision-Recall-curve.
+        summation_method: (Optional) Specifies the Riemann summation method used
+            (https://en.wikipedia.org/wiki/Riemann_sum): 'interpolation' [default],
+            applies mid-point summation scheme for `ROC`. For PR-AUC, interpolates
+            (true/false) positives but not the ratio that is precision (see Davis
+            & Goadrich 2006 for details); 'minoring' that applies left summation
+            for increasing intervals and right summation for decreasing intervals;
+            'majoring' that does the opposite.
+        name: (Optional) string name of the metric instance.
+        dtype: (Optional) data type of the metric result.
+        thresholds: (Optional) A list of floating point values to use as the
+            thresholds for discretizing the curve. If set, the `num_thresholds`
+            parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+            equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
+            be automatically included with these to correctly handle predictions
+            equal to exactly 0 or 1.
+    """
+
+    def __init__(self,
+                 num_thresholds=200,
+                 curve='ROC',
+                 summation_method='interpolation',
+                 name=None,
+                 dtype=None,
+                 thresholds=None):
+        # Validate configurations.
+        if (isinstance(curve, metrics_utils.AUCCurve) and
+                curve not in list(metrics_utils.AUCCurve)):
+            raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
+                curve, list(metrics_utils.AUCCurve)))
+        if isinstance(
+                summation_method,
+                metrics_utils.AUCSummationMethod) and summation_method not in list(
+                    metrics_utils.AUCSummationMethod):
+            raise ValueError(
+                'Invalid summation method: "{}". Valid options are: "{}"'.format(
+                    summation_method, list(metrics_utils.AUCSummationMethod)))
+
+        # Update properties.
+        if thresholds is not None:
+            # If specified, use the supplied thresholds.
+            self.num_thresholds = len(thresholds) + 2
+            thresholds = sorted(thresholds)
+        else:
+            if num_thresholds <= 1:
+                raise ValueError('`num_thresholds` must be > 1.')
+
+            # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+            # (0, 1).
+            self.num_thresholds = num_thresholds
+            thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                          for i in range(num_thresholds - 2)]
+
+        # Add an endpoint "threshold" below zero and above one for either
+        # threshold method to account for floating point imprecisions.
+        self.thresholds = [0.0 - K.epsilon()] + thresholds + [1.0 + K.epsilon()]
+
+        if isinstance(curve, metrics_utils.AUCCurve):
+            self.curve = curve
+        else:
+            self.curve = metrics_utils.AUCCurve.from_str(curve)
+        if isinstance(summation_method, metrics_utils.AUCSummationMethod):
+            self.summation_method = summation_method
+        else:
+            self.summation_method = metrics_utils.AUCSummationMethod.from_str(
+                summation_method)
+        super(AUC, self).__init__(name=name, dtype=dtype)
+
+        # Create metric variables
+        self.true_positives = self.add_weight(
+            'true_positives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+        self.true_negatives = self.add_weight(
+            'true_negatives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+        self.false_positives = self.add_weight(
+            'false_positives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+        self.false_negatives = self.add_weight(
+            'false_negatives',
+            shape=(self.num_thresholds,),
+            initializer='zeros')
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        return metrics_utils.update_confusion_matrix_variables({
+            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
+            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
+            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
+            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
+        }, y_true, y_pred, self.thresholds, sample_weight=sample_weight)
+
+    def interpolate_pr_auc(self):
+        """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+        https://www.biostat.wisc.edu/~page/rocpr.pdf
+
+        Note here we derive & use a closed formula not present in the paper
+        as follows:
+
+            Precision = TP / (TP + FP) = TP / P
+
+        Modeling all of TP (true positive), FP (false positive) and their sum
+        P = TP + FP (predicted positive) as varying linearly within each interval
+        [A, B] between successive thresholds, we get
+
+            Precision slope = dTP / dP
+                            = (TP_B - TP_A) / (P_B - P_A)
+                            = (TP - TP_A) / (P - P_A)
+            Precision = (TP_A + slope * (P - P_A)) / P
+
+        The area within the interval is (slope / total_pos_weight) times
+
+            int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+            int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+
+        where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+
+            int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+
+        Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
+
+            slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
+
+        where dTP == TP_B - TP_A.
+
+        Note that when P_A == 0 the above calculation simplifies into
+
+            int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+
+        which is really equivalent to imputing constant precision throughout the
+        first bucket having >0 true positives.
+
+        # Returns
+            pr_auc: an approximation of the area under the P-R curve.
+        """
+        dtp = self.true_positives[:self.num_thresholds -
+                                  1] - self.true_positives[1:]
+        p = self.true_positives + self.false_positives
+        dp = p[:self.num_thresholds - 1] - p[1:]
+
+        prec_slope = dtp / K.maximum(dp, 0)
+        intercept = self.true_positives[1:] - (prec_slope * p[1:])
+
+        # Logical and
+        pMin = K.expand_dims(p[:self.num_thresholds - 1] > 0, 0)
+        pMax = K.expand_dims(p[1:] > 0, 0)
+        are_different = K.concatenate([pMin, pMax], axis=0)
+        switch_condition = K.all(are_different, axis=0)
+
+        safe_p_ratio = K.switch(
+            switch_condition,
+            p[:self.num_thresholds - 1] / K.maximum(p[1:], 0),
+            K.ones_like(p[1:]))
+
+        numer = prec_slope * (dtp + intercept * K.log(safe_p_ratio))
+        denom = K.maximum(self.true_positives[1:] + self.false_negatives[1:], 0)
+        return K.sum((numer / denom))
+
+    def result(self):
+        if (self.curve == metrics_utils.AUCCurve.PR and
+                (self.summation_method ==
+                 metrics_utils.AUCSummationMethod.INTERPOLATION)):
+            # This use case is different and is handled separately.
+            return self.interpolate_pr_auc()
+
+        # Set `x` and `y` values for the curves based on `curve` config.
+        recall = K.switch(
+            K.greater((self.true_positives), 0),
+            (self.true_positives /
+             (self.true_positives + self.false_negatives)),
+            K.zeros_like(self.true_positives))
+        if self.curve == metrics_utils.AUCCurve.ROC:
+            fp_rate = K.switch(
+                K.greater((self.false_positives), 0),
+                (self.false_positives /
+                 (self.false_positives + self.true_negatives)),
+                K.zeros_like(self.false_positives))
+            x = fp_rate
+            y = recall
+        else:  # curve == 'PR'.
+            precision = K.switch(
+                K.greater((self.true_positives), 0),
+                (self.true_positives / (self.true_positives + self.false_positives)),
+                K.zeros_like(self.true_positives))
+            x = recall
+            y = precision
+
+        # Find the rectangle heights based on `summation_method`.
+        if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
+            # Note: the case ('PR', 'interpolation') has been handled above.
+            heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
+        elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
+            heights = K.minimum(y[:self.num_thresholds - 1], y[1:])
+        else:  # self.summation_method == metrics_utils.AUCSummationMethod.MAJORING
+            heights = K.maximum(y[:self.num_thresholds - 1], y[1:])
+
+        # Sum up the areas of all the rectangles.
+        return K.sum((x[:self.num_thresholds - 1] - x[1:]) * heights)
+
+    def reset_states(self):
+        K.batch_set_value(
+            [(v, np.zeros((self.num_thresholds,))) for v in self.variables])
+
+    def get_config(self):
+        config = {
+            'num_thresholds': self.num_thresholds,
+            'curve': self.curve.value,
+            'summation_method': self.summation_method.value,
+            # We remove the endpoint thresholds as an inverse of how the thresholds
+            # were initialized. This ensures that a metric initialized from this
+            # config has the same thresholds.
+            'thresholds': self.thresholds[1:-1],
+        }
+        base_config = super(AUC, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
 
 # Aliases
diff --git a/keras/utils/__init__.py b/keras/utils/__init__.py
index 8cc39d5..65af329 100644
--- a/keras/utils/__init__.py
+++ b/keras/utils/__init__.py
@@ -4,6 +4,8 @@ from . import generic_utils
 from . import data_utils
 from . import io_utils
 from . import conv_utils
+from . import losses_utils
+from . import metrics_utils
 
 # Globally-importable utils.
 from .io_utils import HDF5Matrix
diff --git a/keras/utils/losses_utils.py b/keras/utils/losses_utils.py
new file mode 100644
index 0000000..617ebb7
--- /dev/null
+++ b/keras/utils/losses_utils.py
@@ -0,0 +1,177 @@
+"""Utilities related to losses."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from .. import backend as K
+
+
+class Reduction(object):
+    """Types of loss reduction.
+
+    Contains the following values:
+
+    * `NONE`: Un-reduced weighted losses with the same shape as input. When this
+      reduction type is used with built-in Keras training loops like
+      `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer but
+      the reported loss will be a scalar value.
+    * `SUM`: Scalar sum of weighted losses.
+    * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
+    """
+
+    NONE = 'none'
+    SUM = 'sum'
+    SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
+
+    @classmethod
+    def all(cls):
+        return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
+
+    @classmethod
+    def validate(cls, key):
+        if key not in cls.all():
+            raise ValueError('Invalid Reduction Key %s.' % key)
+
+
+def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
+    """Squeeze or expand last dimension if needed.
+
+    1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1.
+    2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
+       from the new rank of `y_pred`.
+       If `sample_weight` is scalar, it is kept scalar.
+
+    # Arguments
+        y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
+        y_true: Optional label `Tensor` whose dimensions match `y_pred`.
+        sample_weight: Optional weight scalar or `Tensor` whose dimensions match
+            `y_pred`.
+
+    # Returns
+        Tuple of `y_pred`, `y_true` and `sample_weight`, where each may have its
+        last dimension squeezed and `sample_weight` may be expanded by one
+        dimension; `sample_weight` is omitted from the tuple when it is `None`.
+    """
+    if y_true is not None:
+        y_pred_rank = K.ndim(y_pred)
+        y_pred_shape = K.int_shape(y_pred)
+        y_true_rank = K.ndim(y_true)
+        y_true_shape = K.int_shape(y_true)
+
+        if (y_pred_rank - y_true_rank == 1) and (y_pred_shape[-1] == 1):
+            y_pred = K.squeeze(y_pred, -1)
+        elif (y_true_rank - y_pred_rank == 1) and (y_true_shape[-1] == 1):
+            y_true = K.squeeze(y_true, -1)
+
+    if sample_weight is None:
+        return y_pred, y_true
+
+    y_pred_rank = K.ndim(y_pred)
+    weights_rank = K.ndim(sample_weight)
+    if weights_rank != 0:
+        if weights_rank - y_pred_rank == 1:
+            sample_weight = K.squeeze(sample_weight, -1)
+        elif y_pred_rank - weights_rank == 1:
+            sample_weight = K.expand_dims(sample_weight, -1)
+    return y_pred, y_true, sample_weight
+
+
+def _num_elements(losses):
+    """Computes the number of elements in `losses` tensor."""
+    with K.name_scope('num_elements') as scope:
+        return K.cast(K.size(losses, name=scope), losses.dtype)
+
+
+def reduce_weighted_loss(weighted_losses,
+                         reduction=Reduction.SUM_OVER_BATCH_SIZE):
+    """Reduces the individual weighted loss measurements."""
+    if reduction == Reduction.NONE:
+        loss = weighted_losses
+    else:
+        loss = K.sum(weighted_losses)
+        if reduction == Reduction.SUM_OVER_BATCH_SIZE:
+            loss = loss / _num_elements(weighted_losses)
+    return loss
+
+
+def broadcast_weights(values, sample_weight):
+    # Broadcast weights if possible.
+    weights_shape = K.int_shape(sample_weight)
+    values_shape = K.int_shape(values)
+
+    if values_shape != weights_shape:
+        weights_rank = K.ndim(sample_weight)
+        values_rank = K.ndim(values)
+
+        # Raise error if ndim of weights is > values.
+        if weights_rank > values_rank:
+            raise ValueError(
+                'Incompatible shapes: `values` {} vs `sample_weight` {}'.format(
+                    values_shape, weights_shape))
+
+        # Expand dim of weights to match ndim of values, if required.
+        for i in range(weights_rank, values_rank):
+            sample_weight = K.expand_dims(sample_weight, axis=i)
+
+        if weights_shape is not None and values_shape is not None:
+            for i in range(weights_rank):
+                if (weights_shape[i] is not None and
+                        values_shape[i] is not None and
+                        weights_shape[i] != values_shape[i]):
+                    # Cannot be broadcasted.
+                    if weights_shape[i] != 1:
+                        raise ValueError(
+                            'Incompatible shapes: `values` {} vs '
+                            '`sample_weight` {}'.format(
+                                values_shape, weights_shape))
+                    sample_weight = K.repeat_elements(
+                        sample_weight, values_shape[i], axis=i)
+    return sample_weight
+
+
+def compute_weighted_loss(losses,
+                          sample_weight=None,
+                          reduction=Reduction.SUM_OVER_BATCH_SIZE,
+                          name=None):
+    """Computes the weighted loss.
+
+    # Arguments
+        losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
+        sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
+            as `losses`, or broadcastable to `losses`.
+        reduction: (Optional) Type of `Reduction` to apply to the loss.
+            Default value is `SUM_OVER_BATCH_SIZE`.
+        name: Optional name for the op.
+
+    # Raises
+        ValueError: If the shape of `sample_weight` is not compatible with
+            `losses`.
+
+    # Returns
+        Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
+        `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
+    """
+    Reduction.validate(reduction)
+    if sample_weight is None:
+        sample_weight = 1.0
+    with K.name_scope(name or 'weighted_loss'):
+        input_dtype = K.dtype(losses)
+        losses = K.cast(losses, K.floatx())
+        sample_weight = K.cast(sample_weight, K.floatx())
+
+        # Update dimensions of `sample_weight` to match `losses` if possible.
+        losses, _, sample_weight = squeeze_or_expand_dimensions(
+            losses, None, sample_weight)
+
+        # Broadcast weights if possible.
+        sample_weight = broadcast_weights(losses, sample_weight)
+
+        # Apply weights to losses.
+        weighted_losses = sample_weight * losses
+
+        # Apply reduction function to the individual weighted losses.
+        loss = reduce_weighted_loss(weighted_losses, reduction)
+        # Convert the result back to the input type.
+        loss = K.cast(loss, input_dtype)
+        return loss
+
diff --git a/keras/utils/metrics_utils.py b/keras/utils/metrics_utils.py
new file mode 100644
index 0000000..e6a5bb0
--- /dev/null
+++ b/keras/utils/metrics_utils.py
@@ -0,0 +1,278 @@
+"""Utilities related to metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from enum import Enum
+
+from .. import backend as K
+from . import losses_utils
+
+NEG_INF = -1e10
+
+
+class Reduction(object):
+    """Types of metrics reduction.
+    Contains the following values:
+    * `SUM`: Scalar sum of weighted values.
+    * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` of weighted values divided by
+      number of elements in values.
+    * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights.
+    """
+
+    SUM = 'sum'
+    SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
+    WEIGHTED_MEAN = 'weighted_mean'
+
+
+def update_state_wrapper(update_state_fn):
+    """Decorator to wrap metric `update_state()` with `add_update()`.
+    # Arguments
+        update_state_fn: function that accumulates metric statistics.
+    # Returns
+        Decorated function that wraps `update_state_fn()` with `add_update()`.
+    """
+    def decorated(metric_obj, *args, **kwargs):
+        """Decorated function with `add_update()`."""
+
+        update_op = update_state_fn(*args, **kwargs)
+        metric_obj.add_update(update_op)
+        return update_op
+
+    return decorated
+
+def result_wrapper(result_fn):
+    """Decorator to wrap metric `result()` with an identity op.
+    Wrapping the result in an identity op makes the control dependency
+    between the update op from `update_state` and the result work even
+    when `result` returns a tensor.
+    # Arguments
+        result_fn: function that computes the metric result.
+    # Returns
+        Decorated function that wraps `result()` with an identity op.
+    """
+    def decorated(metric_obj, *args, **kwargs):
+        result_t = K.identity(result_fn(*args, **kwargs))
+        metric_obj._call_result = result_t
+        result_t._is_metric = True
+        return result_t
+    return decorated
+
+
+def to_list(x):
+    if isinstance(x, list):
+        return x
+    return [x]
+
+
+def assert_thresholds_range(thresholds):
+    if thresholds is not None:
+        invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1]
+        if invalid_thresholds:
+            raise ValueError(
+                'Threshold values must be in [0, 1]. Invalid values: {}'.format(
+                    invalid_thresholds))
+
+
+def parse_init_thresholds(thresholds, default_threshold=0.5):
+    if thresholds is not None:
+        assert_thresholds_range(to_list(thresholds))
+    thresholds = to_list(default_threshold if thresholds is None else thresholds)
+    return thresholds
+
+class ConfusionMatrix(Enum):
+    TRUE_POSITIVES = 'tp'
+    FALSE_POSITIVES = 'fp'
+    TRUE_NEGATIVES = 'tn'
+    FALSE_NEGATIVES = 'fn'
+
+
+class AUCCurve(Enum):
+    """Type of AUC Curve (ROC or PR)."""
+    ROC = 'ROC'
+    PR = 'PR'
+
+    @staticmethod
+    def from_str(key):
+        if key in ('pr', 'PR'):
+            return AUCCurve.PR
+        elif key in ('roc', 'ROC'):
+            return AUCCurve.ROC
+        else:
+            raise ValueError('Invalid AUC curve value "%s".' % key)
+
+
+class AUCSummationMethod(Enum):
+    """Type of AUC summation method.
+
+    (https://en.wikipedia.org/wiki/Riemann_sum)
+
+    Contains the following values:
+    * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
+      `PR` curve, interpolates (true/false) positives but not the ratio that is
+      precision (see Davis & Goadrich 2006 for details).
+    * 'minoring': Applies left summation for increasing intervals and right
+      summation for decreasing intervals.
+    * 'majoring': Applies right summation for increasing intervals and left
+      summation for decreasing intervals.
+    """
+    INTERPOLATION = 'interpolation'
+    MAJORING = 'majoring'
+    MINORING = 'minoring'
+
+    @staticmethod
+    def from_str(key):
+        if key in ('interpolation', 'Interpolation'):
+            return AUCSummationMethod.INTERPOLATION
+        elif key in ('majoring', 'Majoring'):
+            return AUCSummationMethod.MAJORING
+        elif key in ('minoring', 'Minoring'):
+            return AUCSummationMethod.MINORING
+        else:
+            raise ValueError('Invalid AUC summation method value "%s".' % key)
+
+def weighted_assign_add(label, pred, weights, var):
+    # Logical AND of `label` and `pred`, computed by stacking and reducing
+    # with K.all, since the Keras backend exposes no logical_and.
+    label = K.expand_dims(label, 0)
+    pred = K.expand_dims(pred, 0)
+    are_different = K.concatenate([label, pred], axis=0)
+    label_and_pred = K.all(are_different, axis=0)
+    label_and_pred = K.cast(label_and_pred, dtype=K.floatx())
+    if weights is not None:
+        label_and_pred *= weights
+    return var.assign_add(K.sum(label_and_pred, 1))
+
+def update_confusion_matrix_variables(variables_to_update,
+                                      y_true,
+                                      y_pred,
+                                      thresholds,
+                                      top_k=None,
+                                      class_id=None,
+                                      sample_weight=None):
+    """Returns op to update the given confusion matrix variables.
+    For every pair of values in y_true and y_pred:
+        true_positives: y_true == True and y_pred > thresholds
+        false_negatives: y_true == True and y_pred <= thresholds
+        true_negatives: y_true == False and y_pred <= thresholds
+        false_positives: y_true == False and y_pred > thresholds
+    The results will be weighted and added together. When multiple thresholds
+    are provided, we will repeat the same for every threshold.
+    For estimation of these metrics over a stream of data, the function creates
+    an `update_op` operation that updates the given variables.
+    If `sample_weight` is `None`, weights default to 1.
+    Use weights of 0 to mask values.
+    # Arguments
+        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
+            and corresponding variables to update as values.
+        y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
+        y_pred: A floating point `Tensor` of arbitrary shape and whose values are
+            in the range `[0, 1]`.
+        thresholds: A float value or a python list or tuple of float thresholds in
+            `[0, 1]`, or NEG_INF (used when top_k is set).
+        top_k: Optional int, indicates that the positive labels should be limited
+            to the top k predictions.
+        class_id: Optional int, limits the prediction and labels to the class
+            specified by this argument.
+        sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
+            as `y_true`, and must be broadcastable to `y_true` (i.e., all
+            dimensions must be either `1`, or the same as the corresponding
+            `y_true` dimension).
+    # Returns
+        Update ops.
+    # Raises
+        ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
+            `sample_weight` is not `None` and its shape doesn't match `y_pred`,
+            or if `variables_to_update` contains invalid keys.
+    """
+    if variables_to_update is None:
+        return
+    y_true = K.cast(y_true, dtype=K.floatx())
+    y_pred = K.cast(y_pred, dtype=K.floatx())
+    if sample_weight is not None:
+        sample_weight = K.cast(sample_weight, dtype=K.floatx())
+
+    if not any(key
+               for key in variables_to_update
+               if key in list(ConfusionMatrix)):
+        raise ValueError(
+            'Please provide at least one valid confusion matrix '
+            'variable to update. Valid variable key options are: "{}". '
+            'Received: "{}"'.format(
+                list(ConfusionMatrix), variables_to_update.keys()))
+
+    invalid_keys = [
+        key for key in variables_to_update if key not in list(ConfusionMatrix)
+    ]
+    if invalid_keys:
+        raise ValueError(
+            'Invalid keys: {}. Valid variable key options are: "{}"'.format(
+                invalid_keys, list(ConfusionMatrix)))
+
+    if sample_weight is None:
+        y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
+            y_pred, y_true=y_true)
+    else:
+        y_pred, y_true, sample_weight = (
+            losses_utils.squeeze_or_expand_dimensions(
+                y_pred, y_true=y_true, sample_weight=sample_weight))
+
+    if top_k is not None:
+        y_pred = _filter_top_k(y_pred, top_k)
+    if class_id is not None:
+        y_true = y_true[..., class_id]
+        y_pred = y_pred[..., class_id]
+
+    thresholds = to_list(thresholds)
+    num_thresholds = len(thresholds)
+    num_predictions = K.size(y_pred)
+
+    # Reshape predictions and labels.
+    predictions_2d = K.reshape(y_pred, [1, -1])
+    labels_2d = K.reshape(
+        K.cast(y_true, dtype='bool'), [1, -1])
+
+    # Tile the thresholds for every prediction.
+    thresh_tiled = K.tile(
+        K.expand_dims(K.constant(thresholds), 1),
+        K.stack([1, num_predictions]))
+
+    # Tile the predictions for every threshold.
+    preds_tiled = K.tile(predictions_2d, [num_thresholds, 1])
+
+    # Compare predictions and threshold.
+    pred_is_pos = K.greater(preds_tiled, thresh_tiled)
+    pred_is_neg = K.greater(thresh_tiled, preds_tiled)
+
+    # Tile labels by number of thresholds
+    label_is_pos = K.tile(labels_2d, [num_thresholds, 1])
+
+    if sample_weight is not None:
+        weights = losses_utils.broadcast_weights(
+            y_pred, K.cast(sample_weight, dtype=K.floatx()))
+        weights_tiled = K.tile(
+            K.reshape(weights, [1, -1]), [num_thresholds, 1])
+    else:
+        weights_tiled = None
+
+    update_ops = []
+    loop_vars = {
+        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
+    }
+    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
+    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
+    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update
+
+    if update_fn or update_tn:
+        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)
+
+    if update_fp or update_tn:
+        label_is_neg = K.equal(
+            label_is_pos, K.zeros_like(label_is_pos, dtype=label_is_pos.dtype))
+        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
+        if update_tn:
+            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)
+
+    for matrix_cond, (label, pred) in loop_vars.items():
+        if matrix_cond in variables_to_update:
+            update_ops.append(
+                weighted_assign_add(label, pred, weights_tiled,
+                                    variables_to_update[matrix_cond]))
+    return update_ops
+
-- 
2.26.2
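
Appendix (not part of the patch, and ignored when the diff above is applied):
a minimal sketch of the subclass contract that the new `Metric` base class
documents -- state created in `__init__()` via `add_weight()`, accumulated in
`update_state()`, and read back in `result()`. `CountTruePositives` is a
hypothetical example, not a class the patch adds:

```python
import keras.backend as K
from keras.metrics import Metric


class CountTruePositives(Metric):
    """Counts samples where label and rounded prediction are both 1."""

    def __init__(self, name='count_true_positives', **kwargs):
        super(CountTruePositives, self).__init__(name=name, **kwargs)
        # State variable, created through Metric.add_weight().
        self.count = self.add_weight('count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight handling omitted for brevity.
        y_true = K.cast(y_true, self.dtype)
        y_pred = K.cast(K.round(y_pred), self.dtype)
        # Accumulate the number of (label == 1, prediction == 1) samples.
        return K.update_add(self.count, K.sum(y_true * y_pred))

    def result(self):
        return self.count
```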