# Ensembling with Deep Generative Views

Invert and perturb GAN images for test-time ensembling.

Project Page | Paper | Bibtex

Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhang \
CVPR 2021
We project an input image into the latent space of a pre-trained GAN and perturb it slightly to obtain modifications of the input image. These alternative views from the GAN are ensembled at test-time, together with the original image, in a downstream classification task.
To synthesize deep generative views, we first align (Aligned Input) and reconstruct an image by finding the corresponding latent code in StyleGAN2 (GAN Reconstruction). We then investigate different approaches to produce image variations using the GAN, such as style-mixing on fine layers (Style-mix Fine), which predominantly changes color, or coarse layers (Style-mix Coarse), which changes pose.
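At test time, ensembling amounts to a weighted average of classifier predictions over the original image and its generative views. A minimal sketch of the idea, with hypothetical names (`ensemble_predict`, `classifier`, and `original_weight` are illustrative, not the repo's API; the full procedure lives in the evaluation scripts):

```python
import torch
import torch.nn.functional as F

def ensemble_predict(classifier, original_image, gan_views, original_weight=0.5):
    """Average softmax outputs over the original image and its GAN views.

    original_image: 1 x C x H x W tensor
    gan_views: N x C x H x W tensor of perturbed GAN reconstructions
    original_weight: assumed knob weighting the original image's prediction
    """
    with torch.no_grad():
        p_orig = F.softmax(classifier(original_image), dim=1)  # 1 x num_classes
        p_views = F.softmax(classifier(gan_views), dim=1).mean(dim=0, keepdim=True)
    probs = original_weight * p_orig + (1 - original_weight) * p_views
    return probs.argmax(dim=1)
```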
This Colab Notebook demonstrates the basic latent code perturbation and classification procedure in a simplified setting on the aligned cat dataset.
Clone this repo:

```bash
git clone https://github.com/chail/gan-ensembling.git
cd gan-ensembling
```
Install dependencies:

We provide an environment.yml file listing the dependencies. You can create the Conda environment using:

```bash
conda env create -f environment.yml
```
Download resources:

Fetch the resources by running:

```bash
bash resources/download_resources.sh
```
Note, Optional: to run the StyleGAN ID-invert models, the models need to be downloaded separately. Follow the directions here to obtain styleganinv_ffhq256_generator.pth and styleganinv_ffhq256_encoder.pth, and place them in models/pretrain.
Note, Dataset: the CelebA-HQ and Stanford Cars images must be obtained separately. Place the CelebA-HQ images in dataset/celebahq/images/images, the Cars images in dataset/cars/images/images, and the Cars devkit in dataset/cars/devkit. An example of the directory organization is below:
```
dataset/celebahq/
    images/images/
        000004.png
        000009.png
        000014.png
        ...
    latents/
    latents_idinvert/
dataset/cars/
    devkit/
        cars_meta.mat
        cars_test_annos.mat
        cars_train_annos.mat
        ...
    images/images/
        00001.jpg
        00002.jpg
        00003.jpg
        ...
    latents/
dataset/catface/
    images/
    latents/
dataset/cifar10/
    cifar-10-batches-py/
    latents/
```
Once the datasets and precomputed resources are downloaded, the following code snippet demonstrates how to perturb GAN images. Additional examples are contained in notebooks/demo.ipynb.
```python
import data
from networks import domain_generator

dataset_name = 'celebahq'
generator_name = 'stylegan2'
attribute_name = 'Smiling'

# Load the validation split with precomputed W latents (load_w=True).
val_transform = data.get_transform(dataset_name, 'imval')
dset = data.get_dataset(dataset_name, 'val', attribute_name, load_w=True,
                        transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name)

# Reconstruct an image from its precomputed latent code.
index = 100
original_image = dset[index][0][None].cuda()
latent = dset[index][1][None].cuda()
gan_reconstruction = generator.decode(latent)

# Style-mix the fine layers with four random latents to get four perturbed views.
mix_latent = generator.seed2w(n=4, seed=0)
perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=4)
```
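To eyeball the views, you can tile the tensors into a grid. A small sketch, assuming the images are N x C x H x W tensors in [-1, 1] (typical for StyleGAN2 outputs; older torchvision versions take range= instead of value_range=):

```python
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Stack the original, the reconstruction, and the four perturbed views into one row.
images = torch.cat([original_image, gan_reconstruction, perturbed_im]).detach().cpu()
grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(-1, 1))
plt.figure(figsize=(12, 3))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
```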
Important: first, set up the symlinks required for the notebooks:

```bash
bash notebooks/setup_notebooks.sh
```

and add the conda environment to the jupyter kernels:

```bash
python -m ipykernel install --user --name gan-ensembling
```
The provided notebooks are:
- notebooks/demo.ipynb: basic usage example
- notebooks/evaluate_ensemble.ipynb: plot classification test accuracy as a function of ensemble weight
- notebooks/plot_precomputed_evaluations.ipynb: notebook to generate figures in paper

The full pipeline contains three main parts: optimizing the latent codes, training the classifiers, and evaluating the ensembled views.
Examples for each step of the pipeline are contained in the following scripts:
```bash
bash scripts/optimize_latent/examples.sh
bash scripts/train_classifier/examples.sh
bash scripts/eval_ensemble/examples.sh
```
To add to the pipeline:
- Dataset: in the data/ directory, add the dataset in data/__init__.py and create the dataset class and transformation functions; see data/data_*.py for examples. A hypothetical skeleton of such a class is sketched after this list.
- Generator: modify networks/domain_generators.py to add the generator in domain_generators.define_generator. The perturbation ranges for each dataset and generator are specified in networks/perturb_settings.py.
- Classifier: modify networks/domain_classifiers.py to add the classifier in domain_classifiers.define_classifier.
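For the dataset item above, a minimal skeleton (illustrative names only, not the repo's API) that returns the transformed image, its precomputed latent, and the label, matching how the quickstart indexes dset[index]:

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Hypothetical dataset returning (image, latent, label) tuples."""

    def __init__(self, image_paths, latent_path, labels, transform=None, load_w=True):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.load_w = load_w
        # Assumed layout: one precomputed W latent per image, saved as an N x ... array.
        self.latents = np.load(latent_path) if load_w else None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index]).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[index]
        if self.load_w:
            latent = torch.from_numpy(self.latents[index]).float()
            return image, latent, label
        return image, label
```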
We thank the authors of these repositories:
If you use this code for your research, please cite our paper:
```
@inproceedings{chai2021ensembling,
  title={Ensembling with Deep Generative Views.},
  author={Chai, Lucy and Zhu, Jun-Yan and Shechtman, Eli and Isola, Phillip and Zhang, Richard},
  booktitle={CVPR},
  year={2021}
}
```