Request for Adaptive Layer Norm MLP
fordflip opened this issue
It'd be amazing to have support for a PyTorch LayerNormMLP implementation that supports a scale and offset tensor applied after the LayerNorm but before the MLP. Would be curious to hear what it would take to implement this! Happy to help.
Hi, could you give a little bit more info on the use case? LayerNorm already has scale and offset in `weight` and `bias` - why do you need an additional set of those parameters?
For things like DiT (https://arxiv.org/abs/2212.09748) - adaptive LayerNorm, i.e. a LayerNorm whose scale and shift are adaptive to a condition vector of some sort.
Hm, interesting - I looked at the HuggingFace implementation here: https://github.com/huggingface/diffusers/blob/v0.27.2/src/diffusers/models/normalization.py#L28. It basically computes the weight and bias of LayerNorm rather than keeping them as parameters.
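For reference, the idea behind that HF module looks roughly like this. This is a hedged sketch, not the diffusers code: the class and parameter names (`AdaLayerNorm`, `cond_size`, `self.linear`) are mine, and the conditioning projection is simplified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLayerNorm(nn.Module):
    """Illustrative adaptive LayerNorm: scale and shift are computed
    from a conditioning vector rather than stored as LN parameters."""

    def __init__(self, hidden_size: int, cond_size: int):
        super().__init__()
        # Projects the conditioning vector to per-channel (shift, scale).
        self.linear = nn.Linear(cond_size, 2 * hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, hidden), cond: (batch, cond_size)
        shift, scale = self.linear(cond).chunk(2, dim=-1)
        # Parameter-free LayerNorm; the computed modulation replaces
        # the usual learned weight/bias.
        x = F.layer_norm(x, (self.hidden_size,))
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```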
A quick and dirty implementation would basically take the LayerNormMLP implementation from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py, add the code computing weight and bias (scale and shift in the HF implementation) in the forward function here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1496, and pass them instead of `self.layer_norm_weight` and `self.layer_norm_bias` here (use `scale`, not `1 + scale`; see note below): https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1551
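In plain PyTorch, the intended semantics of that modification would look something like the sketch below. All names here (`AdaLayerNormMLP`, `ada_proj`, `ffn_hidden_size`) are illustrative, and the unfused `F.layer_norm` only emulates what TE's fused kernel would do; the actual change would go inside TransformerEngine's `LayerNormMLP.forward` as described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLayerNormMLP(nn.Module):
    """Unfused reference for the proposed adaptive LayerNormMLP."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, cond_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Computes (shift, scale) from the condition, like the HF code.
        self.ada_proj = nn.Linear(cond_size, 2 * hidden_size)
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, hidden), cond: (batch, cond_size)
        shift, scale = self.ada_proj(cond).chunk(2, dim=-1)
        # In TE you would pass `scale` as the layer norm weight (with
        # zero_centered_gamma=True so the kernel adds the 1 in FP32);
        # here we emulate that with an unfused, parameter-free LayerNorm.
        x = F.layer_norm(x, (self.hidden_size,))
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.fc2(F.gelu(self.fc1(x)))
```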
A random note on numerical precision: the HF code multiplies the LN output by `1 + scale`. This is not great if you use e.g. BF16, since the addition would then be performed in that precision, which loses accuracy (BF16, like all floating-point formats, preserves the most precision around 0, not around 1). That is why we introduced the `zero_centered_gamma` option in our LN implementation, which takes the weight and adds 1 to it inside the LayerNorm kernel in FP32 precision. So I would enable that option and pass just `scale`, not `1 + scale`.
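You can see the effect directly with a toy example (the value `1e-3` is just an illustration of a small modulation):

```python
import torch

scale = torch.tensor(1e-3)

# Adding 1 in BF16: with only 8 mantissa bits, the spacing between
# representable values near 1 is 2**-8 ~= 0.0039, so the BF16 sum
# rounds back to exactly 1.0 and the small scale vanishes entirely.
gamma_bf16 = (1 + scale.to(torch.bfloat16)).float()

# Adding 1 in FP32, which is what zero_centered_gamma does inside
# the LayerNorm kernel: the scale survives.
gamma_fp32 = 1 + scale

print(gamma_bf16.item(), gamma_fp32.item())
```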