Attention-over-Attention Model for Reading Comprehension
This is a PyTorch implementation of the Attention-over-Attention (AoA) model proposed by Cui et al. (paper).
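For orientation, here is a minimal sketch of the attention-over-attention step as described by Cui et al. It is written in plain Python rather than PyTorch so it runs anywhere, and it is not the code in this repository. The function names (`aoa_scores`, `answer_probs`) are hypothetical, and the input vectors stand in for the GRU outputs the real model would produce.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def aoa_scores(doc_vecs, query_vecs):
    """Return one attention score per document token (hypothetical helper)."""
    n_doc, n_query = len(doc_vecs), len(query_vecs)
    # Pairwise matching matrix: M[i][j] = dot(doc_i, query_j).
    M = [[sum(d * q for d, q in zip(dv, qv)) for qv in query_vecs]
         for dv in doc_vecs]
    # Column-wise softmax: for each query token, a distribution over document tokens.
    alpha = [softmax([M[i][j] for i in range(n_doc)]) for j in range(n_query)]
    # Row-wise softmax: for each document token, a distribution over query tokens.
    rows = [softmax(M[i]) for i in range(n_doc)]
    # Average the row-wise distributions over document tokens to weight each query token.
    beta = [sum(rows[i][j] for i in range(n_doc)) / n_doc for j in range(n_query)]
    # Attention-over-attention: combine the per-query-token columns using beta.
    return [sum(alpha[j][i] * beta[j] for j in range(n_query)) for i in range(n_doc)]

def answer_probs(doc_tokens, scores):
    # Sum the scores of every position at which a word occurs in the document.
    probs = {}
    for tok, s in zip(doc_tokens, scores):
        probs[tok] = probs.get(tok, 0.0) + s
    return probs

if __name__ == "__main__":
    doc_tokens = ["mary", "went", "mary"]
    doc_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy stand-ins for GRU outputs
    query_vecs = [[1.0, 0.0], [0.5, 0.5]]
    print(answer_probs(doc_tokens, aoa_scores(doc_vecs, query_vecs)))
```

The scores form a probability distribution over document positions, so the predicted answer is the candidate word with the largest summed score.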
Prerequisites
- PyTorch with CUDA
- Python 3.6+
- NLTK (with punkt data)
Usage
This implementation uses Facebook’s Children’s Book Test (CBT) dataset.
Preprocessing
Make sure the data files (train.txt, dev.txt, test.txt) are present in the data
directory.
To preprocess the data:
python preprocess.py
This generates the dictionary (dict.pt) from all words that appear in the dataset and
vectorizes all data (train.txt.pt, dev.txt.pt, test.txt.pt).
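To illustrate what building a dictionary and vectorizing the data mean here, the following is a rough sketch and not the actual contents of preprocess.py. The special tokens, the function names (`build_dict`, `vectorize`), and the frequency-based index order are all assumptions. Tokens are passed in as ready-made lists, whereas the real script presumably tokenizes with NLTK.

```python
from collections import Counter

# Hypothetical special tokens; the real script may use different ones.
PAD, UNK = "<pad>", "<unk>"

def build_dict(token_lists):
    """Map every word seen in the data to an integer index."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    word2idx = {PAD: 0, UNK: 1}
    # Assign indices in order of decreasing frequency (an assumption).
    for word, _ in counts.most_common():
        word2idx.setdefault(word, len(word2idx))
    return word2idx

def vectorize(tokens, word2idx):
    """Replace each token with its index, falling back to UNK for unseen words."""
    unk = word2idx[UNK]
    return [word2idx.get(tok, unk) for tok in tokens]

if __name__ == "__main__":
    sentences = [["the", "cat", "sat"], ["the", "dog"]]
    word2idx = build_dict(sentences)
    print(vectorize(["the", "bird"], word2idx))
```

The .pt extension suggests the results are serialized with PyTorch, though that is an inference from the file names rather than something stated here.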
Train the model
Below is an example training command; set the parameters as you like.
python train.py -traindata data/train.txt.pt -validdata data/test.txt.pt -dict data/dict.pt \
-save_model model1 -gru_size 384 -embed_size 384 -batch_size 64 -dropout 0.1 \
-epochs 13 -learning_rate 0.001 -weigth_decay 0.0001 -gpu 1 -log_interval 50
A checkpoint is saved after each epoch. To resume training from a checkpoint:
python train.py -train_from xxx_model_xxx_epoch_x.pt
Testing
python test.py -testdata data/test.txt.pt -dict data/dict.pt -out result.txt -model models/xx_checkpoint_epochxx.pt