StarCoder2 3B SFT models give CUDA OOM on IFEval
lewtun opened this issue · comments
For some peculiar reason, I am getting CUDA OOM when evaluating an SFT of bigcode/starcoder2-3b on ifeval. Note this doesn't happen with the 15b models, which suggests a bug on the transformers side, but I'm opening the issue here first in case you know whether lighteval handles these models differently, or whether it's a user error on my part :)
Here's the command to reproduce:
accelerate launch --multi_gpu --num_processes=8 run_evals_accelerate.py \
--model_args "pretrained=HuggingFaceH4/starcoder2-3b-sft-v00-deploy" \
--use_chat_template \
--tasks "extended|ifeval|0|0" \
--custom_tasks "extended_tasks/ifeval/main.py" \
--override_batch_size 1 \
--output_dir "./scratch/evals"
Note you'll need transformers from main; the version I'm using is:
pip install 'transformers @ git+https://github.com/huggingface/transformers.git@0290ec19c901adc0f1230ebdccad11c40af026f5'
I'm also using lighteval commit df21407d9f714bde9ecfb4dd8283afdc2150eec3.
I've inspected the inputs / outputs and everything looks good until I hit one sample that seems to blow up the memory.
Edit: the OOM issue is also present in gsm8k
Hi! Thanks for the details, I was able to reproduce the issue and look into it.
It does indeed seem to be a bug in transformers (it also OOMs when using model parallelism across 8 GPUs). lighteval does not handle these models differently. It could have been the dtype, which is not set in your case, but setting it to float16 does not solve the issue.
Interestingly, removing --use_chat_template makes it work for --max_samples 1.
What is particularly interesting is that samples are sorted by context size, so failing on a sample partway through the run (rather than at the start) points to the generation being too long, not the prompt.
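To see why a runaway generation (rather than a long prompt) can OOM mid-run, note that the KV cache grows linearly with total sequence length. The sketch below is illustrative only: the layer/head/dim numbers are hypothetical placeholders roughly in the 3B range, not the actual starcoder2-3b config.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_param=2):
    """Per-sequence KV cache size: keys + values for every layer."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_param

# Hypothetical config (illustrative numbers only): 30 layers,
# 4 KV heads (GQA), head_dim 128, float16 (2 bytes per value).
short_seq = kv_cache_bytes(30, 4, 128, seq_len=512)
long_seq = kv_cache_bytes(30, 4, 128, seq_len=16_384)
print(short_seq / 2**20, long_seq / 2**20)  # MiB; cache grows linearly with length
```

So a single sample whose generation never stops can use tens of times the memory of its neighbours, even with batch size 1.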
There could be something in the chat template that encourages the model to keep writing. What's the chat template for this one?
This particular model uses ChatML and appears to follow the prompts OK when I look at the intermediate outputs. I'll see if I can reproduce the issue with pure transformers code and open an issue there.
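For reference, ChatML wraps each turn in `<|im_start|>role ... <|im_end|>` markers. A minimal sketch of that formatting (an illustration of the convention, not the model's actual Jinja chat template):

```python
def chatml_format(messages, add_generation_prompt=True):
    """Render messages in ChatML style. Each turn is wrapped in
    <|im_start|>role ... <|im_end|> markers; a generation prompt opens
    an assistant turn for the model to complete."""
    out = ""
    for m in messages:
        out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        out += "<|im_start|>assistant\n"
    return out

prompt = chatml_format([{"role": "user", "content": "Write a haiku."}])
print(prompt)
```

One hypothesis worth checking in the pure-transformers repro: if the model never emits the `<|im_end|>` stop token (or it isn't registered as an EOS token for generation), each generation runs to the maximum length, which would match the memory blow-up on a single sample.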