TimDettmers / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.

Home Page: https://huggingface.co/docs/bitsandbytes/main/en/index

TypeError: Input tensors need to be on the same GPU, but found the following tensor and device combinations

rlanday opened this issue

System Info

I was using a Vast.ai server with 8x A100 GPUs, each with 80 GB of VRAM, and about 1814 GB of RAM (as reported by Vast):

root@C.10527886:~$ uname -r
5.15.0-1050-azure
root@C.10527886:~$ cat /etc/*-release
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=22.04
DISTRIB_CODENAME=jammy
DISTRIB_DESCRIPTION="Ubuntu 22.04.3 LTS"
PRETTY_NAME="Ubuntu 22.04.3 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.3 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy

Reproduction

I don’t have an easy way to reproduce the problem. I started fine-tuning Mixtral using SFTTrainer, QLoRA, and DeepSpeed ZeRO-3 on 8x A100s, and this error occurred at the 6th training step (perhaps it’s some sort of race condition):

{'loss': 1.8568, 'grad_norm': 0.035528212785720825, 'learning_rate': 4e-05, 'epoch': 0.0}
{'loss': 1.8158, 'grad_norm': 0.029280483722686768, 'learning_rate': 8e-05, 'epoch': 0.01}
{'loss': 1.8727, 'grad_norm': 0.030231980606913567, 'learning_rate': 0.00012, 'epoch': 0.01}
{'loss': 1.8823, 'grad_norm': 0.028545115143060684, 'learning_rate': 0.00016, 'epoch': 0.02}
{'loss': 1.8474, 'grad_norm': 0.026413191109895706, 'learning_rate': 0.0002, 'epoch': 0.02}
{'loss': 1.8512, 'grad_norm': 0.024209117516875267, 'learning_rate': 0.000199163179916318, 'epoch': 0.02}
  2%|██▋                                                                                                           | 6/244 [15:20<10:16:21, 155.39s/it]
Traceback (most recent call last):
  File "/root/train_bitsandbytes.py", line 117, in <module>
    trainer_stats = trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2118, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3036, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3059, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 825, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 813, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
    return self.base_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1360, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1360, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 1217, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/opt/conda/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 482, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 261, in forward
    outputs = run_function(*args)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 947, in forward
    hidden_states, router_logits = self.block_sparse_moe(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 877, in forward
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 803, in forward
    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 468, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 574, in matmul_4bit
    out = F.gemv_4bit(A, B.t(), out, state=quant_state)
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1994, in gemv_4bit
    is_on_gpu([B, A, out, absmax, state.code])
  File "/opt/conda/lib/python3.10/site-packages/bitsandbytes/functional.py", line 432, in is_on_gpu
    raise TypeError(
TypeError: Input tensors need to be on the same GPU, but found the following tensor and device combinations:
 [(torch.Size([29360128, 1]), device(type='cuda', index=6)), (torch.Size([1, 4096]), device(type='cuda', index=6)), (torch.Size([1, 14336]), device(type='cuda', index=6)), (torch.Size([917504]), device(type='cuda', index=6)), (torch.Size([16]), device(type='cuda', index=0))]
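
Reading the failing call is_on_gpu([B, A, out, absmax, state.code]) against the shapes above, the odd one out looks like the last entry: the 16-element tensor on cuda:0 lines up with state.code (the 4-bit codebook), while the packed weight, input, output, and absmax are all on cuda:6. So it seems one layer’s quantization state ended up on a different device than its weight. Here is a rough diagnostic sketch (untested against this exact setup; it assumes the Linear4bit weight exposes a quant_state with absmax and code tensors) for scanning a model for such layers:

import torch
import bitsandbytes as bnb

def find_misplaced_quant_states(model: torch.nn.Module):
    # Report every 4-bit layer whose quant_state tensors live on a different
    # device than the packed weight itself.
    mismatches = []
    for name, module in model.named_modules():
        if not isinstance(module, bnb.nn.Linear4bit):
            continue
        qs = getattr(module.weight, "quant_state", None)
        if qs is None:
            continue
        weight_device = module.weight.device
        for attr in ("absmax", "code"):
            t = getattr(qs, attr, None)
            if torch.is_tensor(t) and t.device != weight_device:
                mismatches.append((name, attr, t.device, weight_device))
    return mismatches

# Example usage right before trainer.train():
# for name, attr, got, expected in find_misplaced_quant_states(model):
#     print(f"{name}.quant_state.{attr} is on {got}, weight is on {expected}")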

I’m not 100% sure which part of the training stack contains the bug here, but I briefly started debugging and found this comment in bitsandbytes that looks suspicious:

# the quant state got lost when the parameter got converted. This happens for example for fsdp
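
That comment appears to be in Linear4bit’s forward path, which suggests the quant state can be dropped or left behind on the wrong device when the parameter gets re-wrapped by the distributed framework. As a stopgap I may try moving any misplaced tensors back before training, along these lines (a rough sketch, not verified under DeepSpeed ZeRO-3; it assumes the quant state stores absmax and code as plain, reassignable tensors):

import torch
import bitsandbytes as bnb

def realign_quant_states(model: torch.nn.Module) -> None:
    # Push each 4-bit layer's quant_state tensors onto its weight's device so
    # gemv_4bit sees a consistent device set.
    for module in model.modules():
        if not isinstance(module, bnb.nn.Linear4bit):
            continue
        qs = getattr(module.weight, "quant_state", None)
        if qs is None:
            continue
        target = module.weight.device
        for attr in ("absmax", "code"):
            t = getattr(qs, attr, None)
            if torch.is_tensor(t) and t.device != target:
                setattr(qs, attr, t.to(target))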

Expected behavior

Training should not throw this error.

Oh, and I had gradient checkpointing enabled, as you can see from the stack trace. I already ran into one bug when using gradient checkpointing together with LoRA that I had to work around (huggingface/peft#137 (comment)), so perhaps this is not a well-tested combination.
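
For context, the usual fix for that class of problem is to make the checkpointed inputs require grad so gradients still reach the LoRA adapters through the frozen base model. A rough sketch of that setup (placeholder checkpoint and 4-bit config, not my exact training script):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

# Placeholder checkpoint and quantization settings, for illustration only.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

# prepare_model_for_kbit_training turns on gradient checkpointing and makes the
# embedding outputs require grad, so checkpointed blocks still produce gradients
# for the LoRA adapters.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Minimal alternative:
# model.gradient_checkpointing_enable()
# model.enable_input_require_grads()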