NVIDIA / warp

A Python framework for high performance GPU simulation and graphics

Home Page: https://nvidia.github.io/warp/


Adjoint of Matmul

moradza opened this issue · comments

The current matmul and batched_matmul implementations break the gradient flow. For example, in the line

ctypes.c_void_p(a.ptr),

a.ptr should be replaced with adj_a.ptr (i.e. ctypes.c_void_p(adj_a.ptr)), and beta should be set to 1. In the current implementation, the old gradient of a is overwritten by the matmul call; the same issue holds for b, c, and d.
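For context, the adjoint of a GEMM-style operation d = alpha * (a @ b) + beta * c should accumulate into the existing gradient buffers rather than replace them. Below is a minimal NumPy sketch of that intended behavior; it is illustrative only and not Warp's internal code, and the names adj_a, adj_b, adj_c, adj_d simply mirror the convention used above.

```python
import numpy as np

def matmul_adjoint(a, b, adj_a, adj_b, adj_c, adj_d, alpha=1.0, beta=1.0):
    # Forward op: d = alpha * (a @ b) + beta * c.
    # Backward pass: accumulate into the existing gradients (+=),
    # which corresponds to calling the GEMM with beta = 1 and passing
    # the gradient buffers (adj_a, adj_b, adj_c) as the outputs.
    adj_a += alpha * adj_d @ b.T   # dL/da
    adj_b += alpha * a.T @ adj_d   # dL/db
    adj_c += beta * adj_d          # dL/dc
```

Overwriting adj_a instead (the effect of beta = 0, or of writing into a.ptr rather than adj_a.ptr) discards gradient contributions from any other operations that also consume a.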

Hi @moradza, thanks for the report - @daedalus5 can you take a look?