NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html


Feature request: Add Llama-style MLP with three linear layers

rationalism opened this issue

Llama and several other popular open-source models use an MLP design with three linear layers:

import torch.nn as nn
from transformers.activations import ACT2FN

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]  # "silu" for Llama

    def forward(self, x):
        # SwiGLU: activate the gate projection, multiply elementwise by the up
        # projection, then project back down to hidden_size.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

(Layer norm isn't included here; it's applied before the MLP as part of LlamaDecoderLayer.)
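For reference, this forward pass is the SwiGLU pattern, and the gate and up projections can be merged into a single matmul whose output is split afterwards. Below is a minimal PyTorch sketch (toy sizes, and not necessarily the exact weight layout any library uses internally) showing that the fused and unfused formulations are numerically equivalent:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 128  # toy sizes for illustration only

gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

# Fuse gate_proj and up_proj into one linear layer by stacking their weights.
fused_fc1 = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
with torch.no_grad():
    fused_fc1.weight.copy_(torch.cat([gate_proj.weight, up_proj.weight], dim=0))

x = torch.randn(4, hidden_size)

# Three-projection (Llama-style) formulation.
y_ref = down_proj(F.silu(gate_proj(x)) * up_proj(x))

# Fused formulation: one matmul, then split into gate and up halves.
gate, up = fused_fc1(x).chunk(2, dim=-1)
y_fused = down_proj(F.silu(gate) * up)

torch.testing.assert_close(y_ref, y_fused)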

However, the current TransformerEngine MLP module only has two linear layers:

https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.LayerNormMLP

It would be cool to add a Llama-style module with three linear layers and a SiLU activation function, since then models like Llama, Mixtral, Qwen, etc. could be ported over easily. Thanks :)

Hi @rationalism, Llama is actually supported by TE's LayerNormMLP module via the swiglu activation. For performance reasons, we fuse the two Linear layers (gate_proj and up_proj) into a single one. I recommend looking into the Llama tutorial we posted recently: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/te_llama/tutorial_accelerate_hf_llama_with_te.html
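For anyone landing here, a minimal sketch of what the reply describes, using the normalization, activation, bias, and params_dtype arguments documented for LayerNormMLP; the sizes are Llama-7B values chosen purely for illustration:

import torch
import transformer_engine.pytorch as te

hidden_size, intermediate_size = 4096, 11008  # Llama-7B sizes, for illustration

# RMSNorm followed by a gated-SiLU (SwiGLU) MLP; the gate and up projections
# live in a single fused first linear layer, as described in the reply above.
mlp = te.LayerNormMLP(
    hidden_size,
    intermediate_size,
    normalization="RMSNorm",
    activation="swiglu",
    bias=False,
    params_dtype=torch.bfloat16,
).cuda()

x = torch.randn(128, 1, hidden_size, device="cuda", dtype=torch.bfloat16)
y = mlp(x)  # same shape as x

Since the normalization is applied inside LayerNormMLP before the MLP, it also covers the pre-MLP RMSNorm from LlamaDecoderLayer. The linked tutorial shows how pretrained gate_proj and up_proj weights map onto the fused first layer when porting a Hugging Face checkpoint.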