THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.

Home Page: https://THUDM.github.io/SwissArmyTransformer


A question: how should the loss be written when using mp_size=2?

kunden0612 opened this issue · comments

from torch.nn import CrossEntropyLoss

logits, *mems = model(inputs_ids, position_ids, attention_mask)
# flatten logits to (batch * seq_len, vocab) and labels to (batch * seq_len,)
loss_func = CrossEntropyLoss(ignore_index=-100)
loss = loss_func(logits.view(-1, logits.size(-1)).float(), labels.view(-1))

This is how I wrote the loss computation, but it raises the following error:

/opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/cuda/Loss.cu:242: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [15,0,0] Assertion `t >= 0 && t < n_classes` failed.

Did you pass parallel_output=True in the forward call (https://github.com/THUDM/SwissArmyTransformer/blob/main/sat/transformer_defaults.py#L146)?

If so, the output has not been gathered yet and is still split across multiple model-parallel ranks, so each rank only sees a slice of the vocabulary.
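
A minimal sketch of one way to avoid the assertion, assuming the model's forward accepts parallel_output as a keyword argument (as the linked transformer_defaults.py suggests): request the gathered, full-vocabulary logits and keep the standard cross entropy. Check your model's forward signature before relying on this.

from torch.nn import CrossEntropyLoss

# Sketch: ask for gathered logits so every rank sees the full vocabulary.
# parallel_output=False is an assumption based on the linked default hooks.
logits, *mems = model(inputs_ids, position_ids, attention_mask, parallel_output=False)
loss_func = CrossEntropyLoss(ignore_index=-100)
loss = loss_func(logits.view(-1, logits.size(-1)).float(), labels.view(-1))

With parallel_output=True, each rank only holds a vocabulary shard of roughly vocab_size / mp_size logits, so label ids that fall outside that shard trigger the `t >= 0 && t < n_classes` assertion. Gathering the logits (or using a vocabulary-parallel cross entropy that operates on the shards directly) avoids this.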