TinyJit does not update bound values in symbolic shape
chenyuxyz opened this issue
chenyu commented
Repro:

```python
from tinygrad import Tensor, TinyJit, Variable

@TinyJit
def f(t): return t + 1

for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  t = f(Tensor.rand(i).reshape(vi))
  print(f"{t.shape=}")
```
With TinyJit, it outputs:

```
t.shape=(Variable('i', 1, 10).bind(1),)
t.shape=(Variable('i', 1, 10).bind(2),)
t.shape=(Variable('i', 1, 10).bind(2),)
t.shape=(Variable('i', 1, 10).bind(2),)
```
Without TinyJit, it outputs:

```
t.shape=(Variable('i', 1, 10).bind(1),)
t.shape=(Variable('i', 1, 10).bind(2),)
t.shape=(Variable('i', 1, 10).bind(3),)
t.shape=(Variable('i', 1, 10).bind(4),)
```
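The stale `bind(2)` on calls 3 and 4 is consistent with a run/capture/replay scheme: the first call runs eagerly, the second call captures, and later calls replay and hand back the output recorded at call 2, shape metadata included. The toy model below is a sketch under that assumption. `BoundVar` and `ToyJit` are hypothetical names, not tinygrad classes, and only the shape metadata is modeled (no kernels run).

```python
from dataclasses import dataclass

# Hypothetical stand-in for a bound symbolic dimension, e.g. Variable("i", 1, 10).bind(i).
@dataclass(frozen=True)
class BoundVar:
  name: str
  val: int

# Hypothetical toy jit: run eagerly on call 1, capture on call 2, replay afterwards.
# On replay it returns the output shape recorded at capture time instead of
# recomputing it, which reproduces the behavior reported in this issue.
class ToyJit:
  def __init__(self, fn):
    self.fn, self.cnt, self.captured_shape = fn, 0, None

  def __call__(self, shape):
    self.cnt += 1
    if self.cnt == 1:
      return self.fn(shape)                  # plain eager run
    if self.cnt == 2:
      self.captured_shape = self.fn(shape)   # capture: remember the output shape
      return self.captured_shape
    return self.captured_shape               # replay: the captured bound value leaks out

# The jitted "function" passes the symbolic shape through unchanged, like t + 1 does.
f = ToyJit(lambda shape: shape)
seen = [f((BoundVar("i", i),))[0].val for i in range(1, 5)]
print(seen)  # [1, 2, 2, 2]
```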
We can either update the shape after the jit call, or enforce that a Variable can only be registered inside the jitted function.
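A minimal sketch of the two proposed fixes, using hypothetical helpers (`BoundVar`, `rebind`, `check_unbound_inputs`) rather than any tinygrad API. Option 1 substitutes the current bound values into the captured output shape after a replay; option 2 rejects inputs whose shapes already carry a Variable bound outside the jitted function.

```python
from dataclasses import dataclass

# Hypothetical stand-in for a bound symbolic dimension.
@dataclass(frozen=True)
class BoundVar:
  name: str
  val: int

def rebind(shape, var_vals):
  # Option 1: after a replay, substitute the current bound values into the captured shape.
  # Concrete int dimensions are passed through untouched.
  return tuple(BoundVar(d.name, var_vals[d.name]) if isinstance(d, BoundVar) else d for d in shape)

def check_unbound_inputs(*shapes):
  # Option 2: refuse inputs whose shapes already carry a Variable bound outside the jit.
  for shape in shapes:
    if any(isinstance(d, BoundVar) for d in shape):
      raise ValueError("bind Variables inside the jitted function")

captured = (BoundVar("i", 2), 3)     # shape recorded at capture time
print(rebind(captured, {"i": 4}))    # (BoundVar(name='i', val=4), 3)
```

Option 1 keeps the call pattern from the repro working; option 2 is simpler to enforce but would turn the repro above into an error.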