yaoyao-liu / meta-transfer-learning

TensorFlow and PyTorch implementation of "Meta-Transfer Learning for Few-Shot Learning" (CVPR 2019)

Home Page: https://lyy.mpi-inf.mpg.de/mtl/

Where is the fine-tuning in the code?

JianqunZhang opened this issue · comments

Hi Liu, I appreciate your work very much. In the course of my study, there is one point I find difficult to understand; could you give me some help? I think fine-tuning should be performed on the test dataset. However, there seems to be no fine-tuning step for the test data in the code. Could you tell me where the fine-tuning steps are implemented? In addition, I am using your PyTorch code.

Thanks for your interest in our work.

Could you please clarify which test data you are referring to? We have meta-train and meta-test stages. In each stage, we sample episodes, so each episode has an episode-train (support) set and an episode-test (query) set. Do you mean fine-tuning on the episode-test data?
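To make the terminology concrete, here is a minimal sketch of how one episode is split into an episode-train (support) set and an episode-test (query) set. This is only an illustration, not the repository's actual sampler; the way/shot/query numbers and tensor shapes are assumptions:

# Illustrative sketch only (not the repository's sampler): splitting one
# N-way K-shot episode into support (episode-train) and query (episode-test) sets.
import torch

n_way, k_shot, k_query = 5, 1, 15                              # assumed episode configuration
images = torch.randn(n_way * (k_shot + k_query), 3, 80, 80)    # one sampled episode
labels = torch.arange(n_way).repeat_interleave(k_shot + k_query)

# The first k_shot samples of each class form the support set; the rest form the query set.
support_mask = (torch.arange(k_shot + k_query) < k_shot).repeat(n_way)
data_shot, label_shot = images[support_mask], labels[support_mask]
data_query, label_query = images[~support_mask], labels[~support_mask]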

Thank you for your timely reply. Yes, I do mean fine-tuning on the episode-test data. I use the test split of the miniImageNet dataset in the meta-test stage. The support set in the test data should be fine-tuned FTN + 1, is that right? Thank you very much!

The fine-tuning on the episode-test data is implemented in the following function:

def meta_forward(self, data_shot, label_shot, data_query):
    """The function to forward the meta-train phase.
    Args:
      data_shot: train images for the task
      label_shot: train labels for the task
      data_query: test images for the task.
    Returns:
      logits_q: the predictions for the test samples.
    """
    embedding_query = self.encoder(data_query)
    embedding_shot = self.encoder(data_shot)
    logits = self.base_learner(embedding_shot)
    loss = F.cross_entropy(logits, label_shot)
    grad = torch.autograd.grad(loss, self.base_learner.parameters())
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.base_learner.parameters())))
    logits_q = self.base_learner(embedding_query, fast_weights)
    for _ in range(1, self.update_step):
        logits = self.base_learner(embedding_shot, fast_weights)
        loss = F.cross_entropy(logits, label_shot)
        grad = torch.autograd.grad(loss, fast_weights)
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
        logits_q = self.base_learner(embedding_query, fast_weights)
    return logits_q
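For reference, here is a rough sketch of how this function could be called when evaluating on meta-test episodes. The names `model`, `test_loader`, and the episode unpacking are assumptions for illustration, not the repository's evaluation script; the point is that the fine-tuning on the support set happens inside the function, and only the returned query logits are used to compute the episode accuracy.

# Hypothetical meta-test loop; `model`, `test_loader`, and the episode
# unpacking are illustrative assumptions.
model.eval()
accuracies = []
for data_shot, label_shot, data_query, label_query in test_loader:
    # Inner-loop fine-tuning on the support set happens inside meta_forward;
    # only the returned query logits are needed for evaluation.
    logits_q = model.meta_forward(data_shot, label_shot, data_query)
    pred = logits_q.argmax(dim=1)
    accuracies.append((pred == label_query).float().mean().item())
print('meta-test accuracy: {:.4f}'.format(sum(accuracies) / len(accuracies)))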

If you have further questions, feel free to add more comments to this issue.

Thank you very much! I got it! Thank you!