google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

TPU cannot do simple arithmetic!

ayaka14732 opened this issue

commented

I am trying to do a simple matrix multiplication on TPU, but it gives the wrong result:

import jax.numpy as np
import numpy as onp

# On CPU
x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]])
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]])
z = onp.dot(x, y)

# On TPU
x_ = np.asarray(x)
y_ = np.asarray(y)
z_ = np.dot(x_, y_)

print('JAX device:', x_.device())

# Compare
print('CPU result:', z)
print('TPU result:', z_)
assert np.allclose(z, z_)

Output:

JAX device: TPU_0(process=0,(0,0,0,0))
CPU result: [[0.17372088 0.09648504]
 [0.21183069 0.10812637]]
TPU result: [[0.17405128 0.09669876]
 [0.21180916 0.10805416]]
Traceback (most recent call last):
  File "/home/ayaka/main.py", line 21, in <module>
    assert np.allclose(z, z_)
AssertionError

Manual calculation:

0.3744 * 0.3946 + 0.1656 * 0.1569 = 0.17372088
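
The same arithmetic in plain Python, as a quick check of the first entry:

# First entry of x @ y, computed in float64.
print(0.3744 * 0.3946 + 0.1656 * 0.1569)  # ~0.17372088, matching the CPU result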

So the result on CPU is correct, while the result on TPU is wrong.

Library versions:

jax                  0.3.4
jaxlib               0.3.2
libtpu-nightly       0.1.dev20220315

There are two "sharp bits" to be aware of here.

1. float64 - Your NumPy example runs with f64 inputs and outputs, but the JAX version is f32. See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision (a minimal sketch of enabling 64-bit mode appears at the end of this comment).

2. Compute precision - XLA supports computing dot products at several precisions. JAX's default is optimised for performance and therefore uses relatively low precision (see #9952). This is typically fine for NN training, but can be the wrong choice for other applications.

You can adjust the precision globally using:

jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)

Or you can do so on a per operation basis:

c = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

Here is a copy of your example using HIGHEST precision on TPU (with f32 inputs), which gives a result very close to the CPU one: https://colab.research.google.com/gist/tomhennigan/69844409e46fd71267acf7479e2ba7f4/example-of-changing-default-precision-in-jax.ipynb
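
For reference, here is a rough self-contained sketch along the same lines as that notebook (reusing the values from this issue, not the notebook's exact code):

import jax
import jax.numpy as jnp
import numpy as onp

x = onp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]], dtype=onp.float32)
y = onp.array([[0.3946, 0.1186],
               [0.1569, 0.3145]], dtype=onp.float32)

z_cpu = onp.dot(x, y)  # NumPy reference in f32

# Request the highest available precision for this one dot product.
z_tpu = jnp.dot(jnp.asarray(x), jnp.asarray(y),
                precision=jax.lax.Precision.HIGHEST)

print(onp.allclose(z_cpu, z_tpu))  # expected: True

And for point 1, enabling 64-bit mode looks something like this (a minimal sketch; the flag must be set before any JAX arrays are created):

import jax

# Must run at program startup, before any JAX arrays are created.
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp

x = jnp.array([[0.3744, 0.1656],
               [0.4707, 0.1663]])
print(x.dtype)  # float64 with x64 enabled; float32 otherwise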

commented

Thank you for the explanation!

I am still wondering:

1. Is there any research indicating that low precision will not affect model performance? As deep learning models grow larger and larger, I suspect that different precisions may produce totally different outputs after many layers of operations.

2. How much training time can I save when using lower precision?

Is there any research indicating that low precision will not affect model performance? As deep learning models grow larger and larger, I suspect that different precisions may produce totally different outputs after many layers of operations.

It is important to note that while the default precision is low, it is deterministic: if you train a model in low precision and run inference on that trained model in low precision, you should get the expected answer.
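
As a small illustration of the determinism point (a sketch reusing the values from this issue):

import jax.numpy as jnp

a = jnp.asarray([[0.3744, 0.1656], [0.4707, 0.1663]])
b = jnp.asarray([[0.3946, 0.1186], [0.1569, 0.3145]])

z1 = jnp.dot(a, b)  # default (low) precision
z2 = jnp.dot(a, b)  # the same op again

# Low precision, but bitwise reproducible on the same hardware and software.
print(bool((z1 == z2).all()))  # True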

For very large models it is typical to drop the precision, because training them requires so many floating-point operations that the performance improvement translates into a very significant reduction in training time.

For Gopher (a large language model from DeepMind) we discuss low-precision training with bfloat16 (even lower than JAX's f32 defaults) in Appendix C.2 of our paper: https://arxiv.org/pdf/2112.11446.pdf
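
As a rough sketch of what bfloat16 looks like in JAX (this only illustrates the dtype, not the training setup from the paper):

import jax.numpy as jnp

# Keep parameters in f32 and cast to bf16 for the matmul-heavy parts.
w = jnp.ones((1024, 1024), dtype=jnp.float32)
w_bf16 = w.astype(jnp.bfloat16)
x_bf16 = jnp.ones((8, 1024), dtype=jnp.bfloat16)

y = jnp.dot(x_bf16, w_bf16)
print(y.dtype)  # bfloat16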

How much training time can I save when using lower precision?

Typically, accelerators have special hardware ("tensor cores") for half-precision (e.g. bf16, f16) compute, and you can expect these computations to run somewhere between 5x and 10x faster than full-precision f32 computations.

With JAX's default precision, an f32 dot product is actually computed in bf16 on the TPU, so the performance improvement is significant compared to Precision.HIGH or Precision.HIGHEST.
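
A rough way to see the difference for yourself is to time the same matmul at different precisions on a TPU (a sketch only; the exact speedup depends on the chip and the shapes):

import time
import jax
import jax.numpy as jnp

a = jnp.ones((4096, 4096), dtype=jnp.float32)
b = jnp.ones((4096, 4096), dtype=jnp.float32)

def bench(precision, iters=10):
    f = jax.jit(lambda x, y: jnp.dot(x, y, precision=precision))
    f(a, b).block_until_ready()  # compile and warm up
    t0 = time.perf_counter()
    for _ in range(iters):
        f(a, b).block_until_ready()
    return (time.perf_counter() - t0) / iters

print('DEFAULT:', bench(jax.lax.Precision.DEFAULT))
print('HIGHEST:', bench(jax.lax.Precision.HIGHEST))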

commented

Thank you for the detailed explanation!

I think we can close this issue, but please let me know if not!