oscarknagg / few-shot

Repository for few-shot learning machine learning projects


Possible solution to LSTM concatenation problem

rtorrisi opened this issue · comments

Hi Oscar,
firstly, thank you for sharing your code with all of us. I noticed you encountered the same trouble I had while trying to implement the Full Context Embedding following the original paper's description.

# h_hat, c = self.lstm_cell(queries, (torch.cat([h, readout], dim=1), c))
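For context, this line fails with the stock torch.nn.LSTMCell because the cell expects its hidden state to have exactly hidden_size features, while torch.cat([h, readout], dim=1) produces a tensor of width 2 * hidden_size. A minimal sketch of the mismatch, with hypothetical sizes (the sum used in the repository keeps the shapes compatible):

import torch
import torch.nn as nn

batch, input_size, hidden_size = 4, 64, 64
lstm_cell = nn.LSTMCell(input_size, hidden_size)

queries = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
readout = torch.randn(batch, hidden_size)
c = torch.randn(batch, hidden_size)

# Works: summing keeps the hidden state at hidden_size
h_hat, c = lstm_cell(queries, (h + readout, c))

# Fails: the concatenated hidden state has shape (batch, 2 * hidden_size),
# which nn.LSTMCell rejects with a size-mismatch error
# h_hat, c = lstm_cell(queries, (torch.cat([h, readout], dim=1), c))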

I built an LSTM model from scratch that could allow you to fix this problem.
You can find the model here: https://github.com/rtorrisi/LSTM-Custom-InOut-Hidden-Size
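For illustration only, and not the code in the linked repository: the general idea is an LSTM cell, written from scratch, whose incoming hidden state may be wider than the hidden state it outputs, so the concatenation of h and readout can be passed in directly. A minimal sketch, assuming a single linear layer that computes all four gates (names such as CustomHiddenLSTMCell and in_hidden_size are hypothetical):

import torch
import torch.nn as nn

class CustomHiddenLSTMCell(nn.Module):
    def __init__(self, input_size, in_hidden_size, hidden_size):
        super().__init__()
        # Gates are computed from [x, h_in]; h_in may be wider than the
        # hidden state this cell produces.
        self.gates = nn.Linear(input_size + in_hidden_size, 4 * hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x, state):
        h_in, c = state
        i, f, g, o = self.gates(torch.cat([x, h_in], dim=1)).chunk(4, dim=1)
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(c_next)
        return h_next, c_next

# Usage with the concatenated hidden state from the FCE readout
batch, input_size, hidden_size = 4, 64, 64
cell = CustomHiddenLSTMCell(input_size, 2 * hidden_size, hidden_size)
queries = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
readout = torch.randn(batch, hidden_size)
c = torch.randn(batch, hidden_size)
h_hat, c = cell(queries, (torch.cat([h, readout], dim=1), c))

The only change from a standard LSTM cell is the extra in_hidden_size argument; the cell state and the output hidden state keep the original hidden_size.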
Let me know your thoughts.

Kind regards,
Riccardo

I tested the model you published on the Omniglot dataset with the following parameters:
python -m experiments.matching_nets --dataset omniglot --fce True --k-test 5 --n-test 1 --distance cosine
I also changed the value of episodes_per_epoch from 100 to 500.

I can confirm that using the LSTM I proposed, which allows the concatenation to be used, the model gains some performance on both validation (~2%) and training, as shown in the following plot:

  • Blue and orange lines are training and validation using the sum of h and readout.
  • Green and red lines are training and validation using the concatenation of h and readout.

[Plot: training and validation accuracy curves, sum vs. concatenation of h and readout]