Multi-Label Image Classification of Chest X-Rays In Pytorch
The NIH Chest X-ray Dataset is used for multi-label disease classification of chest X-rays.
There are 15 classes in total (14 diseases and one for “No findings”).
Images can be classified as “No findings” or one or more disease classes:
There are 112,120 X-ray images of size 1024x1024 pixels, of which 86,524 are for training and 25,596 are for testing.
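Because one image can carry several disease labels at once, each target is naturally a multi-hot vector with one entry per class. The sketch below assumes labels arrive as a pipe-separated string such as "Cardiomegaly|Effusion" (assumed to match the "Finding Labels" column of the NIH metadata CSV, where the no-disease label is spelled "No Finding") and that classes follow the hypothetical ordering in `CLASSES`. `encode_labels` is an illustrative helper; the repository's own encoding may differ.

```python
# Hypothetical class ordering: the 14 NIH ChestX-ray14 disease labels plus "No Finding".
# The repository's actual index order is not stated in this README.
CLASSES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
    "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
    "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia", "No Finding",
]
CLASS_TO_INDEX = {name: i for i, name in enumerate(CLASSES)}


def encode_labels(finding_labels: str) -> list[int]:
    """Turn a pipe-separated label string into a 15-element multi-hot vector."""
    target = [0] * len(CLASSES)
    for name in finding_labels.split("|"):
        # Each listed label switches its own class index on; others stay 0.
        target[CLASS_TO_INDEX[name.strip()]] = 1
    return target


print(encode_labels("Cardiomegaly|Effusion"))
```

A multi-hot target like this pairs with a per-class sigmoid output rather than a softmax, since the classes are not mutually exclusive.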
A pretrained ResNet-50 model is used for transfer learning on this dataset.
There is a choice of loss function
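For multi-label targets, a common baseline (not necessarily one of this repository's options) treats each class as an independent binary problem: a sigmoid on each logit followed by binary cross-entropy. The pure-Python sketch below, using a hypothetical function `bce_with_logits`, computes that loss in its numerically stable form; in PyTorch the same quantity is provided by `torch.nn.BCEWithLogitsLoss` with its default `reduction="mean"` and no `pos_weight`.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (sample, class) entries, from raw logits.

    Uses the stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] without overflow.
    """
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y in zip(row_x, row_y):
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


# Two images, three classes each: logits from the model, multi-hot targets.
loss = bce_with_logits([[2.0, -1.0, 0.0], [-3.0, 0.5, 4.0]],
                       [[1, 0, 1], [0, 1, 1]])
print(loss)
```

Each class contributes its own binary term, so an image with two diseases is penalized independently on both outputs.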
The following layers are set to trainable:
fc
Terminal Code:
python main.py
To resume training, a saved checkpoint needs to be loaded. The checkpoint is simply a dictionary containing the
losses_dict (a dictionary containing the following losses)
Different layers of the model are frozen or unfrozen in different stages, defined at the end of this README.md file, to fit the model well to the data. The stage can be passed from the terminal with the argument --stage STAGE.
Terminal Code:
python main.py --resume --ckpt checkpoint_file.pth --stage 2
Training the model creates a models directory and saves the checkpoints there.
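A sketch of how such a checkpoint dictionary might be written to and read back from the models directory. The key names (`model`, `optimizer`, `epoch`, `losses_dict`) and the helpers `save_checkpoint` and `load_checkpoint` are assumptions for illustration; only `losses_dict` is named in this README. A tiny `nn.Linear` stands in for the ResNet-50 so the example runs quickly.

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(model, optimizer, epoch, losses_dict, filename, directory="models"):
    """Save training state as a plain dictionary inside `directory`."""
    os.makedirs(directory, exist_ok=True)  # create the models directory on first save
    path = os.path.join(directory, filename)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "losses_dict": losses_dict,
        },
        path,
    )
    return path


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer state in place; return (epoch, losses_dict)."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"], ckpt["losses_dict"]


# Stand-in model and optimizer; the loss values are placeholders, not results.
model = nn.Linear(4, 15)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
path = save_checkpoint(model, optimizer, epoch=3,
                       losses_dict={"train": [0.5, 0.4], "val": [0.6, 0.5]},
                       filename="checkpoint_file.pth")
epoch, losses = load_checkpoint(path, model, optimizer)
```

Storing the optimizer state alongside the weights lets a resumed run continue with the same moment estimates instead of restarting them.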
To activate test mode, load a saved checkpoint with the --ckpt argument and pass the --test argument.
Terminal Code:
python main.py --test --ckpt checkpoint_file.pth
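The command-line flags used above (`--resume`, `--ckpt`, `--stage`, `--test`) could be parsed with `argparse` as below. This is a hypothetical reconstruction, not the repository's `main.py`; the default stage of 1 and the check that `--ckpt` accompanies `--resume` or `--test` are assumptions.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the flags shown in this README."""
    parser = argparse.ArgumentParser(description="Multi-label chest X-ray classification")
    parser.add_argument("--resume", action="store_true", help="resume training from --ckpt")
    parser.add_argument("--test", action="store_true", help="run in test mode using --ckpt")
    parser.add_argument("--ckpt", type=str, default=None, help="path to a saved checkpoint")
    parser.add_argument("--stage", type=int, default=1,
                        help="training stage (controls which layers are frozen)")
    args = parser.parse_args(argv)

    # Both resuming and testing need a checkpoint to load; parser.error exits with status 2.
    if (args.resume or args.test) and args.ckpt is None:
        parser.error("--resume and --test require --ckpt")
    return args


args = parse_args(["--resume", "--ckpt", "checkpoint_file.pth", "--stage", "2"])
print(args.resume, args.ckpt, args.stage)  # True checkpoint_file.pth 2
```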
The model achieved an average ROC AUC score of 0.73241 across all classes (excluding the “No findings” class) after training in the following stages:
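The reported metric is a mean of per-class ROC AUC values. The sketch below, with hypothetical helpers `roc_auc` and `mean_auc`, computes AUC as the fraction of (positive, negative) pairs that the model ranks correctly, with ties counting one half, then averages over classes while skipping the column given by `exclude_index` (for example, the “No findings” column).

```python
def roc_auc(labels, scores):
    """ROC AUC as P(random positive outscores random negative); ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC is undefined without both positive and negative examples")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else (0.5 if p == n else 0.0)
    return wins / (len(pos) * len(neg))


def mean_auc(targets, probs, exclude_index=None):
    """Average per-class ROC AUC, optionally skipping one class column."""
    num_classes = len(targets[0])
    aucs = []
    for c in range(num_classes):
        if c == exclude_index:
            continue
        aucs.append(roc_auc([t[c] for t in targets], [p[c] for p in probs]))
    return sum(aucs) / len(aucs)


# Toy data: 2 images x 3 classes; column 2 plays the role of the excluded class.
targets = [[1, 0, 1], [0, 1, 0]]
probs = [[0.9, 0.2, 0.1], [0.1, 0.8, 0.9]]
print(mean_auc(targets, probs, exclude_index=2))  # 1.0
```

The pairwise loop is quadratic in the number of samples, so for the 25,596 test images `sklearn.metrics.roc_auc_score` applied per class is the practical choice; the pure-Python version only makes the definition concrete.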