DrMMZ / ResFPN

Model ensemble: ResNet + FPN, and Focal Loss in TensorFlow2

ResFPN

This is an implementation of ResFPN in Python 3 and TensorFlow 2. The model classifies images by ensembling the predictions of a Residual Network (ResNet) and a Feature Pyramid Network (FPN), and it can be trained by minimizing focal loss.
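The README does not say how the two sets of predictions are combined. The sketch below assumes the ensemble simply averages the softmax probabilities of a ResNet head and an FPN head. The function names and the averaging rule are hypothetical and are not taken from the repository. In plain Python, the idea looks like this:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(resnet_logits: List[float], fpn_logits: List[float]) -> List[float]:
    """Hypothetical combination rule: average the two heads' class probabilities."""
    p_resnet = softmax(resnet_logits)
    p_fpn = softmax(fpn_logits)
    return [(a + b) / 2.0 for a, b in zip(p_resnet, p_fpn)]


# Illustrative logits for a 3-class problem (made-up numbers).
probs = ensemble_predict([2.0, 0.5, -1.0], [1.0, 1.5, -0.5])
predicted_class = max(range(len(probs)), key=probs.__getitem__)
```

Averaging probabilities keeps the two heads' outputs on the same scale. A weighted average or a learned combination layer would be alternatives.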

The repository includes:

  • source code of ResFPN built on ResNet50/101 and FPN;
  • source code of focal loss (generalized to multi-class classification, with a class-balancing parameter); and
  • Jupyter notebooks demonstrating training, evaluation and visualization with ResFPN on the tf_flowers and COVIDx datasets. Example classifications on the tf_flowers dataset, randomly selected from images not used in training, are shown in the accompanying figure.
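For reference, the widely used form of focal loss is FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class and α_t is a per-class balancing weight. Below is a minimal multi-class sketch in plain Python. It assumes one-hot labels and a per-class weight vector `alpha`. These names and the example weights are illustrative, and the repository's TensorFlow implementation in the model folder may differ in detail.

```python
import math
from typing import Sequence


def focal_loss(
    probs: Sequence[float],
    one_hot: Sequence[float],
    alpha: Sequence[float],
    gamma: float = 2.0,
    eps: float = 1e-7,
) -> float:
    """Multi-class focal loss for a single example.

    probs   -- predicted class probabilities (e.g. a softmax output)
    one_hot -- one-hot encoded true label
    alpha   -- per-class balancing weights (illustrative values)
    gamma   -- focusing parameter; gamma = 0 gives alpha-weighted cross-entropy
    """
    loss = 0.0
    for p, y, a in zip(probs, one_hot, alpha):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        # Only the true class contributes, because y is 0 for every other class.
        loss += -y * a * (1.0 - p) ** gamma * math.log(p)
    return loss


alpha = [0.25, 0.25, 0.5]  # e.g. up-weight a rarer third class
easy = focal_loss([0.9, 0.05, 0.05], [1, 0, 0], alpha)  # confident and correct
hard = focal_loss([0.3, 0.4, 0.3], [1, 0, 0], alpha)    # uncertain
ce_easy = focal_loss([0.9, 0.05, 0.05], [1, 0, 0], [1, 1, 1], gamma=0.0)  # plain cross-entropy
```

Take γ = 2. The confident correct prediction (p_t = 0.9) keeps only (1 - 0.9)^2 = 1% of its α-weighted cross-entropy, while the uncertain one (p_t = 0.3) keeps 49%. Easy examples are therefore down-weighted rather than discarded.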

Requirements

Python 3.7.9, TensorFlow 2.3.1, Matplotlib 3.3.4 and NumPy 1.19.2

Updates

  • 07/04/2021: Added synchronized SGD training across multiple GPUs, plus callbacks such as CSVLogger, ModelCheckpoint and ReduceLROnPlateau. Also modified resnet_fpn.select_top() and resnet_fpn.predict() so that predictions can be visualized. A new notebook on the tf_flowers dataset is provided as a demonstration.
  • 05/10/2021: Added a focal loss implementation, with corresponding changes to ResFPN; see the model folder for details. Roughly speaking, focal loss addresses class imbalance by down-weighting the loss from easy, well-classified examples during training, so that hard examples dominate the gradient. Experimental results on the COVIDx dataset are presented in the tutorial folder.
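In synchronized SGD, each replica computes gradients on its own shard of the batch. The gradients are then averaged across replicas (an all-reduce), and every replica applies the same update, so the model copies stay identical. In TensorFlow 2 this is commonly done with `tf.distribute.MirroredStrategy`, though the README does not say which strategy the repository uses. The pure-Python sketch below shows only the averaging step, using a hypothetical one-parameter least-squares model.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair


def shard_gradient(w: float, shard: List[Sample]) -> float:
    """Gradient of mean squared error for the model y_hat = w * x on one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def sync_sgd_step(w: float, shards: List[List[Sample]], lr: float) -> float:
    """One synchronized SGD step: average per-replica gradients, then update."""
    grads = [shard_gradient(w, shard) for shard in shards]  # one gradient per replica ("GPU")
    mean_grad = sum(grads) / len(grads)                      # all-reduce (mean)
    return w - lr * mean_grad                                # same update applied on every replica


# Two replicas, each holding an equal-sized shard of a global batch drawn from y = 3x.
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
w = 0.0
for _ in range(100):
    w = sync_sgd_step(w, shards, lr=0.01)
```

Averaging per-replica gradients reproduces the full-batch gradient only when the shards are the same size. Data-parallel setups such as MirroredStrategy therefore split the global batch evenly across replicas.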

About

License: MIT License


Languages

Jupyter Notebook 97.1%, Python 2.9%