Project author: FrancescoSaverioZuppichini

Project description:
An easy-to-use API to store outputs from forward/backward hooks in PyTorch
Primary language: Jupyter Notebook
Repository: git://github.com/FrancescoSaverioZuppichini/PytorchModuleStorage.git
Created: 2019-07-29T12:15:59Z
Project community: https://github.com/FrancescoSaverioZuppichini/PytorchModuleStorage

License: MIT License


PytorchModuleStorage

Easy-to-use API to store forward/backward features

Francesco Saverio Zuppichini

Install

    pip install git+https://github.com/FrancescoSaverioZuppichini/PytorchModuleStorage.git

Quick Start

You have a model, e.g. vgg19, and you want to store the features of its third layer given an input x.


First, we need a model. We will load vgg19 from torchvision.models:

    import torch
    from torchvision.models import vgg19
    from PytorchStorage import ForwardModuleStorage

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cnn = vgg19(False).to(device).eval()

Then, we define a ForwardModuleStorage instance by passing the model and the list of layers we are interested in:

    storage = ForwardModuleStorage(cnn, [cnn.features[3]])

Finally, we create a random input x and pass it to the storage:

    x = torch.rand(1, 3, 224, 224).to(device) # random input, this can be an image
    storage(x) # pass the input to the storage
    storage[cnn.features[3]][0] # the features can be accessed by using the layer as a key
    tensor([[[[0.0815, 0.0000, 0.0136, ..., 0.0435, 0.0058, 0.0584],
              [0.1270, 0.0873, 0.0800, ..., 0.0910, 0.0808, 0.0875],
              [0.0172, 0.0095, 0.1667, ..., 0.2503, 0.0938, 0.1044],
              ...,
              [0.0000, 0.0181, 0.0950, ..., 0.1760, 0.0261, 0.0092],
              [0.0533, 0.0043, 0.0625, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0776, 0.1942, 0.2467, ..., 0.1669, 0.0778, 0.0969],
              [0.1714, 0.1516, 0.3037, ..., 0.1950, 0.0428, 0.0892],
              [0.1219, 0.2611, 0.2902, ..., 0.1964, 0.2083, 0.2422],
              ...,
              [0.1813, 0.1193, 0.2079, ..., 0.3328, 0.4176, 0.2015],
              [0.0870, 0.2522, 0.1454, ..., 0.2726, 0.1916, 0.2314],
              [0.0250, 0.1256, 0.1301, ..., 0.1425, 0.1691, 0.0775]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.1044],
              [0.0000, 0.0202, 0.0000, ..., 0.0000, 0.0873, 0.0908],
              [0.0000, 0.0000, 0.0000, ..., 0.0683, 0.0053, 0.1209],
              ...,
              [0.0000, 0.0000, 0.0818, ..., 0.0000, 0.0000, 0.1722],
              [0.0000, 0.0493, 0.0501, ..., 0.0112, 0.0000, 0.0864],
              [0.0000, 0.1314, 0.0904, ..., 0.1500, 0.0628, 0.2383]],
             ...,
             [[0.0000, 0.0915, 0.1819, ..., 0.1442, 0.0499, 0.0346],
              [0.0000, 0.0000, 0.0750, ..., 0.1607, 0.0883, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1648, 0.0250],
              ...,
              [0.0000, 0.0000, 0.1259, ..., 0.1193, 0.0573, 0.0096],
              [0.0000, 0.0472, 0.0000, ..., 0.0000, 0.0467, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0000, 0.0000, 0.0154, ..., 0.0080, 0.0000, 0.0000],
              [0.0347, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.1283, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              ...,
              [0.0510, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0130, 0.0165, 0.0000, ..., 0.0302, 0.0000, 0.0000]],
             [[0.0000, 0.0499, 0.0000, ..., 0.0221, 0.0180, 0.0000],
              [0.0730, 0.0000, 0.0925, ..., 0.1378, 0.0475, 0.0000],
              [0.0000, 0.0677, 0.0000, ..., 0.0000, 0.0070, 0.0000],
              ...,
              [0.0712, 0.0431, 0.0000, ..., 0.0420, 0.0116, 0.0086],
              [0.0000, 0.1240, 0.0121, ..., 0.2387, 0.0294, 0.0413],
              [0.0223, 0.0691, 0.0000, ..., 0.0000, 0.0000, 0.0000]]]],
           grad_fn=<ReluBackward1>)

The storage keeps an internal state (storage.state) in which the layers serve as keys to access the stored values.
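For reference, here is a minimal sketch of what such a storage does under the hood with plain PyTorch forward hooks. This is a hypothetical, simplified re-implementation for illustration only; the library's internals may differ, but register_forward_hook is the standard PyTorch mechanism involved.

    # hypothetical re-implementation sketch, not the library's actual code
    features = {}

    def save_output(module, inputs, output):
        # register_forward_hook calls us with (module, inputs, output)
        features.setdefault(module, []).append(output)

    handle = cnn.features[3].register_forward_hook(save_output)
    cnn(torch.rand(1, 3, 224, 224).to(device)) # features are captured during the forward pass
    print(features[cnn.features[3]][0].shape)  # e.g. torch.Size([1, 64, 224, 224])
    handle.remove() # detach the hook once we are done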

Hook to a list of layers

You can pass a list of layers and then access the stored outputs:

    storage = ForwardModuleStorage(cnn, [cnn.features[3], cnn.features[5]])
    x = torch.rand(1, 3, 224, 224).to(device) # random input, this can be an image
    storage(x) # pass the input to the storage
    print(storage[cnn.features[3]][0].shape)
    print(storage[cnn.features[5]][0].shape)

    torch.Size([1, 64, 224, 224])
    torch.Size([1, 128, 112, 112])
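Since the stored values are keyed by the layer objects themselves, you can loop over your list of layers to inspect everything at once. A small convenience sketch, not part of the library:

    layers = [cnn.features[3], cnn.features[5]]
    storage = ForwardModuleStorage(cnn, layers)
    storage(torch.rand(1, 3, 224, 224).to(device))
    for layer in layers:
        # each stored entry is the output captured during one forward pass
        print(layer, storage[layer][0].shape)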

Multiple Inputs

You can also pass multiple inputs; they will be stored in call order:


    storage = ForwardModuleStorage(cnn, [cnn.features[3]])
    x = torch.rand(1, 3, 224, 224).to(device) # random input, this can be an image
    y = torch.rand(1, 3, 224, 224).to(device) # random input, this can be an image
    storage([x, y]) # pass the inputs to the storage
    print(storage[cnn.features[3]][0].shape) # x
    print(storage[cnn.features[3]][1].shape) # y

    torch.Size([1, 64, 224, 224])
    torch.Size([1, 64, 224, 224])
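Having both feature maps stored side by side makes it easy to compare them, for example with a feature-space distance. A hedged sketch using the standard torch.nn.functional.mse_loss:

    import torch.nn.functional as F

    feats_x = storage[cnn.features[3]][0] # features of x (first input)
    feats_y = storage[cnn.features[3]][1] # features of y (second input)
    print(F.mse_loss(feats_x, feats_y))   # scalar distance between the two feature maps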

Different inputs for different layers

Imagine we want to run x on one set of layers and y on another. This can be done by specifying a dictionary of the form `{ NAME: [layers, ...], ... }`:

    storage = ForwardModuleStorage(cnn, {'style' : [cnn.features[5]], 'content' : [cnn.features[5], cnn.features[10]]})
    storage(x, 'style') # we run x only on the 'style' layers
    storage(y, 'content') # we run y only on the 'content' layers
    print(storage['style'])
    print(storage['style'][cnn.features[5]])
    MutipleKeysDict([(Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), tensor([[[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0383, 0.0042, ..., 0.0852, 0.0246, 0.1101],
              [0.0000, 0.0000, 0.1106, ..., 0.0000, 0.0107, 0.0487],
              ...,
              [0.0085, 0.0809, 0.0000, ..., 0.0000, 0.0012, 0.0018],
              [0.0000, 0.0817, 0.1753, ..., 0.0000, 0.0000, 0.0701],
              [0.0000, 0.1445, 0.1105, ..., 0.2428, 0.0418, 0.0803]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              ...,
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0400, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0731, 0.0316, ..., 0.0673, 0.0000, 0.0383],
              [0.0000, 0.0288, 0.0000, ..., 0.0499, 0.0000, 0.0573],
              ...,
              [0.0000, 0.0128, 0.0744, ..., 0.1250, 0.0000, 0.0023],
              [0.0000, 0.0000, 0.0000, ..., 0.0353, 0.0000, 0.0000],
              [0.0093, 0.1436, 0.1009, ..., 0.2187, 0.0988, 0.0693]],
             ...,
             [[0.1177, 0.0370, 0.2002, ..., 0.1878, 0.1076, 0.0000],
              [0.1045, 0.0090, 0.0000, ..., 0.0705, 0.0000, 0.0000],
              [0.1074, 0.1208, 0.0000, ..., 0.1038, 0.1378, 0.0000],
              ...,
              [0.0634, 0.0234, 0.0610, ..., 0.0955, 0.0977, 0.0000],
              [0.1097, 0.0563, 0.0000, ..., 0.0797, 0.0424, 0.0000],
              [0.0090, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0254],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              ...,
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0690, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.1769, 0.0128, 0.1329, ..., 0.0733, 0.1435, 0.0000],
              [0.1478, 0.0476, 0.0000, ..., 0.0192, 0.0000, 0.0000],
              ...,
              [0.2258, 0.0908, 0.0621, ..., 0.1120, 0.0678, 0.0000],
              [0.1161, 0.0625, 0.0694, ..., 0.0365, 0.0000, 0.0000],
              [0.1360, 0.0890, 0.1442, ..., 0.1679, 0.1336, 0.0432]]]],
           grad_fn=<ReluBackward1>))])
    tensor([[[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0383, 0.0042, ..., 0.0852, 0.0246, 0.1101],
              [0.0000, 0.0000, 0.1106, ..., 0.0000, 0.0107, 0.0487],
              ...,
              [0.0085, 0.0809, 0.0000, ..., 0.0000, 0.0012, 0.0018],
              [0.0000, 0.0817, 0.1753, ..., 0.0000, 0.0000, 0.0701],
              [0.0000, 0.1445, 0.1105, ..., 0.2428, 0.0418, 0.0803]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              ...,
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0400, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0731, 0.0316, ..., 0.0673, 0.0000, 0.0383],
              [0.0000, 0.0288, 0.0000, ..., 0.0499, 0.0000, 0.0573],
              ...,
              [0.0000, 0.0128, 0.0744, ..., 0.1250, 0.0000, 0.0023],
              [0.0000, 0.0000, 0.0000, ..., 0.0353, 0.0000, 0.0000],
              [0.0093, 0.1436, 0.1009, ..., 0.2187, 0.0988, 0.0693]],
             ...,
             [[0.1177, 0.0370, 0.2002, ..., 0.1878, 0.1076, 0.0000],
              [0.1045, 0.0090, 0.0000, ..., 0.0705, 0.0000, 0.0000],
              [0.1074, 0.1208, 0.0000, ..., 0.1038, 0.1378, 0.0000],
              ...,
              [0.0634, 0.0234, 0.0610, ..., 0.0955, 0.0977, 0.0000],
              [0.1097, 0.0563, 0.0000, ..., 0.0797, 0.0424, 0.0000],
              [0.0090, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0254],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              ...,
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
             [[0.0690, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
              [0.1769, 0.0128, 0.1329, ..., 0.0733, 0.1435, 0.0000],
              [0.1478, 0.0476, 0.0000, ..., 0.0192, 0.0000, 0.0000],
              ...,
              [0.2258, 0.0908, 0.0621, ..., 0.1120, 0.0678, 0.0000],
              [0.1161, 0.0625, 0.0694, ..., 0.0365, 0.0000, 0.0000],
              [0.1360, 0.0890, 0.1442, ..., 0.1679, 0.1336, 0.0432]]]],
           grad_fn=<ReluBackward1>)
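The 'style'/'content' naming above maps directly onto neural style transfer, where style features are typically summarized by a Gram matrix. A minimal sketch (gram_matrix is a hypothetical helper, not part of the library):

    def gram_matrix(features):
        # correlate channels after flattening the spatial dimensions
        b, c, h, w = features.shape
        flat = features.view(b, c, h * w)
        return flat @ flat.transpose(1, 2) / (c * h * w)

    style_features = storage['style'][cnn.features[5]]
    print(gram_matrix(style_features).shape) # torch.Size([1, 128, 128])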

Backward

You can also store gradients by using BackwardModuleStorage:

    from PytorchStorage import BackwardModuleStorage
    import torch.nn as nn

    # we don't need the module, just the layers
    storage = BackwardModuleStorage([cnn.features[3]])
    x = torch.rand(1, 3, 224, 224, device=device).requires_grad_(True) # random input, this can be an image
    criterion = nn.CrossEntropyLoss()
    # 1 is the ground-truth class; keep the target on the same device as the model
    output = criterion(cnn(x), torch.tensor([1], device=device))
    storage(output) # triggers the backward pass and stores the gradients
    # then we can use the layer to get the gradient out of it
    storage[cnn.features[3]]
    [(tensor([[[[ 1.6662e-05,  0.0000e+00,  9.1222e-06,  ...,  1.2165e-07,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  1.8770e-05],
                [ 4.9425e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                ...,
                [ 7.3107e-05,  0.0000e+00,  0.0000e+00,  ..., -2.6335e-05,
                  0.0000e+00,  2.1168e-05],
                [ 1.0214e-07,  0.0000e+00,  8.3543e-06,  ...,  0.0000e+00,
                  8.6060e-06,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00]],
               [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  2.9192e-05,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00, -1.3629e-05,  0.0000e+00,  ...,  0.0000e+00,
                 -8.7888e-06,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00, -3.7738e-05,  ...,  0.0000e+00,
                 -3.6711e-05,  0.0000e+00],
                ...,
                [ 0.0000e+00,  0.0000e+00,  4.7797e-05,  ...,  0.0000e+00,
                 -1.3995e-05,  0.0000e+00],
                [ 0.0000e+00,  3.2237e-05,  0.0000e+00,  ...,  1.3353e-05,
                  0.0000e+00,  2.6432e-05],
                [ 0.0000e+00,  0.0000e+00, -9.5113e-06,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00]],
               [[ 4.5919e-06,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.2707e-05,
                  0.0000e+00,  6.5265e-06],
                [ 3.4605e-05,  0.0000e+00,  0.0000e+00,  ..., -2.7972e-06,
                  0.0000e+00, -5.2525e-05],
                ...,
                [ 0.0000e+00,  3.6611e-06,  6.0328e-06,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  9.9564e-07,  ...,  2.1010e-05,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  1.1180e-05,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00, -1.4692e-05]],
               ...,
               [[ 0.0000e+00,  3.1771e-05, -2.2892e-05,  ...,  0.0000e+00,
                  0.0000e+00,  1.4811e-05],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                ...,
                [ 5.0065e-06,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 4.7138e-05,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                 -1.1021e-05,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00]],
               [[ 0.0000e+00,  0.0000e+00,  4.6386e-06,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [-9.7505e-06,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                 -9.5954e-07,  0.0000e+00],
                [ 1.1188e-05,  0.0000e+00,  1.7352e-05,  ...,  0.0000e+00,
                  2.6517e-05,  0.0000e+00],
                ...,
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [-2.7686e-06,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  1.7470e-05,  0.0000e+00]],
               [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.3180e-06,
                  2.5051e-05,  0.0000e+00],
                [ 0.0000e+00,  8.3131e-06,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00, -2.1428e-05,  ...,  0.0000e+00,
                 -5.9600e-05,  0.0000e+00],
                ...,
                [ 0.0000e+00,  2.1640e-05,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00],
                [ 0.0000e+00,  4.6622e-05,  0.0000e+00,  ...,  0.0000e+00,
                 -2.1942e-05,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                  0.0000e+00,  0.0000e+00]]]]),)]
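As the output above shows, each stored entry is the tuple a PyTorch backward hook receives, so the gradient tensor itself sits inside a tuple. A small sketch to pull out a summary statistic, assuming the list-of-tuples layout printed above:

    grad, = storage[cnn.features[3]][0] # unpack the single-tensor tuple from the first backward pass
    print(grad.shape)                   # matches the layer's output shape
    print(grad.abs().mean())            # average gradient magnitude flowing through the layer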