google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page:http://jax.readthedocs.io/

Improved docs for (default) matmul precision

sanchit-gandhi opened this issue

  1. It is unintuitive that the default matmul precision on TPU is bfloat16, especially for users coming from PyTorch/GPU, where the default precision is float32. Information about the default matrix multiplication precision on TPUs is extremely difficult to find. There is a short section in the README.md of the cloud TPU Colab folder of the JAX repo: https://github.com/google/jax/tree/main/cloud_tpu_colabs#bfloat16-dtype However, it is unclear: it refers to 'MXUs' without explaining what the abbreviation means, and it only shows how the default precision can be changed manually on an op-by-op basis by setting precision=jax.lax.Precision.XXX. This gives the impression that, to change the TPU precision to float32, one must add the keyword argument precision=jax.lax.Precision.HIGHEST to every jax.numpy operation in one's script.

  2. It is difficult to find out how the default precision can be changed. Performing matmul operations in the default bfloat16 precision can lead to undesirable results. At Hugging Face, we constantly run into problems with the fast but low-precision TPU default, as shown here for example: huggingface/transformers#15754
    On changing the default matmul precision, the docs do mention the default matmul precision context manager: https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html However, they do not show how to use this context manager to change the default matmul precision (for instance, with an example). It is hard to tell from the docs that you have to run your code under the context manager as follows:

with jax.default_matmul_precision('float32'):   # or 'bfloat16' for lowest
  ... = foo(...)
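
For reference, a minimal runnable sketch of the two approaches discussed so far: the scoped context manager and the per-op keyword argument. The function foo and the input arrays are hypothetical stand-ins, not taken from the docs or the linked PR.

```python
import jax
import jax.numpy as jnp


def foo(a, b):
    # Hypothetical stand-in for user code; any matmul-based function works here.
    return jnp.dot(a, b)


a = jnp.ones((4, 4), dtype=jnp.float32)
b = jnp.ones((4, 4), dtype=jnp.float32)

# Scoped override: matmuls traced inside this block request float32 precision.
with jax.default_matmul_precision("float32"):
    out_scoped = foo(a, b)

# Per-op override: only this single call requests the highest precision.
out_per_op = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```

The context manager avoids having to thread a precision argument through every call, which is the main concern raised in point 1.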

The docs also gloss over three additional methods for changing the default matmul precision, which are laid out clearly in this PR: #6143 (comment) These three methods require no change to one's actual script, just a shell/command-line flag or a JAX config change, and are arguably much easier to use and less obtrusive.
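
As a rough sketch of what such a global setting can look like, the snippet below sets the default once through the JAX config at program start. The environment-variable and flag forms appear only as comments; the script name my_script.py is hypothetical, and the flag form is assumed to require that the program parses absl flags.

```python
import jax
import jax.numpy as jnp

# Config change, set once before any computation runs.
jax.config.update("jax_default_matmul_precision", "float32")

# Equivalent settings from outside the script (my_script.py is a hypothetical name):
#   JAX_DEFAULT_MATMUL_PRECISION=float32 python my_script.py
#   python my_script.py --jax_default_matmul_precision=float32  # assumes absl flag parsing

x = jnp.ones((2, 2), dtype=jnp.float32)
y = jnp.dot(x, x)  # picks up the global default; no precision argument needed
```

Unlike the context manager, this applies to the whole program, so no individual call site needs to change.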

It would be great if the default matmul precisions for CPU/GPU/TPU were documented, along with what bfloat16, tensorfloat32, and float32 precision actually mean for matmul in terms of the number of passes. It would also be very helpful if all four methods for changing the default precision were added to the docs with short usage examples, as done in the aforementioned PR.
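
To give a feel for what bfloat16 matmul precision costs numerically, here is an emulation sketch that runs on any backend. It rounds the float32 inputs to bfloat16 before multiplying, which is an assumption used to approximate a single low-precision pass; it is not a statement of how the TPU hardware works, and it does not assert any pass counts.

```python
import jax
import jax.numpy as jnp

# bfloat16 keeps 8 significant bits, so values near 1 are spaced 2**-7 apart.
rounded = float(jnp.asarray(1.001, dtype=jnp.float32).astype(jnp.bfloat16))  # 1.0

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)


def round_to_bf16(x):
    # Emulation assumption: drop the input to bfloat16, then widen back to float32.
    return x.astype(jnp.bfloat16).astype(jnp.float32)


highest = jax.lax.Precision.HIGHEST
reference = jnp.dot(a, b, precision=highest)
emulated = jnp.dot(round_to_bf16(a), round_to_bf16(b), precision=highest)

# Relative error introduced purely by rounding the inputs to bfloat16.
rel_err = float(jnp.linalg.norm(emulated - reference) / jnp.linalg.norm(reference))
print(f"relative error from bfloat16 inputs: {rel_err:.2e}")
```

The relative error is typically orders of magnitude larger than float32 rounding error, which is consistent with the kind of discrepancies reported in the linked transformers issue.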

Note that the default precision for matrix-matrix multiplication is now actually tensorfloat32 on recent NVIDIA GPUs: #14022