Question about Mixtral MLP section
lhallee opened this issue · comments
Hello,
Great work! Is it accurate to call it a standard, vanilla MLP block? According to the Hugging Face implementation, there is an additional third linear layer and an elementwise multiplication.
import torch.nn as nn
from transformers import MixtralConfig
from transformers.activations import ACT2FN

class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # activated up-projection
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down-projection
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # not standard: extra up-projection
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        # not standard: activated w1 branch is multiplied elementwise by the w3 branch
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
I think this has been confusing to some readers, but perhaps this design has been used before and I am unaware of it. Are there any insights you can offer about why this layer was added? It seems to add expressiveness to the experts, but I didn't know whether you had experimented with and without it.
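To make the contrast concrete, here is a minimal, self-contained sketch of a standard two-layer MLP next to the gated variant used above. The class names, dimensions, and choice of SiLU are my own illustration, not taken from the Mixtral code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VanillaMLP(nn.Module):
    # Standard two-layer feed-forward block: up-project, nonlinearity, down-project.
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)))

class GatedMLP(nn.Module):
    # Gated variant: an extra projection (w3) is multiplied elementwise with the activated branch.
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # activated branch
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)  # down-projection
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # second up-projection (gate)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

x = torch.randn(2, 8, 16)
print(VanillaMLP(16, 32)(x).shape)  # torch.Size([2, 8, 16])
print(GatedMLP(16, 32)(x).shape)    # torch.Size([2, 8, 16])

As far as I understand, models that use the gated form usually shrink the intermediate size somewhat so the total parameter count stays comparable to the two-layer version.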
a normal SwiGLU here (MLP)
This is showing up more often, but using the w3 is definitely not the norm?
I mean, it is a normal, i.e., vanilla SwiGLU here, not a norm.
I meant "normal," not norm, sorry. Where is SwiGLU mentioned in papers? Most transformers do not have three Linear layers in the MLP, including the original / vanilla transformer.
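For reference, the position-wise feed-forward network in the original Transformer ("Attention Is All You Need") is just two linear layers with a ReLU in between, FFN(x) = max(0, x W1 + b1) W2 + b2. A minimal sketch (hypothetical class name, default sizes from the paper):

import torch
import torch.nn as nn

class OriginalTransformerFFN(nn.Module):
    # Position-wise FFN from the original Transformer: two linear layers, ReLU in between.
    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # up-projection (with bias)
        self.linear2 = nn.Linear(d_ff, d_model)  # down-projection (with bias)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))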