Hayao41 / var-attn

Latent Alignment and Variational Attention

This is a PyTorch implementation of the paper Latent Alignment and Variational Attention, built on a fork of OpenNMT.

Dependencies

The code was tested with Python 3.6 and PyTorch 0.4. To install the dependencies, run:

pip install -r requirements.txt

Running the code

All commands are defined in the script va.sh, which must be sourced before they can be called.

Preprocessing the data

To preprocess the data, run:

source va.sh && preprocess_bpe

The raw data in data/iwslt14-de-en was obtained from the fairseq repo with BPE_TOKENS=14000.

Training the model

To train a model, run one of the following commands:

  • Soft attention
source va.sh && CUDA_VISIBLE_DEVICES=0 train_soft_b6
  • Categorical attention with exact evidence
source va.sh && CUDA_VISIBLE_DEVICES=0 train_exact_b6
  • Variational categorical attention with exact ELBO
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_enum_b6
  • Variational categorical attention with REINFORCE
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_sample_b6

Checkpoints will be saved to the project's root directory.
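
The difference between the "exact evidence" and "exact ELBO" objectives can be illustrated on a toy example. The following is a minimal sketch, not the repository's training code: it assumes a single target token whose alignment z is a categorical variable over three source positions, and the names `log_evidence`, `exact_elbo`, and all numbers are hypothetical. It shows that the ELBO computed by enumeration is a lower bound on the log evidence, and that the bound is tight when q equals the true posterior.

```python
import math

def log_evidence(prior, likelihood):
    """log p(y) = log sum_z p(z) * p(y | z), enumerating every source position z."""
    return math.log(sum(p * l for p, l in zip(prior, likelihood)))

def exact_elbo(q, prior, likelihood):
    """ELBO = E_q[log p(y | z)] - KL(q || prior), computed by full enumeration."""
    expected_ll = sum(qz * math.log(l) for qz, l in zip(q, likelihood) if qz > 0)
    kl = sum(qz * math.log(qz / p) for qz, p in zip(q, prior) if qz > 0)
    return expected_ll - kl

# Toy values (assumptions for illustration only).
prior = [0.5, 0.3, 0.2]          # p(z): prior attention over 3 source positions
likelihood = [0.10, 0.60, 0.05]  # p(y | z): probability of the target token given z
q = [0.2, 0.7, 0.1]              # q(z): variational (inference network) attention

evidence = log_evidence(prior, likelihood)
elbo = exact_elbo(q, prior, likelihood)

# The bound is tight when q is the true posterior p(z | y).
joint = [p * l for p, l in zip(prior, likelihood)]
posterior = [j / sum(joint) for j in joint]
tight_elbo = exact_elbo(posterior, prior, likelihood)

print(f"log evidence: {evidence:.4f}")
print(f"ELBO with q:  {elbo:.4f}")
print(f"ELBO with posterior: {tight_elbo:.4f}")
```

A sampled estimator, as in the REINFORCE variant, would replace the enumeration over q with Monte Carlo samples; that is not shown here.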

Evaluating on the test set

To obtain the exact perplexity of the generative model, run the following command with $model replaced by a saved checkpoint:

source va.sh && CUDA_VISIBLE_DEVICES=0 eval_cat $model
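
For reference, perplexity is conventionally the exponential of the average negative log-likelihood per token. The sketch below uses a hypothetical helper `perplexity` and is not taken from the repository; it assumes natural-log token probabilities. "Exact" in this context presumably means the log-probabilities come from marginalizing over the attention variable rather than from a lower bound.

```python
import math

def perplexity(token_log_probs):
    """Return exp of the mean negative log-likelihood (natural log) per token."""
    n = len(token_log_probs)
    if n == 0:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_log_probs) / n)

# Four tokens, each predicted with probability 0.25, give a perplexity of about 4.
example_ppl = perplexity([math.log(0.25)] * 4)
print(example_ppl)
```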

The model can also be used to generate translations of the test data:

source va.sh && CUDA_VISIBLE_DEVICES=0 gen_cat $model
sed -e "s/@@ //g" $model.out | perl tools/multi-bleu.perl data/iwslt14-de-en/test.en
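
The sed expression above undoes the BPE segmentation by deleting every "@@ " continuation marker before BLEU is computed. A minimal Python equivalent, with the hypothetical helper name `remove_bpe`, might look like this:

```python
def remove_bpe(line, marker="@@ "):
    """Join BPE subword pieces by deleting every occurrence of the marker,
    mirroring the sed expression s/@@ //g."""
    return line.replace(marker, "")

restored = remove_bpe("wir ha@@ ben ein@@ e Idee")
print(restored)
```

Like the sed command, this sketch does not touch a marker at the very end of a line that has no trailing space.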

Trained Models

For each model type, the checkpoint with the lowest validation perplexity (PPL) was selected for evaluation on the test set. The numbers differ slightly from those reported in the paper because this is a re-implementation.

Model            Test PPL   Test BLEU
Soft             7.17       32.77
Exact            6.34       33.29
VAE Exact ELBO   6.08       33.69
VAE REINFORCE    6.17       33.30

About

License: MIT License