PC-DARTS (PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search, published at ICLR 2020) implemented in TensorFlow 2.0+. This is an unofficial implementation.
PC-DARTS is a memory-efficient differentiable architecture search method that can be trained with a larger batch size and, consequently, enjoys both faster search and higher training stability. In the paper's experiments it reaches a 2.57% error rate on CIFAR-10 with only 0.1 GPU-days of architecture search.
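The memory saving comes from the partial channel connection: on each edge of the search cell, only 1/K of the channels pass through the weighted mixture of candidate operations, while the remaining channels bypass it and are concatenated back, followed by a channel shuffle. The NumPy sketch below illustrates that step for a stride-1 edge in NHWC layout. It is not this repository's TensorFlow code: `partial_channel_edge` and `channel_shuffle` are illustrative names, `mixed_op` is a placeholder for the weighted sum of candidate operations, and edge normalization and reduction (stride-2) edges are left out.

```python
import numpy as np


def channel_shuffle(x, groups):
    """Interleave the channels of an NHWC array across `groups` groups."""
    n, h, w, c = x.shape
    x = x.reshape(n, h, w, groups, c // groups)
    x = np.swapaxes(x, 3, 4)  # swap the group axis with the per-group axis
    return x.reshape(n, h, w, c)


def partial_channel_edge(x, mixed_op, k=4):
    """Send 1/k of the channels through `mixed_op` and bypass the rest.

    `mixed_op` stands in for the weighted sum of candidate operations on
    one edge and must keep the input shape (stride-1 edge assumed).
    """
    c = x.shape[-1]
    if c % k:
        raise ValueError('channel count must be divisible by k')
    selected, bypass = x[..., :c // k], x[..., c // k:]
    out = np.concatenate([mixed_op(selected), bypass], axis=-1)
    # Interleave processed and bypassed channels (ShuffleNet-style shuffle).
    return channel_shuffle(out, k)


x = np.random.default_rng(0).standard_normal((2, 8, 8, 16))
y = partial_channel_edge(x, lambda t: 2.0 * t, k=4)
print(y.shape)  # (2, 8, 8, 16)
```

Only a quarter of the channels (with K = 4) are fed to the candidate operations, which is where the reduction in activation memory comes from; the shuffle mixes processed and bypassed channels before the next edge.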
Original paper: arXiv, OpenReview
Official implementation: PyTorch
Create a new Python virtual environment with Anaconda, or use pip in your existing Python environment. First, clone this repository:
git clone https://github.com/peteryuX/pcdarts-tf2.git
cd pcdarts-tf2
Then install the dependencies with Anaconda:
conda env create -f environment.yml
conda activate pcdarts-tf2
or with pip:
pip install -r requirements.txt
You can set your own dataset path and other model settings in ./configs/*.yaml for training and testing. A config file looks like this:
# general setting
batch_size: 128
input_size: 32
init_channels: 36
layers: 20
num_classes: 10
auxiliary_weight: 0.4
drop_path_prob: 0.3
arch: PCDARTS
sub_name: 'pcdarts_cifar10'
using_normalize: True
# training dataset
dataset_len: 50000 # number of training samples
using_crop: True
using_flip: True
using_cutout: True
cutout_length: 16
# training setting
epoch: 600
init_lr: 0.025
lr_min: 0.0
momentum: 0.9
weights_decay: !!float 3e-4
grad_clip: 5.0
val_steps: 1000
save_steps: 1000
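The training-dataset flags enable random crop, horizontal flip and Cutout. Assuming cutout_length is the side length of the square that Cutout masks out (as in the original DARTS training pipeline), the augmentation step can be sketched in NumPy as follows. The `cutout` function is hypothetical, and the repository's own dataset code may implement it differently.

```python
import numpy as np


def cutout(image, length, rng):
    """Zero one length x length square at a random centre of an HWC image.

    The square is clipped at the image border, so near the edges fewer
    than length x length pixels are masked. The input is not modified.
    """
    h, w = image.shape[:2]
    cy, cx = rng.integers(h), rng.integers(w)
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    out = image.copy()
    out[y1:y2, x1:x2, :] = 0
    return out


rng = np.random.default_rng(0)
img = np.ones((32, 32, 3), dtype=np.float32)  # stands in for a CIFAR-10 image
aug = cutout(img, length=16, rng=rng)          # cutout_length: 16
print(int((aug[..., 0] == 0).sum()), 'pixels masked')
```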
Note:
sub_name is the name of the output directory used in the checkpoints and logs folders (make sure it is unique among your models).
save_steps is the interval, in steps, at which checkpoint files are saved.
Step 1: Search the cell architecture on CIFAR-10 using a small proxy model.
python train_search.py --cfg_path="./configs/pcdarts_cifar10_search.yaml" --gpu=0
Note:
--gpu is used to choose the ID of your available GPU device through the CUDA_VISIBLE_DEVICES system variable.
You can monitor the training status on TensorBoard by running "tensorboard --logdir=./logs/".
My logs can be found in search_log and full_train_log.
You can inspect the learning-rate schedule by running "python ./modules/lr_scheduler.py".
You can check the dataset and its augmentation by running "python ./dataset_checker.py".
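The schedule settings init_lr, lr_min and epoch match the cosine annealing schedule used by the original DARTS and PC-DARTS training code. Assuming ./modules/lr_scheduler.py follows the same shape (it may, for example, update per step rather than per epoch), the learning rate can be sketched as below; `cosine_lr` is a hypothetical helper, not the repository's API.

```python
import math


def cosine_lr(epoch, total_epochs, init_lr=0.025, lr_min=0.0):
    """Cosine-annealed learning rate at a (0-based) epoch.

    Decays from init_lr at epoch 0 to lr_min at epoch == total_epochs.
    """
    progress = epoch / total_epochs
    return lr_min + 0.5 * (init_lr - lr_min) * (1.0 + math.cos(math.pi * progress))


# Values taken from the training config: init_lr 0.025, lr_min 0.0, epoch 600.
for e in (0, 150, 300, 450, 599):
    print(e, round(cosine_lr(e, 600), 6))
```

With these values the rate starts at 0.025, is 0.0125 halfway through training, and decays toward 0.0 by the last epoch.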
Step 2: After the search is complete, you can find the resulting genotypes in ./logs/{sub_name}/search_arch_genotype.py. Open it and copy the latest genotype into ./modules/genotypes.py, which will be used for training later. The genotype looks like this:
TheNameYouWantToCall = Genotype(
normal=[
('sep_conv_3x3', 1),
('skip_connect', 0),
('sep_conv_3x3', 0),
('dil_conv_3x3', 1),
('sep_conv_5x5', 0),
('sep_conv_3x3', 1),
('avg_pool_3x3', 0),
('dil_conv_3x3', 1)],
normal_concat=range(2, 6),
reduce=[
('sep_conv_5x5', 1),
('max_pool_3x3', 0),
('sep_conv_5x5', 1),
('sep_conv_5x5', 2),
('sep_conv_3x3', 0),
('sep_conv_3x3', 3),
('sep_conv_3x3', 1),
('sep_conv_3x3', 2)],
reduce_concat=range(2, 6))
Note:
You can visualize the genotype by running "python ./visualize_genotype.py TheNameYouWantToCall".
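In DARTS-style genotypes, each consecutive pair of (operation, input) entries in normal or reduce defines the two incoming edges of one intermediate node: inputs 0 and 1 are the outputs of the two previous cells, intermediate nodes are numbered from 2, and normal_concat / reduce_concat list the nodes concatenated into the cell output. The sketch below assumes Genotype is the usual namedtuple with these four fields; `cell_edges` is a hypothetical helper that prints the same wiring as a list of edges.

```python
from collections import namedtuple

# Assumed to match the Genotype definition in ./modules/genotypes.py.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

TheNameYouWantToCall = Genotype(
    normal=[('sep_conv_3x3', 1), ('skip_connect', 0),
            ('sep_conv_3x3', 0), ('dil_conv_3x3', 1),
            ('sep_conv_5x5', 0), ('sep_conv_3x3', 1),
            ('avg_pool_3x3', 0), ('dil_conv_3x3', 1)],
    normal_concat=range(2, 6),
    reduce=[('sep_conv_5x5', 1), ('max_pool_3x3', 0),
            ('sep_conv_5x5', 1), ('sep_conv_5x5', 2),
            ('sep_conv_3x3', 0), ('sep_conv_3x3', 3),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 2)],
    reduce_concat=range(2, 6))


def cell_edges(ops):
    """Return (target_node, op_name, source_node) for every edge of a cell.

    Entries 2*i and 2*i + 1 feed intermediate node i + 2; nodes 0 and 1
    are the outputs of the two previous cells.
    """
    edges = []
    for i in range(len(ops) // 2):
        for op_name, source in ops[2 * i:2 * i + 2]:
            edges.append((i + 2, op_name, source))
    return edges


for target, op_name, source in cell_edges(TheNameYouWantToCall.reduce):
    print(f'node {source} --{op_name}--> node {target}')
```

Every source index is smaller than its target node, so each cell is a small DAG with four intermediate nodes.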
Step 1 (training the full-sized model): Make sure you have set the arch flag in ./configs/pcdarts_cifar10.yaml to the name of the genotype you want to use in ./modules/genotypes.py.
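Because arch is a plain name, the training script presumably looks up the attribute with that name in ./modules/genotypes.py. A minimal sketch of such a lookup, assuming a getattr-style resolution; `resolve_arch` and the SimpleNamespace stand-in for the genotypes module are hypothetical, and the genotype inside it is shortened.

```python
from collections import namedtuple
from types import SimpleNamespace

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Hypothetical stand-in for ./modules/genotypes.py with one shortened entry.
genotypes = SimpleNamespace(
    TheNameYouWantToCall=Genotype(
        normal=[('sep_conv_3x3', 1), ('skip_connect', 0)],
        normal_concat=range(2, 3),
        reduce=[('sep_conv_5x5', 1), ('max_pool_3x3', 0)],
        reduce_concat=range(2, 3)))


def resolve_arch(module, arch):
    """Return the genotype named `arch`, failing loudly on a typo."""
    genotype = getattr(module, arch, None)
    if genotype is None:
        raise ValueError(f'arch {arch!r} is not defined in genotypes.py')
    return genotype


cfg = {'arch': 'TheNameYouWantToCall'}  # e.g. the arch flag from the YAML config
print(resolve_arch(genotypes, cfg['arch']).normal)
```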
Note:
The arch value PCDARTS is the genotype proposed in the official paper. You can train this model yourself or download it from the models listed below.
Step 2: Train the full-sized model on CIFAR-10 with the specified genotype.
python train.py --cfg_path="./configs/pcdarts_cifar10.yaml" --gpu=0
To evaluate the full-sized model on the test set, run the following command with the corresponding cfg file. You can also download my trained models from the list below instead of training them yourself; the default arch (PCDARTS) is the best cell proposed in the paper.
python test.py --cfg_path="./configs/pcdarts_cifar10.yaml" --gpu=0
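All three scripts take a --gpu flag that selects the device through the CUDA_VISIBLE_DEVICES variable. The sketch below shows that mechanism with argparse; the parser is an assumption (the repository may use a different flag library), and in a real script the variable has to be set before TensorFlow initializes CUDA, i.e. before `import tensorflow`.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the two flags used in the commands above (hypothetical parser)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_path', default='./configs/pcdarts_cifar10.yaml')
    parser.add_argument('--gpu', default='0',
                        help='GPU id(s) to expose, e.g. "0" or "0,1"')
    return parser.parse_args(argv)


def select_gpu(gpu):
    """Make only the listed GPU(s) visible to CUDA.

    Must run before TensorFlow is imported; afterwards the chosen device
    appears to TensorFlow as GPU:0.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu


args = parse_args(['--cfg_path', './configs/pcdarts_cifar10.yaml', '--gpu', '1'])
select_gpu(args.gpu)
print(os.environ['CUDA_VISIBLE_DEVICES'])
```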
Results on CIFAR-10:
Method | Search Method | Params (M) | Test Error (%) | Search Cost (GPU-days) |
---|---|---|---|---|
NASNet-A | RL | 3.3 | 2.65 | 1800 |
AmoebaNet-B | Evolution | 2.8 | 2.55 | 3150 |
ENAS | RL | 4.6 | 2.89 | 0.5 |
DARTSV1 | gradient-based | 3.3 | 3.00 | 0.4 |
DARTSV2 | gradient-based | 3.3 | 2.76 | 1.0 |
SNAS | gradient-based | 2.8 | 2.85 | 1.5 |
PC-DARTS (official PyTorch version) | gradient-based | 3.63 | 2.57 | 0.1 |
PC-DARTS TF2 (paper architecture) | gradient-based | 3.63 | 2.73 | - |
PC-DARTS TF2 (searched by myself) | gradient-based | 3.56 | 2.88 | 0.12 |
Note:
Download the models below, then extract them into ./checkpoints/ to restore them.
Model Name | Config File | arch | Download Link |
---|---|---|---|
PC-DARTS (CIFAR-10, paper architecture) | pcdarts_cifar10.yaml | PCDARTS | GoogleDrive |
PC-DARTS (CIFAR-10, searched by myself) | pcdarts_cifar10_TF2.yaml | PCDARTS_TF2_SEARCH | GoogleDrive |
Note:
The arch flag in each config file matches a genotype name in ./modules/genotypes.py.
Thanks to these source codes for providing me with the knowledge needed to complete this repository.