Project author: guillaumegenthial

Project description: Multi-class metrics for TensorFlow
Language: Python
Repository: git://github.com/guillaumegenthial/tf_metrics.git
Created: 2018-09-16T17:37:21Z
Project page: https://github.com/guillaumegenthial/tf_metrics

License: Apache License 2.0



TF Metrics


Multi-class metrics for TensorFlow, similar to scikit-learn's multi-class metrics.
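For reference, these are the usual scikit-learn-style definitions, assuming tf_metrics follows them. C is the confusion matrix, where C[i, j] counts examples of true class i predicted as class j, and 𝒫 is the set of positive classes (pos_indices in the examples below). Micro averaging pools the counts over 𝒫 before dividing; macro averaging computes each per-class score first and then takes the unweighted mean (in scikit-learn, macro F-beta is the mean of the per-class F-beta values).

```latex
\mathrm{TP}_k = C_{kk}, \qquad
\mathrm{Pred}_k = \sum_i C_{ik}, \qquad
\mathrm{Gold}_k = \sum_j C_{kj}

P_{\text{micro}} = \frac{\sum_{k \in \mathcal{P}} \mathrm{TP}_k}{\sum_{k \in \mathcal{P}} \mathrm{Pred}_k},
\qquad
R_{\text{micro}} = \frac{\sum_{k \in \mathcal{P}} \mathrm{TP}_k}{\sum_{k \in \mathcal{P}} \mathrm{Gold}_k}

P_{\text{macro}} = \frac{1}{|\mathcal{P}|} \sum_{k \in \mathcal{P}} \frac{\mathrm{TP}_k}{\mathrm{Pred}_k},
\qquad
R_{\text{macro}} = \frac{1}{|\mathcal{P}|} \sum_{k \in \mathcal{P}} \frac{\mathrm{TP}_k}{\mathrm{Gold}_k}

F_\beta = \frac{(1 + \beta^2)\, P\, R}{\beta^2 P + R}, \qquad F_1 = F_{\beta = 1}
```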

Thank you all for keeping this project alive (50-100 clones/day 😎). Contributions are welcome!

Install

To add tf_metrics to your current Python environment, run:

```shell
pip install git+https://github.com/guillaumegenthial/tf_metrics.git
```

For more advanced use (editable mode, for developers), clone the repository and install it in place:

```shell
git clone https://github.com/guillaumegenthial/tf_metrics.git
cd tf_metrics
pip install -r requirements.txt
pip install -e .
```

Example

Prerequisite: familiarity with the general tf.metrics API, in which every metric is a (value, update_op) pair. See, for instance, the official TensorFlow guide on custom estimators or the official tf.metrics documentation.
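The (value, update_op) contract is easiest to see outside TensorFlow. The sketch below is a hypothetical pure-NumPy stand-in (StreamingMicroPrecision and its methods are illustrative names, not part of tf_metrics or TensorFlow) for micro-averaged precision over pos_indices: update() plays the role of update_op by adding a batch to a running confusion matrix, and value() plays the role of the value tensor by only reading that state.

```python
import numpy as np


class StreamingMicroPrecision:
    """Hypothetical stand-in for a tf.metrics-style streaming metric."""

    def __init__(self, num_classes, pos_indices):
        # Running confusion matrix: rows = true class, columns = predicted class.
        self.cm = np.zeros((num_classes, num_classes), dtype=np.int64)
        self.pos = list(pos_indices)

    def reset(self):
        # Analogous to re-running tf.local_variables_initializer().
        self.cm[:] = 0

    def value(self):
        # Read-only: micro precision over the positive classes only.
        tp = self.cm[self.pos, self.pos].sum()   # diagonal of positive classes
        tot_pred = self.cm[:, self.pos].sum()    # predicted as a positive class
        return float(tp / tot_pred) if tot_pred else 0.0

    def update(self, y_true, y_pred):
        # Accumulate this batch's counts, then return the updated value.
        np.add.at(self.cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
        return self.value()


# The README's example data, fed in two batches.
metric = StreamingMicroPrecision(num_classes=4, pos_indices=[1, 2, 3])
first = metric.update([0, 1, 0, 0], [0, 1, 0, 0])
second = metric.update([0, 2, 3, 0, 0, 1], [1, 2, 0, 3, 3, 1])
print(first, second)
```

Here the first update returns 1.0 and the second 0.5. The second value is recomputed from the accumulated counts, not averaged over per-batch values: the second batch alone would score 0.4, and the mean of 1.0 and 0.4 would be 0.7.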

Simple example

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, local variables)
import tf_metrics

y_true = [0, 1, 0, 0, 0, 2, 3, 0, 0, 1]
y_pred = [0, 1, 0, 0, 1, 2, 0, 3, 3, 1]
pos_indices = [1, 2, 3]  # Class 0 is the 'negative' class
num_classes = 4
average = 'micro'

# Each metric is a tuple of (value, update_op)
precision = tf_metrics.precision(
    y_true, y_pred, num_classes, pos_indices, average=average)
recall = tf_metrics.recall(
    y_true, y_pred, num_classes, pos_indices, average=average)
f2 = tf_metrics.fbeta(
    y_true, y_pred, num_classes, pos_indices, average=average, beta=2)
f1 = tf_metrics.f1(
    y_true, y_pred, num_classes, pos_indices, average=average)

# The metric state lives in local variables, which must be initialized first.
# Running the update op accumulates the counts and returns the updated value.
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(precision[1]))
```
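The micro-averaged scores for this example can be checked by hand. Six predictions fall in classes 1, 2 or 3 and three of them are correct, so precision is 3/6 = 0.5; four examples truly belong to those classes and three are recovered, so recall is 3/4 = 0.75. The NumPy sketch below reproduces this, assuming tf_metrics uses the standard micro-averaged definitions restricted to pos_indices; confusion_matrix and micro_scores are hypothetical helper names, not part of the library.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Counts per (true class, predicted class) pair."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def micro_scores(cm, pos_indices, beta=1.0):
    """Micro-averaged precision, recall and F-beta over pos_indices only."""
    pos = list(pos_indices)
    tp = cm[pos, pos].sum()      # diagonal entries of the positive classes
    tot_pred = cm[:, pos].sum()  # examples predicted as a positive class
    tot_gold = cm[pos, :].sum()  # examples whose true class is positive
    prec = float(tp / tot_pred) if tot_pred else 0.0
    rec = float(tp / tot_gold) if tot_gold else 0.0
    denom = beta ** 2 * prec + rec
    fbeta = (1 + beta ** 2) * prec * rec / denom if denom else 0.0
    return prec, rec, fbeta


y_true = [0, 1, 0, 0, 0, 2, 3, 0, 0, 1]
y_pred = [0, 1, 0, 0, 1, 2, 0, 3, 3, 1]
cm = confusion_matrix(y_true, y_pred, num_classes=4)
prec, rec, f1 = micro_scores(cm, [1, 2, 3])
_, _, f2 = micro_scores(cm, [1, 2, 3], beta=2)
print(prec, rec, f1, f2)
```

Under that assumption the four values are 0.5, 0.75, 0.6 and roughly 0.682 (F2 = 1.875 / 2.75), which is what the corresponding update ops should return after one run on this data.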

To use these metrics with tf.estimator.Estimator, add the following to your model_fn:

```python
# Inside model_fn: precision, recall, f1 and f2 are built as above
# from labels and predictions, and loss is already defined.
metrics = {
    'precision': precision,
    'recall': recall,
    'f1': f1,
    'f2': f2,
}

# For TensorBoard
for metric_name, metric in metrics.items():
    tf.summary.scalar(metric_name, metric[1])

if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, eval_metric_ops=metrics)
```