
WMT18-en-zh-Machine-Translation

A course assignment for PKU Computational Linguistics, Fall 2020.

This repository implements WMT18 en-zh machine translation with the vanilla Transformer, the Explicit Sparse Transformer, and PRIME (PaRallel Intersected Multi-scale AttEntion).

Introduction

Dataset

We conduct en-zh translation experiments on a subset of the WMT18 zh-en dataset. The statistics of our subset are:

WMT18 en-zh | # Examples | Avg Length (en/zh) | Max Length (en/zh)
Train       | 176,943    | 22.05 / 23.71      | 153 / 511
Dev         | 25,278     | 22.08 / 23.70      | 94 / 110
Test        | 50,556     | 22.05 / 23.70      | 116 / 111
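
For reference, length statistics like those above can be reproduced with a few lines of Python; the file names train.en / train.zh below are placeholders for the tokenized data splits, not paths from this repo.

def length_stats(path):
    # Return example count, average length, and max length (in tokens).
    lengths = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            lengths.append(len(line.split()))
    return len(lengths), sum(lengths) / len(lengths), max(lengths)

for side in ("train.en", "train.zh"):  # hypothetical file names
    n, avg, longest = length_stats(side)
    print(f"{side}: {n} examples, avg {avg:.2f}, max {longest} tokens")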

Explicit Sparse Transformer

[Figure: comparison of vanilla attention and sparse attention]
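
The Explicit Sparse Transformer keeps only the top-k attention scores for each query and masks the rest to negative infinity before the softmax, so every token attends to at most k others (the sparse_topk8 models below use k = 8). A minimal PyTorch sketch of that mechanism, not the repo's actual implementation:

import torch
import torch.nn.functional as F

def sparse_attention(q, k, v, topk=8):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    if topk < scores.size(-1):
        # k-th largest score per query; everything below it is masked out
        kth = scores.topk(topk, dim=-1).values[..., -1:]
        scores = scores.masked_fill(scores < kth, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 4, 10, 64)
out = sparse_attention(q, k, v)  # same shape as v: (1, 4, 10, 64)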

PRIME

[Figure: the PRIME (PaRallel Intersected Multi-scale AttEntion) architecture]

Prepare environment

conda create -n mt python=3.6
conda activate mt
conda install pytorch torchvision cudatoolkit=10.0 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
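
After installation, a quick sanity check that PyTorch sees the GPU (a minimal sketch, assuming the CUDA 10.0 build installed above):

import torch
print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # should print True on a CUDA-enabled machine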

Process data

bash prepare-wmt18en-zh.sh
sh process.sh joint
  • joint denotes --joined-dictionary, i.e. building a single vocabulary shared by the source and target sides (see the sketch below).
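
A joined dictionary is one vocabulary counted over both the English and Chinese sides, so source and target share token IDs (and can share embeddings). A minimal sketch of the idea; the file names and vocabulary size are illustrative assumptions:

from collections import Counter

def build_joined_dictionary(paths, vocab_size=32000):
    counts = Counter()
    for path in paths:  # count tokens over BOTH language sides
        with open(path, encoding="utf-8") as f:
            for line in f:
                counts.update(line.split())
    specials = ["<pad>", "<s>", "</s>", "<unk>"]
    tokens = [t for t, _ in counts.most_common(vocab_size - len(specials))]
    return {t: i for i, t in enumerate(specials + tokens)}

vocab = build_joined_dictionary(["train.en", "train.zh"])  # shared en/zh vocab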

Train

sh train_transformer.sh

Evaluation

sh evaluate.sh -m transformer_wmt_en_zh -c 0 -n 10
  • -m denotes the model architecture.
  • -c denotes the index of the CUDA device.
  • -n denotes the number of checkpoints to average (see the sketch below).
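
Checkpoint averaging takes the element-wise mean of the parameters from the last n checkpoints, which typically gives a small but consistent BLEU gain (see the "+ average checkpoints" rows below). A minimal sketch, assuming fairseq-style checkpoints that store parameters under a "model" key; the file names are illustrative:

import torch

def average_checkpoints(paths):
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}

# e.g. average the last 10 checkpoints (-n 10 above)
averaged = average_checkpoints([f"checkpoint{i}.pt" for i in range(91, 101)])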

Model Architecture

Model Architecture        | (Sparse) Transformer-Base | (Sparse) Transformer-Big | Prime
Encoder Embedding Size    | 512                       | 512                      | 384
Encoder Feed-forward Size | 1024                      | 2048                     | 768
Encoder Attention Heads   | 4                         | 8                        | 4
Encoder Layers            | 4                         | 6                        | 8
Decoder Embedding Size    | 512                       | 512                      | 384
Decoder Feed-forward Size | 1024                      | 2048                     | 768
Decoder Attention Heads   | 4                         | 8                        | 4
Decoder Layers            | 4                         | 6                        | 8
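
The same hyperparameters written out as configuration dictionaries; the argument names follow common fairseq-style conventions and are illustrative, not necessarily the exact flags used by the training scripts:

CONFIGS = {
    "(sparse) transformer-base": dict(
        encoder_embed_dim=512, encoder_ffn_embed_dim=1024,
        encoder_attention_heads=4, encoder_layers=4,
        decoder_embed_dim=512, decoder_ffn_embed_dim=1024,
        decoder_attention_heads=4, decoder_layers=4),
    "(sparse) transformer-big": dict(
        encoder_embed_dim=512, encoder_ffn_embed_dim=2048,
        encoder_attention_heads=8, encoder_layers=6,
        decoder_embed_dim=512, decoder_ffn_embed_dim=2048,
        decoder_attention_heads=8, decoder_layers=6),
    "prime": dict(
        encoder_embed_dim=384, encoder_ffn_embed_dim=768,
        encoder_attention_heads=4, encoder_layers=8,
        decoder_embed_dim=384, decoder_ffn_embed_dim=768,
        decoder_attention_heads=4, decoder_layers=8),
}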

Results

Model                                  | BLEU (1/2/3/4-gram precisions)
transformer_wmt_en_zh                  | 21.56 (57.8/28.5/16.1/9.6)
  + average checkpoints                | 21.83 (57.9/28.7/16.3/9.7)
transformer_wmt_en_zh_big              | 21.80 (57.8/28.7/16.3/9.7)
  + average checkpoints                | 22.02 (57.3/28.4/16.2/9.7)
transformer_wmt_en_zh_sparse_topk8     | 21.45 (58.0/28.5/16.2/9.6)
  + average checkpoints                | 21.62 (58.4/28.9/16.5/9.8)
transformer_wmt_en_zh_sparse_topk8_big | 21.68 (57.5/28.4/16.1/9.6)
  + average checkpoints                | 21.78 (57.4/28.4/16.1/9.6)
transformer_wmt_en_zh_prime            | 22.37 (57.7/28.9/16.6/10.0)
  + average checkpoints                | 22.74 (57.7/29.1/16.8/10.1)
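
Each entry is corpus BLEU followed by the four n-gram precisions. A minimal sketch of how scores in this format can be computed with sacrebleu (the repo's evaluate.sh may use a different BLEU implementation; the file names are placeholders):

import sacrebleu

with open("hyps.zh", encoding="utf-8") as f:
    hyps = [line.strip() for line in f]
with open("refs.zh", encoding="utf-8") as f:
    refs = [line.strip() for line in f]

bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(bleu.score)       # corpus BLEU, e.g. 22.74
print(bleu.precisions)  # 1/2/3/4-gram precisions, e.g. [57.7, 29.1, 16.8, 10.1]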

About

Machine translation experiments for WMT18 en-zh track

License: MIT License

