Cannot store a custom adjoint function in a separate file
Opened by xuan-li.
To reproduce:
adj.py:
import warp as wp
@wp.func
def overload_fn(x: float, y: float):
    # Warp device function returning a pair: (3*x + y/3, y**2.5).
    # Its backward pass is replaced by overload_fn_grad via @wp.func_grad.
    return x * 3.0 + y / 3.0, y**2.5


@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    # Custom backward pass for overload_fn.
    # Inputs are the forward arguments (x, y) followed by one adjoint per
    # forward return value (adj_ret0, adj_ret1). Gradients are accumulated
    # into wp.adjoint[x] and wp.adjoint[y] rather than returned.
    # NOTE(review): these factors (42.0, 10.0, 3.0) do not match the analytic
    # derivative of overload_fn (e.g. d(ret0)/dx == 3); presumably they are
    # deliberate marker values so a test can tell the custom gradient ran.
    # Confirm before "fixing" them.
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    wp.adjoint[y] += y * adj_ret1 * 3.0


if __name__ == "__main__":
    # Guarded so that `import adj` (as test_adj.py does) does not run this
    # smoke test; only direct execution of adj.py does.
    wp.init()

    @wp.kernel
    def overload_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        # The return value is discarded; the call only exercises code
        # generation for overload_fn (and its custom adjoint) in this module.
        overload_fn(x[tid], y[tid])

    x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
    y = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)
    # Per the report, this launch succeeds when the kernel lives in the same
    # module as overload_fn.
    wp.launch(overload_kernel, inputs=[x, y], dim=x.shape[0])
test_adj.py:
# test_adj.py: uses overload_fn (and its custom gradient) from adj.py in a
# kernel defined in a different module.
# NOTE(review): these imports were missing from the snippet as pasted. Without
# them `wp` and `overload_fn` are undefined names. Importing adj does not run
# its smoke test because that code sits behind an `if __name__ == "__main__"`
# guard.
import warp as wp
from adj import overload_fn

wp.init()


@wp.kernel
def overload_kernel(x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    overload_fn(x[tid], y[tid])


x = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)
y = wp.array([4.0, 5.0, 6.0], dtype=wp.float32)
# Before the fix, this launch failed at NVRTC compile time with:
#   error: identifier "adj_overload_fn" is undefined
wp.launch(overload_kernel, inputs=[x, y], dim=x.shape[0])
Running adj.py works fine, but running test_adj.py fails with the following error:
Warp NVRTC compilation error 6: NVRTC_ERROR_COMPILATION (/buildAgent/work/a9ae500d09a78409/warp/native/warp.cu:1674)
default_program(122): error: identifier "adj_overload_fn" is undefined
Thanks for reporting this! This should also be fixed in the next release.
This should now be fixed.