PyTorch 2.0 DDP example only runs on 1 GPU instead of 2
t1masavin opened this issue · comments
The example is based on multigpu_torchrun.
However, nvidia-smi shows that only one card is in use:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13 Driver Version: 525.60.13 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:01:00.0 Off | Off |
| 0% 43C P2 67W / 450W | 21944MiB / 24564MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... On | 00000000:09:00.0 Off | Off |
| 0% 40C P8 31W / 480W | 67MiB / 24564MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1363 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 961152 C ...3/envs/tagging/bin/python 21936MiB |
| 1 N/A N/A 1363 G /usr/lib/xorg/Xorg 46MiB |
| 1 N/A N/A 1601 G /usr/bin/gnome-shell 16MiB |
+-----------------------------------------------------------------------------+
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import GTSRB
from torchvision import transforms as T
from torch import nn
from torchvision import models as M
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
def ddp_setup():
    """Bind this process to its GPU and join the NCCL process group.

    Must be called once in every process of a job launched by ``torchrun``,
    which sets ``LOCAL_RANK`` as well as the rendezvous variables consumed by
    ``init_process_group``'s default ``env://`` init method.

    Raises:
        RuntimeError: If ``LOCAL_RANK`` is not set, i.e. the script was not
            started through ``torchrun`` (or an equivalent launcher).
    """
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
    except KeyError as err:
        raise RuntimeError(
            "LOCAL_RANK is not set; launch this script with torchrun, e.g. "
            "`torchrun --standalone --nproc_per_node=2 <script>.py`"
        ) from err
    # Select this rank's device before creating the NCCL group so the
    # communicator and CUDA context are placed on it; otherwise the current
    # device defaults to cuda:0 for every rank.
    torch.cuda.set_device(local_rank)
    init_process_group(backend="nccl")
