Project author: amir-abdi

Project description: Framework in python for training/validation of Caffe models

Primary language: Python

Repository: git://github.com/amir-abdi/deep-batch.git

Created: 2016-12-17T01:15:29Z

Project page: https://github.com/amir-abdi/deep-batch

License: Other

Deep Batch: A deep learning Python wrapper for data handling [Caffe and Keras]

Our main motivation for developing this framework is to deal with datasets in which training data is not uniformly distributed among classes [values], in both generative and discriminative models. To mitigate this problem, we developed a DataHandler which holds label maps of the training set and generates mini-batches on the fly, each containing an equal (or near-equal) number of samples per class [value]. As a result, from the model's point of view, data is uniformly distributed among labels.
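
To illustrate the idea, here is a minimal sketch of label-map based uniform batch selection; it is not the actual DataHandler API, and all names are hypothetical:

```python
import numpy as np

def build_label_map(labels):
    """Map each class label to the indices of the samples carrying that label."""
    label_map = {}
    for idx, label in enumerate(labels):
        label_map.setdefault(label, []).append(idx)
    return label_map

def uniform_batch(label_map, batch_size, rng=np.random):
    """Draw a mini-batch holding a (near-)equal number of samples per class."""
    classes = list(label_map)
    per_class = max(1, batch_size // len(classes))
    batch = []
    for c in classes:
        # Sample with replacement so under-represented classes still fill their share.
        batch.extend(rng.choice(label_map[c], size=per_class, replace=True))
    rng.shuffle(batch)
    return batch[:batch_size]
```

For example, with labels `[0, 0, 0, 0, 1, 2]` and `batch_size=6`, each of the three classes contributes two sample indices to the batch, even though class 0 dominates the dataset.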

The given framework provides the following functionalities:

  • Reading data using a list_file index
  • On-the-fly batch selection via different strategies of:
    • uniform
    • random
    • iterative
  • On-the-fly data augmentation including random rotation and translation
  • Input data handling for
    • sequential data,
    • multi-stream networks with multiple inputs and multiple outputs
  • Maintaining training history and log

Project Hierarchy

The framework is built around the following class hierarchy:

  RootModel              # Abstract class [shared functionality between Caffe and Keras]
  ├── RootCaffeModel     # Abstract class [Caffe-specific functionality]
  │   └── MyCaffeNet     # A sample Caffe model which implements the abstract methods
  └── RootKerasModel     # Abstract class [Keras-specific functionality]
      └── MyKerasNet     # A sample Keras model which implements the abstract methods

All data handling is performed by the DataHandler class (utilities/datahandler.py). Its functionalities include, but are not limited to, the following (a short augmentation sketch follows the list):

  • On-the-fly data preprocessing and augmentation
    • resize and crop
    • random rotation from uniform or normal distributions
    • random translation from uniform or normal distributions
  • On-the-fly uniform/random/iterative batch selection
  • Reading image datasets (via OpenCV) and MATLAB datasets (via SciPy)
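
As a rough illustration of on-the-fly rotation and translation, a simplified sketch using cv2 and numpy might look as follows; this is not the framework's exact implementation, and the parameter names are hypothetical:

```python
import cv2
import numpy as np

def augment(image, max_angle=10.0, max_shift=5.0, rng=np.random):
    """Randomly rotate and translate an image, drawing both from uniform distributions."""
    h, w = image.shape[:2]
    angle = rng.uniform(-max_angle, max_angle)        # rotation angle in degrees
    tx, ty = rng.uniform(-max_shift, max_shift, 2)    # horizontal and vertical shift in pixels
    m = cv2.getRotationMatrix2D((w / 2.0, h / 2.0), angle, 1.0)  # rotate about the image center
    m[:, 2] += (tx, ty)                               # add the translation to the affine matrix
    return cv2.warpAffine(image, m, (w, h))
```

Drawing from a normal distribution instead only requires replacing `rng.uniform` with `rng.normal`.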

DataHandler expects the data to be provided in list files. Each list file should follow this format:

  fileaddress, label_0[, label_1, label_2, ...]
  fileaddress, label_0[, label_1, label_2, ...]
  fileaddress, label_0[, label_1, label_2, ...]
  ...

As demonstrated above, each sample can have multiple labels; however, you need to select which label is used for training via the main_label_index hyper-parameter.
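
Below is a rough sketch of how such a list file can be read, with main_label_index selecting the training label; it is illustrative only and not the DataHandler's actual reader:

```python
import csv

def read_list_file(path, main_label_index=0):
    """Parse rows of the form 'fileaddress, label_0[, label_1, ...]'."""
    files, labels = [], []
    with open(path) as f:
        for row in csv.reader(f):
            if not row:
                continue  # skip empty lines
            files.append(row[0].strip())
            # Pick the label column used for training.
            labels.append(float(row[1 + main_label_index]))
    return files, labels
```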

How to use

To take advantage of the functionalities implemented in the framework, you need to write your own class which inherits from RootCaffeModel or RootKerasModel, depending on your preference. Your class needs to override the following methods (a skeleton sketch follows the list):

  init_meta_data       # defines all hyperparameters for training, the data handler, the optimizer, etc.
  net                  # defines the network architecture
  train_validate       # the main training method, which holds the training and validation loops
  evaluate [optional]  # evaluates the performance of a trained model on a given test set
  [Other methods can be implemented as needed, such as an output-specific accuracy calculator]
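
A minimal skeleton of such a subclass is sketched below; the base-class name comes from the hierarchy above, but the import path, hyper-parameter names, and method bodies are placeholders rather than the framework's actual API:

```python
from root_models.root_keras_model import RootKerasModel  # hypothetical import path


class MyNet(RootKerasModel):
    def init_meta_data(self):
        # Hyper-parameters for training, the data handler and the optimizer (placeholder values).
        return {
            'batch_size': 32,
            'learning_rate': 1e-3,
            'main_label_index': 0,
            'batch_selection': 'uniform',
        }

    def net(self):
        # Build and return the network architecture.
        raise NotImplementedError

    def train_validate(self):
        # Main training method holding the training and validation loops.
        raise NotImplementedError

    def evaluate(self, test_list_file):
        # Optional: evaluate a trained model on a given test set.
        raise NotImplementedError
```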

The following two sample classes are implemented to demonstrate functionalities of the framework and provide guidance:

  net
  ├── caffe_demo_net
  │   └── CaffeDemoClassificationNet
  └── keras_demo_net
      └── KerasDemoMultiStreamRegressionNet

The CaffeDemoClassificationNet is a simple convolutional Caffe model, which will be trained on the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) if you run main_caffe_demonet.py.

The KerasDemoMultiStreamRegressionNet is a sample multi-stream, sequential Keras model, hypothetically trained on the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). Note that Cifar10 is not a multi-input, multi-output dataset; in main_keras_demonet.py we have divided the training and validation sets into 4 subsets, assuming that each represents a different input type. The multi-stream network, designed only for demonstration purposes, shares the weights of its first few layers among the streams, while each stream also has its own stream-specific layers. The model ends with recurrent LSTM layers; however, since Cifar10 is not a sequential dataset, we fake a sequence by treating the three channels of the color images as three frames of a video. KerasDemoMultiStreamRegressionNet is implemented only to demonstrate the capabilities of the Deep Batch framework, and the trained model should not be trusted.
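
The general pattern of that demo (shared early layers, stream-specific branches, and a recurrent head over a short, faked sequence) can be sketched with the Keras functional API roughly as follows; the layer sizes, names, and single regression output are illustrative and do not reproduce the demo's exact architecture:

```python
from keras.layers import Input, Conv2D, Flatten, Dense, LSTM, Reshape, concatenate
from keras.models import Model

n_streams = 4
shared_conv = Conv2D(16, (3, 3), activation='relu', padding='same')  # weights shared by all streams

inputs, features = [], []
for i in range(n_streams):
    x_in = Input(shape=(32, 32, 3), name='stream_%d' % i)  # 32x32 RGB images per stream
    x = shared_conv(x_in)                                   # shared weights
    x = Conv2D(8, (3, 3), activation='relu')(x)             # stream-specific layer (new weights per stream)
    x = Flatten()(x)
    inputs.append(x_in)
    features.append(x)

merged = concatenate(features)
sequence = Reshape((3, -1))(merged)           # reshape the merged features into a length-3 sequence
out = LSTM(32)(sequence)                      # recurrent head over the faked sequence
out = Dense(1, name='regression_output')(out)

model = Model(inputs=inputs, outputs=out)
model.compile(optimizer='adam', loss='mse')
```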

Library dependencies

Deep Batch depends on the following libraries:

  • caffe
  • keras
  • tensorflow
  • cv2
  • numpy
  • scipy
  • sklearn
  • matplotlib
  • json
  • PIL (Python Imaging Library)

Limitations

  • So far, the framework has only been developed for the TensorFlow backend of Keras.
  • The data handler accepts data only via list index files and does not directly read the contents of given directories.
  • Label types other than single-value numbers (e.g., mask images) are not handled.