databricks / megablocks

How to add support for swiglu in Megablocks?

fanshiqing opened this issue · comments

Hi @tgale96 @deepakn94,
During integration of Megablocks in to Megatron-LM, I found that when --swiglu (src) is enabled, the corresponding ffn_hidden_size is int((4 * args.hidden_size * 2 / 3) / 64) * 64 instead of general args.ffn_hidden_size = 4 * args.hidden_size (src), this change will break the underlying assert in dmoe.py:

assert self.ffn_hidden_size % self.blocking == 0

Could we add support for the swiglu case, where self.ffn_hidden_size % self.blocking != 0 often holds? Thanks!
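For concreteness, here is a small sketch of the sizing mismatch (the hidden_size and blocking values below are just examples, not from any particular config):

```python
# Illustrative sketch (not Megablocks code): how Megatron-LM sizes
# ffn_hidden_size when --swiglu is enabled, and why the dmoe.py
# assertion can fail.
hidden_size = 2048
blocking = 128  # block size assumed by the block-sparse MLP

# Default FFN width: 4 * hidden_size.
ffn_hidden_size = 4 * hidden_size  # 8192, divisible by 128

# With --swiglu, Megatron-LM shrinks the FFN to keep the parameter count
# comparable, rounding down to a multiple of 64 (not necessarily 128):
swiglu_ffn_hidden_size = int((4 * hidden_size * 2 / 3) / 64) * 64  # 5440

print(swiglu_ffn_hidden_size % blocking)  # 64, so the assert below fires
assert swiglu_ffn_hidden_size % blocking == 0  # AssertionError
```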

Yes, we actually have someone looking into swiglu support this week, funnily enough.

Fwiw, you can get around that assertion by using the grouped MLP rather than the sparse MLP. You can do that by enabling this flag.
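For example, something along these lines (a minimal sketch; treat grouped_mlp as the flag referenced above, and note that the exact field names on the Arguments object may differ between versions):

```python
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE

# Hypothetical sizes matching the sketch earlier in the thread.
args = Arguments(
    hidden_size=2048,
    ffn_hidden_size=5440,  # swiglu-shrunk width, not a multiple of 128
    moe_num_experts=8,
    moe_top_k=2,
    grouped_mlp=True,      # assumed flag: use the grouped GEMM MLP instead of the sparse MLP
)
layer = dMoE(args)
```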

Thanks Trevor.
One quick question: what's the difference between SparseMLP and GroupedMLP in Megablocks? What are their respective use cases (i.e., in which cases should I enable grouped_mlp=True for Megablocks)?

It seems GroupedMLP is a newly added feature from last month that leverages CUTLASS grouped GEMM. Do we have any performance comparison between this grouped GEMM and the previous block-sparse GEMM?

Are there any known constraints when using grouped GEMM?

SparseMLP is what's described in the paper - it's still our recommended path, I would say.

GroupedMLP is an alternative implementation based on a grouped GEMM kernel. We added it to work around the fact that Triton performance on Hopper is currently not very good, but the lack of SM90-optimized grouped GEMM kernels in CUTLASS means that its performance is currently good mostly in the small-experts-per-GPU regime (e.g., when you're training with expert model parallelism).

Using a grouped GEMM can be a little simpler - there are fewer constraints on the matrix dimensions (e.g., the 128-alignment required by SparseMLP), no padding of the activation tensor, and the metadata used to describe the sparsity is simpler (just token counts per expert). We're in touch with the CUTLASS team and once an SM90-optimized kernel is available (expected later this year) we should be able to get comparable performance to SparseMLP across the problem space and across different architectures.
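For concreteness, a toy sketch of the difference in sparsity metadata (illustrative only, not the actual Megablocks API; the routing below is made up):

```python
import torch

num_experts = 4
expert_assignment = torch.tensor([0, 0, 2, 1, 2, 2, 3])  # one expert id per token

# Grouped GEMM: the only sparsity metadata is the token count per expert;
# each expert's tokens form one GEMM in the group, with no alignment padding.
tokens_per_expert = torch.bincount(expert_assignment, minlength=num_experts)
print(tokens_per_expert)  # tensor([2, 1, 3, 1])

# Block-sparse (SparseMLP): token counts per expert are padded up to the
# blocking (e.g. 128), and block indices/offsets describe which blocks of the
# large sparse matrix are non-zero.
blocking = 128
padded_tokens_per_expert = ((tokens_per_expert + blocking - 1) // blocking) * blocking
print(padded_tokens_per_expert)  # tensor([128, 128, 128, 128])
```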

I don't have an in-depth performance evaluation of the two yet (we're primarily using them for different scenarios/hardware platforms). Prior benchmarks of CUTLASS suggested we might lose a small amount of performance relative to the block-sparse approach but I need to study in more depth once the proper kernels are available and integrated.

One thing to note is that our SparseMLP can use less memory when memory_optimized_mlp is enabled. We're working on understanding why that is, since I expect both should be equivalent (and the same memory optimizations are supported for both). It seems like GroupedMLP interacts differently with PyTorch's caching allocator.

I hope this helps! Let us know if there is anything we can do to help with your use case.

Hello! FYI, we just merged swiglu support in for SparseMLP and GroupedMLP.

I cannot find the swiglu support in this repository; could you please point out how to enable it?

In your Arguments object, set mlp_type="glu" :)
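For example (a minimal sketch, assuming the Arguments fields used elsewhere in this thread; names may differ by version):

```python
from megablocks.layers.arguments import Arguments

# Hypothetical sizes; the relevant change is mlp_type="glu".
args = Arguments(
    hidden_size=2048,
    ffn_hidden_size=5440,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_type="glu",  # use the GLU variant of the expert MLP
)
```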

That is actually not swiglu. But I have implemented it myself. Thanks a lot anyway!

Yes, it uses gelu instead of silu, I think. That's the difference you're concerned about?

@tgale96 I believe that's part of it; in addition to using silu, when using glu/swiglu the ffn_hidden_size needs to be multiplied by 2 in the first MLP projection, like here: https://github.com/NVIDIA/Megatron-LM/blob/fab0bd693ec5be55b32c4f12a1ea44766ec63448/megatron/model/transformer.py#L92-L94

Hi, thanks! You mean doubling ffn_hidden_size like we do here?

@tgale96 I think for swiglu it should be args.ffn_hidden_size * 2 here.

That'd be true if you fused the two matrices in the front half of the GLU. These implementations do not - you'll notice they have three weight matrices, the first two of which (w1/v1) effectively double ffn_hidden_size.
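For illustration, a hypothetical unfused GLU expert MLP matching that description (not the actual Megablocks code; shapes are made up):

```python
import torch
import torch.nn.functional as F

hidden_size, ffn_hidden_size = 2048, 5440
x = torch.randn(16, hidden_size)

w1 = torch.randn(hidden_size, ffn_hidden_size)  # first half of the GLU input projection
v1 = torch.randn(hidden_size, ffn_hidden_size)  # second half (the "gate"); w1 + v1 play the
                                                # role of the single doubled projection
w2 = torch.randn(ffn_hidden_size, hidden_size)  # output projection

# GLU with gelu (what mlp_type="glu" reportedly uses):
out_geglu = (F.gelu(x @ w1) * (x @ v1)) @ w2

# swiglu is the same structure with silu as the activation:
out_swiglu = (F.silu(x @ w1) * (x @ v1)) @ w2
```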

I see, then it should just be a matter of replacing gelu with silu for swiglu.

Ya, we're going to make the activation function configurable for these ops. Fwiw, I don't expect silu versus gelu will make much of a difference given they're almost the same :)