Project author: kuan-wang

Project description:
MobileNetV3 in PyTorch, with ImageNet-pretrained models
Language: Python
Repository: git://github.com/kuan-wang/pytorch-mobilenet-v3.git
Created: 2019-05-08T18:16:57Z
Project page: https://github.com/kuan-wang/pytorch-mobilenet-v3

License: Apache License 2.0



A PyTorch implementation of MobileNetV3

This is a PyTorch implementation of the MobileNetV3 architecture as described in the paper "Searching for MobileNetV3".

Some details may differ from the original paper. Discussion and help in tracking them down are welcome.

  • [NEW] The pretrained model for the small version of MobileNetV3 is online, and its accuracy matches the paper.
  • [NEW] The paper was updated on 17 May, so I updated the code accordingly, but there are still some bugs.
  • [NEW] I removed the SE module before the global average pooling (the paper may have included it in error), and the model size is now close to the paper's.
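The last note concerns the squeeze-and-excitation (SE) block. A minimal sketch of an SE block together with the hard-sigmoid and hard-swish activations described in the paper is shown below. The class names and the reduction ratio of 4 are assumptions for illustration; this is not necessarily how the repository implements them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HSigmoid(nn.Module):
    """Hard sigmoid: ReLU6(x + 3) / 6, a piecewise-linear sigmoid approximation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu6(x + 3.0) / 6.0


class HSwish(nn.Module):
    """Hard swish: x * ReLU6(x + 3) / 6."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * F.relu6(x + 3.0) / 6.0


class SqueezeExcite(nn.Module):
    """Reweights channels using globally pooled features.

    reduction=4 is an assumption, not taken from this repository.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: N x C x 1 x 1
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            HSigmoid(),  # excitation weights in [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(n, c)).view(n, c, 1, 1)
        return x * w  # scale each channel by its weight
```

The SE block preserves the input shape, so it can be dropped into a bottleneck without changing the surrounding layers.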

Training & Accuracy

Training settings:

  1. number of epochs: 150
  2. learning-rate schedule: cosine, initial lr = 0.05
  3. weight decay: 4e-5
  4. dropout: removed
  5. batch size: 256
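The optimizer settings above can be sketched with PyTorch's built-in cosine scheduler. The optimizer type is not stated, so SGD with momentum 0.9 is an assumption here, and `build_optimizer_and_scheduler` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def build_optimizer_and_scheduler(model: nn.Module, epochs: int = 150,
                                  lr: float = 0.05, weight_decay: float = 4e-5):
    """SGD with cosine learning-rate decay, stepped once per epoch.

    SGD with momentum 0.9 is an assumption; the original only gives lr,
    weight decay, and the cosine schedule.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                                weight_decay=weight_decay)
    # Cosine annealing from the initial lr down to 0 over `epochs` epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    return optimizer, scheduler
```

In a training loop, `scheduler.step()` would be called once at the end of each epoch, after that epoch's optimizer steps.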

MobileNetV3 large

| Model | MAdds | Parameters | Top-1 Acc | Pretrained Model |
|---|---|---|---|---|
| Official 1.0 | 219 M | 5.4 M | 75.2% | - |
| Official 0.75 | 155 M | 4 M | 73.3% | - |
| Ours 1.0 | 224 M | 5.48 M | 72.8% | - |
| Ours 0.75 | 148 M | 3.91 M | - | - |
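The 0.75 rows use a width multiplier that scales every layer's channel count. A common convention, taken from the TensorFlow MobileNet reference code, rounds each scaled count to a multiple of 8. The sketch below assumes that convention; whether this repository rounds the same way is an assumption, and `make_divisible` / `scaled_channels` are hypothetical names.

```python
from typing import List, Optional


def make_divisible(value: float, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """Round a channel count to the nearest multiple of `divisor`.

    Follows the TensorFlow MobileNet reference convention (an assumption here).
    """
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Do not let rounding down remove more than 10% of the channels.
    if new_value < 0.9 * value:
        new_value += divisor
    return new_value


def scaled_channels(base_channels: List[int], width_mult: float) -> List[int]:
    """Apply a width multiplier (e.g. 0.75) to each layer's channel count."""
    return [make_divisible(c * width_mult) for c in base_channels]
```

Rounding to multiples of 8 is a hardware-friendliness choice, which is one reason a 0.75 model's size does not scale exactly by 0.75 squared.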

MobileNetV3 small

| Model | MAdds | Parameters | Top-1 Acc | Pretrained Model |
|---|---|---|---|---|
| Official 1.0 | 66 M | 2.9 M | 67.4% | - |
| Official 0.75 | 44 M | 2.4 M | 65.4% | - |
| Ours 1.0 | 63 M | 2.94 M | 67.4% | [google drive] |
| Ours 0.75 | 46 M | 2.38 M | - | - |
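The Parameters and MAdds columns can be roughly reproduced for any PyTorch model. The sketch below counts parameters directly and approximates multiply-adds for `Conv2d` and `Linear` layers with forward hooks. It ignores batch norm, activations, pooling, and elementwise products, so it is an approximation and not necessarily the counting rule used in the paper or by the author; the helper names are hypothetical.

```python
import torch
import torch.nn as nn


def count_params(model: nn.Module) -> int:
    """Total number of parameters in the model."""
    return sum(p.numel() for p in model.parameters())


def count_madds(model: nn.Module, input_size=(1, 3, 224, 224)) -> int:
    """Approximate per-sample multiply-adds of Conv2d and Linear layers (CPU input)."""
    total = 0

    def conv_hook(module, inputs, output):
        nonlocal total
        kh, kw = module.kernel_size
        per_output = (module.in_channels // module.groups) * kh * kw
        total += (output.numel() // output.shape[0]) * per_output

    def linear_hook(module, inputs, output):
        nonlocal total
        total += (output.numel() // output.shape[0]) * module.in_features

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    was_training = model.training
    model.eval()
    with torch.no_grad():
        model(torch.zeros(*input_size))  # one dummy forward pass triggers the hooks
    model.train(was_training)
    for h in handles:
        h.remove()
    return total
```

Applied to `net_small` from the Usage section, `count_madds(net_small) / 1e6` gives a figure to compare with the MAdds column; exact agreement is not expected, since counting conventions differ.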

Usage

Pretrained models are still being trained; so far only the small 1.0 model is available.

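```python
# pytorch 1.0.1
import torch
# Assumes mobilenetv3.py from this repository is importable.
from mobilenetv3 import mobilenetv3

# large
net_large = mobilenetv3(mode='large')
# small
net_small = mobilenetv3(mode='small')
# Load the ImageNet-pretrained weights (map_location='cpu' avoids requiring a GPU).
state_dict = torch.load('mobilenetv3_small_67.4.pth.tar', map_location='cpu')
net_small.load_state_dict(state_dict)
```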
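To check a loaded model against the Top-1 column, it can be scored on the ImageNet validation set. A minimal sketch follows; `top1_accuracy` is a hypothetical helper, and in practice the loader would be the `val_loader` built in the data pre-processing code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def top1_accuracy(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    model.eval()  # use running BatchNorm statistics instead of batch statistics
    model.to(device)
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total
```

For example, `top1_accuracy(net_small, val_loader)` would evaluate the pretrained small model.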

Data Pre-processing

I used the following code for data pre-processing on ImageNet:

```python
import torch
from torchvision import datasets, transforms

# Placeholders (not given in the original): set these for your own setup.
traindir = '/path/to/imagenet/train'
valdir = '/path/to/imagenet/val'
batch_size = 256  # matches the training settings above
n_worker = 4

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
input_size = 224

# Training: random resized crop and horizontal flip for augmentation.
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=True,
    num_workers=n_worker, pin_memory=True)

# Validation: resize the shorter side to int(224 / 0.875) = 256, then center-crop 224.
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=n_worker, pin_memory=True)
```