databricks / megablocks

ParallelDroplessMLP initialises self.mlp twice

152334H opened this issue

What the title says. In layers/dmoe.py:

class ParallelDroplessMLP(moe.ParallelMLP):

    def __init__(self, args : Arguments):
        super(ParallelDroplessMLP, self).__init__(args) # <-- first init!
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.mlp = dmlp_registry.get(args) # <-- second init!

As a subclass of moe.ParallelMLP, ParallelDroplessMLP first initialises self.mlp in super().__init__() (at layers/moe.py):

class ParallelMLP(torch.nn.Module):

    def __init__(self, args : Arguments):
        # ... omitted ...

        # Expert MLP.
        self.mlp = mlp.MLP(args)

This causes extra initialisation time and memory usage, as the weights created in this first init are immediately discarded and replaced by the new weights created via self.mlp = dmlp_registry.get(args).

Apologies in advance if this double-init process is actually crucially important to the mechanics of the library; I personally did not observe anything breaking after commenting out the first initialisation.

Hi! This is a good point, and we should probably fix it by refactoring weight initialization into a separate function that can be overridden by the derived class. We would welcome a PR if you happen to have cycles to make the change! :)
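
For illustration, one possible shape for that refactor (a sketch only; the _build_mlp hook name is hypothetical, not existing megablocks API): the base class constructs its experts through an overridable method, and the dropless subclass overrides that method instead of re-assigning self.mlp after super().__init__():

class ParallelMLP(torch.nn.Module):

    def __init__(self, args : Arguments):
        # ... omitted ...

        # Expert MLP, built via an overridable hook.
        self.mlp = self._build_mlp(args)

    def _build_mlp(self, args : Arguments):
        return mlp.MLP(args)


class ParallelDroplessMLP(moe.ParallelMLP):

    def _build_mlp(self, args : Arguments):
        # Dropless variant: construct the implementation from the registry,
        # so no throwaway mlp.MLP(args) weights are ever allocated.
        return dmlp_registry.get(args)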

This double init does not seem to affect memory usage? I printed the memory allocation before and after https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/dmoe_test.py#L41: although the MLP init and the mlp_impl init are both called, the allocated memory is still hidden * intermediate * num_experts * bytes per param, plus the router.
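
(For scale, a back-of-the-envelope evaluation of that formula; the sizes, fp32 dtype, and hidden-to-num_experts router shape below are illustrative assumptions, not values taken from the test.)

# Plug assumed sizes into the formula quoted above.
hidden = 1024          # assumed hidden_size
intermediate = 4096    # assumed ffn_hidden_size
num_experts = 8
bytes_per_param = 4    # fp32

expert_bytes = hidden * intermediate * num_experts * bytes_per_param
router_bytes = hidden * num_experts * bytes_per_param  # assuming a single hidden -> num_experts linear router
print(f'{(expert_bytes + router_bytes) / 2**20:.1f} MiB')  # ~128.0 MiB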

Not sure why that happened for you; I see a clear reduction in memory when I initialise the dMoE with the first inner self.mlp initialisation commented out.

Do you happen to have a repro for the increased memory usage?

repro:

from megablocks.layers.dmoe import ParallelDroplessMLP
from megablocks.layers.moe import mlp
from megablocks.layers.arguments import Arguments
import os, psutil, sys
import torch

# mixtral-like arguments. some configs disabled for speed.
args = Arguments(
    #hidden_size=4096, #1024
    #ffn_hidden_size=14336, #4096
    #num_layers=num_layers,
    bias=False, # True
    return_bias=False,# True
    #activation_fn=torch.nn.functional.silu, #DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts=8, #1
    moe_top_k=2, #1
    mlp_type='glu', # 'mlp'
    #mlp_impl='grouped', # 'sparse'
    device='cpu', # torch.cuda.current_device()
)

if sys.argv[1] == 'x': mlp.MLP = lambda a: None  # monkeypatch: replace mlp.MLP(args) with a no-op so the first init is skipped
m = ParallelDroplessMLP(args)

You can use /usr/bin/time -v to track the peak memory of the Python process, which is higher when the first self.mlp = mlp.MLP(args) init is left in place:

$ /usr/bin/time -v python3 lol.py y 2>&1 | grep Maximum
        Maximum resident set size (kbytes): 803400
$ /usr/bin/time -v python3 lol.py x 2>&1 | grep Maximum
        Maximum resident set size (kbytes): 666412

I assume @cli99 did not see this because they were tracking the final memory usage, presumably after the first self.mlp is garbage collected.
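
A minimal in-process version of the same check (a sketch, assuming Linux, where ru_maxrss is reported in KiB, and the same Mixtral-like Arguments as in the repro above): compare the current RSS after construction with the process peak, which still includes the discarded first self.mlp.

import resource

import psutil

from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import ParallelDroplessMLP

args = Arguments(bias=False, return_bias=False, moe_num_experts=8,
                 moe_top_k=2, mlp_type='glu', device='cpu')
m = ParallelDroplessMLP(args)

# Current RSS: the first self.mlp has already been garbage collected here.
current_kib = psutil.Process().memory_info().rss // 1024
# Peak RSS: still includes the throwaway weights from the first init.
peak_kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss

print(f'current RSS ~{current_kib} KiB, peak RSS ~{peak_kib} KiB')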

Thanks for the repro! This should be relatively easy to fix once we get some free cycles to do the work.