MPS support
joaolcguerreiro opened this issue
Hi,
I was trying to get the summary of a 3D model, but I'm getting an error.
I am running:
summary(generator, input_size=(1, 1, 400, 400, 400), depth=1)
I believe this is because I'm using a MacBook M1 and need this to work on MPS. Maybe something like this would work:
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
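The one-liner in the issue can be unpacked into a small helper that makes the fallback order explicit. This is only a sketch: `pick_device_name` is a hypothetical helper, not part of PyTorch or the summary library. It takes the availability flags as plain booleans, so the precedence logic can be checked without a GPU.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred PyTorch device string.

    Mirrors the precedence of the one-liner in the issue: CUDA first,
    then Apple's MPS backend, then CPU as the final fallback.
    (pick_device_name is a hypothetical helper for illustration.)
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# On an M1 Mac, CUDA is unavailable but MPS usually is.
print(pick_device_name(cuda_available=False, mps_available=True))  # mps
```

With PyTorch installed, the flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and the result can be wrapped as `torch.device(pick_device_name(...))`.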
Thank you :)
Have you tried passing the device as a parameter?
summary(generator, input_size=(1, 1, 400, 400, 400), depth=1, device=torch.device("mps"))
If that doesn't work, try passing input_data instead of input_size, and move both the model and input_data to the correct device.
I haven't tested it yet, but I'm almost sure that will do the trick :) Thank you! <3