google / maxtext

A simple, performant and scalable Jax LLM!

Convert Gemma weights with scan layers

borisdayma opened this issue

Hi,

It would be nice to be able to convert the Gemma checkpoints to support scan layers.
This would allow faster compilation for training & inference.
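To illustrate the compilation benefit (this is a generic sketch, not MaxText's actual implementation): with scan layers, the per-layer weights are stacked along an extra "layer" axis and `jax.lax.scan` traces the layer body once, so XLA compiles one layer instead of N unrolled copies. All names and shapes below are made up for the example.

```python
import jax
import jax.numpy as jnp

num_layers, d = 4, 8

# Per-layer weights stacked along a leading layer axis: (num_layers, d, d).
stacked_w = jnp.stack([jnp.eye(d) for _ in range(num_layers)])

def layer(x, w):
    # One transformer-layer stand-in; the carry is the activations.
    return jnp.tanh(x @ w), None

def forward_scanned(x):
    # lax.scan traces `layer` once and loops it num_layers times,
    # instead of unrolling num_layers separate copies into the graph.
    y, _ = jax.lax.scan(layer, x, stacked_w)
    return y

x = jnp.ones((d,))
print(forward_scanned(x).shape)  # (8,)
```

The unscanned equivalent would be a Python `for` loop over `num_layers` separate weight arrays, which XLA traces and compiles layer by layer.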

Thanks

Gemma comes out scanned by default, right? And when it's time to unscan, use this one:
https://github.com/google/maxtext/blob/main/MaxText/generate_param_only_checkpoint.py

Oh thanks!
I was initially confused by this command, which sets scan_layers=false:

```shell
python MaxText/maxengine_server.py MaxText/configs/base.yml \
  tokenizer_path=assets/tokenizer.gemma run_name=runner_2024-03-06-04-17 \
  steps=10 weight_dtype=bfloat16 async_checkpointing=false \
  model_name=gemma-7b ici_fsdp_parallelism=1 \
  ici_autoregressive_parallelism=-1 scan_layers=false
```

After inspecting the weights, I can see they are indeed scanned.
Just curious why they're stacked along param_scan_axis: 1 instead of just 0.
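For anyone else wondering what param_scan_axis changes, here is a hypothetical illustration (shapes are made up, not Gemma's): it only determines where the layer dimension sits in each stacked weight array.

```python
import jax.numpy as jnp

num_layers, d_in, d_out = 4, 8, 16

# Four per-layer weight matrices, each (d_in, d_out).
per_layer = [jnp.zeros((d_in, d_out)) for _ in range(num_layers)]

# param_scan_axis: 0 -> layer axis leading.
stacked_axis0 = jnp.stack(per_layer, axis=0)  # (4, 8, 16)

# param_scan_axis: 1 -> layer axis inserted in the middle.
stacked_axis1 = jnp.stack(per_layer, axis=1)  # (8, 4, 16)

print(stacked_axis0.shape, stacked_axis1.shape)
```

Note that `jax.lax.scan` itself always scans over the leading axis, so frameworks that stack along axis 1 (e.g. Flax's `nn.scan` with `variable_axes={'params': 1}`) handle the axis placement internally; which axis is preferable can depend on how the weights are sharded.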