karpathy / llm.c

LLM training in simple, raw C/CUDA

Repository from GitHub: https://github.com/karpathy/llm.c

bug: something goes wrong at larger batch sizes

karpathy opened this issue

There's a bug I've had difficulty tracking down today. I'm going to give up for tonight and try again tomorrow.

Reproduction:

./train_gpt2cu -b 12

launches the job with batch size 12. On my 40GB GPU I see this takes up 20GB of memory. But launching with:

./train_gpt2cu -b 16

throws an error and crashes:

an illegal memory access was encountered

There's certainly space on the device (batch size 12 only used 20GB), so something is going very wrong. I tried:

CUDA_LAUNCH_BLOCKING=1 ./train_gpt2cu -b 16

to synchronize after every kernel, which points to this line in attention_backward:

// [cuBLAS ERROR]: 14 train_gpt2.cu 1020
cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, scratch, HS, T * HS, att, T, T * T, &zero, dv, HS, T * HS, B * NH));

But it seems unlikely that cuBLAS itself is wrong; more likely we're just passing it bad pointers, or... not sure.
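For context, cublasSgemmStridedBatched is documented to behave like a loop of ordinary GEMMs over batchCount (here B * NH) sub-matrices, each offset from the base pointers by the given strides. A rough unrolled equivalent, just to illustrate the addressing (not the project's actual code):

// Illustrative unrolling of the strided-batched call above. Each batch i uses
// a fixed offset from the base pointers, so an illegal read inside the kernel
// most likely means scratch / att / dv were already bad when they were passed in.
for (int i = 0; i < B * NH; i++) {
    const float* A_i = scratch + i * T * HS; // per-head dout slice
    const float* B_i = att     + i * T * T;  // per-head attention matrix
    float*       C_i = dv      + i * T * HS; // per-head dV gradient slice
    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T,
                            HS, T, T, &one, A_i, HS, B_i, T, &zero, C_i, HS));
}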

I also tried NVIDIA compute sanitizer:

compute-sanitizer ./train_gpt2cu -b 16

which gives a massive amount of information, but also seems to implicate the same line:

========= Invalid __global__ read of size 16 bytes
=========     at 0x9b0 in void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_64x64_16x6_nt_align4>(T1::Params)
=========     by thread (31,0,0) in block (3,0,0)
=========     Address 0x7faad5183370 is out of bounds
=========     and is 5484563600 bytes before the nearest allocation at 0x7fac1c000000 of size 1056964608 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
...
=========                in /lib/x86_64-linux-gnu/libcublas.so.11
=========     Host Frame:attention_backward(float*, float*, float*, float*, float*, float const*, float const*, float const*, int, int, int, int) [0x10f94]
=========                in /home/ubuntu/llm.c/./train_gpt2cu
=========     Host Frame:gpt2_backward(GPT2*) [0x12f34]
=========                in /home/ubuntu/llm.c/./train_gpt2cu
=========     Host Frame:main [0xb7cb]
=========                in /home/ubuntu/llm.c/./train_gpt2cu
=========     Host Frame:../csu/libc-start.c:342:__libc_start_main [0x24083]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:_start [0xd58e]
=========                in /home/ubuntu/llm.c/./train_gpt2cu

I'm going to wrap up for today, but there's almost certainly something wrong here.

Are you sure the full 40GB is free according to nvidia-smi? It's happened to me that other things, even my browser, were taking up too much space.

Yes, I'm looking closely at nvidia-smi, and I'm not using a workstation; it's a cloud machine where the GPUs aren't doing anything else except llm.c.

This turns out to be embarrassingly simple... the offset (l * B * NH * T * T) overflows a 32-bit int once it exceeds 2^31 :(

Easy to fix: we just need to turn those variables into size_t for both forward & backward so the calculation happens in 64-bit (except for "l", which decrements in the backward loop, so you'd get an infinite loop if it were unsigned like size_t).
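Concretely, assuming the default GPT-2 (124M) settings of NH = 12 heads, T = 1024 tokens and 12 layers, each layer's attention slice at batch size 16 is 16 * 12 * 1024 * 1024 ≈ 2.0e8 floats, so l * B * NH * T * T crosses 2^31 at l = 11; at batch size 12 it never does, which matches the symptom. A minimal, self-contained sketch of the overflow and the fix (the values and names here are assumptions for illustration, not the actual train_gpt2.cu code):

#include <stdio.h>
#include <stddef.h>

int main(void) {
    int B = 16, NH = 12, T = 1024;       // failing config from this issue
    for (int l = 0; l < 12; l++) {       // 12 transformer layers
        // Buggy: the whole product is evaluated in 32-bit int and wraps past
        // 2^31 (technically undefined behavior; in practice it wraps).
        int bad_offset = l * B * NH * T * T;
        // Fixed: promote to 64-bit before multiplying, e.g. by declaring
        // B/NH/T as size_t in the training code. l itself stays a signed int
        // because the backward loop decrements it down to 0.
        size_t good_offset = (size_t)l * B * NH * T * T;
        if ((size_t)bad_offset != good_offset) {
            printf("layer %d: int offset %d vs size_t offset %zu\n",
                   l, bad_offset, good_offset);
        }
    }
    return 0;
}

Running this prints a mismatch only at l = 11, and with B = 12 it prints nothing at all, consistent with batch size 12 training fine while 16 crashes.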

:'( I was looking for exactly this kind of problem but didn't spot it. Good catch!

It may be worth changing other variables to size_t as well, given the overflow risk; see #155.

Fixed here afaik: #259