hsakas / AoAReader

PyTorch implementation of Attention-over-Attention Neural Networks for Reading Comprehension

Attention-over-Attention Model for Reading Comprehension

This repository implements the Attention-over-Attention (AoA) model with PyTorch. The model was proposed by Cui et al. in the paper Attention-over-Attention Neural Networks for Reading Comprehension.
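In the paper, document and query are encoded with bidirectional GRUs. The model then builds a pairwise matching matrix between document and query words, applies a column-wise softmax (attention over the document for each query word) and a row-wise softmax (attention over the query for each document word), averages the row-wise attentions into one query-level attention, and combines the two into a final document-level attention. The sketch below covers only that attention step, in plain Python on toy vectors so it runs without dependencies. It is not the repository's code, which works on batched PyTorch tensors; the function names are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_over_attention(doc_h, query_h):
    """Return the document-level attention s (one weight per document word).

    doc_h:   |D| vectors, stand-ins for the document encoder outputs (hypothetical)
    query_h: |Q| vectors, stand-ins for the query encoder outputs (hypothetical)
    """
    n_doc, n_query = len(doc_h), len(query_h)
    # Pairwise matching matrix M[i][j] = doc_h[i] . query_h[j], shape |D| x |Q|
    M = [[sum(a * b for a, b in zip(d, q)) for q in query_h] for d in doc_h]
    # Column-wise softmax: for each query word j, a distribution over document words
    alpha = [softmax([M[i][j] for i in range(n_doc)]) for j in range(n_query)]
    # Row-wise softmax: for each document word i, a distribution over query words
    beta = [softmax(M[i]) for i in range(n_doc)]
    # Average the row-wise attentions over all document positions
    beta_avg = [sum(beta[i][j] for i in range(n_doc)) / n_doc for j in range(n_query)]
    # Attention over attention: weight each column attention by beta_avg and sum
    return [sum(alpha[j][i] * beta_avg[j] for j in range(n_query)) for i in range(n_doc)]
```

For example, `attention_over_attention([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [[1.0, 0.0]])` gives equal weight to the first and third document words and less to the second. Because every column of `alpha` and the averaged `beta_avg` are probability distributions, the resulting `s` also sums to 1.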

Prerequisite

  • PyTorch with CUDA
  • Python 3.6+
  • NLTK (with punkt data)

Usage

This implementation uses Facebook’s Children’s Book Test (CBT) data.
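As a rough illustration of the input, the sketch below reads CBT examples, assuming the standard CBT text layout: each example has 20 numbered context sentences followed by a 21st line holding the query (with the blank written as XXXXX), the answer, and the candidate words separated by `|`, with tab-separated fields and blank lines between examples. The helper names are hypothetical, and the repository's own parsing may differ.

```python
def parse_cbt_example(lines):
    """Split one CBT example into (context, query, answer, candidates).

    Assumes the standard CBT layout: every line starts with its line number,
    the last line is "N query<TAB>answer<TAB><TAB>cand1|cand2|...".
    """
    # Drop the leading line number from each context sentence
    context = [line.split(" ", 1)[1] for line in lines[:-1]]
    # Ignore empty fields (the answer and candidates are separated by two tabs)
    fields = [f for f in lines[-1].split("\t") if f]
    query = fields[0].split(" ", 1)[1]
    answer = fields[1]
    candidates = fields[2].split("|")
    return context, query, answer, candidates


def read_cbt(path):
    """Yield parsed examples from a CBT file whose examples are blank-line separated."""
    block = []
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.rstrip()
            if line:
                block.append(line)
            elif block:
                yield parse_cbt_example(block)
                block = []
    if block:  # the last example may not be followed by a blank line
        yield parse_cbt_example(block)
```

Parsing line by line keeps memory flat even for the large CBT training files, since only one example is held at a time.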

Preprocessing

Make sure the data files (train.txt, dev.txt, test.txt) are present in the data directory.

To preprocess the data:

python preprocess.py

This generates a dictionary (dict.pt) containing every word that appears in the dataset and vectorizes all data files (train.txt.pt, dev.txt.pt, test.txt.pt).
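Conceptually, building the dictionary and vectorizing amount to mapping each word to an integer id and replacing tokens with ids. The sketch below assumes already-tokenized text (the NLTK prerequisite suggests the real script tokenizes with NLTK) and reserves ids 0 and 1 for padding and unknown words; those reserved ids, the frequency ordering, and the helper names are assumptions, not the layout of dict.pt. Since the dictionary covers every word in the dataset, the unknown-word fallback here is only a safeguard.

```python
from collections import Counter

# Reserved tokens: an assumption of this sketch, not necessarily what dict.pt uses
PAD, UNK = "<pad>", "<unk>"


def build_dict(token_lists):
    """Map every word seen in the token lists to an integer id (hypothetical helper)."""
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    word2id = {PAD: 0, UNK: 1}
    # Sort by descending frequency, then alphabetically, so ids are deterministic
    for word, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        word2id[word] = len(word2id)
    return word2id


def vectorize(tokens, word2id):
    """Replace each token with its id, falling back to the unknown-word id."""
    return [word2id.get(tok, word2id[UNK]) for tok in tokens]
```

With `build_dict([["the", "cat", "sat"], ["the", "dog"]])`, the most frequent word `the` gets id 2 and the remaining words follow alphabetically, so `vectorize(["the", "bird"], ...)` returns `[2, 1]`.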

Train the model

Below is an example training command; adjust the parameters as needed.

python train.py -traindata data/train.txt.pt -validdata data/test.txt.pt -dict data/dict.pt \
 -save_model model1 -gru_size 384 -embed_size 384 -batch_size 64 -dropout 0.1 \
 -epochs 13 -learning_rate 0.001 -weigth_decay 0.0001 -gpu 1 -log_interval 50
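In the paper, the probability of a word being the answer is the sum of the document-level attention over every position where that word occurs, and training maximizes the log-likelihood of the correct answer. The sketch below computes that per-example loss in plain Python; the helper names and the `eps` guard are assumptions, and the repository computes the loss on batched PyTorch tensors rather than lists.

```python
import math


def pointer_sum(s, doc_tokens, word):
    """Probability of `word`: total attention on every position where it occurs."""
    return sum(p for p, tok in zip(s, doc_tokens) if tok == word)


def answer_nll(s, doc_tokens, answer, eps=1e-12):
    """Negative log-likelihood of the correct answer (eps avoids log(0))."""
    return -math.log(pointer_sum(s, doc_tokens, answer) + eps)
```

For attention `[0.5, 0.2, 0.3]` over the document `["dog", "cat", "dog"]`, the answer `dog` has probability 0.8, so the loss is `-log(0.8)`, about 0.223.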

A checkpoint is saved after each epoch. To resume training from a checkpoint:

python train.py -train_from xxx_model_xxx_epoch_x.pt

Testing

python test.py -testdata data/test.txt.pt -dict data/dict.pt -out result.txt -model models/xx_checkpoint_epochxx.pt
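CBT is scored as accuracy: the fraction of queries for which the model picks the correct word among the given candidates. Under the paper's sum-of-attention rule, the prediction is the candidate whose occurrences in the document receive the most total attention. The sketch below shows that selection and the accuracy computation; it does not reproduce the format of the file written by `-out`, and the helper names are hypothetical.

```python
def predict(s, doc_tokens, candidates):
    """Pick the candidate whose occurrences receive the most total attention."""
    def score(word):
        return sum(p for p, tok in zip(s, doc_tokens) if tok == word)
    # max() returns the first candidate on ties
    return max(candidates, key=score)


def accuracy(predictions, answers):
    """Fraction of examples whose predicted word equals the gold answer."""
    if not predictions:
        return 0.0
    correct = sum(p == a for p, a in zip(predictions, answers))
    return correct / len(predictions)
```

A candidate that never occurs in the document scores 0, so restricting the choice to the candidate list is what turns the attention into a multiple-choice prediction.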

License

MIT License
