google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

Home Page:https://jaxopt.github.io

Debugging vanishing gradients in implicit fix-point differentiation

alessandrosimon opened this issue

I have a neural network in flax that is basically a function expansion in a non-linear basis set, imagine

f(x, p) = p_1 * sin(x) + p_2 * exp(x) + ...

and I want to find parameters p such that, for a given X, the function f( . , p) has a fixed point at X. The idea was to use
Anderson acceleration together with implicit differentiation to tune the parameters. The problem is that the gradient
of the fixed point z = f(z, p) with respect to p is zero. I checked the output of the solver (verbose=True), and it converges to a solution in fewer than maxiter iterations.
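For reference, the gradient that implicit differentiation is expected to produce can be written down from the fixed-point condition alone. The derivation below is the standard implicit function theorem argument, not something taken from the jaxopt source; the symbols z^*, v and u are notation introduced here for illustration.

```latex
% Fixed-point condition, differentiated with respect to p:
z^*(p) = f\bigl(z^*(p), p\bigr)
\quad\Longrightarrow\quad
\Bigl(I - \frac{\partial f}{\partial z}\Bigr)\frac{\partial z^*}{\partial p}
  = \frac{\partial f}{\partial p}

% Reverse mode: for an incoming cotangent v, first solve a linear system for u,
% then pull u back through the parameters.
\Bigl(I - \frac{\partial f}{\partial z}\Bigr)^{\!\top} u = v,
\qquad
\nabla_p = \Bigl(\frac{\partial f}{\partial p}\Bigr)^{\!\top} u
```

All derivatives are evaluated at (z^*, p). Note that the parameter gradient is linear in u, so if the linear solve returns u = 0 the gradient is exactly zero regardless of how well the forward solver converged.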

Ordinary backpropagation through the network seems to work: if I take the returned fixed-point solution and apply one more iteration manually, I get a non-vanishing gradient with respect to the parameters. I could look at the generated jaxpr, but it is long and complicated, so I don't think it would help much.
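The "one further iteration" check can be made concrete on a toy problem. The sketch below is not the network from this issue: `f`, `solve_fixed_point` and the parameter values are hypothetical stand-ins, chosen so that plain fixed-point iteration converges. It compares the one-step gradient (which is only the partial derivative of f with respect to p) against the implicit gradient and against differentiating through the unrolled iteration.

```python
import jax
import jax.numpy as jnp

def f(z, p):
    # Hypothetical toy map: a two-term expansion in a non-linear basis.
    return p[0] * jnp.sin(z) + p[1] * jnp.cos(z)

def solve_fixed_point(p, z0=0.0, num_iters=100):
    # Plain fixed-point iteration; it converges here because |df/dz| < 1
    # for the parameter values used below.
    z = z0
    for _ in range(num_iters):
        z = f(z, p)
    return z

p = jnp.array([0.5, 0.3])
z_star = jax.lax.stop_gradient(solve_fixed_point(p))

# One extra iteration with z* held constant: this yields only df/dp.
one_step_grad = jax.grad(lambda q: f(z_star, q))(p)

# Implicit gradient; in the scalar case dz*/dp = (df/dp) / (1 - df/dz).
df_dz = jax.grad(f, argnums=0)(z_star, p)
implicit_grad = one_step_grad / (1.0 - df_dz)

# Reference: differentiate through the unrolled iteration.
unrolled_grad = jax.grad(solve_fixed_point)(p)

print("one-step :", one_step_grad)
print("implicit :", implicit_grad)
print("unrolled :", unrolled_grad)
```

A non-zero one-step gradient therefore only shows that the network itself is differentiable in p. It says nothing about the factor coming from the linear solve, which is where the implicit gradient can still go wrong.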

Do you have a minimal example?

"Unfortunately", the method seems to work for a small test network. I think the problem lies in some later transformation that I apply to the output of the described layer. What I don't understand is why the gradient is non-zero for a simple single-step evaluation but vanishes as soon as a second iteration is done (for example, with maxiter=2).

OK, I think I know what the problem was. During the backward pass, when the linear system involving the Jacobian is solved (by default with solve_cg in linear_solve.py), the solver does not converge and simply returns its initial guess, which is the zero vector. I guess it would be nice to get some kind of error message in the case of non-convergence.
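One way to catch this kind of silent failure is to redo the backward solve by hand and inspect its residual. The sketch below uses plain JAX rather than jaxopt internals: the map `f`, the matrix `W` and all variable names are hypothetical, and GMRES from `jax.scipy.sparse.linalg` is swapped in as a solver that does not need a symmetric system. It is meant as a diagnostic pattern, not as a description of what jaxopt does.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def f(z, p, W):
    # Hypothetical toy fixed-point map; its Jacobian in z is not symmetric.
    return jnp.tanh(W @ z + p)

def fixed_point(p, W, num_iters=100):
    # Plain iteration; contracts because the inf-norm of W is below 1.
    z = jnp.zeros_like(p)
    for _ in range(num_iters):
        z = f(z, p, W)
    return z

W = jnp.array([[0.2, -0.4, 0.1],
               [0.3, 0.1, -0.2],
               [-0.1, 0.25, 0.15]])
p = jnp.array([0.1, -0.2, 0.3])
z_star = jax.lax.stop_gradient(fixed_point(p, W))

# VJPs of f at the fixed point, with respect to z and to p.
_, vjp_z = jax.vjp(lambda z: f(z, p, W), z_star)
_, vjp_p = jax.vjp(lambda q: f(z_star, q, W), p)

def matvec(u):
    # Applies (I - df/dz)^T to u.
    return u - vjp_z(u)[0]

v = jnp.ones_like(z_star)  # cotangent of the loss sum(z*)

u, _ = gmres(matvec, v)
residual = jnp.linalg.norm(matvec(u) - v)
if residual > 1e-3 * jnp.linalg.norm(v):
    print("warning: backward linear solve did not converge:", residual)

implicit_grad = vjp_p(u)[0]

# A solver that gives up and returns its zero initial guess gives a zero gradient.
zero_grad = vjp_p(jnp.zeros_like(v))[0]

# Reference: differentiate through the unrolled iteration.
unrolled_grad = jax.grad(lambda q: jnp.sum(fixed_point(q, W)))(p)
```

As far as I know, jaxopt's iterative solvers also accept an `implicit_diff_solve` argument for replacing the backward linear solver, for example with one of the other routines in `jaxopt.linear_solve`; the exact default and available solvers depend on the installed version, so check its documentation.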