TylerYep / torchinfo

View model summaries in PyTorch!


Error with model parallelism

clessig opened this issue

Hi,

I recently implemented model parallelism. With it, however, the code fails in

torchinfo.summary( self.model, input_data=[batch_data])

with "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! ". Running the model with test() and train() works without problem.

From the behavior, it seems that torchinfo moves the network back to a single device behind the scenes? Is this the case? Is there a way to use torchinfo with model parallelism?

Thanks!

Yes, currently the input_tensor (or auto-generated tensor from input_shape) as well as the model are moved to whatever device the model is on, unless a device is given. However, this works by finding the device of the first parameter and then moving everything to it; fixing this issue will require a bit of a redesign.
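For context, the device inference described above roughly works like this (a simplified sketch for illustration, not the actual torchinfo source):

```python
import torch
import torch.nn as nn

def infer_device(model: nn.Module) -> torch.device:
    # Simplified sketch: take the device of the first parameter found.
    # With model parallelism, later layers can live on other devices,
    # so moving the input (and model) to this single device causes the
    # cross-device mismatch reported above.
    return next(model.parameters()).device
```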

I'll look into it. Could you provide a test-case? I'm not very familiar with doing model-parallelism on PyTorch, so that would really help :)

Hi Sebastian,

Thanks for the quick reply. To begin with, a clearer error message would be nice; it took me an hour to track down that torchinfo might not support model parallelism.

I am a bit time-strapped at the moment, but any test case I could produce wouldn't look different from the official PyTorch model parallelism toy example: https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html. So if you take that example and add torchinfo, you should be able to reproduce the issue (see the sketch below).
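A minimal sketch of such a repro, adapted from the linked tutorial (assumes two CUDA devices are available; the exact model is illustrative):

```python
import torch
import torch.nn as nn
import torchinfo

# Toy model split across two GPUs, as in the PyTorch model-parallel tutorial.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10).to('cuda:0')
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

model = ToyModel()
batch = torch.randn(20, 10)

# A plain forward pass works fine across devices...
out = model(batch)

# ...but summarizing the model raises the device-mismatch RuntimeError
# described above (prior to the fix).
torchinfo.summary(model, input_data=[batch])
```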

Best,
Christian

This has been fixed and should go out in v1.8.0.