bohaohuang / mrs

Models for Remote Sensing

Model loading failed when training paralleled

waynehuu opened this issue

An explanation & solution of batchnorm when training on multiple gpus:
https://github.com/dougsouza/pytorch-sync-batchnorm-example
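
For reference, the pattern that repo demonstrates boils down to converting the BatchNorm layers and training under DistributedDataParallel. A minimal sketch, assuming torch.distributed has already been initialized (e.g. via torchrun) and using a stand-in network rather than the one in this repo:

import torch.nn as nn

# Stand-in network; plain BatchNorm2d keeps per-GPU running statistics,
# while SyncBatchNorm shares batch statistics across processes.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # swap BatchNorm layers for SyncBatchNorm
local_rank = 0  # normally read from the launcher environment
model = model.cuda(local_rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])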

I'll take a look at this, also here's the link for the model: https://drive.google.com/file/d/12Qr7SUhGTWugqJ9AvBEl4aDTDOvTEm-h/view?usp=sharing

#47 solves the optimizer issue when resuming training of a model
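
For context, resuming training generally means restoring the optimizer state along with the model weights. A generic sketch of that pattern; the file name and dictionary keys here are assumptions, not necessarily what network_utils actually uses:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)                             # stand-in network
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Save the optimizer state alongside the model weights ...
torch.save({'state_dict': model.state_dict(),
            'opt_dict': optimizer.state_dict()}, 'checkpoint.pth.tar')

# ... and restore both when resuming, so momentum buffers pick up where they
# left off instead of restarting from scratch.
ckpt = torch.load('checkpoint.pth.tar')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['opt_dict'])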

This hasn't been fixed yet.

The problem is not the "module" prefix in the multi-GPU state_dict keys. Models trained on multiple GPUs can be loaded without errors, but they don't perform as they should, probably because batch normalization statistics are computed on each device separately and are not synchronized across devices. I tested this last week and the previous optimizer fix doesn't solve it.

A quick fix would be:
model.encoder = nn.DataParallel(model.encoder)
model.decoder = nn.DataParallel(model.decoder)
network_utils.load(model, ckpt_dir, disable_parallel=True)
But I think this is due to the DataParallel wrapping in the training process; let me investigate this a little bit.

I believe #53 has fixed the issue

When doing the evaluation, please load the model via:
network_utils.load(model, ckpt_dir)
instead of:
network_utils.load(model, ckpt_dir, disable_parallel=True)

This way the framework will try to wrap the model with nn.DataParallel instead of creating a matching key pattern to load the weights.
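
For illustration, the difference between the two loading strategies, sketched with a stand-in network (network_utils.load's internals aren't shown here, so this is only an approximation of what they do):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))  # stand-in network
saved = nn.DataParallel(net).state_dict()  # multi-GPU checkpoints have keys like "module.0.weight"

# Strategy 1 (loading without disable_parallel): wrap the model the same way it was
# trained, so the prefixed keys match directly.
wrapped = nn.DataParallel(net)
wrapped.load_state_dict(saved)

# Strategy 2 (the key-matching route): strip the "module." prefix so a bare model
# accepts the weights.
stripped = {k.replace('module.', '', 1): v for k, v in saved.items()}
net.load_state_dict(stripped)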

I have tried it and it seems to have fixed the issue, but feel free to reopen if it does not solve your problem.

Also, one downside of the current fix is that it might not be backward compatible with previously trained multi-GPU models.

The new method could not distribute memory across multiple GPUs.

8b23932 should've fixed this issue:

  1. The encoder and decoder still need to be wrapped with DataParallel separately to enable memory distribution across GPUs.
  2. Model attributes need to be forwarded by a custom DataParallel class to avoid an OOM error at inference after loading the model (see the sketch after this list).
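
A minimal sketch of the attribute-forwarding idea in point 2; the class name AttrDataParallel is illustrative, not necessarily what the commit uses:

import torch.nn as nn

class AttrDataParallel(nn.DataParallel):
    """DataParallel that falls back to the wrapped module for unknown attributes."""
    def __getattr__(self, name):
        try:
            # Resolve parameters, buffers and submodules registered on the wrapper itself.
            return super().__getattr__(name)
        except AttributeError:
            # Otherwise forward the lookup to the wrapped model, so code that reads
            # attributes such as model.encoder keeps working after wrapping.
            return getattr(self.module, name)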

When training on multiple GPUs, the model can only be loaded onto gpu:0, not gpu:1, and most of the time I still get OOM errors.

288b2ef fixes this issue:
The GPU loading error is solved by setting the primary device properly for DataParallel; the OOM error seems like a CUDA bug that occurs rarely.
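
For anyone hitting the same thing, setting the primary device usually looks roughly like this; the device indices, checkpoint path, and key names below are assumptions for illustration, not the repo's actual code:

import torch
import torch.nn as nn

primary = 1                                   # hypothetical: make gpu:1 the primary device
device = torch.device(f'cuda:{primary}')

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))  # stand-in network
ckpt = torch.load('model.pth.tar', map_location=device)  # keep checkpoint tensors off cuda:0
model.load_state_dict(ckpt['state_dict'])

# DataParallel expects the module to live on device_ids[0]; gathering outputs on the
# same device avoids everything silently piling up on cuda:0.
model = model.to(device)
model = nn.DataParallel(model, device_ids=[primary, 0], output_device=primary)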