yaoyao-liu / meta-transfer-learning

TensorFlow and PyTorch implementation of "Meta-Transfer Learning for Few-Shot Learning" (CVPR2019)

Home Page: https://lyy.mpi-inf.mpg.de/mtl/

How to fix feature encoder weights during the SS process

mengruwg opened this issue · comments

In the PyTorch version:

new_weight = self.weight.mul(new_mtl_weight) (line 95 in conv2d_mtl.py)
self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size)) (line 42 in conv2d_mtl.py)

How do I load the pre-trained weights and keep the feature encoder weights fixed during the SS process?

How to freeze the convolution weights?

If you set mtl=True in the following line:

def __init__(self, layers=[4, 4, 4], mtl=True):

then you're using the _ConvNdMtl convolution layers, which set

self.weight.requires_grad = False

i.e., the convolution weights are frozen, and only the SS parameters of these layers (mtl_weight and mtl_bias) are trained.
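For reference, here is a minimal sketch of such a scaling-and-shifting (SS) convolution layer, consistent with the lines quoted from conv2d_mtl.py above. The class name Conv2dMtlSketch and the simplified constructor signature are illustrative only, not the repository's exact implementation:

```python
import math

import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, init


class Conv2dMtlSketch(Module):
    """Simplified sketch of a scaling-and-shifting (SS) convolution layer."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Base convolution weights: later overwritten by the pre-trained
        # checkpoint and frozen, so they receive no gradients during SS.
        self.weight = Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size))
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.requires_grad = False
        self.bias = Parameter(torch.zeros(out_channels))
        self.bias.requires_grad = False
        # SS parameters: a per-filter scale and a per-channel shift.
        # These stay trainable during the SS stage.
        self.mtl_weight = Parameter(torch.ones(out_channels, in_channels, 1, 1))
        self.mtl_bias = Parameter(torch.zeros(out_channels))
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # Scale the frozen weights and shift the frozen bias, then run a
        # standard convolution with the modulated parameters.
        new_mtl_weight = self.mtl_weight.expand(self.weight.shape)
        new_weight = self.weight.mul(new_mtl_weight)
        new_bias = self.bias + self.mtl_bias
        return F.conv2d(x, new_weight, new_bias, stride=self.stride,
                        padding=self.padding)
```

During the SS stage, only the parameters with requires_grad=True (here mtl_weight and mtl_bias) receive gradients, so the pre-trained feature encoder weights stay fixed even though they take part in the forward pass.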

How to load the pre-trained model?

You may directly load normal checkpoints into the MTL model like this:

pretrained_dict = torch.load(self.args.init_weights)['params']
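As a rough sketch of the full loading step (the function name load_pretrained_encoder and its arguments are hypothetical, and depending on how your modules are named you may also need to rename keys so they match model.state_dict()):

```python
import torch


def load_pretrained_encoder(model, init_weights_path):
    """Load a normal (non-MTL) checkpoint into an MTL model.

    Assumes the checkpoint stores its weights under the 'params' key,
    as in the line quoted above.
    """
    model_dict = model.state_dict()
    pretrained_dict = torch.load(init_weights_path)['params']
    # Keep only entries whose names and shapes match the MTL model; the extra
    # mtl_weight / mtl_bias parameters of the SS layers are not in the
    # checkpoint and keep their default initialization (scale 1, shift 0).
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
```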

If you have any further questions, feel free to ask.