[Feature] Parallel transformer block
xrsrke opened this issue
Implement this:
Parallel transformer block [14]. We adopt a parallel version of the transformer block in lieu of the standard serialized formulation. Specifically, the standard formula of the transformer block can be reformatted from
$y=x+\text{MLP}(\text{LN}(x+\text{Attention}(\text{LN}(x))))$ into $y=x+\text{MLP}(\text{LN}(x))+\text{Attention}(\text{LN}(x))$. With this approach, the computation of the attention block and the MLP block can be executed in parallel, thereby reducing computation time. Prior work [5] shows that this modification does not degrade the quality of models with hundreds of billions of parameters.
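A minimal PyTorch sketch of the parallel formulation, assuming a standard pre-norm block; the class and parameter names (`ParallelTransformerBlock`, `hidden_size`, `num_heads`) are illustrative, not tied to this codebase:

```python
import torch
import torch.nn as nn


class ParallelTransformerBlock(nn.Module):
    """Parallel block: y = x + MLP(LN(x)) + Attention(LN(x)).

    Both branches read the same normalized input, so neither depends on
    the other's output and they can be scheduled concurrently.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        # One shared LayerNorm feeds both the attention and MLP branches.
        self.ln = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        # Residual plus the two independent branches, summed as in the
        # reformulated equation above.
        return x + self.mlp(h) + attn_out
```

Note that this sketch only exposes the parallelism structurally; realizing the speedup in practice requires the framework to actually overlap the two branches (e.g. via fused kernels or separate CUDA streams).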
Reference: MegaScale: Scaling Large Language Model Training to More Than 10,000 GPUs, page 3