facebookresearch / barlowtwins

PyTorch implementation of Barlow Twins.

Improvement over filtering bias and bn

towzeur opened this issue:

An improvement can be made in main.py:
Instead of re-checking which parameters to filter on every LARS step, one can pass two functions, weight_decay_filter and lars_adaptation_filter, that always return True to the parameter group that should be excluded from weight decay and LARS adaptation.

def exclude_param(p: torch.nn.parameter.Parameter) -> bool:
    # Always True: every parameter in this group is excluded from
    # weight decay and from the LARS adaptation.
    return True

parameters = [
    {"params": param_weights},
    {
        "params": param_biases,
        "weight_decay_filter": exclude_param,
        "lars_adaptation_filter": exclude_param,
    },
]

then

optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay)

In LARS.step, the weights group is then never filtered, because its weight_decay_filter and lars_adaptation_filter keep their default value of None, unlike the param_biases group, whose parameters are always excluded.
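For reference, a filter-aware LARS could look roughly like the sketch below. This is a sketch of the proposed behaviour, assuming the constructor stores the two filters as per-group defaults; it is not the code currently in this repository.

import torch
from torch import optim

class LARS(optim.Optimizer):
    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue
                # No filter set (None) or filter returns False -> apply weight decay.
                if g['weight_decay_filter'] is None or not g['weight_decay_filter'](p):
                    dp = dp.add(p, alpha=g['weight_decay'])
                # Same rule for the LARS trust-ratio scaling.
                if g['lars_adaptation_filter'] is None or not g['lars_adaptation_filter'](p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0.,
                                                g['eta'] * param_norm / update_norm, one),
                                    one)
                    dp = dp.mul(q)
                # Heavy-ball momentum update.
                mu = self.state[p].setdefault('mu', torch.zeros_like(p))
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])

With the groups defined as above, the weights group falls into the "filter is None" branch and keeps weight decay and LARS scaling, while exclude_param makes both checks skip for the param_biases group.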

Hello @towzeur,

Thank you for the suggestion. I agree that your proposed solution is a bit cleaner than the code currently in the repo, but I don't think we will make the change, since the main purpose of this repository is to reproduce the results of the paper.

Best,
Jure