mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering



Scatter to array created from dlpack fails due to spurious references

dvicini opened this issue

Hi,

I have a use case where I would like to use Dr.Jit to write into a pre-allocated array that I receive as a DLPack capsule. When I initialize a Dr.Jit array from the DLPack capsule and then scatter into it, Dr.Jit instead writes to a new array that is a copy of the original. This happens because jitc_var_scatter checks whether the target's ref_count is greater than 2 and, if so, creates a copy before writing.
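To make the failure mode concrete, here is a toy Python model of the copy-on-write rule described above. It is a sketch under stated assumptions: Buffer, scatter and COPY_THRESHOLD are hypothetical names, reference counts are simulated by hand, and only the threshold of 2 comes from the issue. It does not reflect how Dr.Jit actually tracks references internally.

```python
# Toy model of the copy-on-write rule described in the issue.
# Hypothetical: Buffer, scatter and COPY_THRESHOLD are illustrative names,
# not Dr.Jit APIs; the threshold of 2 is taken from the issue text.

COPY_THRESHOLD = 2


class Buffer:
    """Stand-in for a JIT variable that owns a block of memory."""

    def __init__(self, data, ref_count=1):
        self.data = list(data)      # private copy of the contents
        self.ref_count = ref_count  # number of handles referring to this buffer


def scatter(target, values, indices):
    """Write values[k] to target at indices[k]; return the buffer written.

    If too many references exist, an in-place write could be observed by
    other holders, so a private copy is written instead (copy-on-write).
    """
    if target.ref_count > COPY_THRESHOLD:
        target = Buffer(target.data)  # copy; the original is left untouched
    for value, index in zip(values, indices):
        target.data[index] = value
    return target


# Normal case: few references, so the write lands in the original buffer.
shared = Buffer([0.0] * 4, ref_count=2)
assert scatter(shared, [1.0, 2.0], [0, 1]) is shared

# One spurious extra reference pushes the count past the threshold,
# so the result is a new buffer and the caller's memory is not updated.
shared.ref_count = 3
result = scatter(shared, [5.0], [2])
print(result is shared, shared.data, result.data)
# False [1.0, 2.0, 0.0, 0.0] [1.0, 2.0, 5.0, 0.0]
```

In this model, the only difference between the two calls is a single extra reference, which matches the symptom in the reproducer: the write goes to a fresh allocation instead of the caller's buffer.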

What seems to happen is that a spurious Python object is created during initialization and holds an extra reference to the array. If I call the garbage collector (gc.collect()) before calling scatter, the spurious reference is released and scatter then writes to the original array as expected.
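The gc.collect() observation is consistent with the extra reference being held by an object caught in a reference cycle, which CPython's reference counting alone cannot free. Whether that is the actual cause inside the Dr.Jit bindings is an assumption; the pure-Python sketch below, using hypothetical Array and Wrapper classes, only illustrates the general mechanism.

```python
import gc
import weakref


class Array:
    """Stand-in for the array whose memory we want to write into."""


class Wrapper:
    """Hypothetical temporary object that references the array and itself."""

    def __init__(self, target):
        self.target = target
        self.me = self  # self-reference: its refcount can never reach zero alone


def spurious_reference_demo():
    was_enabled = gc.isenabled()
    gc.disable()  # keep automatic collection from running mid-demo
    try:
        arr = Array()
        arr_ref = weakref.ref(arr)  # observe arr without keeping it alive
        Wrapper(arr)  # temporary is discarded, but its cycle survives
        del arr
        alive_before = arr_ref() is not None  # still kept alive by the cycle
        gc.collect()  # the cycle collector frees the wrapper and the array
        alive_after = arr_ref() is not None
        return alive_before, alive_after
    finally:
        if was_enabled:
            gc.enable()


print(spurious_reference_demo())  # (True, False) on CPython
```

In this model, breaking the cycle where the temporary object is created would release the extra reference immediately, without needing a full gc.collect() pass.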

Here is a minimal reproducer:

```python
import drjit as dr
import gc

a = dr.linspace(dr.llvm.Float, 0, 1, 16)
a_dlpack = a.__dlpack__()
b = dr.llvm.Float(a_dlpack)

c = dr.linspace(dr.llvm.Float, 0, 2, 16)

print("Old b", b.data_())
# gc.collect()  # Uncommenting this solves the issue, but is slow
dr.scatter(b, c, dr.arange(dr.llvm.Int32, dr.width(a)))
print("New b", b.data_())
```

Output without gc.collect():

```
Old b 91565209873920
New b 91565227055808
```

Output with gc.collect():

```
Old b 91586695922496
New b 91586695922496
```

This is for the pre-nanobind version. I couldn't verify whether this also happens on the nanobind branch, since DLPack support there does not yet seem to be complete. I wanted to raise it anyway because this may be a somewhat unusual use of the DLPack interface, but it is important in our context: concretely, when Mitsuba is embedded within JAX, it receives pre-allocated buffers from XLA.