Grad dtype issue when interfacing with PyTorch
ClemensSchwarke opened this issue
Hi,
I noticed that in version 0.10, `wp.from_torch()` raises a dtype error when the original Torch tensor has `requires_grad=True` and `grad=None`. In that case Warp creates a `torch.zeros()` tensor and assigns it to the `.grad` attribute of the Warp array by calling `wp.from_torch()` a second time. That second call does not receive the dtype, so it fails whenever a specific `dtype` was passed to the original call.
For example, `wp.from_torch(x, dtype=wp.vec3)` raises `ValueError: The given gradient array is incompatible.`
I think changing line 161 in `torch.py` from `grad = from_torch(t.grad)` to `grad = from_torch(t.grad, dtype=dtype)` would fix the issue.
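To illustrate why the unpatched call fails, here is a simplified pure-Python model. It is not Warp's actual code; `ArrayModel`, `wrap`, `attach_grad` and `from_torch_model` are hypothetical names. It assumes that viewing a float32 tensor with a vector dtype folds the trailing dimension into the dtype, while the automatically created gradient is wrapped without a dtype. The two arrays then disagree in shape and dtype, and passing the dtype through (as the proposed change does) makes them match.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ArrayModel:
    """Hypothetical stand-in for a Warp array: only shape and dtype name."""
    shape: Tuple[int, ...]
    dtype: str


# Hypothetical table: vector dtype name -> number of float32 components.
VECTOR_LENGTHS = {"vec3": 3, "spatial_vector": 6}


def wrap(torch_shape: Tuple[int, ...], dtype: Optional[str] = None) -> ArrayModel:
    """Simplified model of viewing a float32 tensor as a Warp array.

    Without a dtype, the shape is kept and the element type is float32.
    With a vector dtype, the trailing dimension is absorbed into the dtype.
    """
    if dtype is None:
        return ArrayModel(tuple(torch_shape), "float32")
    length = VECTOR_LENGTHS[dtype]
    if torch_shape[-1] != length:
        raise ValueError(f"trailing dimension must be {length} for {dtype}")
    return ArrayModel(tuple(torch_shape[:-1]), dtype)


def attach_grad(data: ArrayModel, grad: ArrayModel) -> ArrayModel:
    """Model of the compatibility check: grad must match data's shape and dtype."""
    if grad.shape != data.shape or grad.dtype != data.dtype:
        raise ValueError("The given gradient array is incompatible.")
    return grad


def from_torch_model(torch_shape, dtype=None, pass_dtype_to_grad=True):
    """Model wrapping a tensor whose .grad is None.

    The zero gradient has the same torch shape as the data; the only
    difference between the buggy and fixed paths is whether it is
    wrapped with the requested dtype.
    """
    data = wrap(torch_shape, dtype)
    grad_dtype = dtype if pass_dtype_to_grad else None
    grad = wrap(torch_shape, grad_dtype)
    return data, attach_grad(data, grad)
```

With `pass_dtype_to_grad=False` (the behavior reported here), a `(4, 3)` tensor viewed as `vec3` yields data of shape `(4,)` but a float32 gradient of shape `(4, 3)`, so the check raises. With the dtype passed through, both are `(4,)` arrays of `vec3`.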
It is also possible that I am using this functionality incorrectly; if so, I would appreciate any hints.
Thanks!
I also encountered this issue. Here is my workaround.
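```python
import torch
import warp as wp

wp.init()

# Each row holds the 6 float32 components of one spatial vector.
x = torch.tensor(
    [[1, 1, 1, 1, 1, 1]],
    dtype=torch.float32,
    requires_grad=True,
)

# Pre-allocate the gradient so Warp does not create it internally
# (the internal path calls from_torch() without the dtype and fails).
x.grad = torch.zeros_like(x, requires_grad=False)

requires_grad = x.requires_grad
if x.grad is not None:
    # Wrap the gradient with the same dtype as the data array.
    grad = wp.from_torch(x.grad, dtype=wp.spatial_vector, requires_grad=False)
else:
    grad = None

x_warp = wp.from_torch(
    x,
    dtype=wp.spatial_vector,
    requires_grad=requires_grad,
    grad=grad,
)
```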
@ClemensSchwarke thanks for reporting this and suggesting the fix! The issue has already been fixed, and the fix will ship in the next release.