oscarknagg / few-shot

Repository for few-shot learning machine learning projects

two questions

Interesting6 opened this issue · comments

Dear Oscar Knagg,
First of all, thank you for your helpful code. However, after downloading your scripts and running them on my computer, I hit the following error:

```
few_shot/core.py", line 80, in __iter__
    query = df[(df['class_id'] == k) & (~df['id'].isin(support_k[k]['id']))].sample(self.q)
ValueError: a must be greater than 0
```
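
For what it's worth, here is a toy reproduction attempt (made-up DataFrame, not your sampler code): I suspect the error appears when a class has only n images, so after the support set is removed nothing is left for `.sample(self.q)`.

```python
# Toy sketch (hypothetical data, not the repo's code): if a class has exactly n
# images, the query pool after removing the support set is empty, and .sample(q)
# raises a ValueError from numpy's random.choice.
import pandas as pd

n, q = 3, 5
df = pd.DataFrame({'id': range(3), 'class_id': [0, 0, 0]})  # class 0 has only n images

support = df[df['class_id'] == 0].sample(n)
query_pool = df[(df['class_id'] == 0) & (~df['id'].isin(support['id']))]
query_pool.sample(q)  # empty pool -> ValueError
```

Is that what is happening, or is something else going on?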

Besides, in the MAML code I see:

order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the query set)
or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated weights on the query with respect to the original weights).

but the corresponding code is:
```python
if order == 1:
    if train:
        sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                              for k in task_gradients[0].keys()}
        hooks = []
        for name, param in model.named_parameters():
            hooks.append(
                param.register_hook(replace_grad(sum_task_gradients, name))
            )

        model.train()
        optimiser.zero_grad()
        # Dummy pass in order to create `loss` variable
        # Replace dummy gradients with mean task gradients using hooks
        logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
        loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
        loss.backward()
        optimiser.step()

        for h in hooks:
            h.remove()

    return torch.stack(task_losses).mean(), torch.cat(task_predictions)

elif order == 2:
    model.train()
    optimiser.zero_grad()
    meta_batch_loss = torch.stack(task_losses).mean()

    if train:
        meta_batch_loss.backward()
        optimiser.step()
```
Isn't the `order == 2` part doing the "update meta-learner weights with gradients of the updated weights on the query set" job?
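
To check my own understanding of the distinction, here is a toy scalar example I put together (made-up losses and learning rate, not your code) of what I think 1st vs 2nd order means:

```python
# Toy sketch of my reading of 1st vs 2nd order MAML with a single scalar weight.
# All numbers are made up; this is not the repo's code.
import torch

w = torch.tensor(1.0, requires_grad=True)            # meta-parameter

inner_loss = (2.0 * w) ** 2                           # "support set" loss
grad_w, = torch.autograd.grad(inner_loss, w, create_graph=True)
w_fast = w - 0.1 * grad_w                             # adapted ("fast") weight

query_loss = (3.0 * w_fast) ** 2                      # "query set" loss

# 2nd order: differentiate the query loss through the inner update back to w.
second, = torch.autograd.grad(query_loss, w, retain_graph=True)

# 1st order: take the gradient w.r.t. the fast weight and apply it to w directly.
first, = torch.autograd.grad(query_loss, w_fast)

print(first.item(), second.item())                    # 3.6 vs 0.72 here
```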

Also, I can't follow your code in the `order == 1` part. Could you explain it in a bit more detail?
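
For reference, here is my current (possibly wrong) reading of the dummy-pass/hook trick in the `order == 1` branch, as a toy example with made-up values:

```python
# Toy sketch (my own example, not the repo's code): a tensor hook can replace
# whatever gradient the dummy backward pass produces with a precomputed one, so
# optimiser.step() would apply the mean task gradients instead of the dummy ones.
import torch

model = torch.nn.Linear(2, 1)
precomputed = {name: torch.ones_like(p) for name, p in model.named_parameters()}

hooks = [p.register_hook(lambda grad, name=name: precomputed[name])
         for name, p in model.named_parameters()]

loss = model(torch.zeros(1, 2)).sum()   # dummy forward/backward just to trigger the hooks
loss.backward()
print(model.weight.grad)                # the precomputed "gradient", not the real one

for h in hooks:
    h.remove()
```

Is that roughly what the `order == 1` branch is doing?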

Any response would be appreciated! Thanks for your time!