MoE performs worse than equivalent dense model?
Muennighoff opened this issue:
Afaict, for the numbers reported here in #115 (comment), the "1 expert" model is still an MoE, correct?
I also get the result that the 8-expert MoE is better than the 1-expert one; however, both are worse than a dense model. In the graph below, OpenLM-41M is a 41M dense model and the other two curves are 8-expert and 1-expert models with 41M active parameters.
*[Screenshot (2024-04-17): training-loss curves for the 8-expert MoE, the 1-expert MoE, and the OpenLM-41M dense model.]*
I would expect the 1-expert model to roughly match the dense one and the 8-expert model to beat both, but maybe I am missing something? @kernelmachine @sagadre
(My setup follows the main README & https://github.com/mlfoundations/open_lm/blob/main/MOE.md)
Great catch, thanks @Muennighoff! I think this is because the MoE defaults from megablocks differ from our default dense model in at least two ways: their FFN uses GELU without gating (`w2 @ gelu(w1 @ x)`), while ours uses SwiGLU (`w3 @ (silu(w2 @ x) * (w1 @ x))`), and we use a different parameter-init function.
I'm not sure what the easiest way to reconcile these is. Probably we want a custom version of the MLP class here (https://github.com/databricks/megablocks/blob/2724ff6775ee7e2a41001a7979c0ec84c417cd84/megablocks/layers/mlp.py#L81-L137) that implements SwiGLU and our init function.
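For concreteness, here is a minimal PyTorch sketch of the two FFN variants (the class names `GeluMLP`/`SwiGLU` are mine for illustration; this mirrors the formulas above, not either codebase's exact implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeluMLP(nn.Module):
    """Ungated FFN, as in the megablocks default: w2 @ gelu(w1 @ x)."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x), approximate="tanh"))

class SwiGLU(nn.Module):
    """Gated FFN, as in OpenLM's dense default: w3 @ (silu(w2 @ x) * (w1 @ x))."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the silu branch gates the linear branch elementwise
        return self.w3(F.silu(self.w2(x)) * self.w1(x))
```

Note that the gated variant carries three weight matrices instead of two, which is why SwiGLU models typically shrink `hidden_dim` to about 2/3 of the ungated value to match parameter counts.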
I trained an OpenLM model with the tanh-approximated GELU used in megablocks and regular normal init by adding the two snippets below to `model.py`:
```python
elif args.ffn_type == "gelutanh":
    self.hidden_dim = args.dim * 4
    self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
    self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
    self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="tanh"), self._ff_w2)
```
and this for the init:
```python
elif self._ffn_type == "gelutanh":
    torch.nn.init.normal_(self._ff_w1.weight, mean=0.0, std=0.02)
    torch.nn.init.normal_(self._ff_w2.weight, mean=0.0, std=0.02)
```
It does indeed perform slightly worse than the regular 41M model but still far better than the MoE variants. (For the 8-expert model, total parameters increased from 69M to 97M because I increased the MoE frequency to every layer instead of every 2nd layer.) Do you have any more ideas where this could come from?
*[Screenshot (2024-04-18): updated loss curves; the `gelutanh` dense model is slightly worse than the regular 41M but still far better than the MoE variants.]*
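As a rough sanity check on the parameter counts above, FFN parameters can be tallied directly (a back-of-the-envelope sketch; `moe_ffn_params` is a hypothetical helper and the dims in the usage lines are placeholders, not the actual `open_lm_41m` config):

```python
def moe_ffn_params(dim: int, hidden_dim: int, num_experts: int,
                   n_layers: int, moe_freq: int) -> int:
    """Total FFN parameters when every moe_freq-th layer is an MoE layer.

    Assumes an ungated 2-matrix MLP per expert and ignores the (small)
    router weights; the remaining layers keep a single dense FFN.
    """
    per_ffn = 2 * dim * hidden_dim   # w1 and w2
    n_moe = n_layers // moe_freq     # layers replaced by an MoE block
    return n_moe * num_experts * per_ffn + (n_layers - n_moe) * per_ffn

# Placeholder dims for illustration only: MoE in every layer (moe_freq=1)
# vs. every 2nd layer (moe_freq=2) substantially increases expert params,
# consistent in direction with the 69M -> 97M jump reported above
# (attention and embedding parameters don't grow).
print(moe_ffn_params(dim=512, hidden_dim=2048, num_experts=8, n_layers=8, moe_freq=2))
print(moe_ffn_params(dim=512, hidden_dim=2048, num_experts=8, n_layers=8, moe_freq=1))
```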
For reference, these are the commands I am running:
No MoE:
```bash
torchrun --nproc-per-node 8 -m open_lm.main \
  --train-data "/data/niklas/openlm/preproc/2048-v1/0/shard-{0000000..0000099}.tar" \
  --train-num-samples 10000000000 \
  --precision amp_bfloat16 \
  --global-batch-size 64 \
  --accum-freq 4 \
  --log-every-n-steps 20 \
  --grad-clip-norm 1 \
  --lr 5e-4 \
  --warmup 200 \
  --model open_lm_41m \
  --wd 0.1 \
  --beta2 0.95 \
  --epochs 50 \
  --report-to wandb \
  --wandb-project-name olmoe \
  --name test$RANDOM \
  --logs /data/niklas/openlm/moe \
  --resume latest \
  --seed 124 \
  --data-key 'txt' \
  --fsdp \
  --fsdp-amp \
  --model-norm gain_only_layer_norm \
  --lr-scheduler cosine \
  --lr-cooldown-end 0.00001 \
  --ffn-type gelutanh
```
MoE w/ 8 experts:
```bash
torchrun --nproc-per-node 8 -m open_lm.main \
  --train-data "/data/niklas/openlm/preproc/2048-v1/0/shard-{0000000..0000099}.tar" \
  --train-num-samples 10000000000 \
  --precision amp_bfloat16 \
  --global-batch-size 64 \
  --accum-freq 4 \
  --log-every-n-steps 20 \
  --grad-clip-norm 1 \
  --lr 5e-4 \
  --warmup 200 \
  --model open_lm_41m \
  --wd 0.1 \
  --beta2 0.95 \
  --epochs 50 \
  --report-to wandb \
  --moe-freq 1 \
  --moe-num-experts 8 \
  --moe-top-k 1 \
  --moe-capacity-factor 1.25 \
  --moe-loss-weight 0.1 \
  --wandb-project-name olmoe \
  --name test$RANDOM \
  --logs /data/niklas/openlm/moe \
  --resume latest \
  --seed 124 \
  --data-key 'txt' \
  --fsdp \
  --fsdp-amp \
  --model-norm gain_only_layer_norm \
  --lr-scheduler cosine \
  --lr-cooldown-end 0.00001
```
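For anyone reading along: `--moe-top-k 1` routes each token through a single expert, which is why the 8-expert run has roughly the same active parameter count per token as the 1-expert run. A toy sketch of top-1 routing, purely illustrative (not open_lm's or megablocks' actual implementation; `top1_route` and its arguments are hypothetical):

```python
import torch

def top1_route(x: torch.Tensor, router_w: torch.Tensor, experts: list) -> torch.Tensor:
    """Toy top-1 routing: each token is sent through exactly one expert FFN,
    weighted by its router probability (no capacity limits, no balancing loss)."""
    probs = torch.softmax(x @ router_w, dim=-1)  # (tokens, num_experts)
    gate, idx = probs.max(dim=-1)                # chosen expert per token
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        mask = idx == e
        if mask.any():
            out[mask] = gate[mask].unsqueeze(-1) * expert(x[mask])
    return out
```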