chengchingwen / Transformers.jl

Julia Implementation of Transformer models

BERT CoLA Example Question

jackn11 opened this issue

I am trying to train several BERT classifier models in one program, but I am running out of GPU RAM because I load too many BERT models with const _bert_model, wordpiece, tokenizer = pretrain"Bert-uncased_L-12_H-768_A-12"

I am following the CoLA example found here: https://github.com/chengchingwen/Transformers.jl/blob/master/example/BERT/cola/train.jl

I am wondering whether the train!() function in the example trains all of the parameters in Flux.params(bert_model) or only those in Flux.params(_bert_model.classifier). This matters because, if only the classifier parameters are modified rather than all BERT model parameters, I can load a single pretrain"Bert-uncased_L-12_H-768_A-12" into RAM instead of many, and then just train a new classifier (_bert_model.classifier) for each BERT classifier I need. That would save a lot of RAM by not loading a new full BERT model for each classifier.

Please let me know if the whole bert model is trained with the train!() function or just the classifier parameters.

Thank you,

Jack

commented

The whole model is trained, because the example sets const ps = params(bert_model). You can switch to params(bert_model.classifier) if you only want to train the classifier. The difference between _bert_model and bert_model is that _bert_model is on the CPU while bert_model is on the GPU. Only one BERT model is loaded onto the GPU. The problem is probably that your GPU RAM is not large enough for a batch of size 4.
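For illustration, the effect of switching the parameter set can be sketched with a small stand-in model rather than BERT itself. This is only a sketch under assumptions: the NamedTuple with encoder and classifier fields is a hypothetical toy, not the Transformers.jl model structure, and it assumes a Flux version that still provides the implicit-parameter API (Flux.params) that the example relies on.

```julia
using Flux

# Hypothetical toy model: an "encoder" standing in for BERT plus a "classifier" head.
model = (encoder = Dense(8 => 8, tanh), classifier = Dense(8 => 2))

ps_all  = Flux.params(model)             # weights and biases of both layers
ps_head = Flux.params(model.classifier)  # weights and bias of the classifier only

# Dummy batch: 4 samples with 8 features each, and 2 classes.
x = rand(Float32, 8, 4)
y = Flux.onehotbatch([1, 2, 1, 2], 1:2)
loss() = Flux.logitcrossentropy(model.classifier(model.encoder(x)), y)

enc_before  = copy(model.encoder.weight)
head_before = copy(model.classifier.weight)

# Gradients are taken only with respect to ps_head, so only the classifier is updated.
gs = gradient(loss, ps_head)
Flux.update!(Descent(0.1), ps_head, gs)
```

Passing ps_all instead of ps_head to gradient and update! would update the encoder as well, which corresponds to what params(bert_model) does in the example.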

Whether to train only the classifier layer is a design decision for your own model. It is entirely doable, but it might change model performance (for better or worse), so that's up to you.
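A rough sketch of the classifier-only design with one shared encoder and several heads might look like the following. Everything here is an assumption for illustration: the Dense encoder is a toy stand-in for the single loaded BERT model, the task names are hypothetical, and the code again assumes a Flux version with the implicit-parameter API.

```julia
using Flux

# One shared encoder (toy stand-in for the single BERT model kept in memory).
encoder = Dense(8 => 8, tanh)

# One small classifier head per task; the task names are hypothetical.
heads = Dict(name => Dense(8 => 2) for name in (:task_a, :task_b))

# Dummy batch: 4 samples with 8 features each, and 2 classes.
x = rand(Float32, 8, 4)
y = Flux.onehotbatch([1, 2, 1, 2], 1:2)

enc_before = copy(encoder.weight)

# The encoder output is computed outside of gradient, so the encoder stays frozen.
features = encoder(x)

for (name, head) in heads
    ps = Flux.params(head)  # only this head's parameters
    gs = gradient(() -> Flux.logitcrossentropy(head(features), y), ps)
    Flux.update!(Descent(0.1), ps, gs)
end
```

Because the encoder is never differentiated in this setup, its parameters stay fixed across all heads. Whether the frozen features are good enough for each task is the performance trade-off mentioned above.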