NVIDIA / warp

A Python framework for high performance GPU simulation and graphics

Home Page: https://nvidia.github.io/warp/

Interoperability with PyTorch

TheDevilWillBeBee opened this issue · comments

Hi,

Thanks for creating this amazing framework. Could you provide a simple example of backpropagation where most of the computation (including the loss function) happens in PyTorch and Warp is only used to implement one of the modules?

Thank you in advance!

Check out the example here: https://github.com/NVIDIA/warp/blob/main/examples/example_sim_fk_grad_torch.py#L29

This example "injects" the Warp functionality as a PyTorch autograd function. By calling it with .apply(), as done there, you can freely use this differentiable function inside any larger PyTorch system that relies on autograd and backprop.
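A minimal sketch of that pattern: subclass torch.autograd.Function so PyTorch treats the custom computation as a single differentiable op. In the linked example, forward() launches a Warp kernel and backward() runs Warp's adjoint kernels; here a simple cube function stands in for the Warp computation (a hypothetical stand-in, not Warp's API).

```python
import torch

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # In the Warp version, this is where wp.launch() would run the kernel.
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        # In the Warp version, this is where the adjoint kernel would run.
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2

x = torch.tensor([2.0], requires_grad=True)
y = Cube.apply(x)   # invoke via .apply(), as in the linked example
loss = y.sum()      # the rest of the graph (here, the loss) stays in PyTorch
loss.backward()
print(x.grad)       # tensor([12.]) = 3 * 2**2
```

Because the op is registered through .apply(), gradients flow through it like any native PyTorch operation, so it can sit anywhere inside a larger differentiable model.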

Another Torch interop example was added in ceb79fc. Thanks Zach!