Recalculating the activations in the backwards pass to conserve memory
ChrisDryden opened this issue
@ngc92 analyzed which buffers take up the most memory and how they limit the batch size that can be used. One of the largest contributors turned out to be the memory that LayerNorm recomputation would free. The estimated savings are:
Layernorm recomputation: 2 layernorms in each of the 48 layers, 1600*2 bytes per token
=> 2*48*32*1024*1600*2 bytes = 9600 MiB
Missing bits for master weights (PR exists): optim-w from 32-bit to 16-bit
=> 372 MiB
FP16 optimizer states (PR exists, but very outdated): optim-m, optim-v from 32-bit to 16-bit
=> 743 MiB
ZeRO-2 (missing prerequisites: CUDA streams, overlapping backward and comms): grads / nGPU
=> 2599 MiB
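As a quick check of the LayerNorm figure, the sketch below multiplies out its factors. It assumes that 32 and 1024 are the batch size and sequence length, and that each stored LayerNorm output element takes 2 bytes (a 16-bit type). The function name is made up for illustration.

```python
def ln_activation_mib(num_ln_per_layer, num_layers, batch, seq_len, channels, bytes_per_el):
    """Memory (MiB) needed to keep every LayerNorm output for one training step.

    Assumption: one output vector of `channels` elements per token per LayerNorm.
    """
    total_bytes = num_ln_per_layer * num_layers * batch * seq_len * channels * bytes_per_el
    return total_bytes / 2**20  # bytes -> MiB


# 2 LayerNorms/layer, 48 layers, B=32, T=1024, C=1600, 2 bytes per element
print(ln_activation_mib(2, 48, 32, 1024, 1600, 2))  # 9600.0
```

Every factor enters linearly, so halving the batch size to 16 would halve this to 4800 MiB.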
This issue tracks adding the ability to recompute the LayerNorm forward activations during the backward pass, similar to how the GELU activations are already recomputed there, so that they no longer have to be kept in memory.
As a first step, I will run the LayerNorm forward inside the backward pass and use the ln1 and ln2 outputs it produces directly. This gives an initial working version of the recomputation.
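A plain-Python sketch of this first step for a single token follows. All function names are hypothetical stand-ins for the real kernels. It assumes the LayerNorm input `x` is still available in the backward pass, and it shows only the weight gradient of the matmul that consumes the LayerNorm output. Because the same arithmetic runs on the same inputs, the recomputed output in this sketch matches the one that would have been stored exactly.

```python
import math


def layernorm_forward(x, weight, bias, eps=1e-5):
    """LayerNorm over the channels of a single token (plain-Python stand-in)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    out = [(v - mean) * rstd * w + b for v, w, b in zip(x, weight, bias)]
    return out, mean, rstd


def matmul_weight_grad(inp, dout):
    """Weight gradient of y = inp @ W for one token: dW[i][j] = inp[i] * dout[j]."""
    return [[a * g for g in dout] for a in inp]


def backward_with_recompute(x, weight, bias, dout):
    """The LayerNorm output was not kept after the forward pass, so the full
    forward (including the mean/variance reductions) is run again here and its
    output is fed straight into the matmul backward."""
    ln_out, _, _ = layernorm_forward(x, weight, bias)
    return matmul_weight_grad(ln_out, dout)


# Usage: the recompute path gives the same gradient as keeping the output.
x = [1.0, 2.0, 3.0, 4.0]
weight, bias = [1.0] * 4, [0.0] * 4
dout = [0.5, -1.0]
stored_out, _, _ = layernorm_forward(x, weight, bias)
print(backward_with_recompute(x, weight, bias, dout) == matmul_weight_grad(stored_out, dout))  # True
```

The cost of this simple version is that the mean and variance reductions run a second time in the backward pass.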
The PR above implements this first step and reduces the activation allocation by 132 MiB.
With recompute set to 1, the allocations went from this:
allocating 1439 MiB for activations
val loss 4.503491
allocating 237 MiB for parameter gradients
allocating 30 MiB for activation gradients
allocating 474 MiB for AdamW optimizer state m
allocating 474 MiB for AdamW optimizer state v
allocating 474 MiB for master copy of params
To this:
allocating 1307 MiB for activations
val loss 4.504488
allocating 237 MiB for parameter gradients
allocating 30 MiB for activation gradients
allocating 474 MiB for AdamW optimizer state m
allocating 474 MiB for AdamW optimizer state v
allocating 474 MiB for master copy of params
The PR has been merged, but the second step is still needed: a simplified kernel that does not recompute everything and instead reuses the values calculated in the forward pass.
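The sketch below shows one way the simplified second step could work. It assumes the forward pass already saves a per-token mean and reciprocal standard deviation (rstd) for each LayerNorm. That is 2 values per token instead of the 1600 channels of the full output. With those statistics, regenerating the output needs no reductions at all. The function names are hypothetical, and a real version would be a GPU kernel operating on 16-bit tensors rather than Python lists.

```python
import math


def layernorm_forward(x, weight, bias, eps=1e-5):
    """Forward LayerNorm for one token. Besides the output it returns the two
    per-token statistics, which are what the forward pass would keep."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    out = [(v - mean) * rstd * w + b for v, w, b in zip(x, weight, bias)]
    return out, mean, rstd


def layernorm_recompute(x, weight, bias, mean, rstd):
    """Regenerate the LayerNorm output from the saved mean/rstd: no reductions,
    just one elementwise pass over the channels."""
    return [(v - mean) * rstd * w + b for v, w, b in zip(x, weight, bias)]


x = [0.5, -1.5, 2.0, 3.0]
weight = [1.0, 0.5, 2.0, 1.5]
bias = [0.1, 0.0, -0.2, 0.3]
out, mean, rstd = layernorm_forward(x, weight, bias)
# Only `mean` and `rstd` (2 floats per token) need to survive the forward pass.
print(layernorm_recompute(x, weight, bias, mean, rstd) == out)  # True
```

Compared with the first step, this skips both reductions over the channels of each token and leaves a single elementwise pass.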