two questions
Interesting6 opened this issue · comments
Dear Oscar Knagg,
First of all, thank you for your helpful code. However, when I downloaded your scripts and ran them on my computer, I got an error at
```
few_shot/core.py", line 80, in __iter__
    query = df[(df['class_id'] == k) & (~df['id'].isin(support_k[k]['id']))].sample(self.q)
ValueError: a must be greater than 0
```
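For context, I can reproduce the same error with a tiny hypothetical DataFrame whenever a class has no examples left for the query set after the support set is removed (e.g. fewer than n + q examples for some class), so I suspect my dataset is the culprit rather than your code:

```python
import pandas as pd

# Hypothetical minimal data: class 0 has only 2 examples,
# and both are already taken by the support set.
df = pd.DataFrame({'class_id': [0, 0], 'id': [1, 2]})
support_ids = [1, 2]

# Same filtering pattern as in few_shot/core.py
query_pool = df[(df['class_id'] == 0) & (~df['id'].isin(support_ids))]

try:
    query_pool.sample(1)  # q = 1, but the pool is empty
except ValueError as e:
    print(e)  # message starts with "a must be greater than 0"
```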
Secondly, in the MAML implementation, the docstring says:

> order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated weights on the query with respect to the original weights).

but the corresponding code is:
```python
if order == 1:
    if train:
        sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                              for k in task_gradients[0].keys()}
        hooks = []
        for name, param in model.named_parameters():
            hooks.append(
                param.register_hook(replace_grad(sum_task_gradients, name))
            )

        model.train()
        optimiser.zero_grad()
        # Dummy pass in order to create `loss` variable
        # Replace dummy gradients with mean task gradients using hooks
        logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
        loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
        loss.backward()
        optimiser.step()

        for h in hooks:
            h.remove()

    return torch.stack(task_losses).mean(), torch.cat(task_predictions)

elif order == 2:
    model.train()
    optimiser.zero_grad()
    meta_batch_loss = torch.stack(task_losses).mean()

    if train:
        meta_batch_loss.backward()
        optimiser.step()
```
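To check my understanding of the 2nd-order case: as far as I can tell, it only needs autograd to differentiate through the inner update, which I can reproduce with this toy example of my own (not from your repo):

```python
import torch

# Toy second-order update: differentiate through one inner SGD step.
w = torch.tensor(2.0, requires_grad=True)
inner_lr = 0.1

# Inner loop: one gradient step on a "support" loss.
# create_graph=True keeps the gradient itself on the autograd graph.
support_loss = (w - 1.0) ** 2
(grad,) = torch.autograd.grad(support_loss, w, create_graph=True)
w_fast = w - inner_lr * grad  # updated weights, still differentiable w.r.t. w

# Outer loop: "query" loss on the updated weights, backprop to the
# ORIGINAL weights through the inner update.
query_loss = (w_fast - 3.0) ** 2
query_loss.backward()
print(w.grad)  # analytic value: 2 * (1.8 - 3.0) * 0.8 = -1.92
```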
Isn't the `order == 2` branch the one doing the "update meta-learner weights with gradients of the updated weights on the query set" job?
Also, I can't understand the code in the `order == 1` branch. Could you explain it in more detail?
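For reference, my current understanding of the `register_hook`/`replace_grad` trick is the following toy sketch (my own hypothetical re-implementation of `replace_grad`, not your actual code) — please correct me if this is wrong:

```python
import torch

# Hypothetical stand-in for replace_grad: the hook ignores whatever
# gradient the dummy backward pass produced and returns a precomputed one.
def replace_grad(grad_dict, name):
    def hook(grad):
        return grad_dict[name]
    return hook

model = torch.nn.Linear(2, 1)

# Pretend these are the averaged per-task gradients from the inner loops
precomputed = {name: torch.ones_like(p) for name, p in model.named_parameters()}

hooks = [p.register_hook(replace_grad(precomputed, name))
         for name, p in model.named_parameters()]

# Dummy pass: the loss value is irrelevant, it only triggers the hooks,
# which overwrite every parameter's .grad with the precomputed tensors
loss = model(torch.zeros(1, 2)).sum()
loss.backward()

for h in hooks:
    h.remove()
```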
Any response would be appreciated! Thanks for your time!