TylerYep / torchinfo

View model summaries in PyTorch!

MPS support

joaolcguerreiro opened this issue

Hi,

I was trying to get the summary of a 3D model, but I'm getting an error.

I am running:
summary(generator, input_size=(1, 1, 400, 400, 400), depth=1)

I believe it's because I'm using a MacBook M1 and need this to work on MPS. Maybe something like this would work:
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
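The one-liner above can be unpacked into a small helper that tries CUDA, then MPS, then falls back to CPU. This is a minimal sketch; the name `pick_device` is hypothetical, and the `getattr` guard is there only so the code does not fail on older PyTorch releases that have no `torch.backends.mps` module.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older PyTorch builds have no torch.backends.mps, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(device)
```

The returned device can then be handed to torchinfo, e.g. `summary(generator, input_size=(1, 1, 400, 400, 400), depth=1, device=pick_device())`.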

Thank you :)

Have you tried passing the device as a parameter?

summary(generator, input_size=(1, 1, 400, 400, 400), depth=1, device=torch.device("mps"))

If that doesn't work, try passing input_data instead of input_size, and move both the model and input_data to the correct device.

I haven't tested it yet, but I'm almost sure that will do the trick :) Thank you! <3