Feature request: Add Llama-style MLP with three linear layers
rationalism opened this issue
Llama and several other popular open-source models use an MLP design with three linear layers:
import torch.nn as nn
from transformers.activations import ACT2FN

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # SwiGLU-style MLP: down_proj(SiLU(gate_proj(x)) * up_proj(x))
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
(Layer norm isn't included here; it's applied before the MLP as part of LlamaDecoderLayer.)
However, TransformerEngine's current MLP module only has two linear layers.
It would be cool to add a Llama-style module with three layers and a SiLU activation function, since then models like Llama, Mixtral, Qwen, etc. could be ported over easily. Thanks :)
Hi @rationalism, Llama is actually supported by TE's LayerNormMLP module via the swiglu activation. For performance reasons we fuse the two Linear layers (the gate and up projections) into a single one. I recommend looking into the Llama tutorial we posted recently: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/te_llama/tutorial_accelerate_hf_llama_with_te.html
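For reference, here is a minimal sketch of how a Llama-style MLP might be expressed with LayerNormMLP. The argument names (ffn_hidden_size, normalization, activation, bias) follow recent TE releases and should be checked against your installed version; the sizes are just Llama-7B values for illustration.

import torch
import transformer_engine.pytorch as te

hidden_size, intermediate_size = 4096, 11008  # Llama-7B sizes, for illustration

# RMSNorm + SwiGLU MLP as one fused module. Roughly speaking, the gate and up
# projection weights are stacked into the single fc1 matrix, so the "three
# linear layers" reduce to two GEMMs (fc1 and fc2).
llama_mlp = te.LayerNormMLP(
    hidden_size=hidden_size,
    ffn_hidden_size=intermediate_size,
    bias=False,                 # Llama projections have no bias
    normalization="RMSNorm",    # Llama normalizes with RMSNorm before the MLP
    activation="swiglu",        # SiLU(gate) * up, as in LlamaMLP.forward
).cuda()

x = torch.randn(128, 1, hidden_size, device="cuda")
y = llama_mlp(x)  # output has the same shape as x

The linked tutorial shows how to copy a pretrained checkpoint's gate_proj/up_proj/down_proj weights into the fused module's parameters.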