Project author: AlexBarbera

Project description:
Keras helper class for adaptive batch size annealing.
Language: Python
Repository: git://github.com/AlexBarbera/KerasBatchSizeAnnealing.git
Created: 2018-06-04T22:36:01Z
Project page: https://github.com/AlexBarbera/KerasBatchSizeAnnealing

License:


Keras BatchSizeAnnealing

This repository contains a wrapper class that adjusts the batch_size after each epoch, as proposed in the paper "Don’t Decay the Learning Rate, Increase the Batch Size" by Samuel L. Smith, Pieter-Jan Kindermans, Chris Ying and Quoc V. Le.
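The central idea of the paper is that wherever a step learning-rate schedule would divide the learning rate by some factor, the batch size can be multiplied by that factor instead. A minimal sketch of a batch-size schedule in that spirit follows; the helper name `make_step_batch_schedule`, the milestones, the factor of 5 and the assumption that epochs are counted from 0 are illustrative choices, not values taken from this repository.

```python
def make_step_batch_schedule(initial_batch_size, factor, milestones, max_batch_size=None):
    """Hypothetical helper: return a callback(epoch) -> batch size that
    multiplies the batch size by `factor` at every epoch in `milestones`."""
    milestones = sorted(milestones)

    def callback(epoch):
        # Number of milestones already reached at this (0-based) epoch
        steps = sum(1 for m in milestones if epoch >= m)
        batch_size = initial_batch_size * factor ** steps
        # Optionally cap the batch size (e.g. at the training-set size)
        if max_batch_size is not None:
            batch_size = min(batch_size, max_batch_size)
        return int(batch_size)

    return callback


schedule = make_step_batch_schedule(128, 5, milestones=[60, 120, 160])
print([schedule(e) for e in (0, 59, 60, 120, 160)])
# [128, 128, 640, 3200, 16000]
```

The returned function has the same `callback(epoch)` shape as the linear schedule used in the examples, so it could be passed to the `BatchSizeAnnealing` constructor in its place.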

Train Example

A minimal working example:

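  from BatchSizeAnnealing import BatchSizeAnnealing
  from keras.datasets import mnist

  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  # Batch size for a given epoch: starts at 32 and grows by 32 every 10 epochs.
  # Integer division keeps the result an int, which Keras expects.
  def callback(epoch):
      return 32 * epoch // 10 + 32

  model = createModel()  # placeholder: build and compile your own Keras model
  trainer = BatchSizeAnnealing(model, callback)
  history = trainer.train(x_train, y_train, validation_data=(x_test, y_test), epochs=EPOCH, verbose=1)  # EPOCH: number of epochs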

The constructor takes the model and a callback that returns the batch size to use in each epoch.
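How `train` uses the callback internally is not documented here. One plausible way to build per-epoch batch-size annealing on top of Keras is to fit one epoch at a time, asking the callback for the batch size before each epoch. The sketch below assumes exactly that; the function name `train_with_batch_size_annealing` and the 0-based epoch index are assumptions, while `batch_size`, `initial_epoch`, `epochs` and the returned `History.history` dict are standard `model.fit` features.

```python
def train_with_batch_size_annealing(model, callback, x, y, epochs, **fit_kwargs):
    """Hypothetical sketch: train `model` for `epochs` epochs, asking
    `callback(epoch)` for the batch size before each one, and merge the
    per-epoch Keras histories into a single dict of lists."""
    merged = {}
    for epoch in range(epochs):
        # Keras expects an integer batch size
        batch_size = int(callback(epoch))
        # initial_epoch/epochs make fit() run exactly one epoch while
        # keeping the epoch numbering shown in the logs consistent
        history = model.fit(x, y,
                            batch_size=batch_size,
                            initial_epoch=epoch,
                            epochs=epoch + 1,
                            **fit_kwargs)
        for key, values in history.history.items():
            merged.setdefault(key, []).extend(values)
    return merged
```

Any extra keyword arguments, such as `validation_data` or `verbose`, are passed straight through to `fit`, so each epoch is an ordinary Keras training call with only the batch size changing.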

Wrapper Example

The class can also be used as a wrapper for a Keras model, since it redirects method calls to the model passed as a parameter. For example:

  from BatchSizeAnnealing import BatchSizeAnnealing
  from keras.datasets import mnist

  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  # Batch size for a given epoch (integer division keeps it an int)
  def callback(epoch):
      return 32 * epoch // 10 + 32

  model = createModel()  # placeholder: build and compile your own Keras model
  trainer = BatchSizeAnnealing(model, callback)
  trainer.summary()  # redirected to model.summary()

The last call is equivalent to model.summary().
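A common way to get this kind of forwarding in Python is `__getattr__`, which is only consulted when normal attribute lookup fails, so the wrapper's own attributes win and everything else falls through to the model. Whether this repository uses exactly this mechanism is an assumption; `ModelWrapper` and `_DummyModel` below are hypothetical names used only for illustration.

```python
class ModelWrapper(object):
    """Hypothetical minimal wrapper showing attribute forwarding."""

    def __init__(self, model, callback):
        self.model = model
        self.callback = callback

    def __getattr__(self, name):
        # Called only when `name` is not found on the wrapper itself.
        # Reading __dict__ directly avoids infinite recursion if `model`
        # has not been set yet (e.g. during unpickling).
        model = self.__dict__.get("model")
        if model is None:
            raise AttributeError(name)
        return getattr(model, name)


class _DummyModel(object):
    """Stand-in for a Keras model, for illustration only."""

    def summary(self):
        return "summary of the wrapped model"


wrapper = ModelWrapper(_DummyModel(), callback=lambda epoch: 32)
print(wrapper.summary())  # summary of the wrapped model
```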

Constructor parameters

  class BatchSizeAnnealing(object):
      def __init__(self, model, callback, show_hist=False, keep_verbosity=False):
          ...
  • model: The model to train (it can be built with the functional API).
  • callback: Function that returns the batch size to use in a given epoch:
      def callback(epoch):
          return ...
  • show_hist: Show a progress bar while training.
  • keep_verbosity: Keep the verbosity level passed to train.
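Keras expects an integer `batch_size`, yet a schedule such as `32 * epoch / 10 + 32` yields floats under Python 3, and an aggressive schedule can also outgrow the dataset. A small wrapper can make any callback safe before it is handed to the constructor; `clamp_batch_size` and its bounds are hypothetical, not part of this repository.

```python
def clamp_batch_size(callback, min_size=1, max_size=None):
    """Hypothetical helper: wrap callback(epoch) so it always returns an
    int in [min_size, max_size]."""
    def wrapped(epoch):
        # int() truncates fractional batch sizes such as 35.2 -> 35
        size = max(min_size, int(callback(epoch)))
        if max_size is not None:
            size = min(size, max_size)
        return size
    return wrapped


# Float-valued linear schedule, capped at 64 samples per batch
linear = clamp_batch_size(lambda epoch: 32 * epoch / 10 + 32, max_size=64)
print([linear(e) for e in (0, 1, 5, 10, 20)])
# [32, 35, 48, 64, 64]
```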

TODO

  • ☑ Add verbose output for every training batch.
  • ☐ Fix per-epoch verbosity from keras.model.
  • ☐ Add parameter and return types.