Project author: rcmalli

Project description:
MAML Implementation using Pytorch-lightning
Language: Python
Repository: git://github.com/rcmalli/lightning-maml.git
Created: 2021-03-19T00:06:12Z
Project community: https://github.com/rcmalli/lightning-maml

License: MIT License



Pytorch Lightning MAML Implementation


PyTorch | Lightning | Conf: hydra | Logging: wandb | Code style: black

This repository is a reimplementation of the MAML (Model-Agnostic
Meta-Learning) algorithm. Differentiable optimizers are handled by the Higher
library, and NN-template is used for structuring the project. The default
settings are for training on the Omniglot (5-way 5-shot) problem, and the code
can be easily extended to other few-shot datasets thanks to the Torchmeta
library.
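
The sketch below shows, in a hedged form, how a differentiable inner-loop
optimizer from Higher fits around a model for a single MAML task; the toy
model, tensor shapes, and hyperparameters are illustrative assumptions, not
the repository's actual code.

    # Illustrative MAML inner/outer loop with the `higher` library.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import higher

    model = nn.Linear(28 * 28, 5)                        # toy 5-way classifier
    inner_opt = torch.optim.SGD(model.parameters(), lr=0.4)
    outer_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    # One task with random support/query placeholders (25 = 5-way x 5-shot).
    support_x, support_y = torch.randn(25, 28 * 28), torch.randint(0, 5, (25,))
    query_x, query_y = torch.randn(25, 28 * 28), torch.randint(0, 5, (25,))

    outer_opt.zero_grad()
    # innerloop_ctx yields a functional copy of the model and a differentiable
    # version of inner_opt, so the inner updates stay in the autograd graph.
    with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        for _ in range(1):                               # num_inner_steps
            diffopt.step(F.cross_entropy(fmodel(support_x), support_y))
        outer_loss = F.cross_entropy(fmodel(query_x), query_y)
        outer_loss.backward()                            # gradients flow to the original parameters
    outer_opt.step()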

Quickstart

On Local Machine

  1. Download the code and install the dependencies:

       git clone https://github.com/rcmalli/lightning-maml.git
       cd ./lightning-maml/
       pip install -r requirements.txt

  2. Create a .env file containing the info given below, using your own
     Wandb.ai account to track experiments. You can use the .env.template
     file as a starting point.

       export DATASET_PATH="/your/project/root/data/"
       export WANDB_ENTITY="USERNAME"
       export WANDB_API_KEY="KEY"

  3. Run the experiment:

       python3 src/run.py train.pl_trainer.gpus=1

On Google Colab

Use the Google Colab badge in the repository to run the experiments directly
in Colab, without any local setup.

Results

Omniglot (5-way 5-shot)

Few-shot learning on this dataset is an easy task for the MAML algorithm to
learn, and also an easy one to overfit.

                                    Metatrain                       Metavalidation
  Algorithm  Model     inner_steps  inner accuracy  outer accuracy  inner accuracy  outer accuracy
  MAML       OmniConv  1            0.992           0.992           0.98            0.98
  MAML       OmniConv  5            1.0             1.0             1.0             1.0

Customization

Inside the conf folder, you can change all the settings depending on your
problem or dataset. The default parameters are set for the Omniglot dataset.
Here are some examples of customization:

Debug on local machine without GPU

    python3 src/run.py train.pl_trainer.gpus=0 train.pl_trainer.fast_dev_run=true

Running more inner_steps and more epochs

    python3 src/run.py train.pl_trainer.gpus=1 train.pl_trainer.max_epochs=1000 \
      data.datamodule.num_inner_steps=5

Running a sweep of multiple runs

    python3 src/run.py train.pl_trainer.gpus=1 data.datamodule.num_inner_steps=5,10,20 -m

Using different dataset from Torchmeta

If you want to try a different dataset (e.g. MiniImageNet), you can copy the
default.yaml file inside conf/data to miniimagenet.yaml and edit these lines:

    datamodule:
      _target_: pl.datamodule.MetaDataModule
      datasets:
        train:
          _target_: torchmeta.datasets.MiniImagenet
          root: ${env:DATASET_PATH}
          meta_train: True
          download: True
        val:
          _target_: torchmeta.datasets.MiniImagenet
          root: ${env:DATASET_PATH}
          meta_val: True
          download: True
        test:
          _target_: torchmeta.datasets.MiniImagenet
          root: ${env:DATASET_PATH}
          meta_test: True
          download: True
      # you may need to update data augmentation and preprocessing steps also!!!

Run the experiment as follows:

    python3 src/run.py data=miniimagenet
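
As a side note, each _target_ entry above is turned into an object by Hydra's
instantiate utility. The snippet below is a minimal, self-contained sketch of
that mechanism, not code taken from the repository; the hard-coded root path
stands in for ${env:DATASET_PATH}.

    # Sketch: how a `_target_` entry like the ones above becomes a Torchmeta dataset.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "_target_": "torchmeta.datasets.MiniImagenet",
        "root": "/your/project/root/data/",   # the repo reads this from ${env:DATASET_PATH}
        "meta_train": True,
        "download": True,
    })
    train_dataset = instantiate(cfg)          # -> a torchmeta.datasets.MiniImagenet instance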

Implementing a different meta learning algorithm

If you plan to implement a new variant of the MAML algorithm (for example
MAML++), you can start by extending the default lightning module and its step
function.
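
A hedged skeleton of that extension pattern is shown below. To keep it
self-contained it subclasses pl.LightningModule directly rather than the
repository's default module, and every class, attribute, and hyperparameter
name in it is illustrative rather than the project's actual API.

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl

    class MAMLPlusPlusSketch(pl.LightningModule):
        def __init__(self, num_inner_steps: int = 5, inner_lr: float = 0.4):
            super().__init__()
            self.model = nn.Linear(28 * 28, 5)  # placeholder backbone
            # One MAML++-style change: a learnable inner-loop learning rate per step.
            self.inner_lrs = nn.Parameter(torch.full((num_inner_steps,), inner_lr))

        def training_step(self, batch, batch_idx):
            # Override the step function here: this is where the variant's
            # inner/outer loop (per-step learning rates, multi-step loss
            # weighting, etc.) would replace the default MAML step.
            ...

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)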

Notes

There are a few modifications required to run a meta-learning algorithm with
PyTorch Lightning as the high-level library:

  1. In supervised learning we have M mini-batches per epoch. In the
     meta-learning setting, however, a single meta-batch consists of N tasks.
     We have to set the dataloader length to 1; otherwise the dataloader will
     sample from the dataset indefinitely.

  2. Unlike the traditional test phase of supervised learning, we also need
     gradient computation during the test phase. Currently, PyTorch Lightning
     does not let you enable gradient computation through a setting, so you
     have to add a single line at the beginning of your test and validation
     steps, as follows:

       torch.set_grad_enabled(True)

  3. In the MAML algorithm we have two different optimizers to train the
     model. The inner optimizer must be differentiable, and the outer
     optimizer should update the model using the weights produced by the
     inner iterations on the support set together with the gradients computed
     on the query set. In PyTorch Lightning, optimizers and weight updates are
     handled automatically. To disable this behaviour, we have to set
     automatic_optimization=False and add the following lines to handle the
     backward computation manually (see the sketch after this list):

       self.manual_backward(outer_loss, outer_optimizer)
       outer_optimizer.step()
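
The following sketch pulls notes 2 and 3 together into one minimal
LightningModule. It is illustrative rather than the repository's actual
module: the helper _outer_loss is a hypothetical stand-in for the full
inner-loop adaptation, and the exact manual-optimization API (for example
whether manual_backward takes the optimizer as an argument) varies across
PyTorch Lightning versions.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class ManualMAMLSketch(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(28 * 28, 5)     # placeholder backbone
            self.automatic_optimization = False    # note 3: we step the outer optimizer ourselves

        def configure_optimizers(self):
            return torch.optim.Adam(self.model.parameters(), lr=1e-3)

        def _outer_loss(self, batch):
            # Hypothetical stand-in: the real implementation first adapts on the
            # support set with a differentiable inner optimizer, then evaluates
            # the adapted weights on the query set.
            inputs, targets = batch
            return F.cross_entropy(self.model(inputs), targets)

        def training_step(self, batch, batch_idx):
            outer_optimizer = self.optimizers()
            outer_optimizer.zero_grad()
            outer_loss = self._outer_loss(batch)
            self.manual_backward(outer_loss)       # replaces the automatic backward pass
            outer_optimizer.step()
            return outer_loss

        def validation_step(self, batch, batch_idx):
            # Note 2: Lightning disables gradients during evaluation, but MAML
            # still needs them for inner-loop adaptation, so re-enable them.
            torch.set_grad_enabled(True)
            return self._outer_loss(batch)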
