Wrong gradient computation through loops
ClemensSchwarke opened this issue
Hi,
when differentiating through a loop, the gradient seems to be wrong if the loop overwrites a variable. In my understanding, this happens because the replay of the backward pass skips the loop. Putting the loop in a separate function therefore works around the issue, since the function is called again during the replay pass.
Looking forward to your opinions on this! (Unfortunately, I don't have time for a minimal example at the moment, but I am happy to clarify my issue if this post is confusing)
Cheers,
Clemens
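A minimal sketch of the pattern described above, assuming the library in question is NVIDIA Warp (whose tape replay and loop unrolling match the behavior discussed in this thread); kernel and function names are illustrative:

```python
import warp as wp

wp.init()

@wp.func
def square_n(v: float, n: int) -> float:
    # Same loop, moved into a separate function: per the report above,
    # the function body is executed again during the tape replay, so
    # the gradient comes out correct.
    for i in range(n):
        v = v * v
    return v

@wp.kernel
def loop_inline(x: wp.array(dtype=float), n: int, out: wp.array(dtype=float)):
    tid = wp.tid()
    v = x[tid]
    for i in range(n):  # dynamic trip count: the loop is not unrolled
        v = v * v       # `v` is overwritten on every iteration
    out[tid] = v        # gradient w.r.t. x reported as wrong

@wp.kernel
def loop_in_func(x: wp.array(dtype=float), n: int, out: wp.array(dtype=float)):
    tid = wp.tid()
    out[tid] = square_n(x[tid], n)  # workaround: gradient comes out correct
```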
Sorry for the late reply!
> when differentiating through a loop, the gradient seems to be wrong if the loop overwrites a variable.
In general, overwriting a variable in a loop makes it non-differentiable; a warning should be produced about this.
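A plain-Python sketch of why overwriting is the problem: the adjoint of each `v = v * v` step needs the primal value of `v` going into that step, and overwriting destroys it. Saving the intermediates (as an unrolled loop effectively does) recovers the correct gradient:

```python
def forward(x, n):
    # Record every intermediate value, like an unrolled loop whose
    # iterations each write a fresh variable.
    vals = [x]
    for _ in range(n):
        vals.append(vals[-1] * vals[-1])
    return vals

def backward(vals):
    # d(v*v)/dv = 2*v: each step's adjoint needs the saved primal `v`.
    grad = 1.0
    for v in reversed(vals[:-1]):
        grad *= 2.0 * v
    return grad

vals = forward(3.0, 2)            # f(x) = x**4
assert backward(vals) == 108.0    # f'(3) = 4 * 3**3 = 108
```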
If the iteration count is constant and small (by default, up to 16), we attempt to unroll the loop, which makes it differentiable.
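Under the same Warp assumption, a sketch of the differentiable case: the trip count is a literal within the default unroll limit, so each unrolled iteration effectively writes a fresh value instead of overwriting one:

```python
import warp as wp

wp.init()

@wp.kernel
def pow8(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
    tid = wp.tid()
    v = x[tid]
    for i in range(3):   # constant trip count <= 16: the loop is unrolled
        v = v * v
    loss[tid] = v

x = wp.array([2.0], dtype=float, requires_grad=True)
loss = wp.zeros(1, dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(pow8, dim=1, inputs=[x, loss])
tape.backward(loss=loss)

print(x.grad)   # f(x) = x**8, so f'(2) = 8 * 2**7 = 1024
```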