stelladk / PretrainingBERT

Pre-training BERT masked language models with custom vocabulary

Home Page: https://aclanthology.org/2021.nllp-1.9/


Pre-training BERT Masked Language Models (MLM)

This repository contains the code to pre-train a BERT masked language model with a custom vocabulary. It was used to pre-train JuriBERT, the model presented in the associated paper.

It also contains the code for the classification task used to evaluate JuriBERT.

Our models can be found at the French Language Portal and downloaded upon request.

Instructions

To pre-train a new BERT model, you need the path to a dataset containing raw text. You can also specify an existing tokenizer for the model. Paths for saving the model and the checkpoints are also required.

python pretrain.py \
      --files /path/to/text \
      --model_path /path/to/save/model \
      --checkpoint /path/to/save/checkpoints \
      --epochs 30 \
      --hidden_layers 2 \
      --hidden_size 128 \
      --attention_heads 2 \
      --save_steps 10 \
      --save_limit 0 \
      --min_freq 0
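
If you do not already have a tokenizer, a custom vocabulary can be built from the raw text beforehand. The snippet below is a minimal sketch using the Hugging Face tokenizers library; it is only an illustration, and the vocab_size, min_frequency, and special tokens are placeholder values, not the settings used for JuriBERT.

from tokenizers import ByteLevelBPETokenizer

# Train a byte-level BPE tokenizer on the raw text corpus
# (illustrative settings, not the JuriBERT configuration)
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(
    files=["/path/to/text"],
    vocab_size=32000,
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)

# Save everything as a single tokenizer.json that can be passed to the scripts
tokenizer.save("/path/to/tokenizer.json")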

To fine-tune on a classification task, you need the path to the pre-trained model and a CSV file containing the classification dataset. You also need to specify the columns containing the category and the text, as well as the paths for saving the final model and the checkpoints.

python classification.py \
  --model "custom" \
  --pretrained_path /path/to/model.bin \
  --tokenizer_path /path/to/tokenizer.json \
  --data /path/to/data.csv \
  --category "category-column" \
  --text "text-column" \
  --model_path /path/to/save/model \
  --checkpoint /path/to/save/checkpoints 
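
After fine-tuning, the resulting classifier can be loaded for inference. The following is a minimal sketch, assuming the fine-tuned model was saved in the standard Hugging Face format under --model_path and that the same tokenizer.json is reused; adapt the paths to your setup.

from transformers import PreTrainedTokenizerFast, pipeline

# Load the tokenizer used for pre-training (assumed to be a tokenizer.json file)
tokenizer = PreTrainedTokenizerFast(tokenizer_file="/path/to/tokenizer.json")

# Assumes the fine-tuned model directory is loadable by the transformers pipeline
classifier = pipeline(
    "text-classification",
    model="/path/to/save/model",
    tokenizer=tokenizer,
)

classifier("Texte à classer.")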

You can use --help to see all available options.

To test the masked language model, use:

from transformers import PreTrainedTokenizerFast, pipeline

# Load the tokenizer used during pre-training (e.g. a tokenizer.json file)
tokenizer = PreTrainedTokenizerFast(tokenizer_file="/path/to/tokenizer.json")

fill_mask = pipeline(
    "fill-mask",
    model="/path/to/model",
    tokenizer=tokenizer
)

fill_mask("Paris est la capitale de la <mask>.")
