Question about the right hidden dim when using SwiGLU
Thytu opened this issue
Context
In the SwiGLU-based MLP section, it is stated:
The SwiGLU-based MLP contains an additional learned matrix in its activation function. Therefore the MLP block contains 3 matrices instead of the original 2. To preserve the total number of parameters in the MLP block, the paper that introduces SwiGLU proposes to use dim_mlp = 8/3*dim_attn instead of the typical dim_mlp = 4*dim_attn. The "The Case for Co-Designing Model Architectures with Hardware" paper provides recommendations for finding the value of hidden dimension (h) that would lead to the best matmul performance, and if you used 8/3*h is likely to result in a much slower MLP block, because 1/3 will break all the alignments.
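To make the parameter-count argument in the quoted passage concrete, here is a minimal sketch of the arithmetic. The helper name `mlp_params` and the example value `h = 4096` are assumptions picked for illustration, not something taken from the quoted document, and biases are ignored.

```python
from fractions import Fraction


def mlp_params(h: int, ratio: Fraction, n_matrices: int) -> Fraction:
    """Weight count of an MLP block made of n_matrices, each h x (ratio * h).

    Hypothetical helper for illustration; biases are ignored.
    """
    return n_matrices * h * (ratio * h)


h = 4096  # example hidden size (assumption)

classic = mlp_params(h, Fraction(4), 2)     # 2 matrices, dim_mlp = 4*h
swiglu = mlp_params(h, Fraction(8, 3), 3)   # 3 matrices, dim_mlp = 8/3*h

# Both layouts end up with 8 * h^2 weights.
assert classic == swiglu == 8 * h * h

# The exact 8/3*h value is not even an integer here, so it cannot be
# used as-is and has to be rounded to some nearby size.
exact = Fraction(8, 3) * h
print(exact, float(exact))
```

The equality shows why 8/3 is the ratio that keeps the parameter count unchanged, and the last line shows why the exact value cannot be used directly for this `h`.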
And later on, in the final recommendations for model sizing section:
The full recommendations are:
[...]
6. For SwiGLU search for the best performing hidden size close to 8/3*h
It took me a couple of readings to understand those statements, and I'm still not entirely certain of their meaning tbh.
It could just be me, as I'm still learning all of this, but I do have a few questions about it:
Questions
Right hidden dim when using SwiGLU
Here you said that: "For SwiGLU search for the best performing hidden size close to 8/3*h"
But earlier you've written "if you used 8/3*h is likely to result in a much slower MLP block, because 1/3 will break all the alignments".
Does that mean that you still recommend using 8/3*h despite it leading to a slower MLP block?
1/3 breaking the alignments
Also regarding the statement "and if you used 8/3*h is likely to result in a much slower MLP block, because 1/3 will break all the alignments."
I'm not sure what the
Note: It could greatly help newcomers like myself to have those small sections clarified for better understanding :)
@Thytu, please correct me if I'm wrong but I'm seeing you highlighting 2 independent things:
- 8/3 vs 1/3
- 8/3h vs close to 8/3h
correct?
- The first one is easy: I was trying to say that it's the division by 3 that messes up the best shape; multiplying by 8 is unlikely to cause a big degradation in performance. So I should probably stick to saying that it's 8/3 instead of 1/3. Would that make things clear?
- For the second one, I would continue saying "close to 8/3*h", and the exact number should be measured using the provided script.
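As a rough illustration of what "close to 8/3*h" could mean in practice, the sketch below only lists candidate hidden sizes around 8/3*h that are rounded to a fixed multiple. It does not benchmark anything and is not the provided script. The function name `swiglu_candidates`, the alignment multiple of 64, and the search span are all assumptions made for this example; the right multiple depends on the hardware and dtype.

```python
def swiglu_candidates(h: int, multiple: int = 64, span: int = 4) -> list[int]:
    """List aligned hidden sizes around 8/3*h.

    Hypothetical helper: `multiple` (alignment) and `span` (how many steps
    to look on each side) are illustrative assumptions.
    """
    target = 8 * h / 3
    # Nearest aligned size to the exact 8/3*h value.
    base = round(target / multiple) * multiple
    sizes = [base + k * multiple for k in range(-span, span + 1)]
    return [s for s in sizes if s > 0]


candidates = swiglu_candidates(4096)
for size in candidates:
    # Each of these would then be timed with a matmul benchmark.
    print(size, f"ratio to h = {size / 4096:.4f}")
```

Each printed size stays near the 8/3 ratio while avoiding the fractional value, so the final choice can be made by measuring which candidate runs fastest.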
Let me draft a PR for you to review and then we can sort it out and then merge an improved version.
Please have a look at the proposed #46 and please make further suggestions if something is still unclear. Thank you.
(Continuing the discussion on #46)