BaguaSys / bagua

Bagua Speeds up PyTorch


My process has been blocked; the screen (as shown below) has not changed for 30 minutes

lixiangMindSpore opened this issue · comments

Describe the bug
The training process hangs: the console output (see the screenshot below) has not changed for 30 minutes.

Environment

  • Your operating system and version: Ubuntu 18.04
  • Your Python version: 3.8
  • Your PyTorch version: 11.0
  • How did you install Python (e.g. apt or pyenv)? Did you use a virtualenv?:
  • Have you tried using latest bagua master (python3 -m pip install git+https://github.com/BaguaSys/bagua.git -f https://repo.arrayfire.com/python/wheels/3.8.0/)?: 0.8.1.post1

Reproducing

Please provide a minimal working example, i.e. runnable code.

Please also write what exact commands are required to reproduce your results.

Additional context
Add any other context about the problem here.

My program hangs, exactly as shown in the screenshot.
[screenshot: console output frozen during training]


Does it hang when --nproc_per_node=1 (which should be the same as running the training with a single card)?

Could you provide a minimal bug producing example script if it still hangs?
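
For reference, with the bagua launcher a single-process run would look something like the following (the invocation is an assumption based on bagua's examples; the script name is a placeholder, substitute your own training script):

python -m bagua.distributed.launch --nproc_per_node=1 train.py

If the hang disappears with one process, the problem is most likely in the inter-process communication; if it still hangs, a stripped-down reproduction script would narrow it down a lot.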

Does it hang when --nproc_per_node=1 (which should be the same as running the training with a single card)?
The phenomenon is the same.
Could you provide a minimal bug producing example script if it still hangs?

"""
Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper:
https://arxiv.org/abs/2010.05222
"""
import os
import argparse
import time
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from backbones.model_irse import IR_SE_50, IR_SE_101
from config import config as cfg
from utils import *
from dataset import MXFaceDataset, DataLoaderX

from dataset import RecgDataset_mask as RecgDataset  # mask augmentation

from partial_classifier import DistSampleClassifier
from partial_loss import MarginSoftmax
from sgd import SGD
from torchsummary import summary
import bagua.torch_api as bagua  ###
from bagua.torch_api.algorithms import gradient_allreduce   ###
from common_utils import find_free_port

torch.backends.cudnn.benchmark = True


def should_distribute():
    return dist.is_available() and world_size >= 1


def is_distributed():
    return dist.is_available() and dist.is_initialized()

def _init_bagua_env(rank, env):
    # initialize subprocess env
    os.environ["WORLD_SIZE"] = env["WORLD_SIZE"]
    os.environ["LOCAL_WORLD_SIZE"] = env["LOCAL_WORLD_SIZE"]
    os.environ["MASTER_ADDR"] = env["MASTER_ADDR"]
    os.environ["MASTER_PORT"] = env["MASTER_PORT"]
    os.environ["BAGUA_SERVICE_PORT"] = env["BAGUA_SERVICE_PORT"]

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)

    # init bagua distributed process group
    torch.cuda.set_device(rank)
    bagua.init_process_group()

def main(local_rank, rank, world_size, cfg):
    # dataloader
    print('loading data...')
    comm = bagua.communication._get_default_group().get_global_communicator()  ###
    trainset = RecgDataset(cfg)

    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=trainset,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=True)

    # model
    print('loading model...')
    backbone = IR_SE_50(cfg.input_size).to(local_rank)
    # backbone = IR_SE_101(cfg.input_size)
    # Memory classifer
    dist_sample_classifer = DistSampleClassifier(trainset.classes, rank=rank, local_rank=local_rank,
                                                 world_size=world_size)
    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifer
    # optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
    #                lr=0.1, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

    # Broadcast init parameters
    # for ps in backbone.parameters():
    #    dist.broadcast(ps, 0)

    backbone_path = os.path.join(cfg.model_save + 'backbone/')
    head_path = os.path.join(cfg.model_save + 'head/')
    log_path = os.path.join(cfg.log_save + 'shows/')

    cfg.model_resume = cfg.model_save
    backbone_resume = os.path.join(cfg.model_resume + 'backbone/')
    head_resume = os.path.join(cfg.model_resume + 'head/')

    # if cfg.resume and os.path.isdir(backbone_resume) and os.path.isdir(head_resume):
    if cfg.resume and os.path.isdir(backbone_resume):
        print('resume~~~~~~~~~~~~~~~~~~~~~~~~~~')
        backbone_list = os.listdir(backbone_resume)
        if backbone_list:
            # pre_flags = [eval(x.split('Epoch_')[1].split('_Time')[0]) for x in backbone_list]
            # tar_flag = max(pre_flags)
            # print(tar_flag)
            # backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
            # backbone_ckpt = torch.load(backbone_path + '/' + str(tar_flag) + '_backbone.pth')
            print(backbone_resume + '/' + backbone_list[0])
            backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
            backbone.load_state_dict(backbone_ckpt['backbone'])
            print('load backbone ~')
            # optimizer.load_state_dict(backbone_ckpt['optimizer'])
            # print(optimizer.param_groups[0]['lr'])
            # start_epoch = backbone_ckpt['epoch'] + 1
            start_epoch = 1
            fg = 0
            if rank == 0 and fg:
                head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head0.pth')
                dist_sample_classifer.load_state_dict(head_ckpt['head'])
            if rank == 1 and fg:
                head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head1.pth')
                dist_sample_classifer.load_state_dict(head_ckpt['head'])
            if rank == 2 and fg:
                head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head2.pth')
                dist_sample_classifer.load_state_dict(head_ckpt['head'])
            if rank == 3 and fg:
                head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head3.pth')
                dist_sample_classifer.load_state_dict(head_ckpt['head'])

    else:
        start_epoch = 0
        print("Train from Scratch")
    print("=" * 60)

    backbone = backbone.to(local_rank)

    optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
                    lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

    # Broadcast init parameters
    for ps in backbone.parameters():
        bagua.broadcast(ps, 0, comm=comm)     ###

    # bagua
    # backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
    algorithm = gradient_allreduce.GradientAllReduceAlgorithm()
    backbone = backbone.with_bagua([optimizer], algorithm)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=cfg.lr_func)
    os.makedirs(log_path, exist_ok=True)
    if local_rank == 0:
        writer = SummaryWriter(log_dir=log_path)

    print('training...')
    global_step = 0
    n_epochs = cfg.num_epoch
    # NUM_EPOCH_WARM_UP = n_epochs // 25
    NUM_EPOCH_WARM_UP = 5
    NUM_BATCH_WARM_UP = NUM_EPOCH_WARM_UP * len(train_loader)
    backbone.train()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)

        for step, (img, label) in enumerate(train_loader):

            if (epoch + 1 <= NUM_EPOCH_WARM_UP) and (
                    global_step + 1 <= NUM_BATCH_WARM_UP) and 1:  # adjust LR for each training batch during warm up
                warm_up_lr(global_step + 1, NUM_BATCH_WARM_UP, cfg.lr, optimizer)

            total_label, norm_weight = dist_sample_classifer.prepare(label, optimizer)
            # print('total_label:', total_label.shape)
            # print('norm_weight:', norm_weight.shape)
            features = backbone(img)  # features are normalized inside the backbone

            # Features all-gather
            total_features = torch.zeros(features.size()[0] * world_size, cfg.embedding_size, device=local_rank)
            dist.all_gather(list(total_features.chunk(world_size, dim=0)), features.data)
            total_features.requires_grad = True

            # Calculate logits
            # print('&' * 60)
            # print('total_features:', total_features.shape)
            # print('norm_weight:', norm_weight.shape)
            # print('total_label:', total_label.shape)
            logits = dist_sample_classifer(total_features, norm_weight)  # cos =
            # print('logits1:', logits.shape)
            # print('logits:', logits.shape)
            # print('&' * 60)
            logits = margin_softmax(logits, total_label)
            # print('logits2:', logits.shape)
            # total_logits = torch.zeros(logits.size()[0], len(DataLoaderX), device=local_rank)
            # dist.all_gather(list(total_logits.chunk(world_size, dim=0)),
            #                 logits.data)

            with torch.no_grad():
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                # print('max_fc:', max_fc.shape)
                ###dist.all_reduce(max_fc, dist.ReduceOp.MAX)
                recv_max_fc = torch.zeros_like(max_fc)
                bagua.allreduce(max_fc, recv_max_fc, bagua.ReduceOp.MAX, comm=comm)   ###
                # print('#'*10, max_fc)
                # Calculate exp(logits) and all-reduce
                logits_exp = torch.exp(logits - recv_max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                recv_logits_sum_exp = torch.zeros_like(logits_sum_exp)
                ###dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
                bagua.allreduce(logits_sum_exp, recv_logits_sum_exp, bagua.ReduceOp.SUM, comm=comm)  ###

                # Calculate prob
                logits_exp.div_(logits_sum_exp)

                # Get one-hot
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0], grad.size()[1], device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                recv_loss = torch.zeros_like(loss)
                ###dist.all_reduce(loss, dist.ReduceOp.SUM)
                bagua.allreduce(loss, recv_loss, bagua.ReduceOp.SUM, comm=comm)  ###
                loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient all-reduce
            ###dist.reduce_scatter(
                ###x_grad, list(total_features.grad.chunk(world_size, dim=0)))
            bagua.reduce_scatter(list(total_features.grad.chunk(world_size, dim=0)), x_grad, comm=comm)  ###
            x_grad.mul_(world_size)
            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifer
            dist_sample_classifer.update()
            optimizer.zero_grad()

            tm = time.asctime().split()[-2]
            if rank == 0 and global_step % cfg.disp_freq == 0:
                writer.add_scalar('loss', loss_v, global_step)
                print('\nEpoch:{}/{} Batch:{}/{}\t'
                      'Loss:{loss:.4f}\t'
                      'lr: {lr:.4f}\t'
                      'TimeNow:{TimeNow}'.format(
                    epoch, n_epochs, (step + 1) % len(train_loader), len(train_loader),
                    loss=loss_v, lr=optimizer.param_groups[0]['lr'], TimeNow=tm))

            global_step += 1
        scheduler.step()

        if rank == 0:
            os.makedirs(backbone_path, exist_ok=True)
            state = {'backbone': backbone.module.state_dict(),
                     'optimizer': optimizer.state_dict(), 'epoch': epoch}

            torch.save(state,
                       backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))

        os.makedirs(head_path, exist_ok=True)
        if rank == 0:
            state = {'head': dist_sample_classifer.state_dict()}
            torch.save(state,
                       os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head0.pth".format(epoch + 1, get_time())))
        if rank == 1:
            state = {'head': dist_sample_classifer.state_dict()}
            torch.save(state,
                       os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head1.pth".format(epoch + 1, get_time())))
        if rank == 2:
            state = {'head': dist_sample_classifer.state_dict()}
            torch.save(state,
                       os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head2.pth".format(epoch + 1, get_time())))
        if rank == 3:
            state = {'head': dist_sample_classifer.state_dict()}
            torch.save(state,
                       os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head3.pth".format(epoch + 1, get_time())))

    ###dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--local_rank', type=int, default=1, help='local_rank')
    args = parser.parse_args()
    os.environ['NCCL_DEBUG'] = 'INFO'
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    env = {
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(find_free_port(8000, 8100)),
        "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
    }
    _init_bagua_env(args.local_rank, env)
    rank = bagua.get_rank()

    main(args.local_rank, rank, world_size, cfg)

Maybe the problem is "bagua.init_process_group()"
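
A minimal probe to check whether initialization alone hangs (just a sketch, assuming the script is started with the bagua launcher so the rank environment variables are already set):

# check_init.py -- hypothetical probe; if this already hangs, the problem is
# in process-group initialization rather than in the training loop itself.
import torch
import bagua.torch_api as bagua

if __name__ == "__main__":
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()
    print("rank", bagua.get_rank(), "initialized")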

@lixiangMindSpore Can you try to replace the function _init_bagua_env with

def _init_bagua_env(rank, env):
    # init bagua distributed process group
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

@lixiangMindSpore Can you try to replace the function _init_bagua_env with

def _init_bagua_env(rank, env):
    # init bagua distributed process group
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

My original code is as follows. You mean that I only need to delete the os.environ[...] = env[...] lines?

def _init_bagua_env_backup(rank, env):
    # initialize subprocess env
    os.environ["WORLD_SIZE"] = env["WORLD_SIZE"]
    os.environ["LOCAL_WORLD_SIZE"] = env["LOCAL_WORLD_SIZE"]
    os.environ["MASTER_ADDR"] = env["MASTER_ADDR"]
    os.environ["MASTER_PORT"] = env["MASTER_PORT"]
    os.environ["BAGUA_SERVICE_PORT"] = env["BAGUA_SERVICE_PORT"]

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)

    # init bagua distributed process group
    torch.cuda.set_device(rank)
    bagua.init_process_group()

Your method may be effective, but I want to know the reason for this phenomenon. Thank you so much!

@lixiangMindSpore Can you try to replace the function _init_bagua_env with

def _init_bagua_env(rank, env):
    # init bagua distributed process group
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

Your method may be effective, but I want to know the reason for this phenomenon.

The code changes are based on the following two reasons (a short sketch follows the list):

  1. Bagua relaxes the restrictions on user scripts and does not pass --local_rank to the user process. Instead, bagua.get_local_rank() should be used to obtain it.
  2. We do not recommend setting environment variables manually; these variables are already set by the bagua launch tool.
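
Under these recommendations the initialization reduces to the sketch below (the launch command in the comment is an assumption based on bagua's examples; adjust --nproc_per_node and the script name to your setup):

# Simplified initialization following the two points above. Launched with
# something like:
#   python -m bagua.distributed.launch --nproc_per_node=4 train.py
import torch
import bagua.torch_api as bagua

def _init_bagua_env():
    # The bagua launch tool already sets WORLD_SIZE, LOCAL_WORLD_SIZE, RANK,
    # LOCAL_RANK, MASTER_ADDR, MASTER_PORT and BAGUA_SERVICE_PORT, so there is
    # no need to export them manually here.
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()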

For more recommendations on how to write your script, you can refer to these two examples:
https://github.com/BaguaSys/bagua/blob/master/examples/benchmark/synthetic_benchmark.py
https://github.com/BaguaSys/bagua/blob/master/examples/communication_primitives/main.py