google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


Autodiff with `numpy.linalg.qr` differs from autodiff with `scipy.linalg.qr`

pnkraemer opened this issue

Hi all!

I am a little confused by the different implementations of QR decompositions in jax.scipy.linalg and jax.numpy.linalg.
When called with certain inputs, one admits a JVP and the other does not.

It is not unlikely that I missed something, but I suspect the reason is the following.
Both jax.scipy.linalg and jax.numpy.linalg dispatch to jax.lax.linalg.qr, but with different arguments: if one uses, for example, mode="r", then the numpy version leads to full_matrices=False

full_matrices = False

and the scipy version to full_matrices=True

full_matrices = True

(Admittedly, it is probably hard to see from the permalinks...)
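
In other words, my reading of the two wrappers is roughly the following (a paraphrased sketch based on the permalinks above, not the actual JAX source; names and branches are simplified):

from jax.lax import linalg as lax_linalg

def numpy_style_qr(a, mode="reduced"):
    # jax.numpy.linalg.qr: mode="r" behaves like the default "reduced",
    # i.e. it calls the lax primitive with full_matrices=False.
    full_matrices = (mode == "complete")
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    return r if mode == "r" else (q, r)

def scipy_style_qr(a, mode="full"):
    # jax.scipy.linalg.qr: mode="r" behaves like the default "full",
    # i.e. it calls the lax primitive with full_matrices=True.
    full_matrices = (mode != "economic")
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    return r if mode == "r" else (q, r)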

This is an issue for me, because the Jacobian-vector product of jax.lax.linalg.qr raises an error whenever full_matrices is True

if full_matrices or m < n:

and thus I have to be really careful with Jacobians in my code, which relies quite heavily on QR decompositions.

Minimal working example:

import jax
import jax.numpy as jnp
import jax.scipy as jsp

key = jax.random.PRNGKey(seed=2)
key, subkey = jax.random.split(key)
x = jax.random.normal(key, shape=(2, 2))
dx = jax.random.normal(subkey, shape=(2, 2))

# The values coincide
r_np = jnp.linalg.qr(x, mode="r")
r_sp = jsp.linalg.qr(x, mode="r")
assert jnp.allclose(r_np, r_sp)

# The numpy jvp seems to work, at least kind of...
_, dr = jax.jvp(lambda s: jnp.linalg.qr(s, mode="r"), (x,), (dx,))
dt = 1e-6
dr_ = (jnp.linalg.qr(x + dt * dx, mode="r") - jnp.linalg.qr(x, mode="r")) / dt
assert jnp.allclose(dr, dr_, atol=1e-3, rtol=1e-3), (dr - dr_)  # with tol=1e-5 it does not work, but at least there are values...

# The scipy jvp fails:
# NotImplementedError: Unimplemented case of QR decomposition derivative
_ = jax.jvp(lambda s: jsp.linalg.qr(s, mode="r"), (x,), (dx,))

From the snippet, it is not 100% clear to me whether the JVP is even correct, given that I had to reduce the precision in the second assertion. But the results are also too similar for something to be drastically broken, so it should be fine?!

There have been some potentially related issues regarding QR decompositions and their differentiability:

  • #2863
  • #433
  • #8542 (only vaguely related, but it does mention QR decompositions!)

Am I missing something? Thanks for your help! Let me know if you need more info from me :)

Hi - thanks for the question. First of all, the difference in the jax.scipy.linalg.qr and jax.numpy.linalg.qr defaults comes from the APIs they wrap: SciPy defaults to mode='full' (doc), while NumPy defaults to mode='reduced' (doc).
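
For illustration, here is a small sketch (plain NumPy/SciPy, nothing JAX-specific) of what those defaults mean for a non-square input:

import numpy as np
import scipy.linalg

a = np.arange(6.0).reshape(3, 2)

q_np, r_np = np.linalg.qr(a)     # numpy default mode='reduced': q has shape (3, 2)
q_sp, r_sp = scipy.linalg.qr(a)  # scipy default mode='full':    q has shape (3, 3)
print(q_np.shape, q_sp.shape)    # (3, 2) (3, 3)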

I'm not entirely familiar with the qr JVP rule, but my understanding is that the reason it is undefined for the full-matrix output is because the gradient itself is mathematically ill-defined in this case: the full output constructs additional basis vectors spanning the remaining subspace, there are infinitely many valid choices for these, and thus how they change with respect to a perturbation of the input is ill-defined.
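
To illustrate (a sketch only, using mode="complete" for the full factorization):

import jax.numpy as jnp

a = jnp.array([[1., 0.],
               [0., 1.],
               [0., 0.]])  # tall 3x2 input

q, r = jnp.linalg.qr(a, mode="complete")  # q: (3, 3), r: (3, 2) with a zero last row

# The extra third column of q only needs to be a unit vector orthogonal to the
# first two. Flipping its sign, say, gives an equally valid factorization,
# since the corresponding (zero) row of r is unaffected.
q_alt = q.at[:, 2].multiply(-1.0)
print(jnp.allclose(q_alt @ r, a))  # True: still a valid complete QR of a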

It may be that we are too strict about erroring in the JVP on full_matrices=True even when the shapes of the inputs mean that the result is equivalent to full_matrices=False, and thus should be differentiable.
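
For concreteness, a minimal sketch of that square case, calling the lax-level function directly (exact keyword arguments may differ between JAX versions):

import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

x = jnp.eye(2)
dx = jnp.ones((2, 2))

# Reduced QR: the JVP rule is implemented.
jax.jvp(lambda s: lax_linalg.qr(s, full_matrices=False), (x,), (dx,))

# Full QR of the same square input, where the reduced and full factorizations
# coincide, still hits the error in the JVP rule.
try:
    jax.jvp(lambda s: lax_linalg.qr(s, full_matrices=True), (x,), (dx,))
except NotImplementedError as err:
    print(err)  # "Unimplemented case of QR decomposition derivative"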

Finally, I suspect the reason your JVP comparison requires such loose tolerance is probably not because of inaccuracy of the automatic gradient, but because of inaccuracy in the finite difference gradient to which you're comparing it.

Hi, thanks for the immediate answer!

My question is less about the defaults or about why errors are raised for specific QR-decomposition modes, and more about why the two APIs call jax.lax.linalg.qr with different arguments when given mode="full" (or mode="r", as in the example above). I understand that different QR formats may have different differentiability properties, but that seems to be a separate question, or am I missing something? I apologise if this wasn't clear from my initial question.

Finally, I suspect the reason your JVP comparison requires such loose tolerance is probably not because of inaccuracy of the automatic gradient, but because of inaccuracy in the finite-difference gradient to which you're comparing it.

True, that could be the case. However, it is not resolved with different dt values (I was a bit too lazy to try more sophisticated schemes, to be perfectly honest). I don't think it is a major part of this issue, though :)
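
For what it's worth, something like the following would be a tighter check: enable float64 and use a central difference (dt is an arbitrary choice on my part):

import jax
jax.config.update("jax_enable_x64", True)  # finite differences are far less noisy in float64
import jax.numpy as jnp

key, subkey = jax.random.split(jax.random.PRNGKey(2))
x = jax.random.normal(key, shape=(2, 2))
dx = jax.random.normal(subkey, shape=(2, 2))

f = lambda s: jnp.linalg.qr(s, mode="r")
_, dr = jax.jvp(f, (x,), (dx,))

dt = 1e-6
dr_central = (f(x + 0.5 * dt * dx) - f(x - 0.5 * dt * dx)) / dt  # O(dt**2) truncation error
print(jnp.max(jnp.abs(dr - dr_central)))  # expected to be well below the 1e-3 tolerance above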

I see - thanks for the clarification. I believe the behavior of jax.numpy.linalg.qr and jax.scipy.linalg.qr matches that of the corresponding numpy/scipy implementations, which is the intent. In particular, the original numpy implementation has mode='full' and mode='reduced' return the same thing, though this behavior is marked as deprecated: https://github.com/numpy/numpy/blob/65a701fd40d5ab3b772131daf45679d6ecf3d721/numpy/linalg/linalg.py#L911-L917

I can't say I understand the reason for this, but the jax.numpy.linalg.qr result does seem consistent with the numpy.linalg.qr result in this case.

If you find any case where the behavior of jax.numpy.linalg.qr or jax.scipy.linalg.qr differs from that of the corresponding numpy/scipy routines, then it is a bug that we should address. But differences between mode conventions used within numpy and scipy are not something we can control in JAX.

It looks like there is one inconsistency: scipy.linalg.qr with mode="r" returns a length-1 tuple containing the resulting array, whereas jax.scipy.linalg.qr returns the array itself. We should address this.
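
For reference, a minimal sketch of that inconsistency as described here (behaviour at the time of this thread; it may since have been addressed):

import numpy as np
import scipy.linalg
import jax.numpy as jnp
import jax.scipy as jsp

x = np.eye(3)

r_scipy = scipy.linalg.qr(x, mode="r")           # length-1 tuple: (r,)
r_jax = jsp.linalg.qr(jnp.asarray(x), mode="r")  # bare array, per the description above
print(type(r_scipy), type(r_jax))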

Oh wow, it seems I hadn't noticed that scipy's mode="r" and numpy's mode="r" return differently structured results!
Thank you for clarifying this :)

From my perspective, the issue can be closed now :)

Thanks!