32-bit optimizer update gives different results despite identical gradients
Edenzzzz opened this issue
System Info
A100 GPU, torch 2.1, CUDA 12.1, bitsandbytes 0.43.1
Reproduction
The tensors to be loaded are zipped here:
grads.zip
import torch
import bitsandbytes.functional as F

# Full low-rank gradient and its tensor-parallel shard (first 32 columns)
low_rank_grad = torch.load("low_rank_grad.pt")
dist_low_rank_grad = torch.load("dist_low_rank_grad.pt")

def update(low_rank_grad):
    # Run a single 32-bit Adam step from a zero-initialized parameter and zero states
    p = torch.zeros_like(low_rank_grad, dtype=torch.float32, device="cuda")
    p.grad = low_rank_grad
    lr = 1e-2
    state = {}
    state["state1"] = torch.zeros_like(low_rank_grad, dtype=torch.float32, device="cuda")  # first moment
    state["state2"] = torch.zeros_like(low_rank_grad, dtype=torch.float32, device="cuda")  # second moment
    beta1 = 0.9
    beta2 = 0.999
    step = 1
    eps = 1e-8
    weight_decay = 1e-2
    # Updates p, state1 and state2 in place
    F.optimizer_update_32bit(
        "adam",
        p.grad,
        p,
        state["state1"],
        beta1,
        eps,
        step,
        lr,
        state["state2"],
        beta2,
        weight_decay,
    )
    return p

print(low_rank_grad.shape, dist_low_rank_grad.shape)
# The shard matches the corresponding columns of the full gradient exactly
assert (low_rank_grad[:, :32] == dist_low_rank_grad).all()

# adam update step
low_rank_grad = update(low_rank_grad)
dist_low_rank_grad = update(dist_low_rank_grad)
# Expected all True, but most entries come out False
print(low_rank_grad[:, :32] == dist_low_rank_grad)
My result shows that most of the updates on the matching gradient chunk diverged.
Expected behavior
This came up while adapting the GaLore optimizer for tensor parallelism and testing the precision of the distributed optimizer against the original one.
The gradient is sharded along dim 1 by tensor parallelism, and the corresponding gradient chunks match exactly (the assert passes). After the optimizer step, however, the updated chunks are no longer identical, even though they should be. I first suspected quantization statistics, but the mismatch reproduces consistently with the 32-bit update, where no quantization is involved.
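The expectation rests on Adam being purely elementwise: each parameter's update depends only on its own gradient and state entries, so updating a column shard should give exactly the same values as slicing the full update. A minimal pure-PyTorch sketch of that expectation, using a simplified Adam step written here for illustration (no weight decay; this is not the bitsandbytes kernel), might look like:

```python
import torch


def adam_step(g, p, m, v, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # Simplified reference Adam step (no weight decay); returns new (p, m, v).
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    p = p - lr * m_hat / (torch.sqrt(v_hat) + eps)
    return p, m, v


torch.manual_seed(0)
full_grad = torch.randn(8, 64)
# Column shard (columns 0..31), as produced by sharding along dim 1
shard_grad = torch.chunk(full_grad, 2, dim=1)[0]

p_full, _, _ = adam_step(full_grad, *(torch.zeros_like(full_grad) for _ in range(3)))
p_shard, _, _ = adam_step(shard_grad, *(torch.zeros_like(shard_grad) for _ in range(3)))

# Elementwise update commutes with column slicing
print(torch.equal(p_full[:, :32], p_shard))
```

Here the shard is passed as a strided view, which is fine because PyTorch's own elementwise ops honor strides.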
@matthewdouglas @TimDettmers any insights? Thanks!
Hi @Edenzzzz,
Make sure this chunk is contiguous, since F.optimizer_update_32bit ultimately treats the tensor as a flat 1D buffer:
dist_low_rank_grad = torch.load("dist_low_rank_grad.pt").contiguous()
I was able to reproduce your results, and after this change I believe I'm seeing the desired result.
RTX 3060, CUDA 12.4, torch==2.2.2+cu121, bitsandbytes==0.43.1
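To see why a flat 1D treatment breaks on a column slice, the sketch below uses torch.as_strided to emulate a reader that only knows the data pointer and the element count (an illustrative stand-in, not the actual bitsandbytes code path). It reads numel consecutive storage elements and therefore picks up values that do not belong to the slice:

```python
import torch

x = torch.arange(12.0).reshape(3, 4)  # contiguous, row-major
shard = x[:, :2]                      # column slice: a strided view into x's storage

print(shard.is_contiguous())  # False
print(shard.stride())         # (4, 1)

# What a stride-unaware kernel would see: shard.numel() consecutive
# elements starting at the view's storage offset.
naive_flat = torch.as_strided(shard, (shard.numel(),), (1,))
print(naive_flat.tolist())    # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# The slice's logical contents, laid out row-major after .contiguous()
fixed_flat = shard.contiguous().view(-1)
print(fixed_flat.tolist())    # [0.0, 1.0, 4.0, 5.0, 8.0, 9.0]
```

Only the leading elements coincide, which is consistent with most, but not all, of the updated entries diverging.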
Thanks a lot! This worked. The non-contiguous tensor came from torch.chunk and torch.distributed.all_gather.
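torch.chunk returns views of its input, so splitting along dim 1 leaves each shard with the parent's row stride. A small sketch (all_gather is left out because it needs an initialized process group) contrasting column and row chunks:

```python
import torch

full = torch.randn(4, 64)

col_shards = torch.chunk(full, 2, dim=1)  # split along dim 1 (column-wise)
row_shards = torch.chunk(full, 2, dim=0)  # split along dim 0, for comparison

print(col_shards[0].shape, col_shards[0].stride(), col_shards[0].is_contiguous())
# torch.Size([4, 32]) (64, 1) False
print(row_shards[0].shape, row_shards[0].stride(), row_shards[0].is_contiguous())
# torch.Size([2, 64]) (64, 1) True

shard = col_shards[0].contiguous()  # copies into a fresh row-major buffer
print(shard.stride(), shard.is_contiguous())
# (32, 1) True
```

Row chunks of a contiguous tensor stay contiguous, so only column-wise sharding needs the extra copy.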
I wonder if that's because the C++ kernel ignores the strides of the reshaped tensor and assumes a row-major layout? I can file a PR to make the input contiguous if you think that would help.
> I wonder if that's because the C++ kernel ignores the strides of the reshaped tensor and assumes a row-major layout?
Yes, exactly. The C++ kernel assumes the tensor is row-major and only knows the total number of elements.
> I can file a PR to make the input contiguous if you think that would help.
That seems like a reasonable check to me, so a PR to add that sounds good!
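A rough sketch of what such a check could guard against. Everything here is hypothetical: flat_adam_step is a simplified Adam step (no weight decay) that emulates a pointer-plus-numel kernel via torch.as_strided, and checked_adam_step is one possible shape for the check, not the bitsandbytes API:

```python
import torch


def _flat_view(t):
    # Emulates a kernel that only receives a data pointer and numel: it walks
    # t.numel() consecutive elements from t's storage offset, ignoring strides.
    return torch.as_strided(t, (t.numel(),), (1,))


def flat_adam_step(g, p, m, v, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # Simplified Adam step updating p, m, v in place through their flat views.
    gf, pf, mf, vf = (_flat_view(t) for t in (g, p, m, v))
    mf.copy_(beta1 * mf + (1 - beta1) * gf)
    vf.copy_(beta2 * vf + (1 - beta2) * gf * gf)
    m_hat = mf / (1 - beta1 ** step)
    v_hat = vf / (1 - beta2 ** step)
    pf.copy_(pf - lr * m_hat / (torch.sqrt(v_hat) + eps))


def checked_adam_step(g, p, m, v, **kwargs):
    # Hypothetical guard: copy a strided gradient into a contiguous buffer, but
    # reject strided parameter/state tensors, because copying those would
    # silently discard the in-place update.
    if not g.is_contiguous():
        g = g.contiguous()
    for name, t in (("p", p), ("state1", m), ("state2", v)):
        if not t.is_contiguous():
            raise ValueError(f"{name} must be contiguous")
    flat_adam_step(g, p, m, v, **kwargs)


full_grad = torch.linspace(-1.0, 1.0, 256).reshape(4, 64)
shard_grad = full_grad[:, :32]  # non-contiguous column shard, stride (64, 1)


def fresh_buffers():
    return [torch.zeros(4, 32) for _ in range(3)]


p_ref, m_ref, v_ref = fresh_buffers()
flat_adam_step(shard_grad.contiguous(), p_ref, m_ref, v_ref)  # correct reference

p_bad, m_bad, v_bad = fresh_buffers()
flat_adam_step(shard_grad, p_bad, m_bad, v_bad)  # reads the wrong gradient entries

p_ok, m_ok, v_ok = fresh_buffers()
checked_adam_step(shard_grad, p_ok, m_ok, v_ok)

print(torch.equal(p_bad, p_ref))  # False
print(torch.equal(p_ok, p_ref))   # True
```

Copying only the gradient keeps the check cheap and safe; an in-place parameter or state buffer cannot be fixed by copying, so raising is the more honest behavior there.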