kermitt2 / delft

a Deep Learning Framework for Text

Sub-tokenization with certain transformers

lfoppiano opened this issue

@pjox and I are working on a model trained with Roberta using the BPE tokenizer, in particular with zeldarose, which uses slightly different special tokens.

We have a problem when the data is tokenized.
In particular, the sub-tokenization done by the tokenizer somehow gets messed up with is_split_into_words=True and transformers version 4.15.0 (tokenizers library version 0.10.3):

The code in question (preprocess.py:304):

# sub-tokenization
encoded_result = self.tokenizer(text_tokens, add_special_tokens=True, is_split_into_words=True,
            max_length=max_seq_length, truncation=True, return_offsets_mapping=True)

text_tokens = ['We', 'are', 'studying', 'the', 'material', 'La', '3', 'A', '2', 'Ge', '2', '(', 'A', '=', 'Ir', ',', 'Rh', ')', '.', 'The', 'critical', 'temperature', 'T', 'C', '=', '4', '.', '7', 'K', 'discovered', 'for', 'La', '3', 'Ir', '2', 'Ge', '2', 'in', 'this', 'work', 'is', 'by', 'about', '1', '.', '2', 'K', 'higher', 'than', 'that', 'found', 'for', 'La', '3', 'Rh', '2', 'Ge', '2', '.']

The output offsets are: [(0, 0), (0, 2), (1, 3), (1, 8), (1, 3), (1, 8), (1, 2), (1, 1), (1, 1), (1, 1), (1, 2), (1, 1), (1, 1), (1, 1), (1, 1), (1, 2), (1, 1), (1, 2), (1, 1), (1, 1), (1, 3), (1, 8), (1, 11), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 10), (1, 3), (1, 2), (1, 1), (1, 2), (1, 1), (1, 2), (1, 1), (1, 2), (1, 4), (1, 4), (1, 2), (1, 2), (1, 5), (1, 1), (1, 1), (1, 1), (1, 1), (1, 6), (1, 4), (1, 4), (1, 5), (1, 3), (1, 2), (1, 1), (1, 2), (1, 1), (1, 2), (1, 1), (1, 1), (0, 0)]

The first two items are correct, but from the third one onward the sequence gets messed up: the third should be (0, 3), then (0, 8), etc., and the DeLFT code then reconstructs it wrongly. I also find the code unclear for the case where the offset pair does not start at 0; I don't understand why <PAD> is added:

                else:
                    # propagate the data to the new sub-token or 
                    # dummy/empty input for sub-tokens
                    label_ids.append("<PAD>")
                    chars_blocks.append(self.empty_char_vector)
                    # 2 possibilities, either empty features for sub-tokens or repeating the 
                    # feature vector of the prefix sub-token 
                    #feature_blocks.append(self.empty_features_vector)
                    feature_blocks.append(features_tokens[word_idx])

If I pass the joined string and set is_split_into_words=False:

self.tokenizer("".join(text_tokens), add_special_tokens=True, is_split_into_words=False, max_length=max_seq_length, truncation=True, return_offsets_mapping=True)

I obtain the correct result: [(0, 0), (0, 2), (2, 7), (7, 10), (10, 13), (13, 14), (14, 17), (17, 24), (24, 26), (26, 27), (27, 28), (28, 29), (29, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 37), (37, 38), (38, 40), (40, 42), (42, 45), (45, 53), (53, 64), (64, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 74), (74, 81), (81, 84), (84, 86), (86, 87), (87, 89), (89, 90), (90, 92), (92, 93), (93, 95), (95, 99), (99, 103), (103, 105), (105, 107), (107, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 122), (122, 124), (124, 128), (128, 130), (130, 135), (135, 138), (138, 140), (140, 141), (141, 143), (143, 144), (144, 146), (146, 147), (147, 148), (0, 0)]

The option is_split_into_words was intended only for text split on spaces, which is not the case for most of our use cases.

There is an explanation here, but I did not fully understand it: huggingface/transformers#8217
(in any case, it works only with the Python tokenizers)

Perhaps we should consider calling

self.tokenizer(text_tokens, add_special_tokens=True, is_split_into_words=False, max_length=max_seq_length, truncation=True, return_offsets_mapping=True)

which returns a list of lists, one per token:

[
    [(0, 0), (0, 2), (0, 0)], 
    [(0, 0), (0, 3), (0, 0)], 
    [(0, 0), (0, 8), (0, 0)], 
    [...]
]

and then, with some additional work, we should be able to reconstruct the output correctly.
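A minimal sketch of that additional work, assuming each inner list holds the offsets of one token tokenized on its own (with (0, 0) entries for the special tokens) and that the character position of each token in the original sentence is known. The helper `merge_token_offsets` is hypothetical, not part of DeLFT or transformers:

```python
from typing import List, Tuple

Offset = Tuple[int, int]

def merge_token_offsets(per_token_offsets: List[List[Offset]],
                        token_starts: List[int]) -> List[Tuple[int, Offset]]:
    """Flatten per-token offset lists into sentence-level offsets.

    per_token_offsets[i]: offsets returned when token i is tokenized alone.
    token_starts[i]: character position of token i in the original sentence.
    Returns (token_index, (start, end)) pairs, one per real sub-token.
    """
    merged = []
    for idx, (offsets, base) in enumerate(zip(per_token_offsets, token_starts)):
        for start, end in offsets:
            if start == end:  # empty span: <s>, </s> or another special token
                continue
            # shift the token-relative offsets by the token position
            merged.append((idx, (base + start, base + end)))
    return merged

per_token = [[(0, 0), (0, 2), (0, 0)],   # "We"
             [(0, 0), (0, 3), (0, 0)],   # "are"
             [(0, 0), (0, 8), (0, 0)]]   # "studying"
print(merge_token_offsets(per_token, [0, 3, 7]))
# [(0, (0, 2)), (1, (3, 6)), (2, (7, 15))]
```

Dropping every empty span is a simplification: it assumes special tokens are the only sub-tokens with start == end.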

I've also found that updating the transformers library to 4.25.1 solves the problem on my M1 Mac, but it opens up new problems on Linux.

Hello!

AFAIK Roberta and its BPE tokenizer work well in my tests with transformers 4.25.1, but I think no longer with version 4.15.0 (although it also used to work with that version at some point in the past :). I changed the versions in 389eb3d, but only in setup.py... I forgot requirements.txt, sorry.

4.25.1 changed the behavior of BPE tokenization, in a good way I think. Below I try to explain how it works and how I added support for Roberta-style tokenizers.

We start with a pretokenized input. According to the Tokenizers library documentation: "If the sequences are provided as list of strings (pretokenized), you must set is_split_into_words=True (to lift the ambiguity with a batch of sequences)."

The difference with a "traditional" BERT tokenizer is that, with pretokenized input, the BPE tokenizer has no leading space with which to perform a proper "tokenization" (to be clear, here it is really a sub-tokenization: we tokenize the tokens...). So the developers of Tokenizers introduced a trick that adds a leading space to every token when sub-tokenizing: is_split_into_words indicates that the input is pretokenized, and add_prefix_space=True indicates that the trick should be used for pretokenized input (add_prefix_space=True is passed when initializing the AutoTokenizer).
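A toy illustration in plain Python (not the Tokenizers implementation) of what the prefix-space trick amounts to: with add_prefix_space=True and pretokenized input, every word, including the first one, reaches the BPE model with a leading space, which byte-level BPE displays as Ġ. In a plain sentence string, only the words that follow a space carry it. Both helper names are hypothetical:

```python
from typing import List

SPACE_SYMBOL = "\u0120"  # 'Ġ', the byte-level BPE display symbol for a space

def words_seen_pretokenized(pretokens: List[str]) -> List[str]:
    """Words as seen by BPE with is_split_into_words=True and add_prefix_space=True:
    a space is prepended to every word, the first one included."""
    return [SPACE_SYMBOL + word for word in pretokens]

def words_seen_in_sentence(text: str) -> List[str]:
    """Rough equivalent for a plain sentence string (single spaces assumed):
    only words preceded by a space carry the space marker."""
    words = text.split(" ")
    return words[:1] + [SPACE_SYMBOL + word for word in words[1:]]

print(words_seen_pretokenized(["We", "are", "studying"]))  # ['ĠWe', 'Ġare', 'Ġstudying']
print(words_seen_in_sentence("We are studying"))           # ['We', 'Ġare', 'Ġstudying']
```

The two outputs differ only on the first word.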

Let me open a parenthesis here:

This creates a problem for BPE with the first token, which has no space before it. The trick adds a leading space before each token so that the sub-tokenization is similar to the one obtained from a complete sentence string including spaces. In general I think this trick remains an approximation, because, depending on the pretokenization pipeline, some tokens are of course not preceded by a space.
But the main problem is the first token in the Tokenizers implementation. For these tokenizers there is an extra space even before the first token, which is always wrong... Looking at the source code, we find this in the comments:

When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one). 

I don't know why they add a space for the token immediately after <s>.

Closing the parenthesis!

The result of the tokenization is a list of sub-tokens whose offsets, in the case of BPE, do NOT refer to the original string (which is never passed to the tokenizer) but to each individual token.

So what is returned with is_split_into_words=True is actually "correct" for version 4.15.0. The offsets are not relative to the whole sentence but to each current input token. This is not specific to BPE: offsets relative to tokens are obtained in general with is_split_into_words=True, for non-BPE tokenizers too.

To illustrate, I changed the example slightly to something more complicated:

original input (pre-tokenized): ['We', 'are', 'studying', 'the', 'retokenization']
roberta-base token ids: [0, 166, 32, 7739, 5, 5494, 22036, 1938, 2]
offsets: [(0, 0), (0, 2), (1, 3), (1, 8), (1, 3), (1, 3), (3, 7), (7, 14), (0, 0)]
self.tokenizer.convert_ids_to_tokens(encoded_result.input_ids): 
['<s>', 'ĠWe', 'Ġare', 'Ġstudying', 'Ġthe', 'Ġret', 'oken', 'ization', '</s>']

The offsets are all relative to each token: each token is an input which starts at 0 (special tokens, and the first token, at least in the case of Roberta) or at 1 (the following tokens, because of the fake space added by the BPE trick mentioned above; the exact symbol encoding this space depends on the transformer implementation).

But this offset behavior changed with version 4.25.1, because to be honest it was confusing, and we now have:

original input (pre-tokenized): ['We', 'are', 'studying', 'the', 'retokenization']
roberta-base token ids: [0, 166, 32, 7739, 5, 5494, 22036, 1938, 2]
offsets: [(0, 0), (0, 2), (0, 3), (0, 8), (0, 3), (0, 3), (3, 7), (7, 14), (0, 0)]
self.tokenizer.convert_ids_to_tokens(encoded_result.input_ids): 
['<s>', 'ĠWe', 'Ġare', 'Ġstudying', 'Ġthe', 'Ġret', 'oken', 'ization', '</s>']

Note that the offsets now all start at 0 when there is a leading space Ġ.

This is good, because it is now the same as with a BERT tokenizer, for instance, with is_split_into_words=True. But note that in both cases we get ĠWe, while we would want the input token id to be the one corresponding to We (cf. the parenthesis above!).
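With the 4.25.1 offsets, whether a sub-token starts a new input token can be read directly from its start offset. A minimal sketch of that heuristic (a hypothetical helper, not DeLFT's code), run on the two offset lists above; the special-token masks are written by hand here, with the same shape that return_special_tokens_mask=True returns:

```python
from typing import List, Tuple

def word_start_flags(offsets: List[Tuple[int, int]], special_mask: List[int]) -> List[bool]:
    """True for sub-tokens that begin a new pretokenized token.

    Heuristic: with per-token offsets, a non-special sub-token whose start
    offset is 0 begins a token; continuation sub-tokens start further in.
    """
    return [mask == 0 and start == 0 for (start, _), mask in zip(offsets, special_mask)]

special = [1, 0, 0, 0, 0, 0, 0, 0, 1]  # <s> ... </s>
offsets_4_25_1 = [(0, 0), (0, 2), (0, 3), (0, 8), (0, 3), (0, 3), (3, 7), (7, 14), (0, 0)]
offsets_4_15_0 = [(0, 0), (0, 2), (1, 3), (1, 8), (1, 3), (1, 3), (3, 7), (7, 14), (0, 0)]

print(sum(word_start_flags(offsets_4_25_1, special)))  # 5, one per input token
print(sum(word_start_flags(offsets_4_15_0, special)))  # 1, only 'ĠWe' is detected
```

With the 4.15.0 offsets the same test misses every token after the first, which would be consistent with the misalignment reported at the start of this issue.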

Should we use rather the is_split_into_words=False and a concatenated input as you suggest?

self.tokenizer("".join(text_tokens), add_special_tokens=True, is_split_into_words=False, max_length=max_seq_length, truncation=True, return_offsets_mapping=True)

First, the "".join(text_tokens) you used cannot work because there would be no spaces at all;
I guess you meant " ".join(text_tokens).
This is then very similar to the existing BPE trick, but it would solve the first-token problem. However, by doing this we lose the initial pretokenized input, which we need to restore later because we still have to align with labels and features (which are pretokenized like the token input). I tried it at some point, but it added complications and sources of errors when re-aligning the other pretokenized "channels". It is also not relevant for "usual" BERT tokenization, only for BPE tokenization. But it could work, I suppose, with more motivation :)
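For reference, a minimal sketch of the re-alignment that the " ".join(text_tokens) approach would need: record where each token starts in the joined string, then map each sub-token offset back to its token with a binary search. It assumes offsets exclude the leading space of each sub-token, and the offsets in the usage lines are hand-computed for "We are studying the retokenization", not real tokenizer output. Function names are hypothetical:

```python
import bisect
from typing import List, Tuple

def join_with_spans(tokens: List[str]) -> Tuple[str, List[int]]:
    """Join tokens with single spaces and record where each token starts."""
    starts, pos = [], 0
    for token in tokens:
        starts.append(pos)
        pos += len(token) + 1  # the token plus one separating space
    return " ".join(tokens), starts

def subtoken_to_token(offsets: List[Tuple[int, int]], starts: List[int]) -> List[int]:
    """Map each sub-token span in the joined string to the index of its token.

    Empty spans (special tokens such as <s> and </s>) are mapped to -1.
    """
    indices = []
    for start, end in offsets:
        if start == end:
            indices.append(-1)
        else:
            # last token whose start position is <= the sub-token start
            indices.append(bisect.bisect_right(starts, start) - 1)
    return indices

tokens = ["We", "are", "studying", "the", "retokenization"]
text, starts = join_with_spans(tokens)
# hand-computed spans for: <s> We are studying the ret oken ization </s>
offsets = [(0, 0), (0, 2), (3, 6), (7, 15), (16, 19), (20, 23), (23, 27), (27, 34), (0, 0)]
print(subtoken_to_token(offsets, starts))  # [-1, 0, 1, 2, 3, 4, 4, 4, -1]
```

This only restores the token channel; labels and features would then have to follow the returned indices, which is where the extra complication comes from.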

Normally is_split_into_words=True is intended for pretokenized input like ours, and I am also worried about introducing more hacks into something that is already quite hacky.

I tested with the current DeLFT version and transformers==4.25.1; normally there is no issue:

> python3 delft/applications/nerTagger.py train_eval --dataset-type conll2003 --architecture BERT --transformer roberta-base

Loading CoNLL 2003 data...
14041 train sequences
	 nb. tokens 203621
	 with nb. entities 34043
3250 validation sequences
	 nb. tokens 51362
	 with nb. entities 8603
3453 evaluation sequences
	 nb. tokens 46435
	 with nb. entities 8112


20744 total sequences
301418 total tokens

total distinct characters: 85 

BERT
roberta-base will be used, loaded via huggingface
---
max_epoch: 50
early_stop: True
batch_size (training): 32
max_sequence_length: 150
model_name: ner-en-conll2003-BERT
learning_rate:  0.001
use_ELMo:  False

...

439/439 [==============================] - ETA: 0s - loss: 6.5545e-04   f1 (micro): 95.92
439/439 [==============================] - 274s 625ms/step - loss: 6.5545e-04 - f1: 0.9592
training runtime: 9372.786 seconds 

Evaluation on test set:
---
max_epoch: 50
early_stop: True
batch_size (training): 32
max_sequence_length: 150
model_name: ner-en-conll2003-BERT
learning_rate:  0.001
use_ELMo:  False

...

__________________________________________________________________________________________________
                  precision    recall  f1-score   support

             LOC     0.9327    0.9382    0.9354      1668
            MISC     0.8030    0.8362    0.8193       702
             ORG     0.8913    0.9235    0.9072      1661
             PER     0.9673    0.9703    0.9688      1617

all (micro avg.)     0.9136    0.9304    0.9219      5648

Great result for a base model, by the way!

> python3 delft/applications/nerTagger.py --dataset-type conll2003 --file-in data/test/test.ner.en.txt --architecture BERT --transformer roberta-base tag

{
    "software": "DeLFT",
    "date": "2023-01-13T20:05:33.170468",
    "model": "ner-en-conll2003-BERT",
    "texts": [
        {
            "text": "The University of California has found that 40 percent of its students suffer food insecurity. At four state universities in Illinois, that number is 35 percent.",
            "entities": [
                {
                    "text": "University of California",
                    "class": "ORG",
                    "score": 0.9999942779541016,
                    "beginOffset": 4,
                    "endOffset": 27
                },
                {
                    "text": "Illinois",
                    "class": "LOC",
                    "score": 0.9999986886978149,
                    "beginOffset": 125,
                    "endOffset": 132
                }
            ]
        },
        {
            "text": "President Obama is not speaking anymore from the White House.",
            "entities": [
                {
                    "text": "Obama",
                    "class": "PER",
                    "score": 0.9999897480010986,
                    "beginOffset": 10,
                    "endOffset": 14
                },
                {
                    "text": "White House",
                    "class": "LOC",
                    "score": 0.9999994039535522,
                    "beginOffset": 49,
                    "endOffset": 59
                }
            ]
        },
        {
            "text": "This",
            "entities": []
        },
        {
            "text": "is",
            "entities": []
        },
        {
            "text": "a",
            "entities": []
        }
    ],
    "runtime": 1.307
}

If the pair does not start at 0, the code is unclear; I don't understand why <PAD> is added

We put an empty label on the sub-tokens added by the tokenizers. This works better than repeating the token's label on all its sub-tokens, and it is the original BERT implementation ("We use the representation of the first sub-token as the input to the token-level classifier over the NER label set.")
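A minimal sketch of this labeling scheme. It uses word ids (one per sub-token, None for special tokens, the shape returned by BatchEncoding.word_ids() on a fast tokenizer) rather than the offsets DeLFT relies on, so it is an alternative technique, not DeLFT's code; `align_labels` is a hypothetical name, and giving <PAD> to special tokens as well is an assumption of the sketch:

```python
from typing import List, Optional

PAD_LABEL = "<PAD>"

def align_labels(word_ids: List[Optional[int]], labels: List[str]) -> List[str]:
    """Label the first sub-token of each token; use <PAD> everywhere else."""
    aligned: List[str] = []
    previous: Optional[int] = None
    for word_id in word_ids:
        if word_id is None or word_id == previous:
            aligned.append(PAD_LABEL)  # special token or continuation sub-token
        else:
            aligned.append(labels[word_id])  # first sub-token of a token
        previous = word_id
    return aligned

# <s> ĠWe Ġare Ġstudying Ġthe Ġret oken ization </s>
word_ids = [None, 0, 1, 2, 3, 4, 4, 4, None]
labels = ["O", "O", "O", "O", "B-MISC"]
print(align_labels(word_ids, labels))
# ['<PAD>', 'O', 'O', 'O', 'O', 'B-MISC', '<PAD>', '<PAD>', '<PAD>']
```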

See also https://datascience.stackexchange.com/questions/69640/what-should-be-the-labels-for-subword-tokens-in-bert-for-ner-task/75225#75225

Should we use rather the is_split_into_words=False and a concatenated input as you suggest?

bla bla

But it could work I suppose with more motivation :)

Actually, I just realized this is exactly the purpose of issue #128 :D

Thanks @kermitt2 for the extended answer.

Indeed, I did not think of checking setup.py 😭.
With the library at 4.25.1 this issue is actually no longer an issue; my idea was just to update the library, but then I ran into another issue (see below).

The Roberta-based model I'm referring to is the following:

https://kdrive.infomaniak.com/app/share/104844/41053dc4-5398-4841-939d-c67583de96d6

With the previous version there were no errors, but the evaluation results were very low, indicating tokenization problems.
However, after updating to 4.25.1, I have another issue on the TensorFlow side; maybe you have already seen something similar:

(delft_tf27) [lfoppian0@mdpfdc005 delft_tf2_transformers]$ python -m delft.applications.grobidTagger matbert-pedro-scicorpus-20000-vocab_100k-1 train --input /lustre/group/tdm/Luca/sampling/superconductors-220630-positive_sampling.train  --architecture BERT_CRF --max-sequence-length 512 --transformer portiz/matbert-pedro-scicorpus-20000-vocab_100k/dir --batch-size 10
Loading data...
16902 total sequences
15211 train sequences
1691 validation sequences

max train sequence length: 212
max validation sequence length: 168
BERT_CRF
portiz/matbert-pedro-scicorpus-20000-vocab_100k/dir will be used, loaded via local_model_dir
---
max_epoch: 60
early_stop: False
batch_size (training): 10
max_sequence_length: 512
model_name: grobid-matbert-pedro-scicorpus-20000-vocab_100k-1-BERT_CRF
learning_rate:  0.001
use_ELMo:  False
---
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_token (InputLayer)       [(None, None)]       0           []                               
                                                                                                  
 input_attention_mask (InputLay  [(None, None)]      0           []                               
 er)                                                                                              
                                                                                                  
 input_token_type (InputLayer)  [(None, None)]       0           []                               
                                                                                                  
 tf_roberta_model (TFRobertaMod  TFBaseModelOutputWi  162842112  ['input_token[0][0]',            
 el)                            thPoolingAndCrossAt               'input_attention_mask[0][0]',   
                                tentions(last_hidde               'input_token_type[0][0]']       
                                n_state=(None, None                                               
                                , 768),                                                           
                                 pooler_output=(Non                                               
                                e, 768),                                                          
                                 past_key_values=No                                               
                                ne, hidden_states=N                                               
                                one, attentions=Non                                               
                                e, cross_attentions                                               
                                =None)                                                            
                                                                                                  
 dropout_37 (Dropout)           (None, None, 768)    0           ['tf_roberta_model[0][0]']       
                                                                                                  
==================================================================================================
Total params: 162,842,112
Trainable params: 162,842,112
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "crf_model_wrapper_for_bert"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 crf (CRF)                   multiple                  10990     
                                                                 
 model (Functional)          (None, None, 768)         162842112 
                                                                 
=================================================================
Total params: 162,853,102
Trainable params: 162,853,102
Non-trainable params: 0
_________________________________________________________________
Epoch 1/60
   5/1691 [..............................] - ETA: 26:18 - loss: 1550.8781 - crf_loss: 1550.8781Traceback (most recent call last):
  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/applications/grobidTagger.py", line 411, in <module>
    train(model, 
  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/applications/grobidTagger.py", line 187, in train
    model.train(x_train, y_train, f_train, x_valid, y_valid, f_valid, incremental=incremental)
  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/wrapper.py", line 187, in train
    trainer.train(x_train, y_train, x_valid, y_valid, features_train=f_train, features_valid=f_valid, callbacks=callbacks)
  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/trainer.py", line 59, in train
    self.model = self.train_model(self.model, x_train, y_train, x_valid=x_valid, y_valid=y_valid,
  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/trainer.py", line 175, in train_model
    local_model.fit(training_generator,
  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.UnknownError: Graph execution error:

2 root error(s) found.
  (0) UNKNOWN:  IndexError: list index out of range
Traceback (most recent call last):

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/keras/engine/data_adapter.py", line 831, in wrapped_generator
    for data in generator_fn():

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/keras/engine/data_adapter.py", line 957, in generator_fn
    yield x[i]

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/data_generator.py", line 254, in __getitem__
    batch_x, batch_x_types, batch_x_masks, batch_c, batch_f, batch_l, batch_input_offsets, batch_y = self.__data_generation(index)

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/data_generator.py", line 334, in __data_generation
    input_ids, token_type_ids, attention_mask, input_chars, input_features, input_labels, input_offsets = self.bert_preprocessor.tokenize_and_align_features_and_labels(

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/preprocess.py", line 260, in tokenize_and_align_features_and_labels
    input_ids, token_type_ids, attention_mask, chars_block, feature_blocks, target_tags, tokens = self.convert_single_text(text,

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/preprocess.py", line 351, in convert_single_text
    label_ids.append(label_tokens[word_idx])

IndexError: list index out of range


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]]
	 [[IteratorGetNext/_34]]
  (1) UNKNOWN:  IndexError: list index out of range
Traceback (most recent call last):

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/keras/engine/data_adapter.py", line 831, in wrapped_generator
    for data in generator_fn():

  File "/home/lfoppian0/anaconda3/envs/delft_tf27/lib/python3.8/site-packages/keras/engine/data_adapter.py", line 957, in generator_fn
    yield x[i]

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/data_generator.py", line 254, in __getitem__
    batch_x, batch_x_types, batch_x_masks, batch_c, batch_f, batch_l, batch_input_offsets, batch_y = self.__data_generation(index)

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/data_generator.py", line 334, in __data_generation
    input_ids, token_type_ids, attention_mask, input_chars, input_features, input_labels, input_offsets = self.bert_preprocessor.tokenize_and_align_features_and_labels(

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/preprocess.py", line 260, in tokenize_and_align_features_and_labels
    input_ids, token_type_ids, attention_mask, chars_block, feature_blocks, target_tags, tokens = self.convert_single_text(text,

  File "/lustre/group/tdm/Luca/delft/delft_tf2_transformers/delft/sequenceLabelling/preprocess.py", line 351, in convert_single_text
    label_ids.append(label_tokens[word_idx])

IndexError: list index out of range


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_34806]

I have this issue only on Linux 😭 and I'm using CUDA 11.2.

@lfoppiano This Roberta model raises encoding issues because its tokenizer is not loaded properly. I could reproduce the error with

(env) lopez@trainer:~/delft$ python3 delft/applications/grobidTagger.py citation train_eval  --architecture BERT_CRF  --transformer /media/lopez/T51/embeddings/matbert-pedro-scicorpus-20000-vocab_100k/ --input data/sequenceLabelling/grobid/citation/citation-231022.train

because the input text includes non-ASCII characters.

The error is due to an encoding error of the created BPE tokenizer (RobertaTokenizer), which adds wrong extra tokens and shifts everything. For example:

input: 'Troadec, É.'
tokens: ['Troadec', ',', 'É', '.']
encoded_result.input_ids: [0, 35893, 9359, 3410, 16227, 185, 3096, 2]
encoded_result.offset_mapping: [(0, 0), (0, 3), (3, 7), (0, 1), (0, 1), (0, 1), (0, 1), (0, 0)]
self.tokenizer.convert_ids_to_tokens(encoded_result.input_ids): ['<s>', 'ĠTro', 'adec', 'Ġ,', 'ĠÃ', 'ī', 'Ġ.', '</s>']

"É" gets encoded as 2 tokens, "Ã" and "ī".

However in principle the right token exists in the vocab (matbert-pedro-scicorpus-20000-vocab_100k/tokenizer.json):

"ĠÉ": 70434,

From this, the decoded string is then:

['<s>', 'Troadec', ',', 'Ã', 'ī', '.', '</s>']

So alignment is wrong, encoding is wrong, and this is not recoverable afaik.

From what I see, the source of the problem is that the local tokenizer file of this local Roberta model is not loaded (matbert-pedro-scicorpus-20000-vocab_100k/tokenizer.json). In the method get_model called by wrapper.py, local_path is not instantiated.

https://github.com/kermitt2/delft/blob/master/delft/sequenceLabelling/wrapper.py#L168

The transformer tokenizer will be initialized without local_path:

https://github.com/kermitt2/delft/blob/master/delft/sequenceLabelling/models.py#L253
https://github.com/kermitt2/delft/blob/master/delft/sequenceLabelling/models.py#L257

So I think what gets initialized is a default RobertaTokenizer, via HuggingFace only, without a vocabulary supporting the actual model/input.

To load the local transformer tokenizer via the current method (utilities/Transformer.py), we would need either to define the local transformer path in the resource_registry.json file, or to introduce an argument in grobidTagger.py to explicitly say that the model is given as a path (and propagate the path in wrapper.py), or maybe to check whether the transformer name looks like a local path (sounds like "magic" :), so that the right loading method in utilities/Transformer.py is used (which should be LOADING_METHOD_LOCAL_MODEL_DIR, I think?).
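A minimal sketch of how the registry lookup and the path check could be combined. The registry shape ("transformers" entries with "name" and "model_dir") follows the resource_registry.json example in this thread, and the values "local_model_dir" and "huggingface" are taken from the "loaded via ..." log lines above; the constants, the function name and the precedence order are assumptions, not DeLFT's utilities/Transformer.py:

```python
import os

# Hypothetical constants; the values mirror the "loaded via ..." log lines.
LOADING_METHOD_LOCAL_MODEL_DIR = "local_model_dir"
LOADING_METHOD_HUGGINGFACE_NAME = "huggingface"

def guess_loading_method(transformer_name: str, registry: dict) -> str:
    """Choose how to load the transformer named by --transformer."""
    # 1. An explicit resource_registry.json entry with a model_dir wins.
    for entry in registry.get("transformers", []):
        if entry.get("name") == transformer_name and "model_dir" in entry:
            return LOADING_METHOD_LOCAL_MODEL_DIR
    # 2. The "magic" option: an existing directory is treated as a local model.
    if os.path.isdir(transformer_name):
        return LOADING_METHOD_LOCAL_MODEL_DIR
    # 3. Otherwise, assume a Hugging Face Hub model name.
    return LOADING_METHOD_HUGGINGFACE_NAME

registry = {"transformers": [{"name": "matbert-pedro-scicorpus-20000-vocab_100k",
                              "model_dir": "/path/to/matbert-pedro-scicorpus-20000-vocab_100k"}]}
print(guess_loading_method("matbert-pedro-scicorpus-20000-vocab_100k", registry))  # local_model_dir
```

Once the method resolves to a local directory, passing that directory to AutoTokenizer.from_pretrained makes transformers read the tokenizer files saved there instead of a default Hub tokenizer.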

By adding the model path info in the resource_registry.json file:

"transformers": [
        {
            "name": "matbert-pedro-scicorpus-20000-vocab_100k",
            "model_dir": "/media/lopez/T51/embeddings/matbert-pedro-scicorpus-20000-vocab_100k"
        }
    ],

I then have the model loaded, correctly I think, with --transformer matbert-pedro-scicorpus-20000-vocab_100k as parameter:

(env) lopez@trainer:~/delft$ python3 delft/applications/grobidTagger.py citation train_eval  --architecture BERT_CRF  --transformer matbert-pedro-scicorpus-20000-vocab_100k --input data/sequenceLabelling/grobid/citation/citation-231022.train

But unfortunately I still have the encoding error. If "É" and "ĠÉ" are in the vocabulary, I don't understand why É is not parsed as one character spanning several bytes, but as 2 distinct tokens (both with offset (0, 1), so there is no way to know it was originally a single token).

So to summarize there are still 2 problems:

  1. Apparently the vocabulary of the model matbert-pedro-scicorpus-20000-vocab_100k is not loaded/considered because "É" and "ĠÉ" are present in tokenizer.json but their encoding gives 2 "sub-bytes" tokens 'Ã', 'ī'

  2. In the pretokenized case, the added token (e.g. 'ī' above) has the same offset as the first one ((0, 1), (0, 1)), so we can't easily re-align the sub-tokens with the original tokens and the tokenized labels/features...
    This issue was mentioned here
    This problem does not appear with a complete string as input, because offsets are then relative to the whole input and we can see the overlapping tokens.
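To make problem 2 concrete: with the offset-based heuristic (a non-special sub-token whose start offset is 0 begins a new token), the 'Troadec, É.' example above yields one more token start than there are input tokens. A minimal sketch with a hypothetical helper, using the offsets and special tokens shown in that example:

```python
from typing import List, Tuple

def count_word_starts(offsets: List[Tuple[int, int]], special_mask: List[int]) -> int:
    """Count sub-tokens that look token-initial under the start == 0 heuristic."""
    return sum(1 for (start, _), mask in zip(offsets, special_mask)
               if mask == 0 and start == 0)

tokens = ["Troadec", ",", "É", "."]
# <s> ĠTro adec Ġ, ĠÃ ī Ġ. </s>
offsets = [(0, 0), (0, 3), (3, 7), (0, 1), (0, 1), (0, 1), (0, 1), (0, 0)]
special = [1, 0, 0, 0, 0, 0, 0, 1]

print(count_word_starts(offsets, special), len(tokens))  # 5 4
```

With 5 apparent starts for 4 tokens, a token index incremented at each start ends up one past the end of the label list, which would be consistent with the IndexError at label_tokens[word_idx] in the traceback above. Offsets alone cannot tell that 'ī' continues 'ĠÃ' while 'Ġ.' starts a new token, since all three are (0, 1).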

Problem 2 above should normally be fixed by #154.
This fixes the error for out-of-vocabulary characters that can appear from time to time, also for other SentencePiece tokenizers.

But issue 1 remains, the problem with matbert-pedro-scicorpus-20000-vocab_100k: apparently the model tokenizer is not initialized from the tokenizer file, which may come from the way this model was saved.

Regarding the first problem, I think the tokenizer is correctly loaded: self.tokenizer.vocab['É'] correctly returns 136.

However, if I replace É with ī, the result is still wrong (both with is_split_into_words True or False):

self.tokenizer.convert_ids_to_tokens(self.tokenizer(['Troadec', ',', 'ī', '.'],
    add_special_tokens=True,
    is_split_into_words=True,
    max_length=max_seq_length, truncation=True, return_offsets_mapping=True).data['input_ids'])

and the result is wrong, with another wrong character: ['<s>', 'ĠTro', 'adec', 'Ġ,', 'ĠÄ', '«', 'Ġ.', '</s>']

@pjox any thoughts?

Indeed, it seems correctly loaded; I also have:

self.tokenizer.vocab['É'] 136
self.tokenizer.vocab['ĠÉ'] 70434

But then still when tokenizing:

input: ['Troadec', ',', 'É', '.']
encode: [(0, 0), (0, 3), (3, 7), (0, 1), (0, 1), (0, 1), (0, 1)]
id_to_tokens: ['<s>', 'ĠTro', 'adec', 'Ġ,', 'ĠÃ', 'ī', 'Ġ.']
decode: <s> ĠTro adec Ġ, ĠÃ ī Ġ </s>

So ĠÉ, although in the vocab, is not properly encoded.

However, if I replace É with ī, the result is still wrong (both with is_split_into_words True or False):

Well here, this should be expected because Ġī is not in the vocabulary of this model, so it is encoded at lower byte level as fallback, and those two "sub-characters" are "correct" here.

OK. I think I get it now.

However, we have only ī in the vocabulary:

self.tokenizer.vocab['ī']=185

However, if we change the list to ['Troadec', ',', 'aī', '.'], we should get token 185, but we don't:

'input_ids' = {list: 9} [0, 35893, 9359, 3410, 212, 131, 109, 3096, 2]
'attention_mask' = {list: 9} [1, 1, 1, 1, 1, 1, 1, 1, 1]
'offset_mapping' = {list: 9} [(0, 0), (0, 3), (3, 7), (1, 1), (1, 1), (1, 2), (1, 2), (1, 1), (0, 0)]

And the reconstructed tokens are:

['<s>', 'ĠTro', 'adec', 'Ġ,', 'Ġa', 'Ä', '«', 'Ġ.', '</s>']

Am I right to say that this is not correct?

If we pass ['Troadec', ',', 'aī', '.'], the token is Ġaī, which is also not in the vocabulary, so we fall back to the byte level to match the vocab, which "correctly" gives 'Ġa', 'Ä', '«', because ī becomes Ä«...

So this is correct BPE, AFAIK: with BPE, the number of characters before encoding and after decoding is not fixed; what is fixed is the total number of bytes.
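A small check of this point, assuming the byte-level tokens quoted above, where each byte-level symbol stands for exactly one byte and 'Ġ' stands for the space that marks the word start: the character count differs, but the byte count is preserved.

```python
word = " aī"                      # space-prefixed word, as seen by the tokenizer
pieces = ["Ġa", "Ä", "«"]          # byte-level tokens reported above for 'aī'

n_chars = len(word)                        # 3 characters: ' ', 'a', 'ī'
n_bytes = len(word.encode("utf-8"))        # 4 bytes: 'ī' takes 2 bytes in UTF-8
n_symbols = sum(len(p) for p in pieces)    # one byte-level symbol per byte

print(n_chars, n_bytes, n_symbols)  # 3 4 4
```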

Also note that in the merges of the tokenizer of the model matbert-pedro-scicorpus-20000-vocab_100k, we have line 170312:
"Ġ É"

which, if I have understood correctly, means merging the pair into "ĠÉ" based on frequency information, hence the "ĠÉ" in the vocab. However, it's not taken into account by the tokenizer.
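To illustrate how a merge rule can exist and still never fire, here is a simplified BPE merge loop (not the Hugging Face implementation): it repeatedly merges the adjacent pair with the lowest rank, so a rule only applies when both of its symbols occur side by side in the input. Assuming byte-level pre-processing, the input word É arrives as the symbols 'Ġ', 'Ã', 'ī', so a rule written as "Ġ É" cannot match it. Whether this is what happens with this particular model is an assumption, not something established in this thread.

```python
def apply_merges(symbols, merges):
    """Simplified BPE: merge the lowest-ranked adjacent pair until no rule applies."""
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    symbols = list(symbols)
    while len(symbols) > 1:
        # Rank of every adjacent pair; pairs without a rule get infinity.
        candidates = [
            (ranks.get((a, b), float("inf")), i)
            for i, (a, b) in enumerate(zip(symbols, symbols[1:]))
        ]
        best_rank, i = min(candidates)
        if best_rank == float("inf"):
            break  # no merge rule matches any adjacent pair
        symbols[i:i + 2] = [symbols[i] + symbols[i + 1]]
    return symbols


# Toy merge table containing only the rule quoted from the merges file ("Ġ É").
merges = [("Ġ", "É")]

print(apply_merges(["Ġ", "É"], merges))       # ['ĠÉ']
print(apply_merges(["Ġ", "Ã", "ī"], merges))  # ['Ġ', 'Ã', 'ī']
```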

I've tested the "pedro" model using the branch of PR #154. There are two pieces of good news here:

  1. The PR mitigates synchronization issues, which previously led to invalid evaluation scores.
  2. The PR also avoids the "list index out of range" error reported in the first comment.

For point 1, one of the 5 evaluation results was:

number of alignment issues with test set: 3904
to solve them consider increasing the maximum sequence input length of the model and retrain
                  precision    recall  f1-score   support

         <class>     0.0000    0.0000    0.0000       271
      <material>     0.0556    0.0012    0.0024      1648
     <me_method>     0.2273    0.0141    0.0265       355
      <pressure>     0.0000    0.0000    0.0000        41
            <tc>     0.0714    0.0016    0.0031       639
       <tcValue>     0.0000    0.0000    0.0000       157

all (micro avg.)     0.1013    0.0026    0.0050      3111

which became:

filter2: ?| 
filter3: velop
filter2: ?| 
filter3: ical
filter2: ?| 
filter3: ity
[....] (many more filter2/filter3 debug lines)
number of alignment issues with test set: 83
to solve them consider increasing the maximum sequence input length of the model and retrain
                  precision    recall  f1-score   support

         <class>     0.7045    0.6863    0.6953       271
      <material>     0.7726    0.8022    0.7871      1648
     <me_method>     0.5616    0.7324    0.6357       355
      <pressure>     0.4884    0.5122    0.5000        41
            <tc>     0.7345    0.8138    0.7721       639
       <tcValue>     0.7440    0.7962    0.7692       157

all (micro avg.)     0.7251    0.7824    0.7526      3111

Compared results:

                  Scibert  matbert-pedro-scicorpus-20000-vocab_100k
         <class>    72.74  69.56
      <material>    78.84  78.05
     <me_method>    66.75  64.79
      <pressure>    42.04  44.85
            <tc>    79.02  78.38
       <tcValue>    78.74  74.96
All (micro avg.)    76.30  75.14

I think they make sense, since this model was trained for only 20k iterations.

In any case, I think we can merge this PR, since the problem seems to be more related to this specific model, as you said, @kermitt2.

Thanks @kermitt2 for all the useful insight on the BPE tokenizer! 😄

I have been looking into this problem with @lfoppiano for the last couple of weeks, but we cannot seem to find a solution/explanation for problem 1.

I looked a bit into the model matbert-pedro-scicorpus-20000-vocab_100k, which is trained with Zeldarose. However, the problem does not seem to come from Zeldarose itself: at least for the tokenizer part, Zeldarose only appears to be a sort of front-end to the Hugging Face implementation of BPE, as you can see here.

I will continue to look into it as soon as I have more time. One clue might be in the flair library, as I have used Zeldarose-trained models with flair and never encountered a problem with the token alignment.

I don't want to bother them too much (as I know they are very busy these days), but I'm also tagging @LoicGrobol as they might have encountered this problem earlier (maybe in hopsparser).

Thanks for tagging, @pjox! Actually, I'm not sure I understand everything here: is this an issue that only concerns the character offsets, or something bigger?

Oh, thanks a lot @LoicGrobol for taking the time to comment here! @kermitt2 can correct me if I'm wrong, but I think the problem comes from the encoding of some special characters, even when they are in the vocabulary, and especially when the input is pre-tokenized. I am almost sure this is coming from HF, but I tagged you just in case. @kermitt2 found a solution last week in #154 (which I haven't been able to check), but from what @lfoppiano told me, it was more of a hack to force re-alignment of the offsets after encountering certain characters (please do correct me if I'm wrong).

One issue I had trouble with was that certain tokenizers (Flaubert does, IIRC, so it's possibly because of XLM) skip some characters altogether, which leads to inconsistencies when the input is already split into words, but I don't think that's your problem here. If it is, though, I'm happy to dig into my archives, and in any case I'm curious about your issue: maybe it'll help me avoid trouble later 👀

To try to clarify, the remaining problem is not the offsets (that was due to an update in the Hugging Face library), nor the re-alignment: there are indeed some tokens in a pre-tokenized input that are added with weird offsets, with behavior that differs from one BPE/model to another, but it's easy to skip these tokens just by looking at them (I tested RoBERTa models, CamemBERT, bart-base, albert-base-v2, and an XLM model).
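A common way to do such a re-alignment without relying on offsets can be sketched as follows. This is a hedged illustration, not DeLFT's code: it assumes a fast tokenizer, whose `BatchEncoding.word_ids()` gives, for each sub-token, the index of the word it comes from (or `None` for special tokens). Only the first sub-token of each word receives the word's label; special tokens and continuation sub-tokens get a placeholder label. The `<PAD>` placeholder value and the example labels are assumptions for illustration.

```python
from typing import List, Optional

PAD_LABEL = "<PAD>"  # placeholder for sub-tokens that do not start a word (assumption)


def align_labels(word_ids: List[Optional[int]], labels: List[str]) -> List[str]:
    """Give each word's label to its first sub-token and PAD_LABEL to everything else.

    `word_ids` follows the convention of `BatchEncoding.word_ids()`: one entry per
    sub-token, holding the word index, or None for special tokens such as <s>, </s>.
    """
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None or word_id == previous:
            aligned.append(PAD_LABEL)        # special token or continuation sub-token
        else:
            aligned.append(labels[word_id])  # first sub-token of a new word
        previous = word_id
    return aligned


# Words ['Troadec', ',', 'aī', '.'] tokenized as <s> ĠTro adec Ġ, Ġa Ä « Ġ. </s>
word_ids = [None, 0, 0, 1, 2, 2, 2, 3, None]
labels = ["<person>", "O", "<other>", "O"]  # hypothetical labels
print(align_labels(word_ids, labels))
# ['<PAD>', '<person>', '<PAD>', 'O', '<other>', '<PAD>', '<PAD>', 'O', '<PAD>']
```

Because the word index is attached to every sub-token, byte-level fallback pieces such as 'Ä' and '«' are skipped regardless of the offsets they were given.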

The problem is that the BPE tokenizer saved with the "Pedro" model (sorry for dragging you into the issue, Pedro :D) is not working as expected, I think. To reproduce, see the example in #150 (comment).

Basically, the vocab contains tokens that appear to be loaded correctly:

self.tokenizer.vocab['É'] 136
self.tokenizer.vocab['ĠÉ'] 70434

with the merges in the tokenizer as expected for these tokens. However, when it is present in the input text sequence, the token 'ĠÉ' is not encoded as expected: it is encoded with 2 "sub-bytes" as a fallback:

input: ['Troadec', ',', 'É', '.']
encode: [(0, 0), (0, 3), (3, 7), (0, 1), (0, 1), (0, 1), (0, 1)]
id_to_tokens: ['<s>', 'ĠTro', 'adec', 'Ġ,', 'ĠÃ', 'ī', 'Ġ.']
decode: <s> ĠTro adec Ġ, ĠÃ ī Ġ </s>

So there is something apparently going wrong in the BPE tokenizer as initialized from the saved tokenizer.
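One way to probe the saved tokenizer files directly is to check that every merge rule produces a token that exists in the vocabulary. The sketch below assumes the usual Hugging Face byte-level BPE file layout (a `vocab.json` mapping tokens to ids, and a `merges.txt` with one space-separated pair per line, optionally preceded by a `#version` header). The function name is hypothetical, and for the demo the file contents are a toy example passed in memory.

```python
import json
from typing import Dict, List, Tuple


def inconsistent_merges(vocab: Dict[str, int], merges_lines: List[str]) -> List[Tuple[int, str]]:
    """Return (line number, rule) for each merge rule whose result is missing from the vocab."""
    problems = []
    for lineno, line in enumerate(merges_lines, start=1):
        line = line.rstrip("\n")
        if not line or line.startswith("#version"):
            continue  # skip the header and blank lines
        left, right = line.split(" ", 1)
        if left + right not in vocab:
            problems.append((lineno, line))
    return problems


# In practice (hypothetical paths):
#   vocab = json.load(open("vocab.json", encoding="utf-8"))
#   merges_lines = open("merges.txt", encoding="utf-8").readlines()
vocab = json.loads('{"Ġ": 0, "É": 1, "ĠÉ": 2, "Ã": 3, "ī": 4}')
merges_lines = ["#version: 0.2\n", "Ġ É\n", "Ã ī\n"]
print(inconsistent_merges(vocab, merges_lines))  # [(3, 'Ã ī')]
```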

When looking at the matbert-pedro-scicorpus-20000-vocab_100k/tokenizer_config.json file, there's a bad-looking path (special_token_map), but fixing it does not change the behavior of the tokenizer.