mistralai / mistral-inference

Official inference library for Mistral models

Home Page: https://mistral.ai/

Question about Mixtral MLP section

lhallee opened this issue

Hello,

Great work! Is it accurate to say each expert is just a standard, vanilla MLP block? According to the Hugging Face implementation, there is an additional third linear layer and an elementwise multiplication.

import torch.nn as nn
from transformers import MixtralConfig
from transformers.activations import ACT2FN

class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down projection
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # not standard: extra up projection

        self.act_fn = ACT2FN[config.hidden_act]  # "silu" in the Mixtral configs

    def forward(self, hidden_states):
        # not standard: elementwise product of the activated gate (w1) with the up projection (w3)
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

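For comparison, the "standard vanilla" MLP referred to above has only two projections and no elementwise product. A rough illustrative sketch (not code from either repository; the class and argument names are arbitrary):

import torch
import torch.nn as nn

class VanillaMLP(nn.Module):
    # Classic position-wise feed-forward block: up projection, activation, down projection.
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # up projection
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)  # down projection
        self.act_fn = nn.ReLU()  # ReLU in the original Transformer; GELU is also common

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(hidden_states)))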
I think this has been confusing to some readers, but perhaps it has been used before and I am unaware. Can you offer any insight into why this layer was added? It seems to add more expressiveness to the experts, but I didn't know whether you had experimented with and without it.

a normal SwiGLU here (MLP)

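In equation form, the block in the snippet above is the gated formulation usually referred to as SwiGLU. A minimal functional sketch, assuming config.hidden_act == "silu" (as in the Mixtral configs) and reusing the w1/w2/w3 modules from the code; the swiglu_ffn name is just for illustration:

import torch
import torch.nn.functional as F

def swiglu_ffn(x: torch.Tensor, w1, w2, w3) -> torch.Tensor:
    # w1 produces the gate, w3 the usual up projection, w2 projects back down
    return w2(F.silu(w1(x)) * w3(x))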

This is showing up more often, but using the w3 layer is definitely not the norm?

I mean, it is a normal, i.e., vanilla, SwiGLU here, not a norm.

I meant "normal," not norm, sorry. Where is SwiGLU mentioned in papers? Most transformers do not have three Linear layers in the MLP, including the original / vanilla Transformer.