Implement vocab parallel cross entropy loss
bzantium opened this issue
Minho Ryu commented
Describe a requested feature
- Implement cross-entropy loss for vocab-parallel logits (logits sharded along the vocabulary dimension) in 1D, 2D, 2p5D, and 3D tensor parallelism.
- Add tests.
Expected behavior
>>> criterion = VocabParallelCrossEntropyLoss(parallel_context)
>>> loss = criterion(vocab_parallel_logits, targets)
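
For reference, the computation such a criterion would need can be sketched in a single process. This is only an illustration and not the requested implementation: the function names below are hypothetical, each list in `logit_shards` stands in for the logits held by one tensor-parallel rank, and the all-reduce collectives are replaced by plain Python `max` and `sum` over the shards. It also assumes every shard is non-empty and that shards cover contiguous vocabulary ranges in rank order.

```python
import math


def vocab_parallel_cross_entropy(logit_shards, target):
    """Cross entropy for one token whose logits are split along the vocab axis.

    logit_shards: list of lists; shard r holds the logits of a contiguous
                  vocabulary range owned by simulated rank r (assumption).
    target:       global vocabulary index of the correct token.
    """
    # Global vocab offset at which each shard starts.
    offsets = []
    start = 0
    for shard in logit_shards:
        offsets.append(start)
        start += len(shard)

    # Step 1: each rank takes its local max; an all-reduce(MAX) would
    # combine them. Subtracting the global max keeps exp() stable.
    global_max = max(max(shard) for shard in logit_shards)

    # Step 2: each rank sums exp(logit - max) locally; an all-reduce(SUM)
    # would produce the full softmax denominator.
    sum_exp = sum(
        sum(math.exp(x - global_max) for x in shard) for shard in logit_shards
    )

    # Step 3: only the rank that owns the target index contributes its
    # shifted logit, the others contribute 0; an all-reduce(SUM) would
    # then share the value with every rank.
    target_logit = 0.0
    for shard, offset in zip(logit_shards, offsets):
        local_index = target - offset
        if 0 <= local_index < len(shard):
            target_logit += shard[local_index] - global_max

    # loss = log(sum_j exp(z_j - max)) - (z_target - max)
    return math.log(sum_exp) - target_logit


def reference_cross_entropy(logits, target):
    """Unsharded cross entropy, used only to check the sharded version."""
    m = max(logits)
    return math.log(sum(math.exp(x - m) for x in logits)) + m - logits[target]
```

The point of the three-step structure is that no rank ever needs the full vocabulary row: only two scalars per token (the max and the exp-sum) plus the target logit cross rank boundaries. A test for the real criterion could follow the same pattern as `reference_cross_entropy`, comparing the sharded loss against an unsharded one on gathered logits.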