BertPuncCap

This is a simple PoC (proof-of-concept) model built to restore punctuation & capitalization in a given text. In other words, given a text with no punctuation and no capitalization, this model restores the punctuation and capitalization needed to make the text human-readable.

BertPuncCap is a PyTorch model built on top of a pre-trained Google BERT model by adding two linear layers that perform the two tasks simultaneously: one layer is responsible for the re-punctuation task and the other for the re-capitalization task.

How this model works can be summarized in the following steps:

  • BertPuncCap takes an input sentence consisting of segment_size=32 tokens (by default). If the input is shorter than segment_size, it is padded by repeating its edges (both ends of the input sentence). The segment_size is a hyper-parameter that you can tune.
  • Then, the pre-trained BERT language model returns the representations of the input tokens. The shape of the output should be segment_size x model_dim; if you are using BERT-base, then model_dim=768.
  • These representations will be sent to the two linear layers for classification. One layer should classify the punctuation after each token while the other should classify the case.
  • The loss function is the weighted sum of the punctuation classification loss punc_loss and the capitalization classification loss cap_loss according to the following formula, where $\alpha$ is a hyper-parameter that you can set in your config.yaml file:

$$\text{loss} = \alpha \cdot \text{punc\_loss} + (1 - \alpha) \cdot \text{cap\_loss}$$
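To make these steps concrete, here is a minimal PyTorch sketch of that pipeline. It is illustrative only, not the actual implementation in model.py: the class and helper names are made up, and the padding and loss details are assumptions based on the description above.

import numpy as np
import torch.nn as nn

class BertPuncCapSketch(nn.Module):
    """Illustrative two-head model: one linear classifier per task on top of BERT."""
    def __init__(self, bert, num_punc_classes=7, num_case_classes=3, dropout=0.3):
        super().__init__()
        self.bert = bert
        hidden = bert.config.hidden_size                       # model_dim, 768 for BERT-base
        self.dropout = nn.Dropout(dropout)
        self.punc_head = nn.Linear(hidden, num_punc_classes)   # re-punctuation task
        self.case_head = nn.Linear(hidden, num_case_classes)   # re-capitalization task

    def forward(self, input_ids):
        # token representations of shape (batch, segment_size, model_dim)
        reps = self.dropout(self.bert(input_ids).last_hidden_state)
        return self.punc_head(reps), self.case_head(reps)

def pad_to_segment(token_ids, segment_size=32):
    """Pad a short input by repeating its edge tokens on both ends."""
    missing = segment_size - len(token_ids)
    left = missing // 2
    return np.pad(token_ids, (left, missing - left), mode="edge")

def joint_loss(punc_logits, case_logits, punc_labels, case_labels, alpha=0.5):
    """Weighted sum of the two classification losses (the formula above)."""
    ce = nn.CrossEntropyLoss()
    # CrossEntropyLoss expects logits of shape (batch, num_classes, seq_len)
    punc_loss = ce(punc_logits.transpose(1, 2), punc_labels)
    case_loss = ce(case_logits.transpose(1, 2), case_labels)
    return alpha * punc_loss + (1 - alpha) * case_loss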

Note:

BertPuncCap was inspired by BertPunc with the following differences:

  • BertPunc only handles the punctuation restoration task, while this model handles both punctuation restoration & re-capitalization.
  • BertPunc only handles COMMA, PERIOD, and QUESTION_MARK, while this model handles three more punctuation marks: EXCLAMATION, COLON, and SEMICOLON. You can also add your own; the set is fully configurable.
  • BertPunc is not compatible with HuggingFace's transformers package, while this model is.
  • BertPunc doesn't provide any pre-trained models, while this model provides many.

Working Example

You can check this notebook for the different ways in which you can use this model, as well as for how to get the confusion matrix of the different classes.
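If you only need a quick confusion matrix outside that notebook, here is a minimal sketch using scikit-learn; the per-token labels below are hypothetical stand-ins for what you would extract from a reference text and the model's predictions:

from sklearn.metrics import confusion_matrix

labels = ["O", "COMMA", "PERIOD", "QUESTION", "EXCLAMATION", "COLON", "SEMICOLON"]
y_true = ["O", "COMMA", "O", "PERIOD"]   # hypothetical gold punctuation labels, one per token
y_pred = ["O", "O",     "O", "PERIOD"]   # hypothetical model predictions
print(confusion_matrix(y_true, y_pred, labels=labels))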

Prerequisites

To install the dependencies, run the following command:

pip install -r requirements.txt

Pre-trained Models

You can download the pre-trained models from the following table:

| Name | Pre-trained BertPuncCap | Training Data | Pre-trained BERT | Supported Languages |
| --- | --- | --- | --- | --- |
| mbert_base_cased_fr | ( Model, Configuration ) | mTEDx | bert-base-multilingual-cased | French (fr) |
| mbert_base_cased_8langs | ( Model, Configuration ) | mTEDx | bert-base-multilingual-cased | Arabic (ar), German (de), Greek (el), French (fr), Italian (it), Spanish (es), Portuguese (pt), Russian (ru) |

Now, it's very easy to use these pre-trained models; here is an example:

>>> import os
>>> from transformers import BertTokenizer, BertModel
>>> from model import BertPuncCap
>>> 
>>> # load pre-trained mBERT from HuggingFace's transformers package
>>> BERT_name = "bert-base-multilingual-cased"
>>> bert_tokenizer = BertTokenizer.from_pretrained(BERT_name)
>>> bert_model = BertModel.from_pretrained(BERT_name)
>>> 
>>> # load trained checkpoint
>>> checkpoint_path = os.path.join("models", "mbert_base_cased")
>>> bert_punc_cap = BertPuncCap(bert_model, bert_tokenizer, checkpoint_path)

Now that we have loaded the model, let's use it:

>>> x = ["bonsoir",
...      "notre planète est recouverte à 70 % d'océan et pourtant étrangement on a choisi de l'appeler « la Terre »"
... ]
>>> # start predicting
>>> bert_punc_cap.predict(x)
[
    'Bonsoir ,',
    "Notre planète est recouverte à 70 % d ' océan . et pourtant étrangement , on a choisi de l ' appeler « La Terre »"
]

Train

To train the model, you need to use the train.py script. Here is how you can do so:

python train.py --seed 1234 \
                --pretrained_bert bert-base-multilingual-cased \
                --optimizer Adam \
                --criterion cross_entropy \
                --alpha 0.5 \
                --dataset mTEDx \
                --langs fr \
                --save_path ./models/mbert_base_cased \
                --batch_size 1024 \
                --segment_size 32 \
                --dropout 0.3 \
                --lr 0.00001 \
                --max_epochs 50 \
                --num_validations 1 \
                --patience 1 \
                --stop_metric overall_f1

Training Parameters

The following is a full list of all training parameters that can be used with this model:

| Parameter | Description | Possible Values | Default |
| --- | --- | --- | --- |
| seed | Random seed. | Any positive integer. | 1234 |
| pretrained_bert | The name of the pre-trained BERT model from HuggingFace. | bert-base-cased, bert-base-uncased, bert-base-multilingual-cased, bert-base-multilingual-uncased, camembert-base, flaubert/flaubert_base_cased | bert-base-multilingual-cased |
| optimizer | The optimizer used to train this model. | Adam | - |
| lr | The learning rate used by the optimizer. | Any positive number. | 0.00001 |
| criterion | The criterion used to train the model. | cross_entropy | - |
| alpha | The weighting parameter between punc_loss & cap_loss. | Any value in [0, 1]. | 0.5 |
| dataset | The dataset used for training. | mTEDx | mTEDx |
| langs | The list of languages from the dataset to train the model on. | Depends on the dataset. | fr |
| save_path | The relative/absolute path where the model is saved. | Any working path. | - |
| batch_size | The batch size for training, validation, and testing. | Any positive integer. | 256 |
| segment_size | The segment size of the model. | Any positive integer. | 32 |
| dropout | The dropout rate of the linear layers built on top of BERT. | Any value between 0 and 1. | 0.3 |
| max_epochs | The maximum number of epochs to train the model. | Any positive integer. | 50 |
| num_validations | The number of validations to perform per epoch. | Any positive integer. | 1 |
| patience | The number of validations to wait for performance improvement before early stopping. | Any positive integer. | 10 |
| stop_metric | The name of the metric to monitor for early stopping. | valid_loss, punc_valid_loss, case_valid_loss, punc_overall_f1, case_overall_f1, overall_f1 | overall_f1 |

Punctuations & Cases

The punctuation marks & cases handled by this model are listed below:

  • Punctuations:

    • COMMA
    • PERIOD
    • QUESTION
    • EXCLAMATION
    • COLON
    • SEMICOLON
    • O: Other (no punctuation follows the token).
  • Cases:

    • F (First_Cap): When the first letter is capital.
    • A (All_Cap): When the whole token is capitalized.
    • O (Other): When neither of the above applies.
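To make the label scheme concrete, the sketch below shows how per-token (case, punctuation) label pairs map back to readable text; the restore helper is purely illustrative and not part of the codebase:

def restore(labeled_tokens):
    """Rebuild readable text from (token, case_label, punc_label) triplets."""
    marks = {"COMMA": ",", "PERIOD": ".", "QUESTION": "?",
             "EXCLAMATION": "!", "COLON": ":", "SEMICOLON": ";", "O": ""}
    words = []
    for token, case, punc in labeled_tokens:
        if case == "F":      # First_Cap
            token = token.capitalize()
        elif case == "A":    # All_Cap
            token = token.upper()
        words.append(token + marks[punc])
    return " ".join(words)

print(restore([("hello", "F", "COMMA"), ("how", "O", "O"),
               ("are", "O", "O"), ("you", "O", "QUESTION")]))
# Hello, how are you?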

Progress Tracking

The training progress will be written in a file called progress.tsv which can be used to monitor the model's performance during training. In this file, you can find important metrics about the training process.

For example, it tracks the training/validation losses, along with the F1 scores of all punctuation classes (punc_overall_f1), all capitalization classes (case_overall_f1), and all classes combined (overall_f1).
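To visualize these metrics yourself, here is a small sketch with pandas and matplotlib, assuming progress.tsv is tab-separated and uses the metric names above as column headers:

import pandas as pd
import matplotlib.pyplot as plt

progress = pd.read_csv("progress.tsv", sep="\t")
# plot the F1 metrics tracked during training
for metric in ["punc_overall_f1", "case_overall_f1", "overall_f1"]:
    plt.plot(progress[metric], label=metric)
plt.xlabel("validation step")
plt.ylabel("F1 score")
plt.legend()
plt.show()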

Inference

You can use this model to re-punctuate & re-capitalize ASR transcriptions. You can use the repunc_recap.py Python script to do so, given the paths of:

  • Pre-trained BertPuncCap.
  • ASR output transcription file.
  • Output file.

The following is a working example:

python repunc_recap.py \
    --ckpt /gfs/project/stag/users/manwar/BertPuncCap/models/mbert_base_cased_8langs \
    --in /gfs/project/stag/users/manwar/results/mTEDx_4/CASCADE/XLSR/test_fr.hyp \
    --out /gfs/project/stag/users/manwar/results/mTEDx_4/CASCADE/XLSR/test_punc_cap_fr.hyp

After running this command, a new file named test_punc_cap_fr.hyp will be created, containing the punctuated and capitalized words.

Benchmarking

To benchmark this model and evaluate how it performs, you can use the benchmark.py Python script. It works similarly to the previous script; you need the absolute/relative path of:

  • Pre-trained BertPuncCap.
  • ASR reference file.

NOTE:
This reference file should contain words that are punctuated (have punctuation) & capitalized.

The following is a working example:

python benchmark.py \
    --ckpt /gfs/project/stag/users/manwar/BertPuncCap/models/mbert_base_cased_8langs \
    --in /gfs/project/stag/users/manwar/results/mTEDx_4/CASCADE/XLSR/test.ref

And the following is the output, which shows the Precision, Recall, and F1 scores of the different punctuation and case classes:

                  O         ,         .         ?    !         :    ;
Precision  0.974879  0.562257  0.536195  0.617647  0.0  0.687500  0.0
Recall     0.960403  0.604181  0.681283  0.253012  0.0  0.488889  0.0
F1         0.967587  0.582466  0.600094  0.358974  0.0  0.571429  0.0

                  O         F         A
Precision  0.962127  0.655914  0.495283
Recall     0.970909  0.559238  0.664557
F1         0.966498  0.603730  0.567568
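Each F1 score above is simply the harmonic mean of the corresponding precision and recall, which you can verify for any column; for example, for the first one:

precision, recall = 0.974879, 0.960403  # first column of the punctuation table
f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 6))  # 0.967587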


License: Apache License 2.0

