NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Request for Adaptive Layer Norm MLP

fordflip opened this issue

It'd be amazing to have support for a PyTorch LayerNormMLP implementation that accepts a scale and offset tensor applied after the layer norm but before the MLP. Would be curious to hear what it would take to implement this! Happy to help.

Hi, could you give a little bit more info on the use case? LayerNorm already has scale and offset in weight and bias - why do you need an additional set of those parameters?

For things like DiT (https://arxiv.org/abs/2212.09748): adaptive LayerNorm, where the scale and shift are adapted to a condition vector of some sort.

Hm, interesting - I looked at the HuggingFace implementation here: https://github.com/huggingface/diffusers/blob/v0.27.2/src/diffusers/models/normalization.py#L28. It basically computes the weight and bias of LayerNorm rather than keeping them as parameters.
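For concreteness, that pattern boils down to something like the following minimal sketch in plain PyTorch (this is not the actual diffusers code; the class and argument names are made up for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLayerNormSketch(nn.Module):
    """Adaptive LayerNorm: scale/shift are regressed from a conditioning vector."""

    def __init__(self, hidden_size: int, cond_size: int):
        super().__init__()
        # elementwise_affine=False: the affine parameters come from the condition,
        # not from the LayerNorm module itself.
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_size, 2 * hidden_size)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, hidden], cond: [batch, cond_size]
        scale, shift = self.to_scale_shift(F.silu(cond)).chunk(2, dim=-1)
        # DiT-style convention: multiply by (1 + scale) so a zero-initialized
        # projection starts out as an identity modulation.
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```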

A quick and dirty implementation would basically take the LayerNormMLP implementation from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py, add the code that computes the weight and bias (scale and shift in the HF implementation) in the forward function here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1496, and pass them instead of self.layer_norm_weight and self.layer_norm_bias here (use scale, not 1 + scale; see the note below): https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1551
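As a reference for what that modified module would compute, here is an unfused plain-PyTorch sketch (it ignores FP8 and the fused kernels entirely, and the class/attribute names are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveLayerNormMLPReference(nn.Module):
    """Unfused reference: the LN weight/bias are computed from a conditioning
    vector instead of being learned LayerNorm parameters, then fed to an MLP."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, cond_size: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.hidden_size = hidden_size
        # Regresses per-channel (weight, bias) for the LayerNorm from the condition.
        self.cond_proj = nn.Linear(cond_size, 2 * hidden_size)
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, hidden], cond: [batch, cond_size]
        ln_weight, ln_bias = self.cond_proj(cond).chunk(2, dim=-1)
        # Normalize without affine parameters, then apply the conditioned affine.
        x = F.layer_norm(x, (self.hidden_size,), eps=self.eps)
        x = x * ln_weight.unsqueeze(1) + ln_bias.unsqueeze(1)
        return self.fc2(F.gelu(self.fc1(x)))
```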

A random note on numerical precision: the HF code multiplies the LN output by 1 + scale. This is not great if you use e.g. BF16, since the addition is then performed in that precision, which loses some precision (BF16, like all floating-point formats, preserves the most precision around 0, not around 1). That is why we introduced the zero_centered_gamma option in our LN implementation, which takes the weight and adds 1 to it inside the LayerNorm kernel in FP32 precision. So I would enable that option and pass just the scale, not 1 + scale.
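A tiny numeric illustration of that precision point (plain PyTorch, no TE involved), showing why adding the 1 in BF16 throws away a small scale while adding it in FP32 does not:

```python
import torch

scale = torch.tensor(1e-3, dtype=torch.bfloat16)

# Doing the +1 in BF16: the spacing between BF16 values near 1 is 2**-7 ≈ 0.0078,
# so 1 + 0.001 rounds back to exactly 1.0 and the scale information is lost.
bf16_gamma = 1 + scale

# Doing the +1 in FP32 (the idea behind zero_centered_gamma) keeps the scale.
fp32_gamma = 1 + scale.float()

print(bf16_gamma.float() - 1)  # 0.0      -> the scale vanished
print(fp32_gamma - 1)          # ~0.00101 -> the scale survives
```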