google / maxtext

A simple, performant and scalable Jax LLM!


Document use of Mistral

borisdayma opened this issue

It looks like you already support Mistral, though maybe missing sliding window attention.
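
For context, sliding window attention just swaps the full causal mask for a banded one, so each token attends only to the most recent window of positions. A minimal JAX sketch of such a mask, purely illustrative (this is not MaxText's attention code, and the window size is a toy value; Mistral 7B uses 4096):

```python
# Illustrative sliding-window causal mask, NOT MaxText's implementation.
import jax.numpy as jnp

def sliding_window_causal_mask(seq_len: int, window: int) -> jnp.ndarray:
    # Position i may attend to positions j with i - window < j <= i,
    # i.e. itself plus the previous window - 1 tokens.
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

print(sliding_window_causal_mask(6, 3).astype(jnp.int32))
```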

Would be great to document this officially.

Looks like this is actually available: https://github.com/google/maxtext/blob/main/end_to_end/test_mistral.sh

The only thing I had to do was replace tokenizer.mistral with tokenizer.model (is it a typo or did you rename it in your bucket?).
Also, I chose to convert the bfloat16 weights to float32 rather than float16, since I think float16 could introduce some imprecision.
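
For anyone reproducing this, here is a minimal sketch of that upcast, assuming the converted checkpoint has been loaded as a pytree of arrays (`params` and `upcast_bf16_to_f32` are placeholder names, not a MaxText API):

```python
# Sketch: upcast every bfloat16 leaf of a weight pytree to float32.
# float32 represents every bfloat16 value exactly, whereas a
# bfloat16 -> float16 cast can lose range (float16 has fewer exponent
# bits) and introduce imprecision.
import jax
import jax.numpy as jnp

def upcast_bf16_to_f32(params):
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
        params,
    )
```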

Can I ask what kind of TPU you are using for the test, @borisdayma? I have a v4-32 available that I'd like to use for continued pre-training of Llama 2/Mistral 7B, but other frameworks have seemed sub-optimal to me so far.

It should work on a v3-8.
You can also try the decode.py script; for me it worked with the 7B models (Gemma and Mistral).

Amazing @borisdayma! We don't actually officially support Mistral (we do support Llama and Gemma), but we're thrilled things are working for you!

Yeah, your inference test of Mistral is correct. I compared with the transformers output and got the same results.
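
For anyone who wants to repeat that check, here is a sketch of the transformers side of the comparison; the checkpoint id and prompt are assumptions (not taken from this thread), and greedy decoding keeps the output deterministic so it can be compared against decode.py:

```python
# Sketch of the reference check: generate greedily with the Hugging Face
# Mistral checkpoint and compare the text against MaxText's decode.py output.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # assumption: the public HF checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

prompt = "I love to"  # assumption: any short prompt works
inputs = tokenizer(prompt, return_tensors="pt")
# Greedy decoding (do_sample=False) makes the output deterministic.
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```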

I'm closing this issue because, after further testing, Mistral already seems to work well.