Project author: zabir-nabil

Project description: Extension of the `Attention Augmented Convolutional Networks` paper for 1-D convolution operation.
Primary language: Jupyter Notebook
Repository: git://github.com/zabir-nabil/keras-attn_aug_cnn.git
Created: 2019-09-03T12:09:43Z
Project page: https://github.com/zabir-nabil/keras-attn_aug_cnn

License: MIT License

keras-attn_aug_cnn

Extension of the Attention Augmented Convolutional Networks paper with a (somewhat hacky) 1-D convolution implementation.
The layers can also be used inside a TensorFlow graph.

Properties

  1. The parameters must satisfy the divisibility constraints `depth_k | filters`, `depth_v | filters`, `Nh | depth_k`, and `Nh | filters - depth_v`, where `x | y` means "`x` divides `y`".
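These constraints can be checked up front before building a model. The helper below is a sketch, not part of the repo's API, that simply encodes the four divisibility rules above:

```python
def check_attn_params(filters, depth_k, depth_v, num_heads):
    """Validate the divisibility constraints (x | y reads 'x divides y')."""
    assert filters % depth_k == 0, "depth_k must divide filters"
    assert filters % depth_v == 0, "depth_v must divide filters"
    assert depth_k % num_heads == 0, "Nh must divide depth_k"
    assert (filters - depth_v) % num_heads == 0, "Nh must divide filters - depth_v"

# The parameters used in the examples below satisfy all four rules:
check_attn_params(filters=20, depth_k=4, depth_v=4, num_heads=4)
```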

1-D CNN

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D
from tensorflow.keras.models import Model
from aug_attn import *

ip = Input(shape=(None, 10))
cnn1 = Conv1D(filters=10, kernel_size=3, strides=1, padding='same')(ip)
x = augmented_conv1d(cnn1, shape=(32, 10), filters=20, kernel_size=5,
                     strides=1,
                     padding='causal',  # if causal convolution is needed
                     depth_k=4, depth_v=4,
                     num_heads=4, relative_encodings=True)
# depth_k | filters, depth_v | filters, Nh | depth_k, Nh | filters - depth_v
model = Model(ip, x)
model.summary()

x = tf.ones((1, 32, 10))
print(x.shape)
y = model(x)
print(y.shape)
```
```
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_3 (InputLayer)            [(None, None, 10)]   0
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, None, 10)     310         input_3[0][0]
__________________________________________________________________________________________________
conv1d_10 (Conv1D)              (None, None, 12)     132         conv1d_8[0][0]
__________________________________________________________________________________________________
reshape_11 (Reshape)            (None, 32, 1, 12)    0           conv1d_10[0][0]
__________________________________________________________________________________________________
attention_augmentation2d_2 (Att (None, None, None, N 64          reshape_11[0][0]
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 32, 4)        0           attention_augmentation2d_2[0][0]
__________________________________________________________________________________________________
conv1d_9 (Conv1D)               (None, None, 16)     816         conv1d_8[0][0]
__________________________________________________________________________________________________
conv1d_11 (Conv1D)              (None, 32, 4)        20          reshape_12[0][0]
__________________________________________________________________________________________________
reshape_10 (Reshape)            (None, 32, 1, 16)    0           conv1d_9[0][0]
__________________________________________________________________________________________________
reshape_13 (Reshape)            (None, 32, 1, 4)     0           conv1d_11[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 1, 20)    0           reshape_10[0][0]
                                                                 reshape_13[0][0]
__________________________________________________________________________________________________
reshape_14 (Reshape)            (None, 32, 20)       0           concatenate_2[0][0]
==================================================================================================
Total params: 1,342
Trainable params: 1,342
Non-trainable params: 0
__________________________________________________________________________________________________
(1, 32, 10)
(1, 32, 20)
```
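As a sanity check, the Conv1D rows in the summary follow the standard parameter-count formula `(kernel_size * in_channels + 1) * filters`. The reading of `conv1d_10` and `conv1d_11` as internal 1×1 projections created by `augmented_conv1d` is an inference from the channel counts (2·depth_k + depth_v = 12, and depth_v = 4), not something stated in the repo:

```python
# Conv1D parameter count: one kernel per output filter plus a bias term.
def conv1d_params(kernel_size, in_channels, filters):
    return (kernel_size * in_channels + 1) * filters

print(conv1d_params(3, 10, 10))  # conv1d_8  (user layer, kernel 3)      -> 310
print(conv1d_params(1, 10, 12))  # conv1d_10 (1x1 qkv projection?)       -> 132
print(conv1d_params(5, 10, 16))  # conv1d_9  (main conv branch, kernel 5) -> 816
print(conv1d_params(1, 4, 4))    # conv1d_11 (1x1 attention output)      -> 20

# Together with the 64 relative-encoding params, this matches Total params:
print(310 + 132 + 816 + 20 + 64)  # -> 1342
```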

2-D CNN

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model
from aug_attn import *

ip = Input(shape=(32, 32, 10))
cnn1 = Conv2D(filters=10, kernel_size=3, strides=1, padding='same')(ip)
x = augmented_conv2d(cnn1, filters=20, kernel_size=5,  # shape parameter is not needed
                     strides=1,
                     depth_k=4, depth_v=4,  # padding is 'same' by default
                     num_heads=4, relative_encodings=True)
# depth_k | filters, depth_v | filters, Nh | depth_k, Nh | filters - depth_v
model = Model(ip, x)
model.summary()

x = tf.ones((1, 32, 32, 10))
print(x.shape)
y = model(x)
print(y.shape)
```
```
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_16 (InputLayer)           (None, 32, 32, 10)   0
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 10)   910         input_16[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 32, 32, 12)   132         conv2d_11[0][0]
__________________________________________________________________________________________________
attention_augmentation2d_14 (At (None, 32, 32, 4)    126         conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 16)   4016        conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 32, 32, 4)    20          attention_augmentation2d_14[0][0]
__________________________________________________________________________________________________
concatenate_14 (Concatenate)    (None, 32, 32, 20)   0           conv2d_12[0][0]
                                                                 conv2d_14[0][0]
==================================================================================================
Total params: 5,204
Trainable params: 5,204
Non-trainable params: 0
__________________________________________________________________________________________________
(1, 32, 32, 10)
(1, 32, 32, 20)
```

Implementations