DDP training question
Henryplay opened this issue · comments
Henry commented
Hi, I'm following the tutorial https://github.com/pytorch/tutorials/blob/master/intermediate_source/ddp_tutorial.rst to add DDP training to my own code with 4 GPUs, based on the Basic Use Case. But after I finished the changes, the run hangs in the demo while GPU memory stays occupied. Could you help me?
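For reference, the Basic Use Case I followed boils down to roughly this (my paraphrase of the tutorial, with a toy linear model standing in for my CNN):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # one process per GPU joins the group before any collective runs
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def demo_basic(rank, world_size):
    setup(rank, world_size)
    model = nn.Linear(10, 10).to(rank)  # toy stand-in for the real model
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(ddp_model(torch.randn(20, 10).to(rank)),
                                  torch.randn(20, 10).to(rank))
    loss.backward()  # gradients are all-reduced across ranks here
    optimizer.step()
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(demo_basic, args=(4,), nprocs=4, join=True)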
Henry commented
Here is my code:
from math import gamma
import os
import torch
import argparse
from tqdm import tqdm
from utils.scheduler import GradualWarmupScheduler
from modeling.model import CNN
from modeling.loss import CTCLoss
from utils.dataset import CharDict, LoadData, ImageTransform
from utils.utils import paser_config, edit_distance_score, setup_logger
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
class Trainer:
    def __init__(self, config_file):
        self.configs = paser_config(config_file)
        # os.environ['CUDA_VISIBLE_DEVICES'] = self.configs['trainer']['gpus']
        self.build_dataloader()
        self.build_model()
        self.start_epoch = 0
        self.max_epochs = self.configs['trainer']['epochs']
        self.save_dir = os.path.join(self.configs['trainer']['output_dir'], self.configs['name'])
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        log_file_mode = 'a' if self.configs['trainer']["resume_ckpt"] else 'w'
        self.logger = setup_logger(log_file_path=os.path.join(self.save_dir, 'train.log'), log_file_mode=log_file_mode)
        self.checkpoint = {
            'epoch': 0,
            'history_acc': [],
            'history_eds': [],
            'model': {},
            'optimizer': {},
            'lr_scheduler': {},
            'configs': self.configs
        }
        if self.configs['trainer']["finetune_ckpt"]:
            self.model.load_state_dict(torch.load(self.configs['trainer']["finetune_ckpt"])['model'], False)
            # ckpt = torch.load(self.configs['trainer']["finetune_ckpt"])['model']
            # self.model.load_state_dict({k: v for k, v in ckpt.items() if 'fc' not in k}, False)
        elif self.configs['trainer']["resume_ckpt"]:
            self.checkpoint = torch.load(self.configs['trainer']["resume_ckpt"])
            self.model.load_state_dict(self.checkpoint['model'])
            self.optimizer.load_state_dict(self.checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(self.checkpoint['lr_scheduler'])
            self.checkpoint['model'].clear()
            self.checkpoint['optimizer'].clear()
            self.checkpoint['lr_scheduler'].clear()
            self.start_epoch = self.checkpoint['epoch'] + 1
        # wrap DP model (replaced by the DDP wrap in train())
        # self.model = torch.nn.DataParallel(self.model)
    def setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup(self):
        dist.destroy_process_group()
    def train(self, rank, world_size):
        self.setup(rank, world_size)
        self.model = self.model.to(rank)
        self.model = DDP(self.model, device_ids=[rank])
        for epoch in range(self.start_epoch, self.max_epochs):
            self.model.train()
            self.checkpoint['epoch'] = epoch
            for i, datas in enumerate(self.train_dataloader):
                img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
                img = img.to(rank)
                preds = self.model(img)
                loss = self.criterion(preds, targets.to(rank), target_lens.to(rank))
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # log info
                if i % 10 == 0:
                    batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
                    msg = "Epoch: %d/%d, " % (epoch, self.max_epochs) + \
                          "Batch: %d/%d, " % (i, len(self.train_dataloader)) + \
                          "Lr: %.6f, " % self.scheduler_warmup.get_last_lr()[0] + \
                          "Loss: %.3f, " % loss.item() + \
                          "Acc: %.3f, EDS: %.3f" % (batch_acc, batch_eds)
                    self.logger.info(msg)
            self.scheduler_warmup.step()
        self.cleanup()
        self.eval()
    @torch.no_grad()
    def eval(self):
        self.model.eval()
        nbatch = len(self.test_dataloader)
        acc, eds = 0, 0
        for datas in tqdm(self.test_dataloader, desc="Testing..."):
            img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
            preds = self.model(img.cuda())
            batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
            acc += batch_acc
            eds += batch_eds
        mean_acc = acc / nbatch
        mean_eds = eds / nbatch
        self.save_model(mean_acc, mean_eds)
        return mean_acc, mean_eds
    def metrics(self, preds, targets, target_lens):
        """WARNING:
        This function will consume a lot of time. Don't use it frequently.
        """
        bs = preds.size(0)
        preds_prob, preds_idx = preds.permute(0, 2, 1).detach().softmax(dim=2).max(2)
        decode_idx, decode_prob, _ = self.chardict.ctc_decode(preds_idx.cpu().numpy(), preds_prob.cpu().numpy())
        preds_texts = [self.chardict.idx2text(i, reserve_char='\a') for i in decode_idx]
        target_texts = [self.chardict.idx2text(t[:l], reserve_char='') for t, l in zip(targets, target_lens)]
        ed_score = 0.0
        n_correct = 0
        for s1, s2 in zip(preds_texts, target_texts):
            ed_score += edit_distance_score(s1, s2)
            n_correct += (s1 == s2)
        ed_score /= bs
        batch_acc = n_correct / bs
        return batch_acc, ed_score
    def save_model(self, cur_acc, cur_eds):
        best_acc_path = os.path.join(self.save_dir, "model_best_acc.pth")
        best_eds_path = os.path.join(self.save_dir, "model_best_eds.pth")
        model_last_path = os.path.join(self.save_dir, "model_last.pth")
        self.checkpoint['history_acc'].append(cur_acc)
        self.checkpoint['history_eds'].append(cur_eds)
        self.checkpoint['model'] = self.model.module.state_dict()
        self.checkpoint['optimizer'] = self.optimizer.state_dict()
        self.checkpoint['lr_scheduler'] = self.lr_scheduler.state_dict()
        torch.save(self.checkpoint, model_last_path)
        self.logger.info("Current acc: %.3f, eds: %.3f" % (cur_acc, cur_eds))
        self.logger.info("Save current epoch model to: %s" % model_last_path)
        best_acc = max(self.checkpoint['history_acc'])
        best_eds = max(self.checkpoint['history_eds'])
        if cur_acc >= best_acc:
            torch.save(self.checkpoint, best_acc_path)
            self.logger.info("Best acc: %.3f", cur_acc)
            self.logger.info("Save best Acc model to: %s" % best_acc_path)
        if cur_eds >= best_eds:
            torch.save(self.checkpoint, best_eds_path)
            self.logger.info("Best eds: %.3f", cur_eds)
            self.logger.info("Save best EDS model to: %s" % best_eds_path)
        # release
        self.checkpoint['model'].clear()
        self.checkpoint['optimizer'].clear()
        self.checkpoint['lr_scheduler'].clear()
    def build_model(self):
        in_dim = 1 if self.configs['dataset']['img_mode'] == 'gray' else 3
        out_dim = self.configs['dataset']['ncls']
        self.model = CNN(in_dim, out_dim)
        self.optimizer = getattr(torch.optim, self.configs['optimizer']['type'])(
            self.model.parameters(), **self.configs['optimizer']['args'])
        # set lr decay
        lr_scheduler_type = self.configs['lr_scheduler']['type']
        if lr_scheduler_type == "StepLR":
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.configs['lr_scheduler']['type'])(
                self.optimizer, **self.configs['lr_scheduler']['args'])
        else:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.configs['lr_scheduler']['type'])(
                self.optimizer, 5)
        self.criterion = CTCLoss()
    def build_dataloader(self):
        self.chardict = CharDict(
            self.configs['dataset']['dict'], self.configs['dataset']['ncls'])
        imtrans = ImageTransform(
            self.configs['dataset']['img_mode'], self.configs['dataset']['img_size'])
        trainset = LoadData(
            self.configs['dataset']['trainset'], self.chardict, imtrans)
        self.train_dataloader = DataLoader(
            trainset, self.configs['dataset']['batch_size'], shuffle=True, collate_fn=trainset.collate_fn, num_workers=16)
        testset = LoadData(
            self.configs['dataset']['testset'], self.chardict, imtrans)
        self.test_dataloader = DataLoader(
            testset, self.configs['dataset']['batch_size'], shuffle=False, collate_fn=trainset.collate_fn, num_workers=16)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file', default='config/pycrnn.yaml', type=str)
    args = parser.parse_args()
    trainer = Trainer(args.config_file)
    world_size = 4
    mp.spawn(trainer.train,
             args=(world_size,),
             nprocs=world_size,
             join=True)
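One thing to note: I did not add a DistributedSampler, so every spawned process iterates the full DataLoader. The usual per-rank split would look something like this (a sketch only, not in my code above; build_train_loader is a hypothetical helper):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_train_loader(trainset, rank, world_size, batch_size):
    # each rank sees a disjoint 1/world_size shard of the dataset
    sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(trainset, batch_size=batch_size, sampler=sampler,
                        collate_fn=trainset.collate_fn, num_workers=16)
    return loader, sampler

# and once per epoch, so shuffling differs across epochs:
# sampler.set_epoch(epoch)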
AntyRia commented
Hi, do you also have the problem of the application getting stuck after starting multiple nodes? On my side, running the official multi-node example gets stuck too.
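For context, my understanding of the multi-node setup is roughly the following (hostnames and ports below are placeholders, not my actual values):

# launched on every node with something like:
#   torchrun --nnodes=2 --nproc_per_node=4 --node_rank=<0|1> \
#            --master_addr=<node0-ip> --master_port=29500 main.py
import torch.distributed as dist

def setup_from_env():
    # torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT;
    # the "env://" init method reads them
    dist.init_process_group(backend="nccl", init_method="env://")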