microsoft / mup

maximal update parametrization (µP)

Home Page: https://arxiv.org/abs/2203.03466

MuAdam not adjusting lr for output weights

zhuzilin opened this issue · comments

Hi, thank you for your great project for hyperparameter tuning!

As our team is migrating muP to another training framework, we noticed that MuAdam does not scale the learning rate for output weights in the way the TP5 paper illustrates:
[image: µP scaling-rule table from the TP5 paper]

mup/mup/optim.py

Lines 55 to 70 in c9d6700

```python
for p in param_group['params']:
    assert hasattr(p, 'infshape'), (
        f'A parameter with shape {p.shape} does not have `infshape` attribute. '
        'Did you forget to call `mup.set_base_shapes` on the model?')
    if p.infshape.ninf() == 2:
        matrix_like_p[p.infshape.width_mult()]['params'].append(p)
    elif p.infshape.ninf() > 2:
        raise NotImplementedError('more than 2 inf dimensions')
    else:
        vector_like_p['params'].append(p)
for width_mult, group in matrix_like_p.items():
    # Scale learning rate and weight decay accordingly
    group['lr'] /= width_mult
    group['weight_decay'] *= width_mult
new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
return impl(new_param_groups, **kwargs)
```

It seems to us that only the learning rate of the hidden layers (the layers with 2 infinite dimensions) is scaled w.r.t. fan-in, while the output weights are left untouched. We wonder whether this is intended. Thank you!
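
For reference, here is a minimal sketch of how we read that grouping (assuming the usual mup workflow of calling `set_base_shapes` with a base and a delta model; the toy `mlp` helper and the widths below are just for illustration):

```python
import torch.nn as nn
from mup import set_base_shapes

def mlp(width):
    # Toy model with an input, a hidden, and an output weight matrix.
    return nn.Sequential(
        nn.Linear(16, width),     # input weight: one infinite dim
        nn.ReLU(),
        nn.Linear(width, width),  # hidden weight: two infinite dims
        nn.ReLU(),
        nn.Linear(width, 10),     # output weight: one infinite dim
    )

model = mlp(width=1024)
set_base_shapes(model, mlp(width=64), delta=mlp(width=128))

for name, p in model.named_parameters():
    if p.infshape.ninf() == 2:
        rule = f"matrix-like, lr /= {p.infshape.width_mult()}"
    else:
        # The output weight (one infinite dim) lands here, so its lr is
        # not scaled by MuAdam -- which is what prompted this question.
        rule = "vector-like, lr unchanged"
    print(f"{name}: {rule}")
```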

Hi zhuzilin,

Thanks for your question.

There are many equivalent ways to implement muP, and you are right that what is implemented in this package is not described by the table you attached. Instead, you want to look at Table 8.
[image: Table 8 from the paper]

We also noted in the caption of the table you attached: "also see Table 8 for a µP formulation that is easier to implement (and compatible with input/output weight sharing)." Please let us know if this answers your question!

@edwardjhu Thank you, Edward! Your answer resolved my confusion. Just to double-check: if I need to implement a custom output layer, Table 8 means that I should initialize the output weight with std 1 and always divide the layer's output by fan-in, right?

That's right!

@zhuzilin I want to fill in more information here that may have been lost in this exchange. We don't want you to use exactly std=1 and divide the output layer by exactly fan-in. You should interpret the 1 as O(1) and the fan-in as O(fan-in). In other words, the table just says that when you double your fan-in, the multiplier on the last layer should be halved, but the initialization should be unchanged. The exact numbers you use for the initialization and the multiplier should be tuned from some base model. This discussion applies to all other parameters in the table.
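
To make that concrete, here is a rough sketch of a custom output layer in this spirit. The names `base_fan_in`, `init_std`, and `output_mult` are illustrative, not part of the mup package (which already ships a `MuReadout` layer that handles this bookkeeping for you):

```python
import torch.nn as nn

class CustomReadout(nn.Linear):
    """Illustrative Table-8-style output layer: O(1) init and an
    O(1/fan_in) output multiplier, with the constants tuned at a base width."""

    def __init__(self, fan_in, fan_out, base_fan_in, init_std=1.0, output_mult=1.0):
        super().__init__(fan_in, fan_out)
        # width_mult tracks how much wider this model is than the base model.
        self.width_mult = fan_in / base_fan_in
        self.output_mult = output_mult               # tunable O(1) constant
        nn.init.normal_(self.weight, std=init_std)   # O(1): same std at every width
        nn.init.zeros_(self.bias)

    def forward(self, x):
        # Dividing by width_mult realizes the O(1/fan_in) multiplier:
        # doubling fan_in (at fixed base_fan_in) halves the effective multiplier.
        return super().forward(x) * self.output_mult / self.width_mult
```

At the base width (fan_in == base_fan_in) this reduces to an ordinary linear layer times output_mult, so the constants you tune there carry over unchanged as you widen the model.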

Regarding the output layer specifically, we actually recommend initializing it at 0 if possible (assuming you don't have tricky weight tying between input and output weights). This should not affect the performance of your model after training, but it typically improves transfer quality. See Section D.2 in the paper for more details.
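
A minimal sketch of that (here `readout` is just a stand-in for whatever output layer your model uses):

```python
import torch.nn as nn

readout = nn.Linear(1024, 10)   # stand-in for the model's output layer
nn.init.zeros_(readout.weight)  # zero-initialize the output weights
nn.init.zeros_(readout.bias)
```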