stas00 / ml-engineering

Machine Learning Engineering Open Book

Home Page: https://stasosphere.com/machine-learning/

Question about the right hidden dim when using SwiGLU

Thytu opened this issue

Context

In the SwiGLU-based MLP section, it is stated:

The SwiGLU-based MLP contains an additional learned matrix in its activation function. Therefore the MLP block contains 3 matrices instead of the original 2. To preserve the total number of parameters in the MLP block, the paper that introduces SwiGLU proposes to use dim_mlp = 8/3*dim_attn instead of the typical dim_mlp = 4*dim_attn. The Case for Co-Designing Model Architectures with Hardware paper provides recommendations for finding the value of hidden dimension (h) that would lead to the best matmul performance, and if you used 8/3*h is likely to result in a much slower MLP block, because 1/3 will break all the alignments.
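The parameter-matching argument in this passage can be written out explicitly. Here h stands for dim_attn and d_mlp for dim_mlp, and bias terms are ignored (a simplifying assumption):

```latex
\begin{aligned}
\text{standard MLP } (d_{\text{mlp}} = 4h):&\quad 2 \cdot h \cdot 4h = 8h^2 \\
\text{SwiGLU MLP (3 matrices)}:&\quad 3 \cdot h \cdot d_{\text{mlp}} = 8h^2
\;\Longrightarrow\; d_{\text{mlp}} = \tfrac{8}{3}h
\end{aligned}
```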

And later on, in the final recommendations for model sizing section:

The full recommendations are:
[...]
6. For SwiGLU, search for the best-performing hidden size close to 8/3*h
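Recommendation 6 can be sketched as a small helper that lists hidden sizes near 8/3*h that are multiples of some alignment value. The helper name `candidate_swiglu_dims` and the default `multiple=128` are illustrative assumptions, not something taken from the book; the actual winner still has to be picked by benchmarking the candidates on the target hardware.

```python
def candidate_swiglu_dims(h: int, multiple: int = 128, num: int = 4) -> list[int]:
    """Hypothetical helper: the `num` multiples of `multiple` closest to 8/3*h.

    `multiple` is an assumed alignment value; results are sorted nearest first.
    """
    target = 8 * h / 3  # exact parameter-matching size, usually not an integer
    base = int(target // multiple) * multiple  # largest multiple <= target
    # Take a window of aligned sizes around the target that surely contains
    # the `num` nearest ones, then keep the closest.
    window = [base + i * multiple for i in range(-num, num + 1)]
    window = [d for d in window if d > 0]
    window.sort(key=lambda d: abs(d - target))
    return window[:num]


# Candidates to benchmark for h = 4096 (8/3*h is about 10922.67).
print(candidate_swiglu_dims(4096))
```

For h = 4096 this yields 10880, 11008, 10752 and 11136. For comparison, LLaMA-7B (h = 4096) uses an MLP hidden size of 11008, which is one of these aligned candidates.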

It took me a couple of readings to understand those statements, and I'm still not entirely certain of their meaning, tbh.
It could just be me, as I'm still learning all of this, but I do have a few questions about it:

Questions

Right hidden dim when using SwiGLU

Here you said that: "For SwiGLU search for the best performing hidden size close to 8/3*h"

But earlier you wrote "if you used 8/3*h is likely to result in a much slower MLP block, because 1/3 will break all the alignments".

Does that mean that you still recommend using 8/3*h despite it leading to a slower MLP block?

1/3 breaking the alignments

Also, regarding the statement "and if you used 8/3*h is likely to result in a much slower MLP block, because 1/3 will break all the alignments":
I'm not sure what the 1/3 refers to. Could you clarify its meaning for me?


Note: It could greatly help newcomers like myself to have those small sections clarified for better understanding :)

@Thytu, please correct me if I'm wrong, but I see you highlighting two independent things:

  1. 8/3 vs 1/3
  2. 8/3h vs close to 8/3h

correct?

  1. The first one is easy: I was trying to say that it's the dividing by 3 that messes up the best shape; multiplying by 8 is unlikely to cause a big degradation in performance. So I should probably stick to saying 8/3 instead of 1/3. Would that make things clear?

  2. I'll continue saying "close to 8/3*h", and the exact number should be measured using the provided script.
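To see why it is the division by 3 (rather than the multiplication by 8) that hurts, the sketch below compares the largest power-of-two factor of 8*h with that of 8/3*h rounded to an integer. Power-of-two divisibility is used here only as a rough proxy for the matmul-friendly alignment discussed above, and the helper names are hypothetical.

```python
def pow2_factor(n: int) -> int:
    """Largest power of two dividing n (n > 0), via the lowest set bit."""
    return n & -n


def naive_swiglu_dim(h: int) -> int:
    """Hypothetical helper: literal 8/3*h rounded to the nearest integer."""
    return round(8 * h / 3)


for h in (1024, 2048, 3072, 4096, 5120):
    d = naive_swiglu_dim(h)
    # Multiplying by 8 keeps (and adds to) the power-of-two factors of h;
    # dividing by 3 and rounding usually destroys them.
    print(f"h={h}: 8*h={8 * h} (2^k factor {pow2_factor(8 * h)}), "
          f"8/3*h={8 * h / 3:.2f} -> {d} (2^k factor {pow2_factor(d)})")
```

Multiplying by 8 only adds three factors of 2, whereas the rounded 8/3*h comes out odd for every h above except 3072, which is itself divisible by 3. That is why the exact value is best replaced by a nearby, well-aligned size whose speed is then measured.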

Let me draft a PR for you to review; then we can sort it out and merge an improved version.

Please have a look at the proposed #46 and make further suggestions if something is still unclear. Thank you.

(Continuing the discussion on #46)