sally20921 / ConSSL

PyTorch Implementation of SOTA SSL methods

Home Page: https://sites.google.com/snu.ac.kr/serileeproject00/

Code improvements

vfdev-5 opened this issue

commented

Hi @sally20921

As discussed before, I am opening this issue to suggest ways to simplify the code:

  1. Metrics. Ignite already provides all the necessary metrics. Please take a look here:

Thus, this part of the code can be simplified (no need to introduce StatMetric):
https://github.com/sally20921/BYOL/blob/455e15e9c9375f8cb99fe133860dc91838092440/code/train.py#L29-L32

Another point is that computing metrics during training, while the model is updated on every batch, can be misleading, because the metric is averaged over the predictions of different models. Please also see the footnote here: https://pytorch.org/ignite/quickstart.html#id1

  2. Code structure. It would be nice to simplify the codebase structure and have at most 4-5 files in the repository:
  • train.py = main script to train a model
  • evaluate.py = main script to run evaluation of a trained model
  • utils.py = various helper methods used by train.py and evaluate.py
  • dataflow.py = helper module to setup training/validation/test dataloaders
  • (optional) models.py = implementation of the models if necessary
  • (optional) losses.py = implementation of loss functions if necessary

What do you think?

  3. Distributed training.

Using the latest API, it is possible to write code that can run on any number of devices.

commented

Thanks for the update, @sally20921!

I would like to test the code on my side. Which dataset are you using for debugging, and which datasets do you think we'll need to check the correctness of the implementation? I suppose we'll need ImageNet-1k for large-scale pretraining, as they do in the paper...

I still have doubts about the repository's structure and unfortunately find it very hard to understand what is done and how it is implemented. Would you like to improve it?

I have implemented CIFAR10, CIFAR100, STL10, and ImageNet, so you can test with any of them!
I will try to improve the code as you suggested in the issue =)
I will contact you soon 😀