This repository contains an MLM fine-tuning task and contextual data regeneration (augmentation) examples using the BERT implementation from:
The training dataset comes from:
Disclaimer: The dataset for this competition contains text that may be considered profane, vulgar, or offensive.
Thanks to "Conditional BERT Contextual Augmentation" https://arxiv.org/pdf/1812.06705.pdf for providing such a wonderful idea.
Section | Description |
---|---|
Requirements | How to install the required packages |
Usage | Quickstart examples |
Modification | How to train on data from other sources |
Effect | The effect of augmentation using the Kaggle Baseline model |
GPU | GPU and memory requirements |
This repo was tested with Python 3.6 and PyTorch 1.1.
pytorch-pretrained-bert can be installed with pip as follows:

```bash
pip install pytorch-pretrained-bert
```

PyTorch can be installed with conda as follows:

```bash
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
```
If you want to reproduce the toxic comment augmentation results, run the following commands:

```bash
python finetune.py
python train_aug.py
```
If you want to use this repo for other text data augmentation, here are some tips for the modification:
- Add your own task name under the AugProcessor object (a minimal sketch follows this list)
- If your data is not in .csv format, modify the reading method under the DataProcessor object
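As a rough illustration, a processor for a new task might look like the sketch below. The exact interface is defined by this repo, so treat the import location, the method names, and the `InputExample` fields as assumptions modeled on the pytorch-pretrained-bert `run_classifier.py` examples this code follows:

```python
import csv
import os

# Assumed import: DataProcessor and InputExample are defined in this
# repo; adjust the module name to the actual file layout.
from finetune import DataProcessor, InputExample

class MyTaskProcessor(DataProcessor):
    """Hypothetical processor for a new binary-label task."""

    def get_train_examples(self, data_dir):
        # Swap this reader out if your data is not a .csv file.
        with open(os.path.join(data_dir, 'train.csv')) as f:
            rows = list(csv.reader(f))
        return [InputExample(guid=str(i), text_a=row[0], text_b=None, label=row[1])
                for i, row in enumerate(rows)]

    def get_labels(self):
        return ['0', '1']  # replace with your task's label set
```

You would then register the new task name in the processor lookup used by finetune.py and train_aug.py (something like `processors['mytask'] = MyTaskProcessor` under AugProcessor).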
You can then run the fine-tuning and augmentation steps with:
```bash
python finetune.py \
  --data_dir $INPUT_DIR/$TASK_NAME \
  --output_dir $OUTPUT_DIR/$TASK_NAME \
  --task_name $TASK_NAME \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 5e-5 \
  --num_train_epochs 10.0
```
```bash
python train_aug.py \
  --data_dir $INPUT_DIR/$TASK_NAME \
  --output_dir $OUTPUT_DIR/$TASK_NAME \
  --task_name $TASK_NAME \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 5e-5 \
  --num_train_epochs 10.0
```
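Here `$INPUT_DIR`, `$OUTPUT_DIR`, and `$TASK_NAME` are shell variables you set yourself; the values below are placeholders, so substitute your own paths and the task name you registered in AugProcessor:

```bash
export INPUT_DIR=./data      # placeholder: parent directory holding $TASK_NAME/train.csv etc.
export OUTPUT_DIR=./output   # placeholder: where checkpoints and augmented data are written
export TASK_NAME=mytask      # placeholder: the task name registered in AugProcessor
```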
The final results are evaluated with the Kaggle Baseline model.
If you want to reproduce our results with the default settings, you need a GPU with more than 14 GB of memory. Otherwise, decrease train_batch_size and max_seq_length.
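For example, a setting like the following should fit in considerably less memory (the numbers are illustrative, so tune them to your GPU):

```bash
python finetune.py \
  --data_dir $INPUT_DIR/$TASK_NAME \
  --output_dir $OUTPUT_DIR/$TASK_NAME \
  --task_name $TASK_NAME \
  --bert_model bert-base-uncased \
  --max_seq_length 64 \
  --train_batch_size 8 \
  --learning_rate 5e-5 \
  --num_train_epochs 10.0
```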