kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku


Can "slim_model.py" work with "d_model" as 768?

leejason opened this issue · comments

To try a smaller model, I updated "6B_roto_256.json" with the following change:

"d_model": 768

Pretraining works on a single TPU v3-8, but the slimmed model produced by "slim_model.py" generates gibberish.

Why? Does "slim_model.py" only work with "d_model": 4096? I don't think it should, but I've found no clue after tracing the source code for hours.
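For anyone else who lands here: my mental model of "slim_model.py" is that it strips the training checkpoint down to inference-only weights and reshards them, and that reshard should be lossless regardless of d_model as long as the splits divide evenly. A toy numpy illustration of that intuition (purely illustrative, not the repo's actual code):

```python
import numpy as np

# Toy illustration of resharding a column-parallel weight matrix.
# This is my mental model of the slimming step, not the repo's code.
d_model, d_ff, old_shards, new_shards = 768, 3072, 8, 1

# Training-time: weight split column-wise across 8 cores.
full = np.random.randn(d_model, d_ff).astype(np.float32)
train_shards = np.split(full, old_shards, axis=1)

# Slimming: concatenate back, then re-split for the target topology.
merged = np.concatenate(train_shards, axis=1)
slim_shards = np.split(merged, new_shards, axis=1)

assert np.allclose(merged, full)  # a reshard like this must be lossless
```

If that intuition is right, the gibberish should come from somewhere other than the reshard itself, which is why I'm asking about hidden assumptions tied to d_model.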

Thank you for shedding some light on this.