Meta-Learning with Memory-Augmented Neural Networks in TensorFlow
A concise alternative TensorFlow implementation of the paper: Santoro, Adam, et al. "Meta-learning with memory-augmented neural networks." International Conference on Machine Learning. 2016. The model is encapsulated in the class MANNCell, which can be used in the same way as BasicRNNCell. The code is inspired by the excellent implementations of tristandeleu and snowkylin.
Memory-Augmented Neural Networks
As described in the reference paper, MANNs (Memory-Augmented Neural Networks) are the class of networks equipped with external memory, such as NTMs (Neural Turing Machines).
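To make the idea of an external memory concrete, here is a minimal NumPy sketch of a content-based read as described in the reference paper: a key is compared with every memory row by cosine similarity, the similarities pass through a softmax, and the read vector is the weighted sum of the rows. The function name `read_memory` and the toy data are illustrative assumptions, not code from this repository.

```python
import numpy as np

def read_memory(memory, key, eps=1e-8):
    """Content-based read (illustrative sketch, not the repository's code).

    memory: array of shape (memory_size, memory_dim)
    key:    array of shape (memory_dim,)
    Returns (read_weights, read_vector).
    """
    # Cosine similarity between the key and each memory row.
    dots = memory @ key
    norms = np.linalg.norm(memory, axis=1) * np.linalg.norm(key)
    similarity = dots / (norms + eps)

    # Softmax over memory slots (shifted by the max for numerical stability).
    exp = np.exp(similarity - similarity.max())
    read_weights = exp / exp.sum()

    # The read vector is the weighted sum of memory rows.
    read_vector = read_weights @ memory
    return read_weights, read_vector

rng = np.random.RandomState(0)
memory = rng.normal(size=(128, 40))   # memory_size=128, memory_dim=40
key = memory[3] * 2.0                 # points in the same direction as row 3
weights, vector = read_memory(memory, key)
print(weights.argmax())               # index of the slot most similar to the key
```

Because cosine similarity ignores vector length, scaling the key does not change which slot receives the largest read weight.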
Dependencies
- Python 3.6
- Tensorflow==1.14
- numpy==1.16.4
- Pillow==7.1.1 (imported as PIL)
Usage
Omniglot DataSet
Download images_background.zip (964 classes) and images_evaluation.zip (659 classes), and place them in the ./omniglot folder.
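A quick sanity check after downloading is to count the character classes in each split. The sketch below assumes the archives have been unpacked to `./omniglot/<split>/<alphabet>/<character>/`; that layout and the helper name `count_classes` are assumptions to verify against your own download.

```python
from pathlib import Path

def count_classes(split_dir):
    """Count character folders two levels below split_dir (alphabet/character)."""
    split_dir = Path(split_dir)
    return sum(
        1
        for alphabet in split_dir.iterdir() if alphabet.is_dir()
        for character in alphabet.iterdir() if character.is_dir()
    )

if __name__ == "__main__":
    for split in ("images_background", "images_evaluation"):
        path = Path("./omniglot") / split   # assumed unpacked location
        if path.is_dir():
            print(split, count_classes(path))
        else:
            print(split, "not found at", path)
```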
Running
python run_mann.py
python run_mann.py --mode test
python run_mann.py --model LSTM
python run_mann.py --model LSTM --mode test
Class MANNCell()
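import tensorflow as tf
from mann.mann_cell import MANNCell

# batch_size and inputs ([batch_size, seq_length, input_dim]) are supplied by the caller.
lstm_size, memory_dim, nb_reads = 200, 40, 4
cell = MANNCell(
    lstm_size=lstm_size,
    memory_size=128,
    memory_dim=memory_dim,
    nb_reads=nb_reads,
    gamma=0.95
)
state = cell.zero_state(batch_size, tf.float32)
# tf.scan iterates over the leading axis, so switch to time-major first.
output, state = tf.scan(
    lambda init, elem: cell(elem, init[1]),
    elems=tf.transpose(inputs, perm=[1, 0, 2]),
    initializer=(tf.zeros(shape=(batch_size, lstm_size + nb_reads * memory_dim)), state)
)
# Back to batch-major: [batch_size, seq_length, lstm_size + nb_reads * memory_dim]
output = tf.transpose(output, perm=[1, 0, 2])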
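In the reference paper, a decay parameter gamma controls the usage weights of the Least Recently Used Access scheme: the previous usage is multiplied by gamma, then the current read and write weights are added. Assuming the `gamma` argument of MANNCell plays that role, the update can be sketched in NumPy as follows. The helper names and the tiny six-slot memory are illustrative assumptions, not the repository's code.

```python
import numpy as np

def update_usage(prev_usage, read_weights, write_weights, gamma=0.95):
    """Decay the previous usage and add the current read and write weights."""
    return gamma * prev_usage + read_weights + write_weights

def least_used_mask(usage, nb_reads):
    """Return a 0/1 vector marking the nb_reads slots with the lowest usage."""
    mask = np.zeros_like(usage)
    mask[np.argsort(usage)[:nb_reads]] = 1.0
    return mask

memory_size = 6                                # tiny memory for illustration
usage = np.zeros(memory_size)
read_w = np.array([0.7, 0.1, 0.1, 0.1, 0.0, 0.0])
write_w = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 0.0])

usage = update_usage(usage, read_w, write_w)   # first step
usage = update_usage(usage, read_w, write_w)   # second step: old usage decays by gamma
print(usage)
print(least_used_mask(usage, nb_reads=1))      # marks the never-touched last slot
```

Slots that are neither read nor written keep decaying toward zero, so they become the preferred targets for new writes.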
Performance
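Omniglot Classification:
Test-set classification accuracies on the Omniglot dataset, using one-hot encodings of labels and five classes presented per episode. Columns give the accuracy on the 1st, 2nd, ... presentation of a class within an episode. Rows marked ref report the figures from the reference paper; rows marked repo were obtained with this implementation.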
Model | 1st | 2nd | 3rd | 4th | 5th | 10th |
---|---|---|---|---|---|---|
LSTM (ref) | 24.4% | 49.5% | 55.3% | 61.0% | 63.6% | 62.5% |
LSTM (repo) | 30.4% | 77.9% | 85.3% | 87.5% | 88.8% | 91.6% |
MANN (ref) | 36.4% | 82.8% | 91.0% | 92.6% | 94.9% | 98.1% |
MANN (repo) | 35.4% | 89.2% | 95.2% | 96.3% | 96.9% | 97.8% |
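The column headers refer to the n-th time a class is presented within an episode. A minimal sketch of how such per-instance accuracy can be computed from one episode's labels and predictions is given below; the function name and toy data are illustrative, and the repository's own evaluation code may differ.

```python
from collections import defaultdict

def instance_accuracy(labels, predictions):
    """Accuracy grouped by how many times each class has been seen so far.

    labels, predictions: equal-length sequences of class ids for one episode.
    Returns {n: accuracy on the n-th presentation of a class}.
    """
    seen = defaultdict(int)       # class id -> presentations so far
    correct = defaultdict(int)    # n -> number of correct predictions
    total = defaultdict(int)      # n -> number of presentations
    for label, pred in zip(labels, predictions):
        seen[label] += 1
        n = seen[label]
        total[n] += 1
        correct[n] += int(label == pred)
    return {n: correct[n] / total[n] for n in sorted(total)}

labels      = [0, 1, 0, 2, 1, 0, 2]
predictions = [1, 1, 0, 0, 1, 0, 2]
print(instance_accuracy(labels, predictions))
```

With one-hot labels and five classes, the first presentation of a class can only be guessed, which is why the 1st-instance column stays far below the later ones.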