Training works with DDP without problems, but the same script fails with bagua as follows.
lixiangMindSpore opened this issue · comments
Describe the bug
A clear and concise description of what the bug is.
backbone = backbone.with_bagua([optimizer], algorithm)
if rank == 0:
os.makedirs(backbone_path, exist_ok=True)
state = {'backbone': backbone.module.state_dict(),
'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state,
backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))
Environment
- Your operating system and version: Ubuntu18.04
- Your python version:3.8.12
- Your PyTorch version:11.0
- How did you install python (e.g. apt or pyenv)? Did you use a virtualenv?: conda create -n torch python=3.8
- Have you tried using latest bagua master (`python3 -m pip install --pre bagua`)?: 0.8.1.post1
Reproducing
Please provide a minimal working example. This means the runnable code.
Please also write what exact commands are required to reproduce your results.
Additional context
Add any other context about the problem here.
import os
import argparse
import time
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from backbones.model_irse import IR_SE_50, IR_SE_101
from config import config as cfg
from utils import *
from dataset import MXFaceDataset, DataLoaderX
from dataset import RecgDataset_mask as RecgDataset # 口罩增强~
from partial_classifier import DistSampleClassifier
from partial_loss import MarginSoftmax
from sgd import SGD
from torchsummary import summary
import bagua.torch_api as bagua ###
from bagua.torch_api.algorithms import gradient_allreduce ###
from bagua.torch_api.communication import (
_get_default_group,
allreduce,
send,
recv,
allgather,
barrier,
) ###
from common_utils import find_free_port
# Let cuDNN auto-tune and cache convolution algorithms; enabled once, process-wide, at import time.
torch.backends.cudnn.benchmark = True
def should_distribute():
    """Report whether distributed training should be used.

    Reads the module-level ``world_size`` that the script entry point binds,
    so it must only be called after that assignment has run.
    """
    # NOTE(review): ``>= 1`` holds for every valid world size, so this reduces
    # to ``dist.is_available()``; confirm whether ``> 1`` was intended.
    if not dist.is_available():
        return False
    return world_size >= 1
