Improvement to the filtering of bias and bn parameters
towzeur opened this issue · comments
An improvement can be made in main.py: instead of recomputing the filtered parameters on every LARS step, one can pass two functions, `weight_decay_filter` and `lars_adaptation_filter`, that always return `True` (i.e. exclude every parameter) to the filtered parameter group.
```python
def exclude_param(p: torch.nn.parameter.Parameter):
    return True

parameters = [
    {"params": param_weights},
    {
        "params": param_biases,
        "weight_decay_filter": exclude_param,
        "lars_adaptation_filter": exclude_param,
    },
]
```
then

```python
optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay)
```
In `LARS.step`, the weights group will not be excluded from weight decay or LARS adaptation, since its `weight_decay_filter` and `lars_adaptation_filter` default to `None`, unlike the `param_biases` group, whose filters always return `True`.
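To make the proposed dispatch concrete, here is a minimal sketch of a LARS-style optimizer that consults per-group filters in `step`. This is not the repository's `LARS` class: momentum is omitted and the trust-ratio computation is simplified, so only the filter-handling logic should be taken as the point of the example.

```python
import torch


def exclude_param(p: torch.nn.parameter.Parameter):
    """Mark every parameter in the group as excluded."""
    return True


class SimpleLARS(torch.optim.Optimizer):
    """Simplified LARS-like optimizer honoring per-group filters (sketch)."""

    def __init__(self, params, lr, weight_decay=0.0, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, eta=eta,
                        weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g["params"]:
                if p.grad is None:
                    continue
                dp = p.grad
                # Apply weight decay unless the group's filter excludes p.
                wd_filter = g["weight_decay_filter"]
                if wd_filter is None or not wd_filter(p):
                    dp = dp.add(p, alpha=g["weight_decay"])
                # Apply the LARS trust ratio unless the filter excludes p.
                lars_filter = g["lars_adaptation_filter"]
                if lars_filter is None or not lars_filter(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    if param_norm > 0 and update_norm > 0:
                        dp = dp.mul(g["eta"] * param_norm / update_norm)
                p.add_(dp, alpha=-g["lr"])


# Usage: biases are excluded via the per-group filters, weights are not.
w = torch.ones(2, requires_grad=True)
b = torch.ones(1, requires_grad=True)
w.grad = torch.zeros_like(w)
b.grad = torch.zeros_like(b)
opt = SimpleLARS(
    [
        {"params": [w]},
        {
            "params": [b],
            "weight_decay_filter": exclude_param,
            "lars_adaptation_filter": exclude_param,
        },
    ],
    lr=1.0,
    weight_decay=0.1,
)
opt.step()
# b is untouched (both filters excluded it); w received decay + adaptation.
```

With this layout the `p.ndim`-style check disappears from the hot loop: membership in a group decides once, at construction time, which parameters are excluded.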
Hello @towzeur,
Thank you for the suggestion. While I agree that your proposed solution is a bit nicer than the code currently in the repo, I don't think we will be making the change, since the main purpose of this repository is to reproduce the results of the paper.
Best,
Jure