Project author: ceshine

Project description:
PyTorch EfficientDet Solution for Global Wheat Detection Challenge
Primary language: Jupyter Notebook
Project URL: git://github.com/ceshine/global-wheat-detection.git
Created: 2021-07-06T13:39:39Z
Project community: https://github.com/ceshine/global-wheat-detection

License: MIT License


PyTorch EfficientDet Solution for Global Wheat Detection Challenge

  1. Training notebook (on Kaggle)
  2. Inference notebook: Single model; Ensemble

See wheat/config.py for hyper-parameters and system configurations.

The best mAP score I was able to get is 0.6167 (Private) / 0.7084 (Public), with a D4 model trained at 768x768 resolution on a single P100 GPU.

Requirements

  1. torch>=1.7.0
  2. pytorch-lightning>=1.3.6
  3. pytorch-lightning-spells==0.0.3
  4. efficientdet-pytorch==0.2.4

Note: You’ll need to use my fork of efficientdet-pytorch to use the O2 level of Apex AMP.
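As a rough aid, the version pins above can be checked against the current environment before training. The sketch below is not part of the repository: the helper names (`parse`, `satisfies`, `check`) are made up for illustration, the version comparison is deliberately naive (leading numeric components only), and it assumes each package is installed under the distribution name written in the list, which may not hold for a fork of efficientdet-pytorch.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the requirements list above.
# Assumption: distribution names match the names as written there.
REQUIREMENTS = {
    "torch": (">=", "1.7.0"),
    "pytorch-lightning": (">=", "1.3.6"),
    "pytorch-lightning-spells": ("==", "0.0.3"),
    "efficientdet-pytorch": ("==", "0.2.4"),
}


def parse(ver):
    """Turn '1.10.0+cu113' into (1, 10, 0), keeping leading digits only."""
    parts = []
    for piece in ver.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def satisfies(installed, op, required):
    """Naive comparison supporting only the '>=' and '==' operators."""
    a, b = parse(installed), parse(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so that 1.7 compares equal to 1.7.0
    b += (0,) * (width - len(b))
    return a == b if op == "==" else a >= b


def check(requirements=REQUIREMENTS):
    """Return a {name: 'ok' | 'mismatch' | 'missing'} report."""
    report = {}
    for name, (op, required) in requirements.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if satisfies(installed, op, required) else "mismatch"
    return report


if __name__ == "__main__":
    for name, status in check().items():
        print(f"{name}: {status}")
```

A "missing" result for efficientdet-pytorch does not necessarily mean the library is absent; it may simply be installed under a different distribution name.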

Instructions

Resizing images:

  python scripts/resize_images.py 512 --root data/

Training (pass --help for more information):

  python -m wheat.train data/512 --epochs 10 --grad-accu 4 --batch-size 8 --arch tf_efficientdet_d3 --fold 0 --mixup 24 --mosaic-p 0.5

Evaluation (pass --help for more information):

  python -m wheat.eval data/512 export/tf_efficientdet_d3-mosaic-mixup-fold0.pth --batch-size 8 --arch tf_efficientdet_d3 --fold 0
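When running several folds or architectures, it can help to generate the training and evaluation commands from one place so that the flags stay in sync. The sketch below only builds and prints the command lines; it does not run anything. The function names are hypothetical, the flag values are copied from the examples above, and the checkpoint naming pattern (`export/<arch>-mosaic-mixup-fold<fold>.pth`) is an assumption inferred from the single evaluation example, so the actual file name written by training should be checked.

```python
import shlex


def train_command(data_dir="data/512", arch="tf_efficientdet_d3", fold=0):
    """Build the training command from the Instructions as an argument list."""
    return [
        "python", "-m", "wheat.train", data_dir,
        "--epochs", "10", "--grad-accu", "4", "--batch-size", "8",
        "--arch", arch, "--fold", str(fold),
        "--mixup", "24", "--mosaic-p", "0.5",
    ]


def eval_command(data_dir="data/512", arch="tf_efficientdet_d3", fold=0):
    """Build the matching evaluation command for the same arch and fold."""
    # Assumption: the checkpoint name follows the pattern of the one example
    # in the Instructions; verify against the files in export/.
    checkpoint = f"export/{arch}-mosaic-mixup-fold{fold}.pth"
    return [
        "python", "-m", "wheat.eval", data_dir, checkpoint,
        "--batch-size", "8", "--arch", arch, "--fold", str(fold),
    ]


if __name__ == "__main__":
    # Dry run: print the commands instead of executing them.
    for cmd in (train_command(), eval_command()):
        print(shlex.join(cmd))
```

To execute a command rather than print it, the argument list can be passed to `subprocess.run(cmd, check=True)` from the repository root.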