NVIDIA / warp

A Python framework for high performance GPU simulation and graphics

Home Page: https://nvidia.github.io/warp/


Grad dtype issue when interfacing with PyTorch

ClemensSchwarke opened this issue

Hi,
I noticed that in version 0.10, wp.from_torch() raises a dtype error when the original Torch tensor has requires_grad=True but grad=None. In that case, Warp creates a torch.zeros() array and assigns it to the .grad attribute of the Warp array by calling wp.from_torch() again, without forwarding the dtype. This leads to an error whenever a dtype is specified in the original call.

E.g. wp.from_torch(x, dtype=wp.vec3) leads to ValueError: The given gradient array is incompatible.
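For reference, a minimal snippet along these lines should reproduce it (assuming Warp 0.10 and default settings; the shape is arbitrary and only chosen to match wp.vec3):

import torch
import warp as wp

wp.init()

# requires_grad=True, but .grad is still None because no backward pass has run
x = torch.zeros((8, 3), dtype=torch.float32, requires_grad=True)

# Warp builds a zero gradient internally and converts it without the vec3 dtype,
# which raises: ValueError: The given gradient array is incompatible.
v = wp.from_torch(x, dtype=wp.vec3)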

I think changing line 161 in torch.py from grad = from_torch(t.grad) to grad = from_torch(t.grad, dtype=dtype) would fix the issue.
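Spelled out, the suggested one-line change (against the 0.10 source) would simply forward the dtype to the inner call:

grad = from_torch(t.grad)               # current: inner conversion drops the dtype
grad = from_torch(t.grad, dtype=dtype)  # suggested: reuse the caller's dtype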

It is also possible that I am using the provided functionality incorrectly; in that case, I would be happy about any hints.
Thanks!

I also encountered this issue. Here is my workaround.

import torch
import warp as wp

wp.init()

# Source tensor with 6 components per row so it maps onto wp.spatial_vector.
x = torch.tensor(
    [[1, 1, 1, 1, 1, 1]],
    dtype=torch.float32,
    requires_grad=True,
)

# Pre-allocate the gradient on the Torch side so Warp does not create one
# internally (with the wrong dtype).
x.grad = torch.zeros_like(
    x,
    requires_grad=False,
)

# Convert the gradient explicitly, using the same dtype as the data.
requires_grad = x.requires_grad
if x.grad is not None:
    grad = wp.from_torch(x.grad, dtype=wp.spatial_vector, requires_grad=False)
else:
    grad = None

# Hand the pre-converted gradient to from_torch instead of letting it build one.
x_warp = wp.from_torch(
    x,
    dtype=wp.spatial_vector,
    requires_grad=requires_grad,
    grad=grad,
)

@ClemensSchwarke thanks for reporting and suggesting the fix! This has already been fixed, and the fix will appear in the next release.