Project author: ramprs

Project description:
[ECCV 2018] Code for Choose Your Neuron: Incorporating Domain Knowledge Through Neuron Importance
Language: Python
Repository: git://github.com/ramprs/neuron-importance-zsl.git
Created: 2018-07-26T01:59:06Z



NIWT: Neuron-Importance aware Weight Transfer

Code for the ECCV’18 paper

Choose-Your-Neuron: Incorporating Domain Knowledge into Deep Networks through Neuron-Importance
Ramprasaath R. Selvaraju, Prithvijit Chattopadhyay, Mohammed Elhoseiny, Tilak Sharma, Dhruv Batra, Devi Parikh, Stefan Lee



This codebase assumes that TensorFlow is installed. If it is not, follow the official TensorFlow installation instructions.
Download the data and pretrained checkpoints with sh download.sh, and make sure the paths in the JSON files under arg_configs/ are correct.
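Checking those paths by hand across several configs is error-prone. The helper below is a minimal sketch (find_missing_paths is a hypothetical name, not part of the repository). It uses a heuristic: any string value containing "/" is treated as a filesystem path and checked with os.path.exists, so URLs would be flagged too. Relative paths resolve against the current directory, so run it from the repository root.

```python
import glob
import json
import os


def find_missing_paths(config_file):
    """Return (key, value) pairs whose value looks like a path but does not exist.

    Heuristic (assumption): any string containing "/" is treated as a path.
    Nested dicts and lists are walked recursively.
    """
    with open(config_file) as f:
        config = json.load(f)

    missing = []

    def walk(node, key):
        if isinstance(node, dict):
            for k, v in node.items():
                walk(v, k)
        elif isinstance(node, list):
            for item in node:
                walk(item, key)
        elif isinstance(node, str) and "/" in node and not os.path.exists(node):
            missing.append((key, node))

    walk(config, None)
    return missing


if __name__ == "__main__":
    # Report every suspicious path in every config file.
    for cfg in sorted(glob.glob("arg_configs/*.json")):
        for key, path in find_missing_paths(cfg):
            print(f"{cfg}: {key} -> {path} (not found)")
```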
You may also need to create an imagenet_files.pkl file containing a list of at least 3000 randomly sampled ImageNet image paths.
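The README specifies only "a list of image paths". The sketch below assumes the file is a plain pickled Python list of path strings. The function name, the accepted extensions, and the directory you point it at are assumptions. Check how the scripts load imagenet_files.pkl before relying on this format.

```python
import os
import pickle
import random

IMAGE_EXTS = (".jpeg", ".jpg", ".png")  # assumed image extensions


def build_imagenet_file_list(root, out_path="imagenet_files.pkl", n=3000, seed=0):
    """Sample n image paths under root and pickle them as a plain Python list."""
    paths = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith(IMAGE_EXTS):
                paths.append(os.path.join(dirpath, name))
    paths.sort()  # os.walk order is filesystem-dependent; sort for reproducibility
    if len(paths) < n:
        raise ValueError(f"found {len(paths)} images under {root}, need at least {n}")
    sample = random.Random(seed).sample(paths, n)  # sampling without replacement
    with open(out_path, "wb") as f:
        pickle.dump(sample, f)
    return sample
```

For example, build_imagenet_file_list("/path/to/imagenet/train") writes imagenet_files.pkl in the current directory. The path is a placeholder.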

Train a Generalized Zero Shot Learning model on AWA2 and CUB (class-level attributes)

  1. python alpha2w.py --config_json arg_configs/vgg16_config_AWA.json
  2. python alpha2w.py --config_json arg_configs/resnet_config_AWA.json
  3. python alpha2w.py --config_json arg_configs/vgg16_config_CUB.json
  4. python alpha2w.py --config_json arg_configs/resnet_config_CUB.json
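The four runs above differ only in the config file. A small helper can run them in sequence from the repository root and stop at the first failure. This is a hypothetical sketch: build_commands and run_all are illustrative names, not part of the repository.

```python
import subprocess
import sys

CONFIGS = [
    "arg_configs/vgg16_config_AWA.json",
    "arg_configs/resnet_config_AWA.json",
    "arg_configs/vgg16_config_CUB.json",
    "arg_configs/resnet_config_CUB.json",
]


def build_commands(script, configs):
    """One command per config, mirroring `python <script> --config_json <cfg>`."""
    return [[sys.executable, script, "--config_json", cfg] for cfg in configs]


def run_all(script, configs, dry_run=False):
    for cmd in build_commands(script, configs):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
```

Calling run_all("alpha2w.py", CONFIGS) reproduces the list above. Passing "mod2alpha.py" instead reuses the helper for the mod2alpha step, and dry_run=True only prints the commands.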

Train a Generalized Zero Shot Learning model on CUB with captions (class-level)

  1. python alpha2w.py --config_json arg_configs/vgg16_config_CUB_captions.json
  2. python alpha2w.py --config_json arg_configs/resnet_config_CUB_captions.json

Train a GZSL classifier from scratch

Pretrain base model on dataset

To do this, we first finetune the base model (vgg16 or resnet_v1) on the seen-class images.

  1. cd seen_pretraining/
  2. sh cnn_finetune.sh

Extract Neuron Importances (alphas)

In the config JSON files, set ckpt_path to the finetuned checkpoint obtained in the previous step.
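If you have several configs to update, a script can set ckpt_path in each one. This is a minimal sketch that assumes ckpt_path is a top-level key; the nesting in the real files may differ. set_ckpt_path is a hypothetical helper. Rewriting a file with json.dump normalizes its formatting. Use a glob that matches only the configs for the backbone you finetuned, so VGG16 configs do not end up pointing at a ResNet checkpoint.

```python
import glob
import json


def set_ckpt_path(config_glob, ckpt_path, key="ckpt_path"):
    """Point `key` at ckpt_path in every matching config JSON; return files changed.

    Assumes `key` is a top-level entry (hypothetical layout; check your files).
    """
    changed = []
    for cfg in sorted(glob.glob(config_glob)):
        with open(cfg) as f:
            config = json.load(f)
        if key not in config:
            continue  # leave files without that top-level key untouched
        config[key] = ckpt_path
        with open(cfg, "w") as f:
            json.dump(config, f, indent=4)
        changed.append(cfg)
    return changed
```

For example, set_ckpt_path("arg_configs/vgg16_config_*.json", "path/to/finetuned/model.ckpt"). The checkpoint path is a placeholder.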
Extract Neuron-Importances (alphas) from the finetuned model.

  1. sh alpha_extraction.sh
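The neuron importances follow the Grad-CAM idea. The importance of channel k for class c is the gradient of the class score with respect to that channel's activation map, averaged over spatial positions. The NumPy sketch below shows only that averaging step. It assumes the gradients are already available in (H, W, K) layout. Computing them, and choosing the layer, is left to the repository's scripts and configs.

```python
import numpy as np


def neuron_importance(grads):
    """Grad-CAM-style importance: spatially average dy^c/dA over an (H, W, K) map.

    grads[i, j, k] is the gradient of class score y^c w.r.t. activation A^k at (i, j);
    the result holds one importance value alpha_k^c per channel k.
    """
    grads = np.asarray(grads, dtype=np.float64)
    if grads.ndim != 3:
        raise ValueError("expected gradients shaped (H, W, K)")
    return grads.mean(axis=(0, 1))


# Toy check: if y^c = sum_k v[k] * mean_{i,j} A[i, j, k],
# then dy^c/dA[i, j, k] = v[k] / (H * W) at every position.
H, W, K = 7, 7, 4
v = np.array([0.5, -1.0, 2.0, 0.0])
grads = np.broadcast_to(v / (H * W), (H, W, K))
alpha = neuron_importance(grads)
```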

Domain knowledge to Neuron Importance

Here we learn a transformation from domain knowledge (e.g. attributes) to the network's neuron importances (alphas).

  1. cd ..
  2. python mod2alpha.py --config_json arg_configs/vgg16_config_AWA.json
  3. python mod2alpha.py --config_json arg_configs/resnet_config_AWA.json
  4. python mod2alpha.py --config_json arg_configs/vgg16_config_CUB.json
  5. python mod2alpha.py --config_json arg_configs/resnet_config_CUB.json
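mod2alpha.py learns this mapping on seen classes and applies it to unseen classes. The README does not state its exact model or loss, and the paper's objective may differ. As a plainly labeled stand-in, the sketch below fits closed-form ridge regression from class attribute vectors to alphas on seen classes, then predicts alphas for unseen classes. The function names and toy sizes are hypothetical.

```python
import numpy as np


def fit_attr_to_alpha(A_seen, alpha_seen, lam=1e-2):
    """Closed-form ridge regression: W = argmin ||A W - alpha||^2 + lam * ||W||^2.

    A_seen: (n_seen_classes, n_attributes) class-level domain knowledge.
    alpha_seen: (n_seen_classes, n_neurons) neuron importances of seen classes.
    """
    d = A_seen.shape[1]
    return np.linalg.solve(A_seen.T @ A_seen + lam * np.eye(d), A_seen.T @ alpha_seen)


def predict_alpha(A, W):
    """Map class-level domain knowledge to predicted neuron importances."""
    return A @ W


# Toy data generated from a known linear map, so the fit can be checked.
rng = np.random.default_rng(0)
n_seen, n_unseen, n_attr, n_neurons = 40, 10, 8, 32
W_true = rng.normal(size=(n_attr, n_neurons))
A_seen = rng.normal(size=(n_seen, n_attr))
A_unseen = rng.normal(size=(n_unseen, n_attr))
W = fit_attr_to_alpha(A_seen, A_seen @ W_true, lam=1e-6)
alpha_unseen = predict_alpha(A_unseen, W)
```

With real data, lam trades fit on the few seen classes against generalization to unseen ones. That matters most when there are more attributes than seen classes.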

Neuron Importance of unseen classes to classifier weights of unseen classes (training a GZSL model)

  1. python alpha2w.py --config_json arg_configs/vgg16_config_AWA.json
  2. python alpha2w.py --config_json arg_configs/resnet_config_AWA.json
  3. python alpha2w.py --config_json arg_configs/vgg16_config_CUB.json
  4. python alpha2w.py --config_json arg_configs/resnet_config_CUB.json
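This step turns predicted unseen-class importances into classifier weights. As we read the paper, NIWT optimizes new weights so that the importances they induce match the predicted importances under a cosine-similarity objective with a weight regularizer. The toy sketch below keeps that idea but assumes a linear path from the importance layer to the class score, so the induced importance is simply M.T @ w. The real code backpropagates through the network instead. All names here are hypothetical.

```python
import numpy as np


def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


def induced_importance(w, M):
    # Toy assumption: class score y = w @ (M @ h), so importance dy/dh = M.T @ w
    return M.T @ w


def fit_weights(alpha_target, M, lr=0.5, steps=2000, l2=1e-5, seed=0):
    """Gradient descent on 1 - cos(M.T @ w, alpha_target) + l2 * ||w||^2."""
    rng = np.random.default_rng(seed)
    w = rng.normal(size=M.shape[0])
    t_norm = np.linalg.norm(alpha_target)
    for _ in range(steps):
        v = induced_importance(w, M)
        v_norm = np.linalg.norm(v)
        cos = v @ alpha_target / (v_norm * t_norm)
        # d cos / d v for cos = v.t / (|v| |t|)
        dcos_dv = alpha_target / (v_norm * t_norm) - cos * v / v_norm**2
        grad = -(M @ dcos_dv) + 2 * l2 * w  # chain rule through v = M.T @ w
        w -= lr * grad
    return w


rng = np.random.default_rng(1)
M, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # orthogonal toy map (well-conditioned)
alpha_hat = rng.normal(size=16)  # stand-in for predicted unseen-class importances
w = fit_weights(alpha_hat, M)
print(cosine(induced_importance(w, M), alpha_hat))
```

The cosine objective ignores the scale of the importances. A small L2 term therefore keeps the weight norm from drifting without changing which direction is learned.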