stanford-crfm / mistral

Mistral: A strong, northwesterly wind: Framework for transparent and accessible large-scale language model training, built with Hugging Face 🤗 Transformers.

Allow finetuning of mistral models using the HuggingFace Flax LM classes

TheodoreGalanos opened this issue

It would be amazing if we could load and fine-tune the models on TPUs using the Flax LM classes in HF. In my experience, this makes training and generation very straightforward on TPUs, while of course also taking advantage of their compute.

I have tried to load a Mistral checkpoint with the following code:
from transformers import FlaxAutoModelForCausalLM
model = FlaxAutoModelForCausalLM.from_pretrained("alias/arwen-x21-checkpoint-400000", from_pt=True, pad_token_id=50256)
This seems to work: the model loads, I can access its properties, and I can even generate text.

However, once I try to fine-tune it using (more or less) the code here: https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_clm_flax.py, it takes about 10 minutes to compile and then about 5 minutes for each step (for reference, these should be about 2 minutes and a few seconds, respectively, for gpt2-medium).

Finally, it would be nice if the changes in the Mistral models were somehow included when loading the model in HF (I am actually not 100% sure that this doesn't already happen). Specifically, I'm thinking of this line here:

scale_factor = 1 / ((float(v.size(-1)) ** 0.5) * self.layer_num)
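To make concrete what that line changes, here is a minimal NumPy sketch of causal single-head attention with the extra division by the layer index. It assumes `layer_num` is the 1-based index of the attention layer and that the last dimension of `v` (the `v.size(-1)` in the quoted line) is the per-head size; `layer_scaled_attention` is a hypothetical helper for illustration, not Mistral's or Transformers' code.

```python
import numpy as np

def layer_scaled_attention(q, k, v, layer_num):
    """Causal single-head attention with the extra 1/layer_num scaling.

    q, k, v: float arrays of shape (seq_len, d_head).
    layer_num: assumed 1-based index of this attention layer.
    Returns (output, attention_weights).
    """
    seq_len, d_head = q.shape
    # Usual 1/sqrt(d_head) scaling, additionally divided by the layer index
    scale_factor = 1.0 / (np.sqrt(d_head) * layer_num)
    scores = (q @ k.T) * scale_factor
    # Causal mask: position i may attend only to positions <= i
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
out_first, w_first = layer_scaled_attention(q, k, v, layer_num=1)  # plain 1/sqrt(d) scaling
out_deep, w_deep = layer_scaled_attention(q, k, v, layer_num=12)   # logits 12x smaller
```

For the same queries and keys, a larger `layer_num` shrinks the logits, so each row of attention weights becomes flatter (its largest entry can only decrease); with `layer_num=1` the sketch reduces to standard scaled dot-product attention.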

Hope this makes sense. Thank you in advance!

Best,
Theodore.

Hey Theodore - so we're definitely working on pushing the Mistral-specific operation changes (like the one you mentioned) to Transformers proper, as a flag in the GPT-2 Model class. This should happen by the end of the week (or at least, we'll have a PR in transformers you can use!).

As for why the Flax code is running slower - that's super interesting, and I don't have a good answer! It could be some weird interaction between the way we handle the upcasting code and the defaults in the run_clm_flax.py script. It would be great if you could do some digging (or create an issue/PR!), as we're not too familiar with Flax ourselves; otherwise, I'll take a look when I can!
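The "upcasting" mentioned here presumably concerns how attention is computed under mixed precision. As a rough illustration only, assuming it means computing the attention logits and softmax in float32 even when activations are float16, with the scale factor applied to the query before the matmul, a NumPy sketch could look like this. `upcast_attention` is a hypothetical name, and this is not Mistral's actual implementation.

```python
import numpy as np

def upcast_attention(q, k, v, layer_num):
    """Causal single-head attention computed in float32 for float16 inputs.

    q, k, v: float16 arrays of shape (seq_len, d_head).
    layer_num: assumed 1-based layer index used in the extra scaling.
    Returns the output in v's dtype.
    """
    seq_len, d_head = q.shape
    scale_factor = np.float32(1.0 / (np.sqrt(d_head) * layer_num))
    # Upcast to float32 and scale the query before the matmul, so the
    # logits are never materialized in float16
    q32 = q.astype(np.float32) * scale_factor
    k32 = k.astype(np.float32)
    scores = q32 @ k32.T
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, np.float32(-np.inf))
    # Softmax in float32
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Cast the weights back to the activation dtype for the final matmul
    return weights.astype(v.dtype) @ v

big = np.full((4, 64), 100, dtype=np.float16)
# Raw q.k here would be 64 * 100 * 100 = 640,000, above the float16 maximum
out = upcast_attention(big, big, big, layer_num=1)
```

Float16 tops out at about 65504, so logits like these would overflow if the matmul result were stored in half precision; keeping the matmul and softmax in float32 avoids that at the cost of extra casts, which is one place where such code can interact with a framework's own dtype defaults.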

Hello,

Bumping this real quick. I haven't checked in a while, so excuse me if this has already been done, but has it? :)

Would love to finetune some mistral models on TPUs.