A question: how should the loss be written when using mp_size=2?
kunden0612 opened this issue · comments
```python
from torch.nn import CrossEntropyLoss

logits, *mems = model(inputs_ids, position_ids, attention_mask)
# print(logits.shape)
loss_func = CrossEntropyLoss(ignore_index=-100)
loss = loss_func(logits.view(-1, logits.size(-1)).float(), labels.view(-1))
```
This is how I compute the loss, but it fails with the following error:

```
/opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [15,0,0] Assertion `t >= 0 && t < n_classes` failed.
```
Did you pass `parallel_output=True` in the forward call?
https://github.com/THUDM/SwissArmyTransformer/blob/main/sat/transformer_defaults.py#L146
If so, the output has not been gathered yet: each model-parallel rank holds only its own slice of the vocabulary dimension.
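That explanation matches the assertion text: with the vocabulary dimension sharded across `mp_size` ranks, a label drawn from the full vocabulary can exceed the per-rank `n_classes` that `CrossEntropyLoss` sees. A minimal sketch of the failure mode (the numbers and the helper `target_ok` are illustrative, not SAT internals):

```python
# Hypothetical illustration of why the assertion `t >= 0 && t < n_classes`
# fires under model parallelism. Not SAT code; all names are made up.

vocab_size = 8
mp_size = 2
shard = vocab_size // mp_size   # classes visible to one rank: 4

label = 6                       # a valid id in the FULL vocabulary

def target_ok(t, n_classes):
    # The per-target range check that the CUDA kernel asserts.
    return 0 <= t < n_classes

# On a sharded (un-gathered) output, the label is out of range -> assertion fails.
assert not target_ok(label, shard)
# Once the slices are gathered to the full vocabulary, the same label is fine.
assert target_ok(label, vocab_size)
```

If this is indeed the cause, either call the forward with `parallel_output=False` so the logits are gathered across ranks before the loss, or keep `parallel_output=True` and use a vocabulary-parallel cross-entropy (in the style of Megatron-LM tensor parallelism) that works on sharded logits directly.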