kermitt2 / delft

a Deep Learning Framework for Text

Classification and transformers

lfoppiano opened this issue

I am opening this as a separate issue (from #131 (comment)): something seems to be wrong when using the classification applications with transformers.

If I run something like:

python -m delft.applications.citationClassifier train_eval --fold-count=2 --transformer allenai/scibert_scivocab_cased/dir
loading citation sentiment corpus...

------------------------ fold 0--------------------------------------
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 150, 300)]   0           []                               
                                                                                                  
 bidirectional (Bidirectional)  (None, 150, 128)     140544      ['input_1[0][0]']                
                                                                                                  
 dropout (Dropout)              (None, 150, 128)     0           ['bidirectional[0][0]']          
                                                                                                  
 bidirectional_1 (Bidirectional  (None, 150, 128)    74496       ['dropout[0][0]']                
 )                                                                                                
                                                                                                  
 global_max_pooling1d (GlobalMa  (None, 128)         0           ['bidirectional_1[0][0]']        
 xPooling1D)                                                                                      
                                                                                                  
 global_average_pooling1d (Glob  (None, 128)         0           ['bidirectional_1[0][0]']        
 alAveragePooling1D)                                                                              
                                                                                                  
 concatenate (Concatenate)      (None, 256)          0           ['global_max_pooling1d[0][0]',   
                                                                  'global_average_pooling1d[0][0]'
                                                                 ]                                
                                                                                                  
 dense (Dense)                  (None, 32)           8224        ['concatenate[0][0]']            
                                                                                                  
 dense_1 (Dense)                (None, 3)            99          ['dense[0][0]']                  
                                                                                                  
==================================================================================================
Total params: 223,363
Trainable params: 223,363
Non-trainable params: 0
__________________________________________________________________________________________________
Traceback (most recent call last):
  File "/Users/lfoppiano/opt/anaconda3/envs/delft2/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/lfoppiano/opt/anaconda3/envs/delft2/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/lfoppiano/development/projects/delft/delft/applications/citationClassifier.py", line 159, in <module>
    y_test = train_and_eval(embeddings_name, args.fold_count, architecture=architecture, transformer=transformer)    
  File "/Users/lfoppiano/development/projects/delft/delft/applications/citationClassifier.py", line 75, in train_and_eval
    model.train_nfold(x_train, y_train)
  File "/Users/lfoppiano/development/projects/delft/delft/textClassification/wrapper.py", line 176, in train_nfold
    self.models = train_folds(x_train, y_train, self.model_config, self.training_config, self.embeddings,
  File "/Users/lfoppiano/development/projects/delft/delft/textClassification/models.py", line 302, in train_folds
    foldModel.train_model(model_config.list_classes, training_config.batch_size, max_epoch, use_roc_auc, 
  File "/Users/lfoppiano/development/projects/delft/delft/textClassification/models.py", line 153, in train_model
    y_pred = self.model.predict(
  File "/Users/lfoppiano/opt/anaconda3/envs/delft2/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/lfoppiano/development/projects/delft/delft/textClassification/data_generator.py", line 46, in __getitem__
    batch_x, batch_y = self.__data_generation(index)
  File "/Users/lfoppiano/development/projects/delft/delft/textClassification/data_generator.py", line 80, in __data_generation
    input_ids, input_masks, input_segments = create_batch_input_bert(self.x[(index*self.batch_size):(index*self.batch_size)+max_iter], 
  File "/Users/lfoppiano/development/projects/delft/delft/textClassification/preprocess.py", line 63, in create_batch_input_bert
    encoded_tokens = transformer_tokenizer.batch_encode_plus(texts, add_special_tokens=True, truncation=True, 
AttributeError: 'NoneType' object has no attribute 'batch_encode_plus'

It seems that the tokenizer is not initialised correctly. What am I missing?
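For reference, the last frame of the traceback is the usual symptom of calling a method on None. A minimal sketch of a guard that fails earlier with a clearer message is below. The helper name encode_batch and the max_length default of 150 are hypothetical and are not taken from DeLFT; only the batch_encode_plus call and its add_special_tokens and truncation arguments come from the traceback.

```python
def encode_batch(texts, tokenizer=None, max_length=150):
    """Hypothetical helper (not DeLFT code): encode texts for a transformer.

    Raises a descriptive error instead of the bare AttributeError that
    appears when the tokenizer was never initialised.
    """
    if tokenizer is None:
        raise ValueError(
            "No transformer tokenizer is initialised; check that the model "
            "was configured with a transformer architecture."
        )
    # batch_encode_plus is the Hugging Face tokenizer call seen in the traceback.
    return tokenizer.batch_encode_plus(
        texts,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,  # assumed value, for illustration only
    )
```

With a guard like this, a misconfigured run would stop with a message that points at the configuration rather than at the tokenizer call itself.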

In the current applications/citationClassifier.py, you need to indicate the architecture when using a transformer layer:

$ python3 delft/applications/citationClassifier.py train_eval --architecture bert --transformer allenai/scibert_scivocab_cased
loading citation sentiment corpus...
allenai/scibert_scivocab_cased will be used, loaded via huggingface
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_token (InputLayer)    [(None, None)]            0         
                                                                 
 tf_bert_model (TFBertModel)  TFBaseModelOutputWithPoo  109938432
                             lingAndCrossAttentions(l            
                             ast_hidden_state=(None,             
                             None, 768),                         
                              pooler_output=(None, 76            
                             8),                                 
                              past_key_values=None, h            
                             idden_states=None, atten            
                             tions=None, cross_attent            
                             ions=None)                          
                                                                 
 dropout_37 (Dropout)        (None, 768)               0         
                                                                 
 dense (Dense)               (None, 3)                 2307      
                                                                 
=================================================================
Total params: 109,940,739
Trainable params: 109,940,739
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3
 11/245 [>.............................] - ETA: 2:08 - loss: 1.6049 - accuracy: 0.2926

OK, shouldn't we then just make the architecture a required parameter?

Update: Actually we should have architecture == bert if we want to use the transformers, right?
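One way to enforce this at the command line is sketched below. It is only an illustration, not the actual parser in applications/citationClassifier.py: the function name parse_args and the default architecture "gru" are assumptions, while the option names --architecture, --transformer and --fold-count are the ones used in the commands above.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical CLI parsing sketch that rejects inconsistent options."""
    parser = argparse.ArgumentParser(
        description="Sketch of option validation (not DeLFT's actual parser)"
    )
    parser.add_argument("action")
    parser.add_argument("--fold-count", type=int, default=1)
    # "gru" is an assumed placeholder default, chosen for illustration.
    parser.add_argument("--architecture", default="gru")
    parser.add_argument("--transformer", default=None)
    args = parser.parse_args(argv)

    # A transformer name only makes sense together with the bert architecture.
    if args.transformer is not None and args.architecture != "bert":
        parser.error("--transformer requires --architecture bert")
    return args
```

Here parser.error prints the usage message and exits with status 2, so the mismatch is reported before any corpus or model is loaded.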

There was an if condition that ended up inverted, so bert_data was True when it should have been False. Long story short, we can close this 😄
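For anyone hitting a similar bug, a small sketch of the intended logic and a check that would catch an inverted condition is below. The function name uses_bert_data and its parameters are hypothetical, and the rule it encodes (transformer input only when the architecture is bert and a transformer name is given) is an assumption based on this thread, not a copy of the DeLFT fix.

```python
def uses_bert_data(architecture, transformer_name=None):
    """Hypothetical helper: should batches be tokenized for a transformer?

    Assumed rule: only when the architecture is "bert" and a transformer
    name has been supplied.
    """
    return architecture == "bert" and transformer_name is not None


# Regression-style checks: an inverted condition would flip these results.
assert uses_bert_data("bert", "allenai/scibert_scivocab_cased") is True
assert uses_bert_data("gru") is False
assert uses_bert_data("gru", "allenai/scibert_scivocab_cased") is False
```

Keeping the decision in one small function makes it easy to cover both branches with a unit test.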