NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Request for Adaptive Layer Norm MLP

fordflip opened this issue

It'd be amazing to have support for a PyTorch LayerNormMLP implementation that supports a scale and offset tensor applied after the layernorm but before the MLP. Would be curious to hear what it would take to implement this! Happy to help.

Hi, could you give a little bit more info on the use case? LayerNorm already has scale and offset in weight and bias - why do you need an additional set of those parameters?

For things like DiT (https://arxiv.org/abs/2212.09748): adaptive LayerNorm, where the normalization's scale and shift are computed from a condition vector of some sort.
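For context, here is roughly what DiT-style adaptive LayerNorm computes in plain PyTorch (a minimal sketch with made-up names, not TransformerEngine API; the affine parameters are produced per sample from the condition vector instead of being learned directly):

import torch
import torch.nn as nn

class AdaLayerNormSketch(nn.Module):
    # Hypothetical module for illustration only.
    def __init__(self, hidden_size, cond_hidden_size):
        super().__init__()
        # No learnable affine parameters on the norm itself.
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        # Projection of the condition vector into per-sample scale and shift.
        self.c_proj = nn.Linear(cond_hidden_size, 2 * hidden_size)

    def forward(self, x, cond):
        # x: [batch, seq, hidden], cond: [batch, cond_hidden]
        scale, shift = self.c_proj(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)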

Hm, interesting - I looked at the HuggingFace implementation here: https://github.com/huggingface/diffusers/blob/v0.27.2/src/diffusers/models/normalization.py#L28. It basically computes the weight and bias of LayerNorm rather than keeping them as parameters.

A quick and dirty implementation would basically take the LayerNormMLP implementation from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py, add the code computing the weight and bias (scale and shift in the HF implementation) in the forward function here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1496, and pass them instead of self.layer_norm_weight and self.layer_norm_bias here (use scale, not 1 + scale, see the note below): https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/module/layernorm_mlp.py#L1551

A random note on numerical precision: the HF code multiplies the LN output by 1 + scale. This is not great if you use e.g. BF16, since the addition is then performed in that precision, which loses accuracy (BF16, like all floating-point formats, has the most precision around 0, not around 1). That is why we introduced the zero_centered_gamma option in our LayerNorm implementation, which takes the weight and adds 1 to it inside the LayerNorm kernel in FP32 precision. That's why I would enable that option and pass just the scale, not 1 + scale.
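To make the precision point concrete, a tiny standalone check (illustrative only, not tied to any TE kernel code):

import torch

# A typical small adaLN scale value.
scale = torch.tensor([0.003], dtype=torch.bfloat16)

# Adding 1 in BF16: the spacing between representable BF16 values around 1.0
# is ~0.0078, so the scale information is essentially destroyed.
print((1 + scale) - 1)           # tensor([0.], dtype=torch.bfloat16)

# Adding 1 in FP32 (the zero_centered_gamma idea: keep the near-zero value as the
# parameter and do the +1 in FP32 inside the kernel) preserves the scale.
print((1 + scale.float()) - 1)   # tensor([0.0030])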

Would also be very interested in both an AdaLayerNorm and an AdaLayerNormMLP, or alternatively a fused MLP without the norm included, as requested in #817.

Would this be a correct implementation?
I don't know if the _LayerNorm forward method can take a batched weight and bias, though, so I just added a for loop (which is probably pretty slow).

# Imports assumed from TransformerEngine internals used below.
import torch
from transformer_engine.pytorch import LayerNorm, Linear
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.layernorm import _LayerNorm

class AdaLayerNorm(LayerNorm):
    def __init__(self, hidden_size, cond_hidden_size, scale=True, shift=True, bias=True, **kwargs):
        assert scale or shift, "at least one of scale/shift must be enabled"
        kwargs["zero_centered_gamma"] = True
        super().__init__(hidden_size, **kwargs)
        # Replace the learnable affine parameters with zero buffers;
        # the effective weight/bias come from the conditioning projection.
        ln_weight = torch.zeros_like(self.weight)
        ln_bias = torch.zeros_like(self.bias)
        del self.weight, self.bias
        self.register_buffer("weight", ln_weight)
        self.register_buffer("bias", ln_bias)
        if scale and shift:
            self.c_proj = Linear(cond_hidden_size, hidden_size * 2, bias=bias)
        else:
            self.c_proj = Linear(cond_hidden_size, hidden_size, bias=bias)
        self.scale = scale
        self.shift = shift

    @no_torch_dynamo()
    def forward(self, x, cond):
        # Set the activation type for AMP.
        TransformerEngineBaseModule.set_activation_dtype(self, x)

        assert len(cond.shape) <= 2
        if len(cond.shape) == 1:
            cond = cond.unsqueeze(0).expand(x.shape[0], cond.shape[-1])
        embs = self.c_proj(cond)

        if torch.is_grad_enabled():
            fwd_fn = _LayerNorm.apply
        else:
            fwd_fn = _LayerNorm.forward

        # The LayerNorm kernel takes a single weight/bias, so loop over the batch
        # and apply each sample's scale/shift to that sample's input.
        out = []
        for inp, emb in zip(x, embs):
            if self.scale and self.shift:
                scale, shift = emb.chunk(2, dim=-1)
            elif self.scale:
                scale, shift = emb, self.bias
            else:  # shift only
                scale, shift = self.weight, emb
            if torch.is_grad_enabled():
                args = []
            else:
                args = [None]  # placeholder for ctx when calling forward() directly
            args += (
                inp,
                scale,
                shift,
                self.eps,
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.inf_ln_sm_margin,
                self.zero_centered_gamma,
                torch.is_grad_enabled(),
                self.activation_dtype,
            )
            out.append(fwd_fn(*args))

        return torch.stack(out)

Yeah, considering that LayerNorm doesn't seem to be able to take a batched weight and bias for the conditioning, this probably needs to be implemented a different way, since the for loop is slow.
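For comparison, a batched fallback in plain PyTorch that avoids the per-sample loop (a sketch that does not use the TE kernels; it keeps the 1 + scale affine math in FP32, in the spirit of the precision note above):

import torch
import torch.nn.functional as F

def ada_layer_norm_batched(x, scale, shift, eps=1e-5):
    # x: [batch, seq, hidden]; scale, shift: [batch, hidden] from the conditioning projection.
    x_norm = F.layer_norm(x.float(), x.shape[-1:], eps=eps)  # normalize in FP32, no affine params
    gamma = 1 + scale.float().unsqueeze(1)                   # add 1 in FP32 (zero-centered style)
    out = x_norm * gamma + shift.float().unsqueeze(1)
    return out.to(x.dtype)                                   # cast back to the input dtype at the end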

About that numerical precision thing: I'm trying to make a compiled version of it, and I'm confused about whether only the addition is performed in FP32 or the multiplication too.

import torch
import torch.nn.functional as F

@torch.compile
def ada_rms_norm(x: torch.Tensor, n_weight: torch.Tensor):
    # x: [B, ..., D]; n_weight: [B, D] per-sample (zero-centered) RMSNorm weight.
    B, D = x.shape[0], x.shape[-1]
    rms_scale = D ** 0.5  # F.normalize gives unit L2 norm; rescale to unit RMS
    n_weight = n_weight.view(B, *((1,) * (len(x.shape) - 2)), -1)

    # Add 1 in FP32, then cast back to the input dtype before multiplying.
    return F.normalize(x, dim=-1) * (1 + n_weight.to(dtype=torch.float32)).to(dtype=x.dtype) * rms_scale
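A quick way to check where the casts land in that version (a usage sketch; shapes and dtypes chosen just for illustration):

import torch

x = torch.randn(2, 16, 64, dtype=torch.bfloat16, device="cuda")
n_weight = torch.zeros(2, 64, dtype=torch.bfloat16, device="cuda")

y = ada_rms_norm(x, n_weight)
# In this version only the 1 + n_weight addition happens in FP32; the result is
# cast back to BF16 before the multiply, so the multiplication itself runs in BF16.
print(y.dtype)  # torch.bfloat16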