Use bfloat16 for eval
tbaker2 opened this issue · comments
I'm running paxml on an Intel Xeon CPU server via the paxml/main.py program. I'm trying to create a model whose weights are stored in bfloat16, and that uses that dtype during eval. I modified the LmCloudSpmd2B configuration with the following lines:
```python
MODEL_DTYPE = jnp.bfloat16
ICI_MESH_SHAPE = [1, 1, 1]
```
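For context, an override like this would normally be expressed as an experiment subclass. The import path and class below are a sketch based on the stock paxml LmCloudSpmd2B config and are not verified against any particular paxml version:

```python
# Hypothetical sketch of the modified experiment config; assumes the
# stock LmCloudSpmd2B experiment class lives at this import path.
import jax.numpy as jnp
from paxml.tasks.lm.params.lm_cloud import LmCloudSpmd2B


class LmCloudSpmd2BBf16(LmCloudSpmd2B):
  # Request bfloat16 for the model's parameter dtype.
  MODEL_DTYPE = jnp.bfloat16
  # Single-host CPU mesh (no model/data parallelism).
  ICI_MESH_SHAPE = [1, 1, 1]
```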
The training status output includes the following lines:

```
model.dtype : type/jax.numpy/float32
model.fprop_dtype : dtype[bfloat16]
```
All of the other operator dtypes are float32, and when I run that model with the --eval switch all of the computation is done in float32. How can I direct paxml to use bfloat16 throughout?
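For what it's worth, the two fields in the log appear to control different things: `dtype` is the storage dtype of the parameters, while `fprop_dtype` is the dtype the forward pass computes in. A minimal pure-JAX sketch (not paxml API) of that distinction:

```python
import jax.numpy as jnp

# Weights stored in float32 -- this corresponds to model.dtype.
w_f32 = jnp.ones((4, 4), dtype=jnp.float32)

# Casting at the start of the forward pass corresponds to fprop_dtype:
# compute runs in bfloat16 even though the stored weights are float32.
x = jnp.ones((2, 4), dtype=jnp.bfloat16)
y = jnp.dot(x, w_f32.astype(jnp.bfloat16))
print(y.dtype)  # bfloat16

# Storing the weights themselves in bfloat16 (what setting MODEL_DTYPE
# aims for) halves parameter memory as well.
w_bf16 = jnp.ones((4, 4), dtype=jnp.bfloat16)
print(w_bf16.dtype)  # bfloat16
```

So the log above suggests the fprop override took effect but the parameter dtype override did not.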
Tom
Any comments on this?