Lightning-AI / lit-llama

Implementation of the LLaMA language model based on nanoGPT. Supports flash attention, Int8 and GPTQ 4bit quantization, LoRA and LLaMA-Adapter fine-tuning, pre-training. Apache 2.0-licensed.

Full fine-tuning on Alpaca dataset with 4 L40s GPUs fails 8 hours into the training job with index_copy_

cabal-daniel opened this issue

I'm trying to run a training job on 4 L40 GPUs with the Alpaca fine-tuning set, using the default finetune/full.py configuration on the open-llama 7B model. This is my nvidia-smi output:

root@C.7407591:~/lit-llama$ nvidia-smi
Fri Nov  3 15:14:35 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L40                     On  | 00000000:81:00.0 Off |                    0 |
| N/A   31C    P8              33W / 300W |      4MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA L40                     On  | 00000000:A1:00.0 Off |                    0 |
| N/A   30C    P8              33W / 300W |      4MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA L40                     On  | 00000000:C1:00.0 Off |                    0 |
| N/A   30C    P8              35W / 300W |      4MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA L40                     On  | 00000000:E1:00.0 Off |                    0 |
| N/A   30C    P8              33W / 300W |      4MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

I ran python finetune/full.py with very small modifications. This is the git diff:

root@C.7407591:~/lit-llama$ git diff
diff --git a/finetune/full.py b/finetune/full.py
index 9248e8d..4a6ba99 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -15,6 +15,7 @@ from lightning.fabric.strategies import FSDPStrategy
 import numpy as np
 import torch
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from datetime import datetime
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -31,7 +32,7 @@ instruction_tuning = True
 eval_interval = 1000
 save_interval = 1000
 eval_iters = 100
-log_interval = 100
+log_interval = 1
 devices = 4
 
 # Hyperparameters
@@ -137,7 +138,7 @@ def train(
 
         dt = time.time() - t0
         if iter_num % log_interval == 0:
-            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
+            fabric.print(f"[{datetime.now()}] iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
 
 
 def generate_response(model, instruction):

But then, 8 hours into the training run, I got these errors:

[2023-11-03 08:04:57.214671] iter 7921: loss 1.0308, time: 3301.58ms
[2023-11-03 08:05:00.208211] iter 7922: loss 0.9911, time: 2988.20ms
[2023-11-03 08:05:03.182380] iter 7923: loss 0.7134, time: 2968.94ms
[2023-11-03 08:05:06.498092] iter 7924: loss 0.7830, time: 3310.13ms
Traceback (most recent call last):
  File "/root/lit-llama/finetune/full.py", line 225, in <module>
    CLI(main)
  File "/root/.venv/lib/python3.10/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/root/.venv/lib/python3.10/site-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "/root/lit-llama/finetune/full.py", line 86, in main
    train(fabric, model, optimizer, train_data, val_data, out_dir)
  File "/root/lit-llama/finetune/full.py", line 131, in train
    val_loss = validate(fabric, model, val_data)
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/lit-llama/finetune/full.py", line 177, in validate
    output = generate_response(model, instruction)
  File "/root/lit-llama/finetune/full.py", line 152, in generate_response
    output = generate(
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/lit-llama/generate.py", line 65, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 119, in forward
    output = self._forward_module(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward

and then, a few iterations later, it crashed with:

[2023-11-03 08:09:08.894303] iter 7996: loss 0.9990, time: 3127.77ms
[2023-11-03 08:09:11.912121] iter 7997: loss 0.6333, time: 3011.81ms
[2023-11-03 08:09:14.982724] iter 7998: loss 0.8165, time: 3064.63ms
Validating ...
Traceback (most recent call last):
  File "/root/lit-llama/finetune/full.py", line 225, in <module>
    CLI(main)
  File "/root/.venv/lib/python3.10/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/root/.venv/lib/python3.10/site-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "/root/lit-llama/finetune/full.py", line 86, in main
    train(fabric, model, optimizer, train_data, val_data, out_dir)
  File "/root/lit-llama/finetune/full.py", line 131, in train
    val_loss = validate(fabric, model, val_data)
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/lit-llama/finetune/full.py", line 177, in validate
    output = generate_response(model, instruction)
  File "/root/lit-llama/finetune/full.py", line 152, in generate_response
    output = generate(
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/lit-llama/generate.py", line 65, in generate
    logits = model(x, max_seq_length, input_pos)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 119, in forward
    output = self._forward_module(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/lit-llama/lit_llama/model.py", line 114, in forward
    x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 164, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/root/.venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
    ret = function(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/lit-llama/lit_llama/model.py", line 163, in forward
    h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/lit-llama/lit_llama/model.py", line 217, in forward
    k = cache_k.index_copy(2, input_pos, k)
RuntimeError: index_copy_(): self and source expected to have the same dtype, but got (self) Float and (source) BFloat16

Each of the four ranks raised the same RuntimeError.
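The dtype complaint is easy to reproduce in isolation: `index_copy` requires the destination and source tensors to share a dtype, and here the float32 KV cache is apparently receiving a bfloat16 `k`. A minimal sketch (shapes and names are illustrative, not the actual lit-llama ones, and the cast at the end is just one possible workaround, not the repo's fix):

```python
import torch

# float32 cache vs bfloat16 source -- mirrors the reported mismatch
cache_k = torch.zeros(1, 1, 8, 4, dtype=torch.float32)  # (B, n_head, T, head_dim)
k = torch.ones(1, 1, 2, 4, dtype=torch.bfloat16)
input_pos = torch.tensor([0, 1])

try:
    cache_k.index_copy(2, input_pos, k)  # raises: self and source dtypes differ
    failed = False
except RuntimeError as e:
    failed = True
    print(e)

# casting the source to the cache dtype avoids the error
patched = cache_k.index_copy(2, input_pos, k.to(cache_k.dtype))
print(patched.dtype)  # torch.float32
```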

Any ideas what went wrong?

Maybe it's because precision is set to bf16-mixed?
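If I understand bf16-mixed correctly, float32 master weights and buffers are kept while compute runs in bfloat16, which is exactly the kind of setup where a float32 cache buffer meets a bfloat16 activation. The same dtype split is easy to see with autocast on CPU (a sketch of the general mechanism, not necessarily Fabric's exact implementation):

```python
import torch

cache = torch.zeros(2, 4)                      # a plain float32 buffer
with torch.autocast("cpu", dtype=torch.bfloat16):
    x = torch.randn(2, 4) @ torch.randn(4, 4)  # matmuls run in bfloat16

print(x.dtype, cache.dtype)  # torch.bfloat16 torch.float32
```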

I ran it again with bf16-true and got this error instead:

[2023-11-03 19:13:17.491399] iter 7998: loss 0.8175, time: 1408.26ms
Validating ...
Traceback (most recent call last):
  File "/root/lit-llama/finetune/full.py", line 225, in <module>
    CLI(main)
  File "/root/.venv/lib/python3.10/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/root/.venv/lib/python3.10/site-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "/root/lit-llama/finetune/full.py", line 86, in main
    train(fabric, model, optimizer, train_data, val_data, out_dir)
  File "/root/lit-llama/finetune/full.py", line 131, in train
    val_loss = validate(fabric, model, val_data)
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/lit-llama/finetune/full.py", line 177, in validate
    output = generate_response(model, instruction)
  File "/root/lit-llama/finetune/full.py", line 152, in generate_response
    output = generate(
  File "/root/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/lit-llama/generate.py", line 83, in generate
    idx = idx.index_copy(0, input_pos, idx_next)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument source in method wrapper_CUDA_index_copy)

Each rank reported the same error against its own device (cuda:0 through cuda:3).
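This second failure mode is also reproducible in isolation: `index_copy` requires the destination, index, and source to live on one device, and here `idx` is apparently on the CPU while the sampled token is on CUDA. I can't trigger the cross-device case without a second device, but the obvious guard is to move the index and source onto `idx`'s device first (variable names follow generate.py; the guard itself is my guess, not the repo's code):

```python
import torch

idx = torch.zeros(8, dtype=torch.long)           # token buffer (on CPU here)
idx_next = torch.tensor([42], dtype=torch.long)  # newly sampled token
input_pos = torch.tensor([3])

# defensive version of: idx = idx.index_copy(0, input_pos, idx_next)
# aligning devices first avoids the cpu/cuda mismatch
idx = idx.index_copy(0,
                     input_pos.to(idx.device),
                     idx_next.to(idx.device))
print(idx.tolist())  # [0, 0, 0, 42, 0, 0, 0, 0]
```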