µTransfer across batch size and weight decay setting
PanYue2023 opened this issue
Hello there! I have reproduced the Transformer results with your code and am now looking to apply mup to my own model. However, I have a few doubts and would greatly appreciate it if you could clarify them for me.
Question 1: µTransfer across batch size
In the paper, it is mentioned that the optimal learning rate can be transferred between different batch sizes, for example, in Figure 19. However, I couldn't find the relevant implementation in the provided examples.
I would like to confirm whether these experiments kept all other hyperparameters fixed and only varied the batch size. Regarding Init. Var. and LR, do they follow the same approach as width, where a base batch size is set and Init. Var. and LR are then scaled according to the current batch size? And if I want to adjust the model's width and batch size simultaneously while keeping the optimal learning rate unchanged, is that feasible? A sketch of what I have in mind is below.
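For concreteness, here is a minimal sketch of how I imagine the batch-size sweep would look with the mup package, keeping everything except the batch size fixed. This is only my guess at the setup, not anything taken from your examples: `build_model`, `train_dataset`, and `base_shapes.bsh` are hypothetical placeholders, and the learning rate is the one tuned on the small proxy model.

```python
# My guess at the batch-size transfer setup (not from the repo's examples):
# only the dataloader's batch size changes; width scaling is left entirely to
# muP via set_base_shapes / MuAdam. `build_model`, `train_dataset`, and
# "base_shapes.bsh" are hypothetical placeholders.
import torch
from torch.utils.data import DataLoader
from mup import set_base_shapes, MuAdam

BASE_LR = 1e-3  # learning rate tuned on the small proxy model

for batch_size in [32, 64, 128, 256]:
    model = build_model(width=1024)             # target width (could this also vary?)
    set_base_shapes(model, "base_shapes.bsh")   # base shapes saved earlier with make_base_shapes
    optimizer = MuAdam(model.parameters(), lr=BASE_LR)
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # ... identical training loop for every batch size ...
```

Is this roughly how the Figure 19 experiments were set up, or do Init. Var. and LR need an explicit rescaling as a function of batch size?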
Question 2: weight decay setting
Regarding the setting of weight decay, the explanation in the paper is quite brief, and I am a bit confused about how it should be implemented with AdamW. The paper notes that weight decay should be independent of the width, and I found a code snippet on Hugging Face that uses mup: https://huggingface.co/cerebras/btlm-3b-8k-base/blob/main/modeling_btlm.py
```python
def get_mup_param_groups(self, lr, weight_decay=0.0, decoupled_wd=True):
    """
    Returns list of dicts defining parameter groups for muP:
    group 0: most model params get scaled learning rate and weight decay.
    group 1: embedding layer gets non-scaled learning rate and weight decay.
    group 2: normalization layers and biases get non-scaled learning rate only.

    The output can be passed to Adam-base optimizers
    e.g.
        param_groups = model.get_mup_param_groups(lr=1e-3, weight_decay=0.1)
        torch.optim.AdamW(param_groups, betas=(0.9, 0.95), eps=1e-8)
    """
    norm_modules = (
        torch.nn.LayerNorm,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.GroupNorm,
        torch.nn.SyncBatchNorm,
        torch.nn.LocalResponseNorm,
    )

    def get_group_index(param_name):
        for name, module in self.named_modules():
            if name in param_name:
                if isinstance(module, norm_modules):
                    return 2
                elif isinstance(module, torch.nn.Embedding):
                    return 1
        return 0

    width_scale = self.config.mup_width_scale
    new_param_groups = []
    new_param_groups.append(
        {"params": [], "lr": lr * width_scale, "weight_decay": weight_decay}
    )
    if not decoupled_wd:
        new_param_groups[0]["weight_decay"] /= width_scale
    new_param_groups.append({"params": [], "lr": lr, "weight_decay": weight_decay})
    new_param_groups.append({"params": [], "lr": lr, "weight_decay": 0.0})

    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        if name.endswith("bias"):
            new_param_groups[2]["params"].append(param)
        else:
            new_param_groups[get_group_index(name)]["params"].append(param)

    for idx, param_group in enumerate(new_param_groups):
        if len(param_group["params"]) == 0:
            del new_param_groups[idx]

    return new_param_groups
```
Is this the intended approach, i.e., categorizing the parameters into three groups and setting their learning rates and weight decays separately? And should decoupled_wd be set to True so that weight decay is not scaled?
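To make my question about decoupled_wd concrete, here is how I currently read the two settings. This is only my own interpretation, not anything from the paper, and the numbers are made up.

```python
# PyTorch's AdamW applies decoupled weight decay as
#     param -= group_lr * group_weight_decay * param
# so what matters for group 0 is the product of its width-scaled lr and its weight_decay.
lr, weight_decay, width_scale = 1e-3, 0.1, 0.25  # hypothetical values

# decoupled_wd=True: weight_decay is left untouched, so the effective per-step
# decay of the hidden-weight group shrinks together with its learning rate.
effective_decay_true = (lr * width_scale) * weight_decay                    # 2.5e-05

# decoupled_wd=False: weight_decay is divided by width_scale, cancelling the
# lr scaling, so lr * weight_decay stays constant across widths.
effective_decay_false = (lr * width_scale) * (weight_decay / width_scale)   # 1e-04

print(effective_decay_true, effective_decay_false)
```

Which of these two behaviours corresponds to "weight decay should be independent of the width" as stated in the paper?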
Once again, I appreciate your valuable work and its significance for the field. I am eagerly looking forward to your explanation.