MAML Implementation using PyTorch Lightning
This repository is a reimplementation of the MAML (Model-Agnostic Meta-Learning)
algorithm. Differentiable optimizers are handled by the Higher library,
and NN-template is used for structuring the project. The default settings
train on the Omniglot (5-way, 5-shot) problem, and the project can easily be
extended to other few-shot datasets thanks to the Torchmeta library.
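To get a feel for what Torchmeta provides, here is a minimal standalone sketch (not part of this repository's code) that loads the 5-way 5-shot Omniglot problem and draws one meta-batch:

```python
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

# 5-way 5-shot Omniglot; "data" is an assumed download folder
dataset = omniglot("data", ways=5, shots=5, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    # each meta-batch holds 16 tasks, split into support ("train")
    # and query ("test") examples
    support_inputs, support_targets = batch["train"]
    query_inputs, query_targets = batch["test"]
    break
```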
On Local Machine
```bash
git clone https://github.com/rcmalli/lightning-maml.git
cd ./lightning-maml/
pip install -r requirements.txt
```
Create a `.env` file containing the info given below, using the `.env.template`
file as a template:
```bash
export DATASET_PATH="/your/project/root/data/"
export WANDB_ENTITY="USERNAME"
export WANDB_API_KEY="KEY"
```
Then start training:

```bash
python3 src/run.py train.pl_trainer.gpus=1
```
On Google Colab
Results
Few-shot learning on this dataset is an easy task for the MAML algorithm to
learn, and even to overfit.
| Algorithm | Model | inner_steps | Meta-train inner accuracy | Meta-train outer accuracy | Meta-val inner accuracy | Meta-val outer accuracy |
|---|---|---|---|---|---|---|
| MAML | OmniConv | 1 | 0.992 | 0.992 | 0.98 | 0.98 |
| MAML | OmniConv | 5 | 1.0 | 1.0 | 1.0 | 1.0 |
Inside the `conf` folder, you can change all the settings depending on your
problem or dataset. The default parameters are set for the Omniglot dataset.
Here are some examples of customization:
```bash
# quick debug run on CPU
python3 src/run.py train.pl_trainer.gpus=0 train.pl_trainer.fast_dev_run=true

# train on a single GPU for up to 1000 epochs with 5 inner steps
python3 src/run.py train.pl_trainer.gpus=1 train.pl_trainer.max_epochs=1000 \
  data.datamodule.num_inner_steps=5

# Hydra multirun (-m): sweep over different numbers of inner steps
python3 src/run.py train.pl_trainer.gpus=1 data.datamodule.num_inner_steps=5,10,20 -m
```
If you want to try a different dataset (e.g. MiniImageNet), you can copy the
`default.yaml` file inside `conf/data` to `miniimagenet.yaml` and edit these
lines:
```yaml
datamodule:
  _target_: pl.datamodule.MetaDataModule
  datasets:
    train:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_train: True
      download: True
    val:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_val: True
      download: True
    test:
      _target_: torchmeta.datasets.MiniImagenet
      root: ${env:DATASET_PATH}
      meta_test: True
      download: True
# you may need to update the data augmentation and preprocessing steps as well!
```
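The comment above matters: MiniImageNet preprocessing differs from Omniglot. As a rough sketch of what the Torchmeta dataset expects (the 84x84 resize and split sizes follow common MiniImageNet practice and are assumptions here):

```python
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.datasets import MiniImagenet
from torchmeta.transforms import Categorical, ClassSplitter

dataset = MiniImagenet(
    "data",
    num_classes_per_task=5,
    meta_train=True,
    transform=Compose([Resize(84), ToTensor()]),  # RGB 84x84 inputs
    target_transform=Categorical(num_classes=5),  # map class labels to 0..4
    download=True,
)
# split each task into support (train) and query (test) examples
dataset = ClassSplitter(dataset, shuffle=True,
                        num_train_per_class=5, num_test_per_class=15)
```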
Run the experiment as follows:
```bash
python3 src/run.py data=miniimagenet
```
If you plan to implement a new variant of the MAML algorithm (for example
MAML++), you can start by extending the default LightningModule and its `step`
function, as in the sketch below.
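A rough sketch of that extension point, assuming the repo exposes its default module under `pl.model` (the import path and the `step` signature here are assumptions, check the actual source):

```python
from pl.model import MAML  # hypothetical import path into this repo


class MAMLPlusPlus(MAML):
    """Sketch of a MAML++ style variant."""

    def step(self, batch, batch_idx, phase):
        # Override the inner/outer loop here, e.g. weight the outer loss
        # over all inner steps or learn per-layer inner learning rates.
        return super().step(batch, batch_idx, phase)
```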
There are a few modifications required to run a meta-learning algorithm with
PyTorch Lightning as the high-level library:
In supervised learning we have M mini-batches per epoch. In the meta-learning
setting, however, we have N tasks per meta-batch. We have to set our
dataloader length to 1; otherwise, the dataloader will sample from the dataset
indefinitely.
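One way to picture this (a sketch; `sample_meta_batch` is a hypothetical helper, not this repo's API):

```python
from torch.utils.data import Dataset


class SingleMetaBatchDataset(Dataset):
    """One 'epoch' yields exactly one meta-batch of N tasks."""

    def __init__(self, sample_meta_batch, num_tasks):
        self.sample_meta_batch = sample_meta_batch
        self.num_tasks = num_tasks

    def __len__(self):
        # length 1: without this, the loader would keep drawing freshly
        # sampled tasks and the epoch would never end
        return 1

    def __getitem__(self, index):
        return self.sample_meta_batch(self.num_tasks)
```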
Unlike the traditional test phase of supervised learning, we also need
gradient computation during the test phase. Currently, PyTorch Lightning does
not allow you to enable gradient computation through a setting; you have to
add a single line at the beginning of your test and validation steps, as
follows:

```python
torch.set_grad_enabled(True)
```
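In a LightningModule that could look like this (`self.step` is a hypothetical shared step function):

```python
def validation_step(self, batch, batch_idx):
    # Lightning disables gradients during evaluation by default, but
    # MAML's inner-loop adaptation still needs them
    torch.set_grad_enabled(True)
    return self.step(batch, batch_idx, phase="val")
```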
MAML's two-level optimization does not fit PyTorch Lightning's automatic
optimization, so you have to set

```python
automatic_optimization=False
```

and add the following lines to handle the backward pass and the outer update:

```python
self.manual_backward(outer_loss, outer_optimizer)
outer_optimizer.step()
```
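Putting the pieces together, here is a minimal sketch of a manual-optimization training step built on `higher` (names such as `self.model` and `inner_lr`, and the Torchmeta batch layout, are assumptions; note also that `manual_backward`'s signature varies across PyTorch Lightning versions):

```python
import higher
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class MAMLModule(pl.LightningModule):
    def __init__(self, model, inner_lr=0.4, num_inner_steps=1, outer_lr=1e-3):
        super().__init__()
        self.automatic_optimization = False  # we drive both loops manually
        self.model = model
        self.inner_lr = inner_lr
        self.num_inner_steps = num_inner_steps
        self.outer_lr = outer_lr

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.outer_lr)

    def training_step(self, batch, batch_idx):
        outer_optimizer = self.optimizers()
        outer_optimizer.zero_grad()
        support_xs, support_ys = batch["train"]  # Torchmeta meta-batch layout
        query_xs, query_ys = batch["test"]

        outer_loss = 0.0
        for task in range(support_xs.size(0)):  # iterate over the N tasks
            inner_optimizer = torch.optim.SGD(self.model.parameters(),
                                              lr=self.inner_lr)
            # higher gives a differentiable copy of the model, so the outer
            # gradient can flow back through the inner updates
            with higher.innerloop_ctx(self.model, inner_optimizer,
                                      copy_initial_weights=False
                                      ) as (fmodel, diffopt):
                for _ in range(self.num_inner_steps):
                    inner_loss = F.cross_entropy(fmodel(support_xs[task]),
                                                 support_ys[task])
                    diffopt.step(inner_loss)
                outer_loss = outer_loss + F.cross_entropy(
                    fmodel(query_xs[task]), query_ys[task])

        self.manual_backward(outer_loss / support_xs.size(0))
        outer_optimizer.step()
        return outer_loss
```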