This repository contains a solution to the NER task based on a PyTorch reimplementation of Google's TensorFlow repository for the BERT model, which was released together with the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular Google's pre-trained models) and a conversion script is provided (see below).
1. Loading a TensorFlow checkpoint (e.g. Google's pre-trained models)
You can convert any TensorFlow checkpoint for BERT (in particular the pre-trained models released by Google) to a PyTorch save file by using the `convert_tf_checkpoint_to_pytorch.py` script.

This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model, and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()`.
You only need to run this conversion script once to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`), but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`), as these are needed for the PyTorch model too.
To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch.
Here is an example of the conversion process for a pre-trained BERT-Base, Multilingual model:

```bash
export BERT_BASE_DIR=/path/to/bert/multilingual_L-12_H-768_A-12

python3 convert_tf_checkpoint_to_pytorch.py \
    --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
    --bert_config_file $BERT_BASE_DIR/bert_config.json \
    --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
```
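Once converted, the resulting file can be opened with `torch.load()` as described above. A minimal sketch, assuming the output path from the example command and that the file holds the model weights keyed by parameter name:

```python
import torch

# Load the file produced by convert_tf_checkpoint_to_pytorch.py (example path).
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# Expected to be a mapping of parameter names to weight tensors.
print(len(state_dict), "entries")
print(sorted(state_dict)[:3])
```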
You can download Google's pre-trained models for the conversion here.
This solution uses the BERT-Base, Multilingual Uncased and BERT-Base, Multilingual Cased (recommended) models.
We did not search for the best hyperparameters and obtained the following results after training for no more than 10 epochs.
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook |
|---|---|---|---|---|---|
| FactRuEval | ru | 0.937 | 0.883 | 4 | factrueval.ipynb |
| Atis | en | 0.852 | 0.787 | 65 | conll-2003.ipynb |
| Conll-2003 | en | 0.945 | 0.858 | 5 | atis.ipynb |
- FactRuEval (f1): 0.9163±0.006, best 0.926.
- Atis (f1): 0.882±0.02, best 0.896.
- Conll-2003 (f1, dev): 0.949±0.002, best 0.951; 0.892 (f1, test).
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook |
|---|---|---|---|---|---|
| FactRuEval | ru | 0.925 | 0.827 | 4 | factrueval-nmt.ipynb |
| Atis | en | 0.919 | 0.829 | 65 | atis-nmt.ipynb |
| Conll-2003 | en | 0.936 | 0.900 | 5 | conll-2003-nmt.ipynb |
| Dataset | Lang | IOB precision | Span precision | Clf precision | Total spans in test set | Total classes | Notebook |
|---|---|---|---|---|---|---|---|
| Atis | en | 0.877 | 0.824 | 0.894 | 65 | 17 | atis-joint.ipynb |
| Dataset | Lang | IOB precision | Span precision | Clf precision | Total spans in test set | Total classes | Notebook |
|---|---|---|---|---|---|---|---|
| Atis | en | 0.913 | 0.820 | 0.888 | 65 | 17 | atis-joint-nmt.ipynb |
We tested `BertBiLSTMCRF`, `BertBiLSTMAttnCRF` and `BertBiLSTMAttnNMT` on the Russian dataset FactRuEval with a frozen `ElmoEmbedder`:
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook |
|---|---|---|---|---|---|
| FactRuEval | ru | 0.903 | 0.851 | 4 | samples.ipynb |
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook |
|---|---|---|---|---|---|
| FactRuEval | ru | 0.899 | 0.819 | 4 | factrueval.ipynb |
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook |
|---|---|---|---|---|---|
| FactRuEval | ru | 0.902 | 0.752 | 4 | factrueval-nmt.ipynb |
This code was tested on Python 3.5+. The requirements are:
- PyTorch (>= 0.4.1)
- tqdm
- tensorflow (for conversion)
To install the dependencies:

```bash
pip install -r ./requirements.txt
```
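As a quick sanity check that the environment matches the requirements above, a minimal illustrative sketch:

```python
# Verify that the listed dependencies are importable and report their versions.
import torch
import tqdm
import tensorflow as tf  # only needed for the checkpoint conversion step

print("PyTorch:", torch.__version__)      # expected >= 0.4.1
print("tqdm:", tqdm.__version__)
print("TensorFlow:", tf.__version__)
```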
All models are organized as `Encoder`-`Decoder`. The `Encoder` is a frozen, weighted (as proposed in ELMo) BERT output from 12 layers; a minimal sketch of this weighted layer pooling follows the list below. The models are obtained by plugging in different `Decoder`s:
- `Encoder`: `BertBiLSTM`
- `BertBiLSTMCRF`: `Encoder` + `Decoder` (BiLSTM + CRF)
- `BertBiLSTMAttnCRF`: `Encoder` + `Decoder` (BiLSTM + MultiHead Attention + CRF)
- `BertBiLSTMAttnNMT`: `Encoder` + `Decoder` (LSTM + Bahdanau Attention - NMT Decode)
- `BertBiLSTMAttnCRFJoint`: `Encoder` + `Decoder` (BiLSTM + MultiHead Attention + CRF) + (PoolingLinearClassifier for classification) - joint model with classification
- `BertBiLSTMAttnNMTJoint`: `Encoder` + `Decoder` (LSTM + Bahdanau Attention - NMT Decode) + (LinearClassifier for classification) - joint model with classification
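The sketch below illustrates the ELMo-style weighted sum over frozen BERT layer outputs mentioned above. The class name `WeightedLayerPooling` and the tensor shapes are illustrative assumptions, not the repository's actual module:

```python
import torch
import torch.nn as nn

class WeightedLayerPooling(nn.Module):
    """Illustrative sketch: ELMo-style weighted sum over frozen BERT layers.

    `layer_outputs` is assumed to have shape (num_layers, batch, seq_len, hidden)
    and to come from a BERT encoder whose parameters are kept frozen.
    """
    def __init__(self, num_layers=12):
        super().__init__()
        # One learnable scalar per layer plus a global scale (gamma), as in ELMo.
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, layer_outputs):
        # Softmax-normalize the per-layer weights and mix the frozen outputs.
        weights = torch.softmax(self.layer_weights, dim=0)
        mixed = (weights.view(-1, 1, 1, 1) * layer_outputs).sum(dim=0)
        return self.gamma * mixed

# Usage with random tensors standing in for frozen BERT layer outputs:
layers = torch.randn(12, 2, 16, 768)   # (num_layers, batch, seq_len, hidden)
pooled = WeightedLayerPooling()(layers)
print(pooled.shape)                     # torch.Size([2, 16, 768])
```

The usage example that follows creates the data and the model, builds the learner, trains, and predicts on new data.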
```python
# Create the NER data bundle from CSV files (train_path, valid_path and vocab_file are your own paths).
from modules.bert_data import BertNerData as NerData
data = NerData.create(train_path, valid_path, vocab_file)

# Build the model from the BERT config and the converted PyTorch checkpoint.
from modules.bert_models import BertBiLSTMCRF
model = BertBiLSTMCRF.create(len(data.label2idx), bert_config_file, init_checkpoint_pt, enc_hidden_dim=256)

# Create the learner and train (num_epochs is an example value; it was not set in the original snippet).
from modules.train import NerLearner
num_epochs = 10
learner = NerLearner(model, data, best_model_path="/datadrive/models/factrueval/exp_final.cpt", lr=0.01, clip=1.0, sup_labels=data.id2label[5:], t_total=num_epochs * len(data.train_dl))
learner.fit(2, target_metric='prec')

# Predict on new data with the best saved model.
from modules.data.bert_data import get_bert_data_loader_for_predict
dl = get_bert_data_loader_for_predict(data_path + "valid.csv", learner)
learner.load_model(best_model_path)  # best_model_path: the path passed to NerLearner above
preds = learner.predict(dl)
```
- For more detailed instructions on using the BERT model, see samples.ipynb.
- For more detailed instructions on using the ELMo model, see samples.ipynb.