Leputa / MANN-meta-learning

A TensorFlow implementation of a Memory-Augmented Neural Network

Meta-Learning with Memory-Augmented Neural Networks in TensorFlow

A concise alternative TensorFlow implementation of the paper: Santoro, Adam, et al. "Meta-learning with memory-augmented neural networks." International Conference on Machine Learning. 2016. The model is encapsulated in the class MANNCell, which can be used like BasicRNNCell. The code is inspired by the excellent implementations of tristandeleu and snowkylin.

Memory-Augmented Neural Networks

As described in the reference paper, MANNs (Memory-Augmented Neural Networks) are the class of networks equipped with external memory, such as NTMs (Neural Turing Machines).
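
To make "external memory" concrete, here is a minimal NumPy sketch of the content-based read described in the reference paper: cosine similarity between a key and every memory row, a softmax over those similarities, then a weighted sum of the rows. The function name read_memory and the toy shapes are assumptions for illustration; this is not the code inside MANNCell.

```python
import numpy as np

def read_memory(memory, key):
    """Content-based read: cosine similarity -> softmax -> weighted sum of rows."""
    # memory: (memory_size, memory_dim), key: (memory_dim,)
    eps = 1e-8  # guards against division by zero for all-zero rows
    sim = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + eps)
    weights = np.exp(sim - sim.max())   # numerically stable softmax
    weights /= weights.sum()
    return weights, weights @ memory    # read weights and read vector

memory = np.eye(4)                      # four slots, each holding a distinct one-hot row
key = np.array([0.0, 0.0, 1.0, 0.0])    # matches slot 2 exactly
weights, read_vector = read_memory(memory, key)
```

The slot whose content is most similar to the key receives the largest read weight, so the read vector is dominated by that slot.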

Dependencies

  • Python 3.6
  • Tensorflow==1.14
  • numpy==1.16.4
  • Pillow==7.1.1 (imported as PIL)

Usage

Omniglot DataSet

Download images_background.zip (964 classes) and images_evaluation.zip (659 classes), extract them, and place the resulting folders in the ./omniglot folder.
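
A quick sanity check after downloading can catch an incomplete extraction. The sketch below counts class folders, assuming the archives unpack to an alphabet/character/*.png layout; the helper name count_classes and the demo folder names are hypothetical.

```python
import os
import tempfile

def count_classes(root):
    """Count directories under root that directly contain at least one .png file."""
    return sum(
        1
        for _, _, files in os.walk(root)
        if any(name.lower().endswith(".png") for name in files)
    )

# Demo on a throwaway tree that mimics the assumed layout.
with tempfile.TemporaryDirectory() as tmp:
    for alphabet, character in [("Latin", "character01"),
                                ("Latin", "character02"),
                                ("Greek", "character01")]:
        folder = os.path.join(tmp, alphabet, character)
        os.makedirs(folder)
        open(os.path.join(folder, "0001_01.png"), "wb").close()
    demo_count = count_classes(tmp)
```

Under the same layout assumption, calling count_classes("./omniglot/images_background") should give the class count quoted for that archive.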

Running

python run_mann.py
python run_mann.py --mode test
python run_mann.py --model LSTM
python run_mann.py --model LSTM --mode test

Class MANNCell()

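import tensorflow as tf  # TensorFlow 1.x
from mann.mann_cell import MANNCell

# batch_size and inputs (batch_size, seq_length, input_dim) are assumed to be defined by the caller.
lstm_size, memory_dim, nb_reads = 200, 40, 4
cell = MANNCell(
    lstm_size=lstm_size,
    memory_size=128,
    memory_dim=memory_dim,
    nb_reads=nb_reads,
    gamma=0.95
)
state = cell.zero_state(batch_size, tf.float32)
# tf.scan iterates over the first axis, so switch to time-major; the accumulator is (output, state).
output, state = tf.scan(
    lambda acc, elem: cell(elem, acc[1]),
    elems=tf.transpose(inputs, perm=[1, 0, 2]),
    initializer=(tf.zeros(shape=(batch_size, lstm_size + nb_reads * memory_dim)), state)
)
# Back to batch-major: (batch_size, seq_length, lstm_size + nb_reads * memory_dim).
output = tf.transpose(output, perm=[1, 0, 2])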
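
The gamma argument presumably corresponds to the decay parameter of the usage weights in the reference paper, where usage is updated as w_u(t) = gamma * w_u(t-1) + w_r(t) + w_w(t). The NumPy sketch below illustrates that update with a single read head; the function name update_usage and the toy weights are assumptions, and the code is not taken from MANNCell.

```python
import numpy as np

def update_usage(prev_usage, read_weights, write_weights, gamma=0.95):
    """Decay the previous usage, then add the current read and write weights."""
    return gamma * prev_usage + read_weights + write_weights

usage = np.zeros(4)                          # one usage value per memory slot
read_w = np.array([0.0, 1.0, 0.0, 0.0])      # slot 1 is read at every step
write_w = np.array([0.0, 0.0, 1.0, 0.0])     # slot 2 is written at every step
for _ in range(2):
    usage = update_usage(usage, read_w, write_w)

least_used = int(np.argmin(usage))           # candidate slot for the next write
```

Slots that are neither read nor written keep a low usage value, which is what makes them candidates for overwriting under the paper's least-recently-used access scheme.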

Performance

Omniglot Classification:

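Test-set classification accuracies on the Omniglot dataset, using one-hot encodings of labels and five classes presented per episode. Columns give the accuracy at the n-th instance of a class within an episode; "ref" rows are the values reported in the reference paper and "repo" rows are from this implementation.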

Model         1st     2nd     3rd     4th     5th     10th
LSTM (ref)    24.4%   49.5%   55.3%   61.0%   63.6%   62.5%
LSTM (repo)   30.4%   77.9%   85.3%   87.5%   88.8%   91.6%
MANN (ref)    36.4%   82.8%   91.0%   92.6%   94.9%   98.1%
MANN (repo)   35.4%   89.2%   95.2%   96.3%   96.9%   97.8%
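
The instance columns can be read as accuracy grouped by how many times a class has already been presented within an episode. A minimal pure-Python sketch of that bookkeeping for one episode follows; the function name accuracy_by_instance and the toy labels are assumptions, and this is not the evaluation code of this repo.

```python
from collections import defaultdict

def accuracy_by_instance(true_labels, predicted_labels):
    """Accuracy for one episode, grouped by the presentation count of each class."""
    seen = defaultdict(int)     # class -> number of presentations so far
    hits = defaultdict(int)     # instance index -> correct predictions
    totals = defaultdict(int)   # instance index -> predictions made
    for truth, guess in zip(true_labels, predicted_labels):
        seen[truth] += 1
        k = seen[truth]         # 1 for the first presentation, 2 for the second, ...
        totals[k] += 1
        hits[k] += int(truth == guess)
    return {k: hits[k] / totals[k] for k in sorted(totals)}

truth = [0, 1, 0, 1, 0]         # two classes, interleaved within one episode
guess = [1, 1, 0, 1, 0]         # the first presentation of class 0 is misclassified
result = accuracy_by_instance(truth, guess)
```

Averaging these per-episode dictionaries over many test episodes would give one value per column of the table.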

About

License: MIT License

Language: Python 100.0%