pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.


Mistral support

Nikita-Sherstnev opened this issue

Would it be hard to adapt this code for Mistral? I tried the OpenOrca version and set vocab_size to 32002 in the config, but the shapes did not match:

File "/experiments/dev/nsherstnev/gpt-fast/scripts/convert_hf_checkpoint.py", line 61, in permute
    w.view(n_head, 2, config.head_dim // 2, dim)
RuntimeError: shape '[32, 2, 64, 4096]' is invalid for input of size 4194304

You'll need to change a few more configuration params (e.g. n_local_heads should be 8, since Mistral-7B uses grouped-query attention with 8 key/value heads rather than one KV head per attention head).
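The size in the error message is consistent with that: with 8 KV heads of head_dim 128, k_proj.weight has 1024 rows rather than 4096, and the Llama-style permute in convert_hf_checkpoint.py tries to view it with the full head count. A minimal sketch of the arithmetic (shapes inferred from the error message, not taken from the repo):

```python
import torch

# Mistral-7B k_proj.weight: (n_local_heads * head_dim, dim) = (8 * 128, 4096)
w = torch.empty(8 * 128, 4096)
print(w.numel())  # 4194304 -- the "input of size" in the RuntimeError

# Viewing with the full head count (as for Llama-style MHA) needs
# 32 * 2 * 64 * 4096 = 16_777_216 elements, hence the shape error:
#   w.view(32, 2, 64, 4096)  # RuntimeError
# Viewing with n_local_heads=8 matches the element count:
w.view(8, 2, 64, 4096)  # OK
```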

I'd copy them from here: https://huggingface.co/docs/transformers/main/model_doc/mistral#transformers.MistralConfig
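For reference, mapping the MistralConfig defaults onto gpt-fast's config dict might look roughly like the sketch below. The entry name and the field mapping are assumptions based on the MistralConfig docs and gpt-fast's ModelArgs conventions; check the actual merged change for the authoritative values.

```python
# Hypothetical gpt-fast config entry for Mistral-7B, derived from
# transformers' MistralConfig defaults (not the merged PR itself).
transformer_configs["Mistral-7B"] = dict(
    n_layer=32,               # num_hidden_layers
    n_head=32,                # num_attention_heads
    n_local_heads=8,          # num_key_value_heads (grouped-query attention)
    dim=4096,                 # hidden_size
    intermediate_size=14336,  # MLP hidden width
    vocab_size=32000,         # 32002 for the OpenOrca fine-tune's extra tokens
)
```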

Done in #116. The issue can be closed now.