wellecks / mgs

MLE-Guided Parameter Search (AAAI 2021)

MLE-Guided Parameter Search (MGS)

PyTorch implementation of the paper:

MLE-Guided Parameter Search for Task Loss Minimization in Neural Sequence Modeling
Sean Welleck, Kyunghyun Cho
AAAI 2021

Main Logic

For a quick overview of MGS's main logic, see this section from its training loop.
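The paper describes MGS as sampling candidate update directions from a mixture of random search around the current parameters and around the maximum-likelihood gradient, then weighting each direction by its improvement in the task loss. The pure-Python sketch below illustrates that idea on a toy 2-D problem. It is not the repository's training loop. The task loss, the MLE gradient, the 50/50 mixture, and the names and values of `lr`, `sigma`, `num_candidates` and `temperature` are hypothetical. The softmax weighting is also an assumed simplification, with no importance-sampling correction for the proposal.

```python
import math
import random


def task_loss(theta):
    # Hypothetical task loss: L1 distance to a target point (never differentiated).
    target = (3.0, -2.0)
    return sum(abs(t - g) for t, g in zip(theta, target))


def mle_grad(theta):
    # Hypothetical MLE gradient: the MLE optimum is close to, but not at, the task optimum.
    mle_optimum = (2.5, -1.5)
    return [2.0 * (t - m) for t, m in zip(theta, mle_optimum)]


def mgs_step(theta, rng, lr=0.1, sigma=0.05, num_candidates=8, temperature=0.01):
    """One toy MGS-style update of a parameter vector given as a list of floats."""
    base_loss = task_loss(theta)
    mle_step = [-lr * g for g in mle_grad(theta)]
    zero = [0.0] * len(theta)

    deltas, improvements = [], []
    for _ in range(num_candidates):
        # Mixture proposal: random search around the current parameters (center 0)
        # or around the MLE gradient step, each chosen with probability 0.5.
        center = mle_step if rng.random() < 0.5 else zero
        delta = [c + rng.gauss(0.0, sigma) for c in center]
        candidate = [t + d for t, d in zip(theta, delta)]
        deltas.append(delta)
        improvements.append(base_loss - task_loss(candidate))

    # Weight each direction by its task-loss improvement (numerically stable softmax).
    best = max(improvements)
    weights = [math.exp((imp - best) / temperature) for imp in improvements]
    total = sum(weights)

    # Move along the improvement-weighted average direction.
    update = [sum(w * d[i] for w, d in zip(weights, deltas)) / total for i in range(len(theta))]
    return [t + u for t, u in zip(theta, update)]


rng = random.Random(0)
theta = [0.0, 0.0]
start = task_loss(theta)
for _ in range(200):
    theta = mgs_step(theta, rng)
end = task_loss(theta)
print(f"task loss: {start:.3f} -> {end:.3f}")
```

Because candidates are scored only by evaluating the task loss, the task loss never has to be differentiable; in this sketch the MLE gradient only shapes where candidates are proposed.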

Installation

python setup.py develop

Data

For downloading the datasets below, it may be helpful to use gdown.pl.

Pretrained Models

We provide an example base MLE model and example models finetuned with MGS, PG, and MRT.
Note that metrics in the paper were computed using 5 models per method, each initialized with a different random seed.

Method
MLE
MGS-LM
MGS-LM (ancestral)
PG-LM
MRT-LM

Example commands

Below we show example commands for each stage of the pipeline.
The experiments in the paper were run with a script external to this repository.

Finetune starting from the MLE-finetuned model

# MGS
python seq_level/gpt2/train.py \
  --loss ggs \
  --ggs-metric lm \
  --ggs-beta 1.0 \
  --model-load-dir /path/to/mle_model

# PG
python seq_level/gpt2/train.py \
  --loss pg \
  --ggs-metric lm \
  --pg-normalize-distance 1 \
  --pg-mle-mix 0.1 \
  --pg-baseline avg \
  --model-load-dir /path/to/mle_model
  
# MRT
python seq_level/gpt2/train.py \
  --loss mrt \
  --ggs-metric lm \
  --mrt-normalize-distance 1 \
  --mrt-mle-mix 0.1 \
  --model-load-dir /path/to/mle_model
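PG and MRT are standard sequence-level objectives. The pure-Python sketch below shows textbook forms of both over a batch of sampled continuations: policy gradient with an average-reward baseline (which `--pg-baseline avg` presumably selects), and minimum risk training, which renormalizes model probabilities over the sampled candidates and minimizes the expected cost. It is not the repository's code. The inputs (`log_probs`, `costs`), the use of negative cost as the reward, the MRT sharpness `alpha`, and the linear MLE mixing (a guess at what `--pg-mle-mix` / `--mrt-mle-mix` control) are all assumptions. The `--*-normalize-distance` options are not modeled.

```python
import math


def pg_loss(log_probs, costs, mle_loss, mle_mix=0.1):
    """REINFORCE-style surrogate loss with an average-reward baseline (toy sketch).

    log_probs: log p(y_k | x) for each sampled continuation y_k.
    costs: task cost of each continuation (lower is better); reward = -cost.
    Under an autodiff framework, the gradient of this surrogate is the
    policy-gradient estimate; here only its value is computed.
    """
    rewards = [-c for c in costs]
    baseline = sum(rewards) / len(rewards)  # average-reward baseline
    surrogate = -sum((r - baseline) * lp for r, lp in zip(rewards, log_probs)) / len(rewards)
    # Assumed linear mixing with the token-level MLE loss.
    return (1.0 - mle_mix) * surrogate + mle_mix * mle_loss


def mrt_loss(log_probs, costs, mle_loss, mle_mix=0.1, alpha=1.0):
    """Minimum risk training: expected cost under the model distribution
    renormalized over the sampled candidates (toy sketch)."""
    scaled = [alpha * lp for lp in log_probs]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    q = [e / z for e in exps]
    risk = sum(qk * c for qk, c in zip(q, costs))
    # Assumed linear mixing with the token-level MLE loss.
    return (1.0 - mle_mix) * risk + mle_mix * mle_loss


log_probs = [-2.0, -3.0, -1.5]
costs = [0.4, 0.9, 0.2]
print(pg_loss(log_probs, costs, mle_loss=3.1))
print(mrt_loss(log_probs, costs, mle_loss=3.1))
```

Subtracting the baseline leaves the expected policy gradient unchanged while reducing its variance: when every sample has the same cost, the PG surrogate is exactly zero.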

Finetune MLE

python seq_level/gpt2/train.py \
  --loss mle \
  --valid-every 5000 \
  --print-every 100

Evaluate

# --eval-split: valid | test    --eval-decoder: greedy | temp-1.0
python seq_level/gpt2/train.py --mode eval \
  --eval-split valid \
  --score-model-load-dir /path/to/mle_model \
  --model-load-dir /path/to/model \
  --eval-decoder greedy \
  --token-limit-eval 500 \
  --eval-decode-max-length 500 \
  --chunk-size-valid 512 \
  --loss ggs \
  --ggs-metric lm

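The `--eval-decoder` choices correspond to two common decoding strategies: greedy decoding, which takes the highest-scoring token at each step, and (presumably, for `temp-1.0`) ancestral sampling at temperature 1.0, which samples each token from the softmax of the logits divided by the temperature. The pure-Python sketch below shows both over a hypothetical toy scoring function `toy_logits`. The vocabulary, the scoring function, and the end-of-sequence handling are illustrative assumptions, not the repository's decoder.

```python
import math
import random

VOCAB = ["<eos>", "a", "b", "c"]


def toy_logits(prefix):
    # Hypothetical model: prefers "a", and prefers <eos> more as the prefix grows.
    return [0.6 * len(prefix), 2.0, 1.0, 0.5]


def softmax(logits, temperature=1.0):
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def decode(logits_fn, max_length, strategy="greedy", temperature=1.0, rng=None):
    """Greedy decoding or temperature-based ancestral sampling.

    Stops at <eos> or after max_length tokens.
    """
    rng = rng or random.Random()
    prefix = []
    for _ in range(max_length):
        logits = logits_fn(prefix)
        if strategy == "greedy":
            token = max(range(len(logits)), key=lambda i: logits[i])
        else:
            probs = softmax(logits, temperature)
            token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        if VOCAB[token] == "<eos>":
            break
        prefix.append(VOCAB[token])
    return prefix


print(decode(toy_logits, max_length=10, strategy="greedy"))
print(decode(toy_logits, max_length=10, strategy="sample", temperature=1.0, rng=random.Random(0)))
```

Greedy decoding is deterministic for a fixed model. Sampling at temperature 1.0 draws from the model's own distribution, so repeated runs can produce different continuations.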
Preprocess raw wikitext

(Not needed if you download the dataset above.)

python seq_level/gpt2/prepare_wikitext.py --data-dir /path/to/wikitext-raw
