Gradient through multi-dimensional arrays
ClemensSchwarke opened this issue
Hi,
I had some trouble figuring out that the gradient of a 2d array (in my case) is only backpropagated correctly if its elements are accessed with [a,b] indexing. Using two slicing operators, [a][b], leads to the adjoints being 0.
Minimal example:
import warp as wp


@wp.kernel
def test(
    a: wp.array2d(dtype=wp.vec3),
    b: wp.array2d(dtype=wp.vec3),
    c: wp.array2d(dtype=wp.vec3),
):
    tid = wp.tid()
    # Chained indexing [i][j]: this is the form whose adjoints come out as 0
    c[tid][0] = a[tid][0] + b[tid][0]


wp.init()
tape = wp.Tape()

a = wp.full((1, 1), value=1.0, dtype=wp.vec3, requires_grad=True)
b = wp.full((1, 1), value=1.0, dtype=wp.vec3, requires_grad=True)
c = wp.zeros((1, 1), dtype=wp.vec3, requires_grad=True)

# Record the forward launch on the tape
with tape:
    wp.launch(
        kernel=test,
        dim=1,
        inputs=[a, b],
        outputs=[c],
    )

# Seed the upstream gradient dL/dc with ones, then run the backward pass
c.grad = wp.full((1, 1), value=1.0, dtype=wp.vec3)
tape.backward()

print(c.grad)
print(a.grad)
print(b.grad)
Output:
[[[1. 1. 1.]]]
[[[0. 0. 0.]]]
[[[0. 0. 0.]]]
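For reference, the expected adjoints follow from the chain rule. Writing \bar{x} for the adjoint stored in x.grad, the kernel computes c = a + b on element [0, 0], and the Jacobian of a vec3 sum with respect to either operand is the 3x3 identity, so both inputs should receive exactly the seeded upstream gradient (ones), not zeros:

```latex
c_{00} = a_{00} + b_{00}, \qquad
\frac{\partial c_{00}}{\partial a_{00}} = \frac{\partial c_{00}}{\partial b_{00}} = I_3

\bar{a}_{00} = \left(\frac{\partial c_{00}}{\partial a_{00}}\right)^{\!\top} \bar{c}_{00} = I_3\,\bar{c}_{00} = (1, 1, 1),
\qquad
\bar{b}_{00} = \left(\frac{\partial c_{00}}{\partial b_{00}}\right)^{\!\top} \bar{c}_{00} = I_3\,\bar{c}_{00} = (1, 1, 1)
```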
Is this intended? Thanks in advance!
Clemens
Hi Clemens, this looks like a bug; we will look at supporting gradient propagation through slicing in the next release.
Thanks!
Miles
Sorry for the delay in getting to this. I'll have a look soon.