ictnlp / DDRS-NAT

Code for NAACL2022 main conference paper "One Reference Is Not Enough: Diverse Distillation with Reference Selection for Non-Autoregressive Translation"

One Reference Is Not Enough: Diverse Distillation with Reference Selection for Non-Autoregressive Translation

This repository contains the source code for our NAACL 2022 main conference paper "One Reference Is Not Enough: Diverse Distillation with Reference Selection for Non-Autoregressive Translation" (pdf). The code is built on the open-source toolkit fairseq-0.10.2.

Requirements

This system has been tested in the following environment.

  • Python version = 3.8
  • PyTorch version = 1.7

Diverse Distillation

Perform diverse distillation to obtain a dataset with multiple references for each source sentence. Follow the steps below to build the diverse distillation dataset for WMT14 En-De, or download our diverse distillation dataset directly and skip to Step 4.

Step 1: Follow the instructions from fairseq to prepare and preprocess the WMT14 En-De dataset, or download the preprocessed dataset here. Save the raw data to data/wmt14_ende (train.en-de.{en,de}, valid.en-de.{en,de}, test.en-de.{en,de}) and the binarized data to data-bin/wmt14_ende_raw.

Step 2: Train three autoregressive models with three different random seeds.

data_dir=data-bin/wmt14_ende_raw
save_dir=output/wmt14_ende_at
for seed in {1..3}
do
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py $data_dir \
    --dropout 0.1 --fp16 --seed $seed --save-dir $save_dir$seed \
    --arch transformer_wmt_en_de  --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0007 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --max-tokens 4096 --update-freq 1\
    --no-progress-bar --log-format json --log-interval 1000 --save-interval-updates 5000 \
    --max-update 150000 --keep-interval-updates 5 --keep-last-epochs 5
sh tools/average.sh $save_dir$seed
done
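
tools/average.sh is not reproduced here. Checkpoint averaging generally means taking the element-wise mean of every parameter over the last few saved checkpoints (fairseq itself ships scripts/average_checkpoints.py for real checkpoints). The sketch below only illustrates that idea; the function name and the dict-of-lists "state dict" are hypothetical stand-ins for real PyTorch tensors.

```python
from typing import Dict, List

# Hypothetical stand-in for a model state dict: parameter name -> flat list of weights.
StateDict = Dict[str, List[float]]


def average_checkpoints(checkpoints: List[StateDict]) -> StateDict:
    """Return the element-wise mean of every parameter across the given checkpoints."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    names = checkpoints[0].keys()
    for ckpt in checkpoints[1:]:
        if ckpt.keys() != names:
            raise ValueError("checkpoints must have identical parameter names")
    n = len(checkpoints)
    averaged: StateDict = {}
    for name in names:
        # zip(*...) walks the same position of this parameter in every checkpoint
        averaged[name] = [sum(vals) / n for vals in zip(*(c[name] for c in checkpoints))]
    return averaged
```

For example, averaging {"w": [1.0, 2.0]} and {"w": [3.0, 4.0]} gives {"w": [2.0, 3.0]}.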

Step 3: Use each model to decode the training set, producing three decoding results: pred.1, pred.2, and pred.3.

data_dir=data-bin/wmt14_ende_raw
save_dir=output/wmt14_ende_at
for seed in {1..3}
do
CUDA_VISIBLE_DEVICES=0 python generate.py $data_dir --path $save_dir$seed/average-model.pt --gen-subset train --beam 5 --batch-size 100 --lenpen 0.6 > out.$seed
grep ^H out.$seed | cut -f1,3- | cut -c3- | sort -k1n | cut -f2- > pred.$seed
done
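
The grep/cut/sort pipeline above extracts the hypotheses from the fairseq-generate log. fairseq prints each hypothesis as H-<sentence id>, a tab, the score, a tab, then the text, and it emits sentences in batch order rather than corpus order, which is why the pipeline sorts numerically by id. A minimal Python equivalent, offered as an illustration (the function name is hypothetical, not a tool in this repository):

```python
from typing import Iterable, List, Tuple


def extract_hypotheses(lines: Iterable[str]) -> List[str]:
    """Collect 'H-<id>\t<score>\t<text>' lines and return the texts ordered by <id>."""
    hyps: List[Tuple[int, str]] = []
    for line in lines:
        if not line.startswith("H-"):
            continue                      # skip S-, T-, P-, D- and log lines
        fields = line.rstrip("\n").split("\t")
        sent_id = int(fields[0][2:])      # strip the "H-" prefix, like `cut -c3-`
        text = "\t".join(fields[2:])      # drop the score column, like `cut -f1,3-`
        hyps.append((sent_id, text))
    hyps.sort(key=lambda pair: pair[0])   # restore corpus order, like `sort -k1n`
    return [text for _, text in hyps]
```

Usage, assuming out.1 exists: `pred = extract_hypotheses(open("out.1", encoding="utf-8"))`.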

Step 4: Concatenate the three decoding results, separated by the special token <divide>, and then preprocess (binarize) the diverse distillation dataset.

data_dir=data/wmt14_ende
dest_dir=data-bin/wmt14_ende_divdis

python tools/concat.py
mv train.divdis.de $data_dir/
cp $data_dir/train.en-de.en $data_dir/train.divdis.en
python preprocess.py --source-lang en --target-lang de \
        --trainpref $data_dir/train.divdis \
        --validpref $data_dir/valid.en-de \
        --testpref $data_dir/test.en-de \
        --destdir $dest_dir \
        --joined-dictionary --workers 32
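
tools/concat.py is not shown in this README. A minimal sketch of what the concatenation step could look like, assuming each output line holds the three references for one source sentence joined by " <divide> " (the surrounding spaces and the function names are our assumptions):

```python
from typing import List, Sequence

DIVIDE = "<divide>"  # separator token named in Step 4


def read_lines(path: str) -> List[str]:
    """Read a text file into a list of lines without trailing newlines."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def concat_references(pred_files: Sequence[str], out_file: str) -> int:
    """Write line i of every prediction file onto one line, separated by <divide>."""
    columns = [read_lines(p) for p in pred_files]
    if len({len(col) for col in columns}) != 1:
        raise ValueError("prediction files must have the same number of lines")
    with open(out_file, "w", encoding="utf-8") as out:
        for refs in zip(*columns):
            out.write(f" {DIVIDE} ".join(refs) + "\n")
    return len(columns[0])  # number of training pairs written
```

Usage: `concat_references(["pred.1", "pred.2", "pred.3"], "train.divdis.de")`. Keeping <divide> as a separate whitespace-delimited token lets preprocess.py add it to the dictionary as an ordinary symbol.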

Reference Selection

Train a CTC model on the diverse distillation dataset with reference selection. We implement the loss functions in nat_loss.py.
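
The exact losses live in nat_loss.py and are not reproduced here. As a hedged illustration, reference selection in its simplest form trains on the reference the model currently fits best, i.e. the minimum CTC loss over all references. The sketch below implements the standard CTC forward algorithm in pure Python and takes that minimum; the function names and the plain min-over-references rule are our simplification, not a copy of the repository code. It also hints at why the decoder is longer than the target: we read --ctc-ratio 3 as the decoder upsampling ratio relative to the source length (an assumption), since CTC gives zero probability to any target that needs more positions than the decoder has.

```python
import math
from typing import List, Sequence

NEG_INF = float("-inf")


def logsumexp(*xs: float) -> float:
    """Numerically stable log(sum(exp(x))) that tolerates -inf inputs."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_nll(log_probs: Sequence[Sequence[float]], target: Sequence[int], blank: int = 0) -> float:
    """Negative log-likelihood of `target` under CTC (forward algorithm, log space).

    log_probs[t][v] is the log-probability of vocabulary item v at decoder position t.
    """
    if not log_probs:
        raise ValueError("need at least one decoder position")
    ext = [blank]
    for tok in target:
        ext += [tok, blank]               # blank-interleaved target: _ y1 _ y2 _ ... _
    S, T = len(ext), len(log_probs)
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]       # start with a blank ...
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]   # ... or with the first target token
    for t in range(1, T):
        new = [NEG_INF] * S
        for s in range(S):
            paths = [alpha[s]]            # stay on the same symbol
            if s >= 1:
                paths.append(alpha[s - 1])  # advance by one
            # skip a blank unless that would merge two identical labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                paths.append(alpha[s - 2])
            new[s] = logsumexp(*paths) + log_probs[t][ext[s]]
        alpha = new
    final = alpha[0] if S == 1 else logsumexp(alpha[S - 1], alpha[S - 2])
    return -final                         # inf when the target cannot be produced


def reference_selection_loss(log_probs: Sequence[Sequence[float]],
                             references: List[Sequence[int]], blank: int = 0) -> float:
    """Train on the reference the model already fits best: min over per-reference CTC losses."""
    return min(ctc_nll(log_probs, ref, blank) for ref in references)
```

With two uniform decoder positions over {blank, a, b}, the target "a" has probability 1/3, while "a a" needs at least three positions (a, blank, a) and gets probability 0, so reference selection picks "a".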

Step 1: Train the CTC model with reference selection. If you are not using 8 GPUs, adjust --update-freq so that the effective batch size stays the same (e.g., --update-freq 2 on 4 GPUs).

data_dir=data-bin/wmt14_ende_divdis
save_dir=output/wmt14ende_disdiv
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py $data_dir \
    --num-references 3 --ctc-ratio 3 --src-embedding-copy --fp16 --ddp-backend=no_c10d --save-dir $save_dir \
    --task translation_lev \
    --criterion ddrs_loss \
    --arch nonautoregressive_transformer \
    --noise full_mask \
    --optimizer adam --adam-betas '(0.9,0.98)'  \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --activation-fn gelu \
    --dropout 0.2 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 1000 \
    --max-tokens 4096 --update-freq 1\
    --save-interval-updates 5000 \
    --max-update 300000 --keep-interval-updates 5 --keep-last-epochs 5
sh tools/average.sh $save_dir
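
fairseq accumulates gradients over --update-freq steps, so the number of tokens per update is at most GPUs × --max-tokens × --update-freq. A small helper for picking --update-freq on a different number of GPUs (the function name and defaults are ours; the defaults mirror the command above: 8 GPUs, --max-tokens 4096, --update-freq 1):

```python
def update_freq_for(num_gpus: int, max_tokens: int = 4096,
                    ref_gpus: int = 8, ref_max_tokens: int = 4096,
                    ref_update_freq: int = 1) -> int:
    """Smallest --update-freq whose token budget per update reaches the reference setup."""
    if num_gpus < 1 or max_tokens < 1:
        raise ValueError("num_gpus and max_tokens must be positive")
    target = ref_gpus * ref_max_tokens * ref_update_freq  # 8 * 4096 * 1 = 32768 tokens
    per_step = num_gpus * max_tokens
    return -(-target // per_step)                          # ceiling division
```

For example, 4 GPUs need --update-freq 2, and 1 GPU needs 8. For the NMLA finetuning stage (8 GPUs, --max-tokens 2048, --update-freq 16), pass ref_max_tokens=2048 and ref_update_freq=16; on 4 GPUs this gives 32.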

Step 2: Finetune the CTC model with either max-reward reinforcement learning or the newly proposed NMLA (non-monotonic latent alignments) training objective. In practice, we find that NMLA performs much better than max-reward reinforcement learning.
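
With several references, a sampled translation that matches one reference well should not be penalized for differing from the others. The sketch below illustrates only the max-reward idea: score the sample against every reference and keep the best score. The reward here is a clipped bigram F1 chosen as a simple stand-in; the repository's actual reward and NMLA's n-gram objective over CTC alignments are not reproduced, and all names are hypothetical.

```python
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Multiset of the n-grams in a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def ngram_f1(hyp: Sequence[str], ref: Sequence[str], n: int = 2) -> float:
    """Clipped n-gram F1 between a hypothesis and one reference (stand-in reward)."""
    h, r = ngrams(hyp, n), ngrams(ref, n)
    total = sum(h.values()) + sum(r.values())
    if total == 0:
        return 0.0
    matches = sum((h & r).values())   # Counter intersection = clipped match counts
    return 2.0 * matches / total


def max_reward(hyp: Sequence[str], references: List[Sequence[str]], n: int = 2) -> float:
    """Max-reward idea: reward the sample by its best-matching reference."""
    return max(ngram_f1(hyp, ref, n) for ref in references)
```

For the sample "the cat sat" with references "a cat sat" and "the cat sat", the per-reference rewards are 0.5 and 1.0, so the max-reward signal is 1.0.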

Finetune with NMLA:

data_dir=data-bin/wmt14_ende_divdis
save_dir=output/wmt14ende_disdiv
mkdir ${save_dir}tune
cp $save_dir/average-model.pt ${save_dir}tune/checkpoint_last.pt
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py $data_dir \
    --tune --use-ngram --reset-optimizer --num-references 3 --ctc-ratio 3 --src-embedding-copy --fp16 --ddp-backend=no_c10d --save-dir ${save_dir}tune \
    --task translation_lev \
    --criterion ddrs_loss \
    --arch nonautoregressive_transformer \
    --noise full_mask \
    --optimizer adam --adam-betas '(0.9,0.98)'  \
    --lr 0.0003 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 500 \
    --warmup-init-lr '1e-07'  --activation-fn gelu \
    --dropout 0.1 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --apply-bert-init \
    --log-format 'simple' --log-interval 1 \
    --max-tokens 2048 --update-freq 16\
    --save-interval-updates 500 \
    --max-update 6000 --keep-interval-updates 5 --keep-last-epochs 5

Finetune with max-reward reinforcement learning:

data_dir=data-bin/wmt14_ende_divdis
save_dir=output/wmt14ende_disdiv
mkdir ${save_dir}tune
cp $save_dir/average-model.pt ${save_dir}tune/checkpoint_last.pt
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py $data_dir \
    --tune --reset-optimizer --num-references 3 --ctc-ratio 3 --src-embedding-copy --fp16 --ddp-backend=no_c10d --save-dir ${save_dir}tune \
    --task translation_lev \
    --criterion ddrs_loss \
    --arch nonautoregressive_transformer \
    --noise full_mask \
    --optimizer adam --adam-betas '(0.9,0.98)'  \
    --lr 0.00002 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 500 \
    --warmup-init-lr '1e-07' --activation-fn gelu \
    --dropout 0.1 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --max-tokens 4096 --update-freq 1\
    --save-interval-updates 500 \
    --max-update 3000 --keep-interval-updates 5 --keep-last-epochs 5

Inference

Step 1: Decode the test set with argmax decoding.

model=output/wmt14ende_disdivtune/checkpoint_last.pt
data_dir=data-bin/wmt14_ende_divdis
CUDA_VISIBLE_DEVICES=0 python generate.py $data_dir \
    --gen-subset test \
    --task translation_lev \
    --iter-decode-max-iter  0  \
    --iter-decode-eos-penalty 0 \
    --path $model \
    --beam 1  \
    --left-pad-source False \
    --batch-size 100 > out
grep ^H out | cut -f1,3- | cut -c3- | sort -k1n | cut -f2- > pred.raw
python tools/dedup.py
python tools/deblank.py
sed -r 's/(@@ )|(@@ ?$)//g' pred.deblank > pred.de
perl tools/multi-bleu.perl ref.de < pred.de
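
tools/dedup.py and tools/deblank.py are not shown; judging by their names and order, they presumably apply the standard CTC collapse to the argmax output: first merge consecutive duplicate tokens, then drop blank tokens. A minimal sketch of that post-processing plus the BPE removal done by sed, assuming a hypothetical blank symbol "<blank>" (the symbol the model actually emits may differ):

```python
import re
from typing import List, Optional

BLANK = "<blank>"  # hypothetical blank symbol; check what the model actually writes


def collapse_ctc(tokens: List[str], blank: str = BLANK) -> List[str]:
    """CTC collapse: merge consecutive duplicates first, then drop blanks."""
    out: List[str] = []
    prev: Optional[str] = None
    for tok in tokens:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out


def remove_bpe(line: str) -> str:
    """Same substitution as `sed -r 's/(@@ )|(@@ ?$)//g'`."""
    return re.sub(r"(@@ )|(@@ ?$)", "", line)


def postprocess(raw_line: str, blank: str = BLANK) -> str:
    """Turn one line of pred.raw into a detokenized-BPE hypothesis."""
    return remove_bpe(" ".join(collapse_ctc(raw_line.split(), blank)))
```

The order matters: a blank between two identical tokens (e.g. "ist <blank> ist") keeps both copies, which is how CTC can emit genuine repetitions.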

Step 2: Alternatively, use beam search decoding combined with a 4-gram language model to search for the target sentence. First, install the ctcdecode package:

git clone --recursive https://github.com/MultiPath/ctcdecode.git
cd ctcdecode && pip install .

Note that you must install MultiPath/ctcdecode rather than the original package: this fork pre-computes the top-K candidates before running beam search, which makes decoding much faster.

Then, follow the kenlm instructions to train a target-side 4-gram language model and save it as wmt14ende.arpa. Finally, decode the test set with beam search combined with the 4-gram language model:

model=output/wmt14ende_disdivtune/checkpoint_last.pt
data_dir=data-bin/wmt14_ende_divdis
CUDA_VISIBLE_DEVICES=0 python generate.py $data_dir \
    --use-beamlm \
    --beamlm-path ./wmt14ende.arpa \
    --alpha $1 \
    --beta $2 \
    --gen-subset test \
    --task translation_lev \
    --iter-decode-max-iter  0  \
    --iter-decode-eos-penalty 0 \
    --path $model \
    --beam 1  \
    --left-pad-source False \
    --batch-size 100 > out
grep ^H out | cut -f1,3- | cut -c3- | sort -k1n | cut -f2- > pred.raw
sed -r 's/(@@ )|(@@ ?$)//g' pred.raw > pred.de
perl tools/multi-bleu.perl ref.de < pred.de

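Here $1 and $2 are the values of alpha and beta, passed as arguments when the commands above are saved as a script. In ctcdecode, alpha weights the language-model score and beta weights the number of words in the hypothesis. Their optimal values vary across datasets and can be found by grid search.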
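
A minimal grid-search sketch. The `score` callback is hypothetical: in practice it would run the decoding command above with the given alpha and beta, ideally on the validation set (--gen-subset valid) so the test set stays untouched, and return BLEU. Here it is left abstract so the search logic itself is clear.

```python
import itertools
from typing import Callable, Iterable, Tuple


def grid_search(score: Callable[[float, float], float],
                alphas: Iterable[float],
                betas: Iterable[float]) -> Tuple[float, float, float]:
    """Try every (alpha, beta) pair and return (best_alpha, best_beta, best_score)."""
    best = None
    for alpha, beta in itertools.product(alphas, betas):
        s = score(alpha, beta)
        if best is None or s > best[2]:
            best = (alpha, beta, s)
    if best is None:
        raise ValueError("empty grid")
    return best
```

Because each call to `score` is a full decoding run, a coarse grid followed by a finer grid around the best point is usually cheaper than one dense grid.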

Citation

If you find the resources in this repository useful, please cite as:

@inproceedings{ddrs,
  title = {One Reference Is Not Enough: Diverse Distillation with Reference Selection for Non-Autoregressive Translation},
  author= {Chenze Shao and Xuanfu Wu and Yang Feng},
  booktitle = {Proceedings of NAACL 2022},
  year = {2022},
}

License: MIT License

