google / paxml

Pax is a Jax-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model flop utilization rates.

Use bfloat16 for eval

tbaker2 opened this issue · comments

I'm running paxml on an Intel Xeon CPU server using the paxml/main.py program. I'm trying to create a model whose weights are created in bfloat16 and used in that dtype during eval. I modified the LmCloudSpmd2B configuration with the following lines:

```python
MODEL_DTYPE = jnp.bfloat16
ICI_MESH_SHAPE = [1, 1, 1]
```
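For context, a minimal sketch of how such an override typically looks as a registered experiment subclass. This assumes the class-attribute names (`MODEL_DTYPE`, `FPROP_DTYPE`, `ICI_MESH_SHAPE`) and registry pattern from paxml's `lm_cloud` example configs; verify against the paxml version in use.

```python
import jax.numpy as jnp
from paxml import experiment_registry
from paxml.tasks.lm.params import lm_cloud


@experiment_registry.register
class LmCloudSpmd2BBf16(lm_cloud.LmCloudSpmd2B):
  """LmCloudSpmd2B with bfloat16 weights on a single-device mesh.

  Assumed attribute semantics (check your paxml version):
  MODEL_DTYPE sets the parameter/storage dtype, while FPROP_DTYPE
  sets the forward-pass compute dtype.
  """
  MODEL_DTYPE = jnp.bfloat16
  FPROP_DTYPE = jnp.bfloat16
  ICI_MESH_SHAPE = [1, 1, 1]
```

The experiment is then selected by name, e.g. `--exp=LmCloudSpmd2BBf16`, the same way the stock `LmCloudSpmd2B` config is.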

The training status log includes the following lines:

```
model.dtype : dtype[float32]
model.fprop_dtype : dtype[bfloat16]
```

All of the other operator dtypes are float32. When I run that model with the --eval switch, all of the computation is done in float32. How can I direct paxml to use bfloat16?
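Independent of paxml, a small runnable JAX sketch of the underlying mechanics: in JAX, the compute dtype of an op follows its operand dtypes, so eval runs in bfloat16 only once both the weights and the activations actually carry that dtype.

```python
import jax.numpy as jnp

# Toy "weights" and "activations" in float32, as in the report above.
w = jnp.ones((4, 4), dtype=jnp.float32)
x = jnp.ones((2, 4), dtype=jnp.float32)

# Compute dtype follows the operands: float32 in, float32 out.
y_f32 = x @ w

# Casting both operands to bfloat16 makes the matmul compute in bfloat16.
y_bf16 = x.astype(jnp.bfloat16) @ w.astype(jnp.bfloat16)

print(y_f32.dtype)   # float32
print(y_bf16.dtype)  # bfloat16
```

This is why a config where the stored parameters remain float32 will still evaluate in float32 even when a separate fprop dtype setting reads bfloat16: the cast has to reach the tensors that actually enter the computation.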

Tom

Any comments on this?