BaguaSys / bagua

Bagua Speeds up PyTorch

Home Page: https://tutorials-8ro.pages.dev/


I use bagua with the phenomenon as follows ( bagua.broadcast(ps, 0, comm=comm) )

lixiangMindSpore opened this issue · comments

Describe the bug
A clear and concise description of what the bug is.
(screenshot of the error attached)

Environment

  • Your operating system and version: Ubuntu 18.04
  • Your Python version: 3.8.12
  • Your PyTorch version: 11.0
  • How did you install python (e.g. apt or pyenv)? Did you use a virtualenv?: conda create -n torch17 python=3.8
  • Have you tried using latest bagua master (python3 -m pip install --pre bagua)?: 0.8.1.post1

Reproducing

Please provide a minimal working example. This means the runnable code.

Please also write what exact commands are required to reproduce your results.

Additional context
Add any other context about the problem here.

comm = bagua.communication._get_default_group().get_global_communicator()  ###
 
# Broadcast init parameters
for ps in backbone.parameters():
    bagua.broadcast(ps, 0, comm=comm)     ###

This error usually means the model is not on the expected GPU device. You can check whether the GPU the model is placed on matches bagua.get_local_rank().
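For example, a minimal check along these lines (a sketch only; `backbone` stands in for your model):

import torch
import bagua.torch_api as bagua

# Every parameter should live on the GPU that corresponds to this
# process's local rank before calling bagua collectives on it.
expected = torch.device("cuda", bagua.get_local_rank())
for name, p in backbone.named_parameters():
    assert p.device == expected, f"{name} is on {p.device}, expected {expected}"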

If it still does not work, please provide a minimal bug-reproducing example script.

import os
import argparse
import time
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from backbones.model_irse import IR_SE_50, IR_SE_101
from config import config as cfg
from utils import *
from dataset import MXFaceDataset, DataLoaderX

from dataset import RecgDataset_mask as RecgDataset  # mask augmentation

from partial_classifier import DistSampleClassifier
from partial_loss import MarginSoftmax
from sgd import SGD
from torchsummary import summary
import bagua.torch_api as bagua ###
from bagua.torch_api.algorithms import gradient_allreduce ###
from common_utils import find_free_port

torch.backends.cudnn.benchmark = True

def should_distribute():
    return dist.is_available() and world_size >= 1

def is_distributed():
    return dist.is_available() and dist.is_initialized()

def _init_bagua_env(rank, env):
    # init bagua distributed process group
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

def main(local_rank, rank, world_size, cfg):
# dataloader
print('loading data...')
comm = bagua.communication._get_default_group().get_global_communicator() ###
trainset = RecgDataset(cfg)

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
train_loader = DataLoaderX(local_rank=local_rank,
                           dataset=trainset,
                           batch_size=cfg.batch_size,
                           sampler=train_sampler,
                           num_workers=0,
                           pin_memory=True,
                           drop_last=True)

# model
print('loading model...')
backbone = IR_SE_50(cfg.input_size).to(local_rank)
# backbone = IR_SE_101(cfg.input_size)
# Memory classifer
dist_sample_classifer = DistSampleClassifier(trainset.classes, rank=rank, local_rank=local_rank,
                                             world_size=world_size)
# Margin softmax
margin_softmax = MarginSoftmax(s=64.0, m=0.4)

# Optimizer for backbone and classifer
# optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
#                lr=0.1, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
# for ps in backbone.parameters():
#    dist.broadcast(ps, 0)

backbone_path = os.path.join(cfg.model_save + 'backbone/')
head_path = os.path.join(cfg.model_save + 'head/')
log_path = os.path.join(cfg.log_save + 'shows/')

cfg.model_resume = cfg.model_save
backbone_resume = os.path.join(cfg.model_resume + 'backbone/')
head_resume = os.path.join(cfg.model_resume + 'head/')

# if cfg.resume and os.path.isdir(backbone_resume) and os.path.isdir(head_resume):
if cfg.resume and os.path.isdir(backbone_resume):
    print('resume~~~~~~~~~~~~~~~~~~~~~~~~~~')
    backbone_list = os.listdir(backbone_resume)
    if backbone_list:
        # pre_flags = [eval(x.split('Epoch_')[1].split('_Time')[0]) for x in backbone_list]
        # tar_flag = max(pre_flags)
        # print(tar_flag)
        # backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        # backbone_ckpt = torch.load(backbone_path + '/' + str(tar_flag) + '_backbone.pth')
        print(backbone_resume + '/' + backbone_list[0])
        backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        backbone.load_state_dict(backbone_ckpt['backbone'])
        print('load backbone ~')
        # optimizer.load_state_dict(backbone_ckpt['optimizer'])
        # print(optimizer.param_groups[0]['lr'])
        # start_epoch = backbone_ckpt['epoch'] + 1
        start_epoch = 1
        fg = 0
        if rank == 0 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head0.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 1 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head1.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 2 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head2.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 3 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head3.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])

else:
    start_epoch = 0
    print("Train from Scratch")
print("=" * 60)

backbone = backbone.to(local_rank)

optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
                lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
for ps in backbone.parameters():
    bagua.broadcast(ps, 0, comm)     ###

# bagua
# backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
algorithm = gradient_allreduce.GradientAllReduceAlgorithm()
backbone = backbone.with_bagua([optimizer], algorithm)

# Lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                              lr_lambda=cfg.lr_func)
os.makedirs(log_path, exist_ok=True)
if local_rank == 0:
    writer = SummaryWriter(log_dir=log_path)

print('trainning...')
global_step = 0
n_epochs = cfg.num_epoch
# NUM_EPOCH_WARM_UP = n_epochs // 25
NUM_EPOCH_WARM_UP = 5
NUM_BATCH_WARM_UP = NUM_EPOCH_WARM_UP * len(train_loader)
backbone.train()
for epoch in range(start_epoch, n_epochs):
    train_sampler.set_epoch(epoch)

    for step, (img, label) in enumerate(train_loader):

        if (epoch + 1 <= NUM_EPOCH_WARM_UP) and (
                global_step + 1 <= NUM_BATCH_WARM_UP) and 1:  # adjust LR for each training batch during warm up
            warm_up_lr(global_step + 1, NUM_BATCH_WARM_UP, cfg.lr, optimizer)

        total_label, norm_weight = dist_sample_classifer.prepare(label, optimizer)
        # print('total_label:', total_label.shape)
        # print('norm_weight:', norm_weight.shape)
        features = backbone(img)  # features are normalized internally

        # Features all-gather
        total_features = torch.zeros(features.size()[0] * world_size, cfg.embedding_size, device=local_rank)
        dist.all_gather(list(total_features.chunk(world_size, dim=0)), features.data)
        total_features.requires_grad = True

        # Calculate logits
        # print('&' * 60)
        # print('total_features:', total_features.shape)
        # print('norm_weight:', norm_weight.shape)
        # print('total_label:', total_label.shape)
        logits = dist_sample_classifer(total_features, norm_weight)  # cos =
        # print('logits1:', logits.shape)
        # print('logits:', logits.shape)
        # print('&' * 60)
        logits = margin_softmax(logits, total_label)
        # print('logits2:', logits.shape)
        # total_logits = torch.zeros(logits.size()[0], len(DataLoaderX), device=local_rank)
        # dist.all_gather(list(total_logits.chunk(world_size, dim=0)),
        #                 logits.data)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            # print('max_fc:', max_fc.shape)
            ###dist.all_reduce(max_fc, dist.ReduceOp.MAX)
            recv_max_fc = torch.zeros_like(max_fc)
            bagua.allreduce(max_fc, recv_max_fc, bagua.ReduceOp.MAX, comm=comm)   ###
            # print('#'*10, max_fc)
            # Calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - recv_max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            recv_logits_sum_exp = torch.zeros_like(logits_sum_exp)
            ###dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
            bagua.allreduce(logits_sum_exp, recv_logits_sum_exp, bagua.ReduceOp.SUM, comm=comm)  ###

            # Calculate prob
            logits_exp.div_(logits_sum_exp)

            # Get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(index.size()[0], grad.size()[1], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # Calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            recv_loss = torch.zeros_like(loss)
            ###dist.all_reduce(loss, dist.ReduceOp.SUM)
            bagua.allreduce(loss, recv_loss, bagua.ReduceOp.SUM, comm=comm)  ###
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # Calculate grad
            grad[index] -= one_hot
            grad.div_(features.size()[0])

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad = torch.zeros_like(features)

        # Feature gradient all-reduce
        ###dist.reduce_scatter(
            ###x_grad, list(total_features.grad.chunk(world_size, dim=0)))
        bagua.reduce_scatter(list(total_features.grad.chunk(world_size, dim=0)), x_grad, comm=comm)  ###
        x_grad.mul_(world_size)
        # Backward backbone
        features.backward(x_grad)
        optimizer.step()

        # Update classifer
        dist_sample_classifer.update()
        optimizer.zero_grad()

        tm = time.asctime().split()[-2]
        if rank == 0 and global_step % cfg.disp_freq == 0:
            writer.add_scalar('loss', loss_v, global_step)
            print('\nEpoch:{}/{} Batch:{}/{}\t'
                  'Loss:{loss:.4f}\t'
                  'lr: {lr:.4f}\t'
                  'TimeNow:{TimeNow}'.format(
                epoch, n_epochs, (step + 1) % len(train_loader), len(train_loader),
                loss=loss_v, lr=optimizer.param_groups[0]['lr'], TimeNow=tm))

        global_step += 1
    scheduler.step()

    if rank == 0:
        os.makedirs(backbone_path, exist_ok=True)
        state = {'backbone': backbone.module.state_dict(),
                 'optimizer': optimizer.state_dict(), 'epoch': epoch}

        torch.save(state,
                   backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))

    os.makedirs(head_path, exist_ok=True)
    if rank == 0:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head0.pth".format(epoch + 1, get_time())))
    if rank == 1:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head1.pth".format(epoch + 1, get_time())))
    if rank == 2:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head2.pth".format(epoch + 1, get_time())))
    if rank == 3:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head3.pth".format(epoch + 1, get_time())))

###dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--local_rank', type=int, default=1, help='local_rank')
    args = parser.parse_args()
    os.environ['NCCL_DEBUG'] = 'INFO'
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    print("lixiang")

    env = {
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(find_free_port(8000, 8100)),
        "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
    }
    print("lixiang2")
    _init_bagua_env(args.local_rank, env)
    rank = bagua.get_rank()
    print(rank)

    main(args.local_rank, rank, world_size, cfg)

This error usually means the model is not on the expected GPU device. You can check whether the GPU the model is placed on matches bagua.get_local_rank().

If it still does not work, please provide a minimal bug-reproducing example script.

If I use DDP, it works fine:
"""
Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper:
https://arxiv.org/abs/2010.05222
"""
import os
import argparse
import time
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from backbones.model_irse import IR_SE_50, IR_SE_101
from config import config as cfg
from utils import *
from dataset import MXFaceDataset, DataLoaderX

from dataset import RecgDataset_mask as RecgDataset  # mask augmentation

from partial_classifier import DistSampleClassifier
from partial_loss import MarginSoftmax
from sgd import SGD
from torchsummary import summary

torch.backends.cudnn.benchmark = True

def should_distribute():
    return dist.is_available() and world_size >= 1

def is_distributed():
    return dist.is_available() and dist.is_initialized()

def main(local_rank, rank, world_size, cfg):
# dataloader
print('loading data...')
trainset = RecgDataset(cfg)

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
train_loader = DataLoaderX(local_rank=local_rank,
                           dataset=trainset,
                           batch_size=cfg.batch_size,
                           sampler=train_sampler,
                           num_workers=0,
                           pin_memory=True,
                           drop_last=True)

# model
print('loading model...')
backbone = IR_SE_50(cfg.input_size).to(local_rank)
# backbone = IR_SE_101(cfg.input_size)

# Memory classifer
dist_sample_classifer = DistSampleClassifier(trainset.classes, rank=rank, local_rank=local_rank, world_size=world_size)
# Margin softmax
margin_softmax = MarginSoftmax(s=64.0, m=0.4)

# Optimizer for backbone and classifer
# optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
#                lr=0.1, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
# for ps in backbone.parameters():
#    dist.broadcast(ps, 0)

backbone_path = os.path.join(cfg.model_save + 'backbone/')
head_path = os.path.join(cfg.model_save + 'head/')
log_path = os.path.join(cfg.log_save + 'shows/')

cfg.model_resume = cfg.model_save
backbone_resume = os.path.join(cfg.model_resume + 'backbone/')
head_resume = os.path.join(cfg.model_resume + 'head/')

# if cfg.resume and os.path.isdir(backbone_resume) and os.path.isdir(head_resume):
if cfg.resume and os.path.isdir(backbone_resume):
    print('resume~~~~~~~~~~~~~~~~~~~~~~~~~~')
    backbone_list = os.listdir(backbone_resume)
    if backbone_list:
        # pre_flags = [eval(x.split('Epoch_')[1].split('_Time')[0]) for x in backbone_list]
        # tar_flag = max(pre_flags)
        # print(tar_flag)
        # backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        #backbone_ckpt = torch.load(backbone_path + '/' + str(tar_flag) + '_backbone.pth')
        print(backbone_resume + '/' + backbone_list[0])
        backbone_ckpt = torch.load(backbone_resume + '/' + backbone_list[0], map_location=torch.device('cpu'))
        backbone.load_state_dict(backbone_ckpt['backbone'])
        print('load backbone ~')
        #optimizer.load_state_dict(backbone_ckpt['optimizer'])
        #print(optimizer.param_groups[0]['lr'])
        #start_epoch = backbone_ckpt['epoch'] + 1
        start_epoch = 1
        fg = 0
        if rank == 0 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head0.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 1 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head1.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 2 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head2.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
        if rank == 3 and fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58' + '_head3.pth')
            dist_sample_classifer.load_state_dict(head_ckpt['head'])

else:
    start_epoch = 0
    print("Train from Scratch")
print("=" * 60) 

backbone = backbone.to(local_rank)

optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
                lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

# Broadcast init parameters
for ps in backbone.parameters():
    dist.broadcast(ps, 0)

# DDP

backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)

backbone = torch.nn.parallel.DistributedDataParallel(
    module=backbone, broadcast_buffers=False, device_ids=[rank])

# Lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                              lr_lambda=cfg.lr_func)
os.makedirs(log_path, exist_ok=True)
if local_rank == 0:
    writer = SummaryWriter(log_dir=log_path)

print('trainning...')
global_step = 0
n_epochs = cfg.num_epoch
# NUM_EPOCH_WARM_UP = n_epochs // 25
NUM_EPOCH_WARM_UP = 5
NUM_BATCH_WARM_UP = NUM_EPOCH_WARM_UP * len(train_loader)
backbone.train()
for epoch in range(start_epoch, n_epochs):
    train_sampler.set_epoch(epoch)
    print('lixiang0000000')
    print(len(train_loader))

    for step, (img, label) in enumerate(train_loader):
        print('lixiang1111111111111111')

        if (epoch + 1 <= NUM_EPOCH_WARM_UP) and (global_step + 1 <= NUM_BATCH_WARM_UP) and 1: # adjust LR for each training batch during warm up
            warm_up_lr(global_step + 1, NUM_BATCH_WARM_UP, cfg.lr, optimizer)

        total_label, norm_weight = dist_sample_classifer.prepare(label, optimizer)
        # print('total_label:', total_label.shape)
        # print('norm_weight:', norm_weight.shape)
        features = backbone(img)    # features are normalized internally
         
        # Features all-gather
        total_features = torch.zeros(features.size()[0] * world_size, cfg.embedding_size, device=local_rank)
        dist.all_gather(list(total_features.chunk(world_size, dim=0)), features.data)
        total_features.requires_grad = True

        # Calculate logits
        # print('&' * 60)
        # print('total_features:', total_features.shape)
        # print('norm_weight:', norm_weight.shape)
        # print('total_label:', total_label.shape)
        logits = dist_sample_classifer(total_features, norm_weight)  # cos =
        # print('logits1:', logits.shape)
        # print('logits:', logits.shape)
        # print('&' * 60)
        logits = margin_softmax(logits, total_label)
        # print('logits2:', logits.shape)
        # total_logits = torch.zeros(logits.size()[0], len(DataLoaderX), device=local_rank)
        # dist.all_gather(list(total_logits.chunk(world_size, dim=0)),
        #                 logits.data)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            #print('max_fc:', max_fc.shape)
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)
            #print('#'*10, max_fc)
            # Calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # Calculate prob
            logits_exp.div_(logits_sum_exp)

            # Get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(index.size()[0], grad.size()[1], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # Calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # Calculate grad
            grad[index] -= one_hot
            grad.div_(features.size()[0])
        
        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad = torch.zeros_like(features)

        # Feature gradient all-reduce
        dist.reduce_scatter(
            x_grad, list(total_features.grad.chunk(world_size, dim=0)))
        x_grad.mul_(world_size)
        # Backward backbone
        features.backward(x_grad)
        optimizer.step()

        # Update classifer
        dist_sample_classifer.update()
        optimizer.zero_grad()

        tm = time.asctime().split()[-2]
        print('neighbour:',rank)
        if rank == 0 and global_step % cfg.disp_freq == 0:
            writer.add_scalar('loss', loss_v, global_step)
            print('\nEpoch:{}/{} Batch:{}/{}\t'
                  'Loss:{loss:.4f}\t'
                  'lr: {lr:.4f}\t'
                  'TimeNow:{TimeNow}'.format(
                epoch, n_epochs, (step + 1) % len(train_loader), len(train_loader),
                loss=loss_v, lr=optimizer.param_groups[0]['lr'], TimeNow=tm))

        global_step += 1
    scheduler.step()

    if rank == 0:
        os.makedirs(backbone_path, exist_ok=True)
        state = {'backbone': backbone.module.state_dict(),
                 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        
        torch.save(state,
                   backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))

    os.makedirs(head_path, exist_ok=True)
    if rank == 0:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head0.pth".format(epoch+1, get_time())))
    if rank == 1:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head1.pth".format(epoch+1, get_time())))
    if rank == 2:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head2.pth".format(epoch+1, get_time())))
    if rank == 3:
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state, os.path.join(head_path, "Head_Arcface_Epoch_{}_Time_{}_head3.pth".format(epoch+1, get_time())))

    
dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--local_rank', type=int, default=1, help='local_rank')
    args = parser.parse_args()
    os.environ['NCCL_DEBUG'] = 'INFO'
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if should_distribute():
        print('Using distributed PyTorch with {} backend'.format(dist.Backend.NCCL))
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                rank=args.local_rank,
                                world_size=world_size)
        rank = dist.get_rank()
    else:
        rank = torch.cuda.device_count()

    torch.cuda.set_device(args.local_rank)

    main(args.local_rank, rank, world_size, cfg)

@lixiangMindSpore Bagua relaxes the restrictions on user scripting. We don't pass --local_rank to the user process, so args.local_rank won't work as you expect; you should use bagua.get_local_rank() instead.

As I mentioned in the previous comment.
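For example, a minimal sketch of the launcher block after that change (assuming the rest of the script stays the same):

import torch
import bagua.torch_api as bagua

# Bagua does not pass --local_rank, so derive rank and device from bagua itself.
torch.cuda.set_device(bagua.get_local_rank())
bagua.init_process_group()
local_rank = bagua.get_local_rank()
rank = bagua.get_rank()
world_size = bagua.get_world_size()
main(local_rank, rank, world_size, cfg)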


Feel free to reopen the issue if it does not work as expected :)