class Trainer:
    """Per-process training loop for a ``torchrun``-launched DDP job.

    One instance runs in every process. The process's ``LOCAL_RANK`` selects
    its CUDA device: the model is moved there, optionally restored from a
    snapshot, wrapped in ``DistributedDataParallel`` and trained on the batches
    produced by ``train_data``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        """Place the model on this process's GPU, resume if possible, wrap in DDP.

        Args:
            model: Un-wrapped model; moved to the local GPU, then DDP-wrapped.
            train_data: Loader whose sampler supports ``set_epoch`` (e.g. a
                ``DistributedSampler``).
            optimizer: Optimizer expected to be over ``model``'s parameters.
            save_every: Save a snapshot every this many epochs.
            snapshot_path: File to resume from (if it exists) and to save to.

        Raises:
            KeyError: If ``LOCAL_RANK`` is not set in the environment.
        """
        # LOCAL_RANK is used directly as the CUDA device index.
        self.gpu_id = int(os.environ["LOCAL_RANK"])
        self.model = model.to(self.gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            # Load into the bare module before DDP wrapping: snapshots store
            # the inner module's state dict (see _save_snapshot).
            self._load_snapshot(snapshot_path)
        self.model = DDP(self.model, device_ids=[self.gpu_id])

    def _load_snapshot(self, snapshot_path):
        """Restore model weights and the resume epoch from *snapshot_path*."""
        loc = f"cuda:{self.gpu_id}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        """Run one forward/backward/optimizer step with cross-entropy loss.

        ``targets`` are presumably integer class indices — confirm against
        the dataset.
        """
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _first_batch_size(self) -> int:
        """Size of the first batch this rank will see, without loading data."""
        loader = self.train_data
        if loader.batch_size is None or loader.drop_last:
            # Custom batching: fall back to actually fetching one batch.
            return len(next(iter(loader))[0])
        # Default batching: the first batch holds batch_size samples unless
        # this rank's shard is smaller than that.
        return min(loader.batch_size, len(loader.sampler))

    def _run_epoch(self, epoch):
        """Train once over every batch ``train_data`` yields for this rank."""
        b_sz = self._first_batch_size()
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        # For a DistributedSampler this changes the shuffle order per epoch.
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        """Write model weights and the resume point to ``snapshot_path``.

        ``EPOCHS_RUN`` is the number of completed epochs (``epoch + 1``), so a
        resumed run continues with the next epoch instead of repeating this one.
        """
        snapshot = {
            # Inner (un-wrapped) module, so the weights load into the bare
            # model in _load_snapshot, which runs before DDP wrapping.
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch + 1,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        """Train epochs ``epochs_run`` .. ``max_epochs - 1``, snapshotting periodically.

        Only the local-rank-0 process writes snapshots: all ranks writing the
        same file at the same time could leave it corrupted.
        """
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
def get_resnet(NUM_CLASSES):
    """Build a ResNet-18 with torchvision's default pretrained weights and a new head.

    Args:
        NUM_CLASSES: Number of output logits for the replacement final layer.

    Returns:
        The ResNet-18 whose ``fc`` layer has been swapped for a freshly
        initialised ``nn.Linear`` with the same input width and
        ``NUM_CLASSES`` outputs. (Obtaining the weights may trigger a download.)
    """
    backbone = M.resnet18(weights=M.ResNet18_Weights.DEFAULT)
    # Keep the backbone's feature width; only the number of classes changes.
    backbone.fc = nn.Linear(
        in_features=backbone.fc.in_features,
        out_features=NUM_CLASSES,
    )
    return backbone
def load_train_objs(data_root: str = '/home/timssh/ML/TAGGING/NEW_TRAIN/data/'):
    """Build the training dataset, model and optimizer.

    Args:
        data_root: Directory where torchvision keeps (and, if missing,
            downloads) the GTSRB dataset. The default is the original
            author's local path, so existing calls behave as before.

    Returns:
        ``(train_set, model, optimizer)``: GTSRB resized to 50x50 and converted
        to tensors, a ResNet-18 with a 43-way head (presumably GTSRB's class
        count), and SGD with ``lr=1e-3`` over that model's parameters.
    """
    transform = T.Compose([
        T.Resize((50, 50)),
        T.ToTensor(),
    ])
    # NOTE(review): every DDP process calls this with download=True at the same
    # time; if the data is not on disk yet, the ranks may race while
    # downloading/extracting. Consider pre-downloading, or downloading on rank 0
    # behind a barrier.
    train_set = GTSRB(data_root, transform=transform, download=True)
    model = get_resnet(43)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer
def prepare_dataloader(dataset: Dataset, batch_size: int):
    """Wrap *dataset* in a DataLoader that iterates only this rank's shard.

    Args:
        dataset: Map-style dataset to load from.
        batch_size: Number of samples per batch on each process.

    Returns:
        A pinned-memory ``DataLoader`` driven by a ``DistributedSampler``.
        ``DistributedSampler(dataset)`` with no explicit ``num_replicas`` or
        ``rank`` takes both from the default process group, so that group must
        already be initialised.
    """
    shard_sampler = DistributedSampler(dataset)
    # Ordering is delegated to the sampler, so DataLoader-level shuffling
    # stays off (DataLoader does not accept shuffle=True with a sampler).
    loader_options = dict(
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
    )
    return DataLoader(dataset, sampler=shard_sampler, **loader_options)
def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    """Run one DDP training process end to end.

    Args:
        save_every: Snapshot interval in epochs.
        total_epochs: Epoch count to train up to (resumed runs start later).
        batch_size: Per-process batch size.
        snapshot_path: File used to resume from and to save snapshots to.
    """
    ddp_setup()
    # Tear the process group down even if data loading or training raises,
    # so the process does not exit with a live NCCL group.
    try:
        dataset, model, optimizer = load_train_objs()
        train_data = prepare_dataloader(dataset, batch_size)
        trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
        trainer.train(total_epochs)
    finally:
        destroy_process_group()
if __name__ == "__main__":
    import argparse

    # This script runs as one process per GPU and relies on the environment
    # torchrun provides (LOCAL_RANK plus the init_process_group rendezvous
    # variables). Running it with plain `python`, or with --nproc_per_node=1,
    # starts a single process and therefore uses a single GPU. To use two GPUs
    # on one machine:
    #   torchrun --standalone --nproc_per_node=2 <this_script>.py
    parser = argparse.ArgumentParser(
        description='simple distributed training job',
        epilog='Launch with torchrun, one process per GPU, e.g.: '
               'torchrun --standalone --nproc_per_node=2 %(prog)s',
    )
    parser.add_argument('--total_epochs', default=10, type=int, help='Total epochs to train the model')
    parser.add_argument('--save_every', default=1, type=int, help='How often to save a snapshot')
    # %(default)s keeps the help text in sync with the real default value.
    parser.add_argument('--batch_size', default=15000, type=int, help='Input batch size on each device (default: %(default)s)')
    parser.add_argument('--snapshot_path', default='snapshot.pt', type=str, help='Snapshot file to resume from and save to (default: %(default)s)')
    args = parser.parse_args()
    main(args.save_every, args.total_epochs, args.batch_size, args.snapshot_path)