Multi-gpu setting bug?
priancho opened this issue
Han-Cheol Cho commented
Description
When a machine is equipped with more than one GPU, the current code hard-codes the `device_ids` parameter of `torch.nn.DataParallel()` to `[0, 1, 2]`.
# current code:
# https://github.com/Tramac/Fast-SCNN-pytorch/blob/master/train.py#L90
if torch.cuda.device_count() > 1:
    self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1, 2])
self.model.to(args.device)
Possible solution
The following change derives the device list from the actual device count, so the code works on machines with any number of GPUs (2, 4, ...):
if torch.cuda.device_count() > 1:
    device_ids = list(range(torch.cuda.device_count()))
    self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
self.model.to(args.device)
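To show just the device-list logic in isolation, here is a minimal sketch that mirrors the fix without requiring a GPU machine. The `make_device_ids` helper is hypothetical (not part of the repository); it takes the value that `torch.cuda.device_count()` would return and produces the list to pass as `device_ids`:

```python
def make_device_ids(device_count):
    """Return GPU indices [0, ..., device_count - 1] for DataParallel,
    or None when there is at most one device and wrapping is unnecessary."""
    if device_count > 1:
        return list(range(device_count))
    return None

# With 2 GPUs the wrapper sees both devices; with 4 it sees all four.
print(make_device_ids(2))  # [0, 1]
print(make_device_ids(4))  # [0, 1, 2, 3]
print(make_device_ids(1))  # None
```

This way the hard-coded `[0, 1, 2]` never appears, and a 2-GPU machine no longer triggers an invalid-device error.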