- 2022/12/1: We have finished the human evaluation and provided guides and annotation results in ./huamn_evaluation.
- torch>=1.7
- transformers==4.8.1
- datasets==1.12.1
- nltk==3.7
- rouge-score
- data -> directories to store the datasets and the data preprocessing code
- data_utils -> code for the dataloader and evaluation metrics
- model_utils -> generator model and ranker model
- trainer_utils -> trainer and trainer configuration
- warmup-generator -> warming-up of the generator
- warmup-ranker -> first training iteration of the ranker
- run_train.py -> the main function to run JGR
JGR currently supports CNN/DailyMail, SAMSum, SquadQG, and Personachat. Users should download the raw data of these datasets and run the code in ./data to preprocess them. The scripts preprocess and re-organize the data samples into JSON files for the later steps. For details on preprocessing each dataset, please check the README.md in ./data.
To achieve better performance with JGR, it is necessary to pre-finetune the generator with the MLE loss on the target training set. Follow the steps described in ./warmup-generator/README.md and you will get a warmed-up generator for the target dataset, stored in ./warmup-generator/saves.
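Conceptually, this warm-up is standard sequence-to-sequence fine-tuning with cross-entropy (MLE) loss. The sketch below illustrates the idea with a tiny, randomly initialized BART as a stand-in for the actual pretrained checkpoint and the real training loop in ./warmup-generator; the shapes, learning rate, and token ids are illustrative only.

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny random BART standing in for the real pretrained checkpoint.
config = BartConfig(
    vocab_size=100, d_model=16, max_position_embeddings=64,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Toy (source, target) token ids; real ids come from the BART tokenizer.
input_ids = torch.randint(3, 100, (2, 10))
labels = torch.randint(3, 100, (2, 8))

out = model(input_ids=input_ids, labels=labels)  # out.loss is the MLE loss
out.loss.backward()
optimizer.step()
```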
Note: before the first ranker training iteration, you should use the warmed-up generator to generate the candidates for that iteration. Don't forget to execute step "2. Generate candidates for warming-up ranker" in ./warmup-generator/README.md.
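Conceptually, candidate generation runs beam search on the warmed-up generator and keeps several returned hypotheses per source (in the real scripts this count is set by a command-line flag; the tiny random model and the value 4 below are stand-ins, not the repo's settings):

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny random BART standing in for the warmed-up generator checkpoint.
config = BartConfig(
    vocab_size=100, d_model=16, max_position_embeddings=64,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config).eval()

source = torch.randint(3, 100, (1, 12))  # one toy source sequence
# Beam search with multiple returned hypotheses = multiple candidates.
candidates = model.generate(
    source, num_beams=4, num_return_sequences=4, max_length=16
)
print(candidates.shape[0])  # one row per candidate
```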
As mentioned in the paper, to initialize the ranker with a more general and reasonable ranking function, we increase the number of training steps and add a certain number of warm-up steps in the first ranker training iteration. Here we use the fine-tuned generator to generate the candidates for the first ranker training iteration. Follow the steps described in ./warmup-ranker/README.md and you will get a warmed-up ranker for the target dataset, stored in ./warmup-ranker/saves.
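The ranker learns to score candidates so that candidates with a higher evaluation metric (e.g. ROUGE) receive higher scores. A common way to express this in candidate-reranking work is a pairwise margin loss over metric-sorted candidates; the sketch below is an illustrative example of that family of objectives, not JGR's exact loss (which is defined in the paper and implemented in trainer_utils).

```python
import torch

# Toy ranker scores for 3 candidates, already sorted by metric (best first).
# In JGR these scores would come from a RoBERTa-based ranker.
scores = torch.tensor([0.2, 0.5, 0.1], requires_grad=True)
margin = 0.01

loss = scores.new_zeros(())
n = scores.numel()
for i in range(n):             # candidate i has a better metric ...
    for j in range(i + 1, n):  # ... than candidate j, so its score should be higher
        loss = loss + torch.relu((j - i) * margin - (scores[i] - scores[j]))
loss = loss / (n * (n - 1) // 2)
loss.backward()  # gradients push better candidates' scores up
```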
After obtaining the warmed-up generator and ranker, you can turn to JGR training. Taking cnndm as an example, to train JGR, run:
```shell
export generator=warmup-generator/saves/bart-large-cnndm # the warmed-up generator
export ranker=warmup-ranker/saves/roberta-large-cnndm # the ranker after the first iteration
export save_name=JGR-large-cnndm # the save name of the model

python -m torch.distributed.launch --nproc_per_node 8 run_train.py --overwrite_output_dir \
--task_name sum --dataset_name cnndm \
--train_data_path data/cnndm \
--dev_data_path data/cnndm \
--test_data_path data/cnndm \
--load_tokenized_data False \
--evaluate_generator True \
--generator_num_cand_generated 8 --generator_num_cand_picked 8 \
--num_cand_generated 16 --num_cand_picked 3 --candidate_pick_strategy bottom \
--do_train True --do_eval True --do_predict True --prediction_loss_only False \
--per_device_train_batch_size 2 --per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--generator_learning_rate 5e-5 --reranker_learning_rate 1e-5 \
--num_train_epochs 3 \
--evaluation_strategy steps --eval_steps 1000 \
--logging_strategy steps --logging_steps 500 \
--save_strategy steps --save_steps 1000 --save_total_limit 20 \
--iteration_steps 1000 --iteration_reranker_steps 500 \
--load_best_model_at_end True \
--metric_for_best_model generator_eval_rouge1 --greater_is_better True \
--reranker_model_name_or_path $ranker \
--generator_model_name_or_path $generator \
--output_dir saves/$save_name \
--generator_max_source_length 1020 --reranker_max_source_length 400 --generator_max_target_length 109 --reranker_max_target_length 109 \
--cache_data \
--disable_tqdm False
```
The above instructions will store the trained generator and ranker in saves/JGR-large-cnndm/generator and saves/JGR-large-cnndm/reranker, respectively. For JGR training on other datasets, check run_train.sh.
To evaluate the trained generator and ranker, run:
```shell
export generator=saves/JGR-large-cnndm/generator # the trained generator
export ranker=saves/JGR-large-cnndm/reranker # the trained ranker
export save_name=JGR-large-cnndm # the save name of the model

python -m torch.distributed.launch --nproc_per_node 8 run_train.py --overwrite_output_dir \
--task_name sum --dataset_name cnndm \
--train_data_path data/cnndm \
--dev_data_path data/cnndm \
--test_data_path data/cnndm \
--load_tokenized_data False \
--generator_num_cand_generated 8 --generator_num_cand_picked 8 \
--num_cand_generated 16 --num_cand_picked 3 --candidate_pick_strategy bottom \
--do_predict True --prediction_loss_only False \
--per_device_eval_batch_size 4 \
--evaluation_strategy steps --eval_steps 1000 \
--logging_strategy steps --logging_steps 500 \
--save_strategy steps --save_steps 1000 --save_total_limit 20 \
--iteration_steps 1000 --iteration_reranker_steps 500 \
--load_best_model_at_end True \
--metric_for_best_model generator_eval_rouge1 --greater_is_better True \
--reranker_model_name_or_path $ranker \
--generator_model_name_or_path $generator \
--output_dir saves/$save_name \
--generator_max_source_length 1020 --reranker_max_source_length 400 --generator_max_target_length 109 --reranker_max_target_length 109 \
--cache_data \
--disable_tqdm False
```
We also provide our trained checkpoints for you:
|  | Generator | Ranker |
|---|---|---|
| CNNDM | generator.zip | reranker.zip |
| CNNDM (initialized with BRIO) | generator.zip | reranker.zip |
| SAMSum | generator.zip | reranker.zip |
| SquadQG | generator.zip | reranker.zip |
| Personachat | generator.zip | reranker.zip |
We also provide our code for other compared methods in Compared_method, including GAN, RL (with evaluation metrics as the reward), and actor-critic. We only conducted experiments with these methods on CNN/DailyMail.
The checkpoints for these methods are also provided:
|  | Generator | Discriminator |
|---|---|---|
| GAN | generator.zip | discriminator.zip |
| RL | generator.zip | - |
| Actor-critic | generator.zip | critic.zip |