Project author: titu1994

Project description:
Implementation of Squeeze and Excitation Networks in Keras
Language: Python
Project URL: git://github.com/titu1994/keras-squeeze-excite-network.git
Created: 2017-08-27T05:05:14Z
Project community: https://github.com/titu1994/keras-squeeze-excite-network

License: MIT License


Squeeze and Excitation Networks in Keras

Implementation of Squeeze and Excitation Networks in Keras 2.0.3+.

[Figure: squeeze-excite block]

Models

Currently supported models:

  • SE-ResNet. Custom ResNets can be built with the SEResNet model builder; prebuilt models such as SEResNet50, SEResNet101 and SEResNet154 can be built directly.
  • SE-InceptionV3
  • SE-Inception-ResNet-v2
  • SE-ResNeXt

Additional models (not from the paper; not verified to improve performance):

  • SE-MobileNets
  • SE-DenseNet. Custom SE-DenseNets can be built with the SEDenseNet model builder; the prebuilt models SEDenseNetImageNet121, SEDenseNetImageNet169, SEDenseNetImageNet161, SEDenseNetImageNet201 and SEDenseNetImageNet264 build SE-DenseNets in the ImageNet configuration. To use SEDenseNet in CIFAR mode, use the SEDenseNet model builder.

Squeeze and Excitation block

The block is simple to implement in Keras. It consists of a GlobalAveragePooling2D layer, two Dense layers and an element-wise multiplication.
Keras infers the intermediate shapes automatically. The block can be imported from se.py.

    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Permute, multiply


    def squeeze_excite_block(tensor, ratio=16):
        init = tensor
        channel_axis = 1 if K.image_data_format() == "channels_first" else -1
        # tf.keras tensors expose a static `shape`; `_keras_shape` exists only in standalone Keras.
        filters = init.shape[channel_axis]
        se_shape = (1, 1, filters)

        # Squeeze: one average per channel, reshaped so it broadcasts over H and W.
        se = GlobalAveragePooling2D()(init)
        se = Reshape(se_shape)(se)

        # Excite: bottleneck of filters // ratio units, then a sigmoid gate per channel.
        se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
        se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

        # Move the channel axis to the front so it lines up with a channels_first input.
        if K.image_data_format() == 'channels_first':
            se = Permute((3, 1, 2))(se)

        # Scale: reweight every channel of the input by its gate.
        x = multiply([init, se])
        return x
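
To make the shape arithmetic of the block concrete, the following is a minimal NumPy sketch of the same squeeze, excite and scale steps for a channels-last batch. The function name `se_forward` and the random weights `w1` and `w2` are illustrative assumptions; they stand in for the two bias-free Dense layers and are not part of se.py.

```python
import numpy as np


def se_forward(x, w1, w2):
    """Hypothetical helper: squeeze-excite forward pass on a channels-last batch.

    x:  (batch, height, width, channels)
    w1: (channels, channels // ratio)  stands in for the first Dense layer
    w2: (channels // ratio, channels)  stands in for the second Dense layer
    """
    squeezed = x.mean(axis=(1, 2))                  # squeeze: (batch, channels)
    hidden = np.maximum(squeezed @ w1, 0.0)         # ReLU bottleneck
    gates = 1.0 / (1.0 + np.exp(-(hidden @ w2)))    # sigmoid gate per channel
    scaled = x * gates[:, None, None, :]            # broadcast over height and width
    return scaled, gates


rng = np.random.default_rng(0)
batch, height, width, channels, ratio = 2, 4, 4, 32, 16
x = rng.normal(size=(batch, height, width, channels))
w1 = rng.normal(size=(channels, channels // ratio))
w2 = rng.normal(size=(channels // ratio, channels))

y, gates = se_forward(x, w1, w2)
print(y.shape, gates.shape)
```

The output keeps the input shape; only the per-channel magnitudes change, which is why the block can be dropped into an existing architecture without altering the surrounding layers.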

Addition of Squeeze and Excitation blocks to Inception and ResNet blocks

[Figure: SE architectures] [Figure: SE-ResNet architecture]
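
As a rough illustration of the ResNet case, the sketch below gates the residual branch with a squeeze-excite step before adding it to the identity shortcut, which is the placement described in the SE paper. It assumes the default channels_last data format and a shortcut whose channel count already matches `filters`. The names `se_gate` and `se_residual_block` are hypothetical and do not reflect the repository's own model builders.

```python
from tensorflow.keras import layers, Model


def se_gate(tensor, ratio=16):
    # Hypothetical channels-last-only variant of the squeeze-excite block.
    filters = tensor.shape[-1]
    se = layers.GlobalAveragePooling2D()(tensor)
    se = layers.Reshape((1, 1, filters))(se)
    se = layers.Dense(filters // ratio, activation="relu", use_bias=False)(se)
    se = layers.Dense(filters, activation="sigmoid", use_bias=False)(se)
    return layers.multiply([tensor, se])


def se_residual_block(tensor, filters, ratio=16):
    # Assumes the input already has `filters` channels, so the shortcut is the identity.
    shortcut = tensor
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(tensor)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = se_gate(x, ratio)              # reweight the residual branch only
    x = layers.add([shortcut, x])      # then merge with the shortcut
    return layers.Activation("relu")(x)


inputs = layers.Input(shape=(8, 8, 32))
outputs = se_residual_block(inputs, filters=32)
model = Model(inputs, outputs)
print(model.output_shape)
```

Gating only the residual branch leaves the identity path untouched, so the block still reduces to a plain shortcut when the gates suppress the residual.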