pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.

Home Page:https://pytorch.org/examples

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

DDP training question

Henryplay opened this issue · comments

commented

Hi, I'm using the tutorial https://github.com/pytorch/tutorials/blob/master/intermediate_source/ddp_tutorial.rst for DDP train,using 4 gpus in myself code, reference Basic Use Case. But when I finished the modification, it was stuck during run the demo,meanwhile,video memory has been occupied.Could you help me?

commented

and my code is here

from math import gamma
import os
import torch
import argparse
from tqdm import tqdm
from utils.scheduler import GradualWarmupScheduler
from modeling.model import CNN
from modeling.loss import CTCLoss
from utils.dataset import CharDict, LoadData, ImageTransform
from utils.utils import paser_config, edit_distance_score, setup_logger
from torch.utils.data import DataLoader

import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:

    def __init__(self, config_file):
        self.configs = paser_config(config_file)
        # os.environ['CUDA_VISIBLE_DEVICES'] = self.configs['trainer']['gpus']
        self.build_dataloader()
        self.build_model()
        self.start_epoch = 0
        self.max_epochs = self.configs['trainer']['epochs']
        self.save_dir = os.path.join(self.configs['trainer']['output_dir'], self.configs['name'])
        if not os.path.exists(self.save_dir) : os.makedirs(self.save_dir)
        log_file_mode = 'a' if self.configs['trainer']["resume_ckpt"] else 'w'
        self.logger = setup_logger(log_file_path=os.path.join(self.save_dir, 'train.log'), log_file_mode=log_file_mode)
        self.checkpoint = {
            'epoch': 0,
            'history_acc': [],
            'history_eds': [],
            'model': {},
            'optimizer': {},
            'lr_scheduler': {},
            'configs': self.configs
        }
        if self.configs['trainer']["finetune_ckpt"]:
            self.model.load_state_dict(torch.load(self.configs['trainer']["finetune_ckpt"])['model'], False)
            #ckpt = torch.load(self.configs['trainer']["finetune_ckpt"])['model']
            #self.model.load_state_dict({k: v for k, v in ckpt.items() if 'fc' not in k},False)
        elif self.configs['trainer']["resume_ckpt"]:
            self.checkpoint = torch.load(self.configs['trainer']["resume_ckpt"])
            self.model.load_state_dict(self.checkpoint['model'])
            self.optimizer.load_state_dict(self.checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(self.checkpoint['lr_scheduler'])
            self.checkpoint['model'].clear()
            self.checkpoint['optimizer'].clear()
            self.checkpoint['lr_scheduler'].clear()
            self.start_epoch = self.checkpoint['epoch'] + 1
        # warp dp-model
        # self.model = torch.nn.DataParallel(self.model)
    def setup(self,rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    def cleanup(self):
        dist.destroy_process_group()
    def train(self,rank,world_size):
        self.setup(rank,world_size)
        self.model = self.model.to(rank)
        self.model = DDP(self.model, device_ids=[rank])
        for epoch in range(self.start_epoch, self.max_epochs):
            self.model.train()
            self.checkpoint['epoch'] = epoch
            for i, datas in enumerate(self.train_dataloader):
                img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
                img = img.to(rank)
                preds = self.model(img)
                loss = self.criterion(preds, targets.to(rank), target_lens.to(rank))
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # log info
                if i%10 == 0:
                    batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
                    msg = "Epoch: %d/%d, " % (epoch, self.max_epochs) + \
                          "Batch: %d/%d, "%(i, len(self.train_dataloader)) + \
                          "Lr: %.6f, " %  self.scheduler_warmup.get_last_lr()[0] + \
                          "Loss: %.3f, " % loss.item() + \
                          "Acc: %.3f, EDS: %.3f" % (batch_acc, batch_eds)
                    self.logger.info(msg)
            self.scheduler_warmup.step()
            self.cleanup()
            self.eval()

    @torch.no_grad()
    def eval(self):
        self.model.eval()
        nbatch = len(self.test_dataloader)
        acc, eds = 0, 0
        for datas in tqdm(self.test_dataloader, desc="Testing..."):
            img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
            preds = self.model(img.cuda())
            batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
            acc += batch_acc
            eds += batch_eds
        mean_acc = acc / nbatch
        mean_eds = eds / nbatch

        self.save_model(mean_acc, mean_eds)
        return mean_acc, mean_eds

    def metrics(self, preds, targets, target_lens):
        """WARNING:
            This function will consume a lot of time. Don't use it frequently.
        """
        bs = preds.size(0)
        preds_prob,  preds_idx = preds.permute(0,2,1).detach().softmax(dim=2).max(2)
        decode_idx, decode_prob,_ = self.chardict.ctc_decode(preds_idx.cpu().numpy(), preds_prob.cpu().numpy())
        preds_texts = [self.chardict.idx2text(i, reserve_char='\a') for i in decode_idx]
        target_texts = [self.chardict.idx2text(t[:l], reserve_char='') for t, l in zip(targets, target_lens)]
        ed_score = 0.0
        n_correct = 0
        for s1, s2 in zip(preds_texts, target_texts):
            ed_score += edit_distance_score(s1, s2)
            n_correct += (s1 == s2)
        ed_score /= bs
        batch_acc = n_correct / bs
        return batch_acc, ed_score

    def save_model(self, cur_acc, cur_eds):
        best_acc_path = os.path.join(self.save_dir, "model_best_acc.pth")
        best_eds_path = os.path.join(self.save_dir, "model_best_eds.pth")
        model_last_path = os.path.join(self.save_dir, "model_last.pth")
        self.checkpoint['history_acc'].append(cur_acc)
        self.checkpoint['history_eds'].append(cur_eds)
        self.checkpoint['model'] = self.model.module.state_dict()
        self.checkpoint['optimizer'] = self.optimizer.state_dict()
        self.checkpoint['lr_scheduler'] = self.lr_scheduler.state_dict()

        torch.save(self.checkpoint, model_last_path)
        self.logger.info("Current acc: %.3f, eds: %.3f" % (cur_acc, cur_eds))
        self.logger.info("Save current epoch model to: %s" % model_last_path)
        best_acc = max(self.checkpoint['history_acc'])
        best_eds = max(self.checkpoint['history_eds'])
        if cur_acc >= best_acc:
            torch.save(self.checkpoint, best_acc_path)
            self.logger.info("Best acc: %.3f", cur_acc)
            self.logger.info("Save best Acc model to: %s" % best_acc_path)
        if cur_eds >= best_eds:
            torch.save(self.checkpoint, best_eds_path)
            self.logger.info("Best eds: %.3f", cur_eds)
            self.logger.info("Save best EDS model to: %s" % best_eds_path)

        # release
        self.checkpoint['model'].clear()
        self.checkpoint['optimizer'].clear()
        self.checkpoint['lr_scheduler'].clear()

    def build_model(self):
        in_dim = 1 if self.configs['dataset']['img_mode'] == 'gray' else 3
        out_dim = self.configs['dataset']['ncls']
        self.model = CNN(in_dim, out_dim)
        self.optimizer = getattr(torch.optim, self.configs['optimizer']['type'])(
            self.model.parameters(), **self.configs['optimizer']['args'])
        #set lr_decay
        lr_scheduler_type = self.configs['lr_scheduler']['type']
        if lr_scheduler_type == "StepLR":
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.configs['lr_scheduler']['type'])(
                self.optimizer, **self.configs['lr_scheduler']['args'])
        else:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler,self.configs['lr_scheduler']['type'])(
                self.optimizer,5
            )
        self.criterion = CTCLoss()

    def build_dataloader(self):
        self.chardict = CharDict(
            self.configs['dataset']['dict'], self.configs['dataset']['ncls'])
        imtrans = ImageTransform(
            self.configs['dataset']['img_mode'], self.configs['dataset']['img_size'])
        trainset = LoadData(
            self.configs['dataset']['trainset'], self.chardict, imtrans)
        self.train_dataloader = DataLoader(
            trainset, self.configs['dataset']['batch_size'], shuffle=True, collate_fn=trainset.collate_fn, num_workers=16)
        testset = LoadData(
            self.configs['dataset']['testset'], self.chardict, imtrans)
        self.test_dataloader = DataLoader(
            testset, self.configs['dataset']['batch_size'], shuffle=False, collate_fn=trainset.collate_fn, num_workers=16)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file', default='config/pycrnn.yaml', type=str)
    args = parser.parse_args()
    trainer = Trainer(args.config_file)
    world_size = 4
    mp.spawn(trainer.train,
            args=(world_size, ),
            nprocs = world_size,
            join=True)

Hi, do you have a problem with the application getting stuck after starting multiple nodes? On my side, too, running the official multi-node example would get stuck