def is_distributed():
    """Return True when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
def _init_bagua_env(rank, env):
    """Select this process's CUDA device and start the bagua process group.

    Args:
        rank: Accepted for signature parity with ``_init_bagua_env_backup``;
            not read here.
        env: Accepted for the same reason; not read here.

    The device index comes from ``bagua.get_local_rank()`` rather than from
    the ``rank`` argument.
    """
    device_index = bagua.get_local_rank()
    torch.cuda.set_device(device_index)
    bagua.init_process_group()
def _init_bagua_env_backup(rank, env):
    """Populate the process environment by hand, then start bagua.

    Args:
        rank: Used as both ``RANK`` and ``LOCAL_RANK`` and as the CUDA device
            index for this process.
        env: Mapping that must provide string values for ``WORLD_SIZE``,
            ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT`` and
            ``BAGUA_SERVICE_PORT``.
    """
    print("rank:", rank)

    # Copy the shared settings into the environment, in a fixed order.
    shared_keys = (
        "WORLD_SIZE",
        "LOCAL_WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
        "BAGUA_SERVICE_PORT",
    )
    for key in shared_keys:
        os.environ[key] = env[key]

    # The same value is used for the global and the local rank.
    rank_text = str(rank)
    os.environ["RANK"] = rank_text
    os.environ["LOCAL_RANK"] = rank_text

    # init bagua distributed process group
    torch.cuda.set_device(rank)
    print("222222222222")
    bagua.init_process_group()
    print("33333333333")
def main(local_rank, rank, world_size, cfg):
    """Train the IR-SE-50 backbone and the per-rank classifier head with bagua.

    Args:
        local_rank: Device index on which this process places its model and
            tensors.
        rank: Global rank. Rank 0 logs progress and saves the backbone
            checkpoint; every rank saves its own classifier head.
        world_size: Number of participating processes.
        cfg: Project configuration object. Attributes read here:
            ``batch_size``, ``input_size``, ``model_save``, ``log_save``,
            ``resume``, ``lr``, ``weight_decay``, ``lr_func``, ``num_epoch``,
            ``embedding_size`` and ``disp_freq``. ``cfg.model_resume`` is
            overwritten with ``cfg.model_save``.
    """
    # ---- data ----
    print('loading data...')
    comm = bagua.communication._get_default_group().get_global_communicator()
    trainset = RecgDataset(cfg)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=trainset,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=True)

    # ---- model ----
    print('loading model...')
    backbone = IR_SE_50(cfg.input_size).to(local_rank)
    dist_sample_classifer = DistSampleClassifier(trainset.classes, rank=rank, local_rank=local_rank,
                                                 world_size=world_size)
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Paths are built by plain concatenation, so cfg.model_save / cfg.log_save
    # are expected to end with a path separator.
    backbone_path = os.path.join(cfg.model_save + 'backbone/')
    head_path = os.path.join(cfg.model_save + 'head/')
    log_path = os.path.join(cfg.log_save + 'shows/')
    cfg.model_resume = cfg.model_save
    backbone_resume = os.path.join(cfg.model_resume + 'backbone/')
    head_resume = os.path.join(cfg.model_resume + 'head/')

    # ---- optional resume ----
    # Default first, so start_epoch is always bound even when the resume
    # directory exists but holds no checkpoint.
    start_epoch = 0
    if cfg.resume and os.path.isdir(backbone_resume):
        print('resume~~~~~~~~~~~~~~~~~~~~~~~~~~')
        backbone_list = os.listdir(backbone_resume)
        if backbone_list:
            # Loads the first directory entry; os.listdir order is arbitrary.
            ckpt_file = backbone_resume + '/' + backbone_list[0]
            print(ckpt_file)
            backbone_ckpt = torch.load(ckpt_file, map_location=torch.device('cpu'))
            backbone.load_state_dict(backbone_ckpt['backbone'])
            print('load backbone ~')
            start_epoch = 1
        # Manual switch: set to 1 to also restore this rank's classifier head.
        fg = 0
        if fg:
            head_ckpt = torch.load(head_resume + '/' + 'Head_Arcface_Epoch_27_Time_2021-07-04-05-58'
                                   + '_head{}.pth'.format(rank))
            dist_sample_classifer.load_state_dict(head_ckpt['head'])
    else:
        print("Train from Scratch")
    print("=" * 60)

    backbone = backbone.to(local_rank)
    optimizer = SGD([{'params': backbone.parameters()}, {'params': dist_sample_classifer.parameters()}],
                    lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, rescale=world_size)

    # Broadcast rank 0's initial backbone parameters to every rank.
    for ps in backbone.parameters():
        bagua.broadcast(ps, 0, comm)

    # Hand gradient synchronisation of the backbone to bagua.
    algorithm = gradient_allreduce.GradientAllReduceAlgorithm()
    backbone = backbone.with_bagua([optimizer], algorithm)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cfg.lr_func)

    os.makedirs(log_path, exist_ok=True)
    # The writer is only used under ``rank == 0`` below.
    # NOTE(review): this assumes global rank 0 also has local_rank 0 - confirm.
    writer = SummaryWriter(log_dir=log_path) if local_rank == 0 else None

    print('trainning...')
    global_step = 0
    n_epochs = cfg.num_epoch
    NUM_EPOCH_WARM_UP = 5
    NUM_BATCH_WARM_UP = NUM_EPOCH_WARM_UP * len(train_loader)
    backbone.train()

    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            # Adjust the LR for each training batch during warm-up.
            if epoch + 1 <= NUM_EPOCH_WARM_UP and global_step + 1 <= NUM_BATCH_WARM_UP:
                warm_up_lr(global_step + 1, NUM_BATCH_WARM_UP, cfg.lr, optimizer)

            total_label, norm_weight = dist_sample_classifer.prepare(label, optimizer)
            # Per the original author's note, features are normalised inside the backbone.
            features = backbone(img)

            # Features all-gather: collect every rank's features into one leaf
            # tensor (bagua equivalent of dist.all_gather). The gathered tensor
            # itself must feed the classifier, not a separate zero buffer.
            total_features = torch.zeros(features.size()[0] * world_size, cfg.embedding_size,
                                         device=features.device, dtype=features.dtype)
            allgather(features.data, total_features)
            total_features.requires_grad = True

            # Calculate logits
            logits = dist_sample_classifer(total_features, norm_weight)
            logits = margin_softmax(logits, total_label)

            with torch.no_grad():
                # Global row-wise max for a numerically stable softmax.
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                recv_max_fc = torch.zeros_like(max_fc)
                bagua.allreduce(max_fc, recv_max_fc, bagua.ReduceOp.MAX, comm=comm)

                # Calculate exp(logits) and all-reduce the row sums.
                logits_exp = torch.exp(logits - recv_max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
                recv_logits_sum_exp = torch.zeros_like(logits_sum_exp)
                bagua.allreduce(logits_sum_exp, recv_logits_sum_exp, bagua.ReduceOp.SUM, comm=comm)

                # Calculate prob: divide by the all-reduced sum, not the local one.
                logits_exp.div_(recv_logits_sum_exp)

                # Get one-hot for the rows whose label is not -1.
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0], grad.size()[1], device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss from the all-reduced target probabilities.
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                recv_loss = torch.zeros_like(loss)
                bagua.allreduce(loss, recv_loss, bagua.ReduceOp.SUM, comm=comm)
                loss_v = recv_loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient reduce-scatter (bagua equivalent of
            # dist.reduce_scatter over the per-rank chunks along dim 0).
            send_tensor_bagua = total_features.grad.contiguous()
            bagua.reduce_scatter(send_tensor_bagua, x_grad, comm=comm)
            x_grad.mul_(world_size)

            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifer
            dist_sample_classifer.update()
            optimizer.zero_grad()

            tm = time.asctime().split()[-2]
            if rank == 0 and global_step % cfg.disp_freq == 0:
                writer.add_scalar('loss', loss_v, global_step)
                print('\nEpoch:{}/{} Batch:{}/{}\t'
                      'Loss:{loss:.4f}\t'
                      'lr: {lr:.4f}\t'
                      'TimeNow:{TimeNow}'.format(
                          epoch, n_epochs, (step + 1) % len(train_loader), len(train_loader),
                          loss=loss_v, lr=optimizer.param_groups[0]['lr'], TimeNow=tm))
            global_step += 1

        scheduler.step()

        if rank == 0:
            os.makedirs(backbone_path, exist_ok=True)
            # bagua's with_bagua() does not wrap the model in ``.module`` the
            # way DDP does, so state_dict() is called on the model directly.
            state = {'backbone': backbone.state_dict(),
                     'optimizer': optimizer.state_dict(), 'epoch': epoch}
            torch.save(state,
                       backbone_path + "Backbone_IR_SE_50_Epoch_{}_Time_{}.pth".format(epoch + 1, get_time()))

        # Every rank stores its own classifier head, suffixed with its rank.
        os.makedirs(head_path, exist_ok=True)
        state = {'head': dist_sample_classifer.state_dict()}
        torch.save(state,
                   os.path.join(head_path,
                                "Head_Arcface_Epoch_{}_Time_{}_head{}.pth".format(epoch + 1, get_time(), rank)))
# Script entry point. The scraped text had lost the double underscores
# (``if name == "main"``), which raises NameError when executed.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--local_rank', type=int, default=1, help='local_rank')
    args = parser.parse_args()  # parsed so the launcher's flag is accepted; not read below

    os.environ['NCCL_DEBUG'] = 'INFO'
    # Kept at module level: should_distribute() reads this global.
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    print("lixiang")
    # Settings for the manual initialiser (_init_bagua_env_backup); the
    # initialiser called below ignores them.
    env = {
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(find_free_port(8000, 8100)),
        "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
    }
    print("lixiang2")
    _init_bagua_env(bagua.get_local_rank(), env)
    rank = bagua.get_rank()
    print(rank)
    main(bagua.get_local_rank(), rank, world_size, cfg)
bagua does not use the `.module` attribute. You can call `model.state_dict()` directly on line 289 (the `backbone.module.state_dict()` call).
Feel free to reopen if there are further problems :)