Convert Gemma weights with scan layers
borisdayma opened this issue
Boris Dayma commented
Hi,
It would be nice to be able to convert the Gemma checkpoints to support scan layers.
This will allow faster compilation for training & inference.
Thanks
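For readers unfamiliar with the term, here is a minimal JAX sketch of what scanning over layers means and why it can shorten compilation. It is not MaxText code: the toy layer, the sizes, and the function names are all hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, DIM = 4, 8  # hypothetical sizes, for illustration only


def layer(x, w):
    # One toy "layer": a matmul followed by a nonlinearity.
    return jnp.tanh(x @ w)


key = jax.random.PRNGKey(0)
# Scanned layout: a single array holds every layer's weights, stacked on axis 0.
stacked_w = jax.random.normal(key, (NUM_LAYERS, DIM, DIM)) * 0.1
x = jnp.ones((2, DIM))


def forward_unrolled(x, stacked_w):
    # Python loop: the layer body is traced once per layer.
    for i in range(NUM_LAYERS):
        x = layer(x, stacked_w[i])
    return x


def forward_scanned(x, stacked_w):
    # jax.lax.scan: the layer body is traced once and reused for every layer.
    def body(carry, w):
        return layer(carry, w), None

    out, _ = jax.lax.scan(body, x, stacked_w)
    return out


y_unrolled = jax.jit(forward_unrolled)(x, stacked_w)
y_scanned = jax.jit(forward_scanned)(x, stacked_w)
```

Both functions compute the same result. The scanned version gives the compiler a single layer body to process instead of one copy per layer, which is where the compile-time saving usually comes from with deep models.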
Rafi Witten commented
Gemma comes out scanned by default, right? And when it is time to unscan, use this script:
https://github.com/google/maxtext/blob/main/MaxText/generate_param_only_checkpoint.py
Boris Dayma commented
Oh thanks!
I got confused initially by this line with scan_layers=false.
After inspecting the weights I can see they’re scanned.
Just curious why it’s scanned over param_scan_axis: 1 instead of just 0.
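To make the question concrete, here is a small sketch of how the two layouts differ. It assumes param_scan_axis names the axis of each stacked parameter that indexes the layer. The shapes and the unscan helper are hypothetical and are not taken from MaxText.

```python
import jax.numpy as jnp

NUM_LAYERS, EMBED, MLP = 3, 4, 5  # hypothetical sizes

# Per-layer kernels, each of shape (EMBED, MLP); layer i is filled with the value i.
per_layer = [jnp.full((EMBED, MLP), float(i)) for i in range(NUM_LAYERS)]

# Layer axis first (what a scan axis of 0 would look like).
stacked_axis0 = jnp.stack(per_layer, axis=0)  # shape (NUM_LAYERS, EMBED, MLP)
# Layer axis second (what a scan axis of 1 would look like).
stacked_axis1 = jnp.stack(per_layer, axis=1)  # shape (EMBED, NUM_LAYERS, MLP)


def unscan(stacked, scan_axis):
    # Hypothetical helper: split a stacked parameter into one array per layer.
    n = stacked.shape[scan_axis]
    return [jnp.take(stacked, i, axis=scan_axis) for i in range(n)]


layers_from_axis0 = unscan(stacked_axis0, scan_axis=0)
layers_from_axis1 = unscan(stacked_axis1, scan_axis=1)
```

Either layout holds the same per-layer weights. Only the position of the layer dimension changes, so a weight inspection would show the layer count in the second position of each shape rather than the first.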