你必须告诉TensorFlow你想跟踪你的损失。您只是将图表添加到编写器中。例如,您可以执行此操作来跟踪您的损失:
loss = ... (your def) tf.summary.scalar('MyLoss', loss) # ... maybe add some other variables (you can also make histograms, images, etc. via tf.summary.historam(...)) summ = tf.summary.merge_all()
在您的会话中,您可以像创建一样创建编写器。然后,您必须评估摘要操作并将其添加到编写器。但是,您应该在训练循环之外创建编写器,因为您不希望每次迭代都有编写器。您在迭代中提供迭代作为参数 add_summary 方法。
add_summary
saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) writer = tf.summary.FileWriter("outputLogs", sess.graph) for iteration in range(int(n_epochs*train_set_size/batch_size)): x_batch, y_batch = get_next_batch(batch_size) # fetch the next training batch [_, s] = sess.run([training_op, summ], feed_dict={X: x_batch, y: y_batch}) writer.add_summary(s, iteration) if iteration % int(5*train_set_size/batch_size) == 0: mse_train = loss.eval(feed_dict={X: x_train, y: y_train}) mse_valid = loss.eval(feed_dict={X: x_valid, y: y_valid}) print('%.2f epochs: MSE train/valid = %.10f/%.10f'%( iteration*batch_size/train_set_size, mse_train, mse_valid)) save_path = saver.save(sess, "models\\model"+str(iteration)+".ckpt")
您的培训代码应如下所示。