tracel-ai / models

Models and examples built with Burn

Custom BERT Model outputs

ashdtu opened this issue · comments

For flexibility when fine-tuning on downstream tasks, we should have the following options in the BERT-family model outputs (a rough sketch of what this could look like follows the list):

  1. Return the output of a pooling layer on top of the [CLS] token embedding, via a user-configurable flag (`add_pooling_layer=True`) like the transformers API.
  2. Return the full last hidden states of the encoder layer, `[Batch, Sequence, Embedding Dim]`, instead of just the [CLS] token embedding, also via a user-configurable flag.
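Roughly, the kind of combined output type I have in mind, as a Burn sketch (the struct and field names below are only illustrative, not an existing API in this repo):

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Illustrative output type only; the names are assumptions, not the crate's API.
pub struct BertModelOutput<B: Backend> {
    /// Full last hidden states of the encoder: [batch, sequence, embedding_dim].
    pub hidden_states: Tensor<B, 3>,
    /// Pooling-layer output over the [CLS] token: [batch, embedding_dim].
    /// Only present when the model is configured with the pooling layer enabled.
    pub pooled_output: Option<Tensor<B, 2>>,
}
```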

I have a quick first pass at point 1 in a fork, based on how rust-bert handles it: https://github.com/bkonkle/burn-models/blob/707153f5ef1f1f2e8478cebf45e3ca58247d8348/bert-burn/src/model.rs#L168-L171
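For reference, that first pass boils down to a dense layer plus a tanh over the first ([CLS]) token, as in the original BERT pooler. A minimal Burn sketch of the idea, not the exact code from the fork (module and method names are illustrative):

```rust
use burn::{
    module::Module,
    nn::{Linear, LinearConfig},
    tensor::{backend::Backend, Tensor},
};

/// Illustrative pooler module, not the actual bert-burn implementation.
#[derive(Module, Debug)]
pub struct Pooler<B: Backend> {
    dense: Linear<B>,
}

impl<B: Backend> Pooler<B> {
    pub fn new(hidden_size: usize, device: &B::Device) -> Self {
        Self {
            dense: LinearConfig::new(hidden_size, hidden_size).init(device),
        }
    }

    /// hidden_states: [batch, sequence, hidden_size] -> pooled: [batch, hidden_size]
    pub fn forward(&self, hidden_states: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, _seq, hidden] = hidden_states.dims();
        // Take the hidden state of the first ([CLS]) token of every sequence,
        // project it through the dense layer, and apply tanh.
        let cls = hidden_states
            .slice([0..batch, 0..1, 0..hidden])
            .reshape([batch, hidden]);
        self.dense.forward(cls).tanh()
    }
}
```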

How would I approach point 2?

I'm making more progress on the very beginnings of a transformers-style library for Burn, using traits for pipeline implementations, but in my WIP testing so far the model isn't learning properly. It doesn't seem to be using the pre-trained weights from `bert-base-uncased` correctly, so accuracy fluctuates between roughly 25% and 50%.

https://github.com/bkonkle/burn-transformers

This uses my branch with the pooled BERT output. The branch doesn't currently build, but I plan to do more work on it this week to fix that and get a good example in place for feedback.

Awesome @bkonkle! I think the current implementation is using RoBERTa weights instead of BERT, so maybe this isn't compatible with the BERT weights for the classification head. Not sure if this helps, but if you find something not working, make sure to test multiple backends and report a bug if there are differences.

Okay, I believe I understand point 2 better now. I was initially thinking it meant a flag for `all_hidden_states`, like this flag in Hugging Face's transformers library. I now believe it means just the full tensor from the last `hidden_states` value, like this property in Hugging Face's transformers. This would correspond to the final `x` value in Burn's Transformer Encoder, here.

If my interpretation is correct, I believe the approach in my fork here addresses this by returning both the last hidden states and the optional pooled output if available.
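To make the downstream-task side concrete: assuming an output shape like the one sketched in the issue above, a classification head could pick its input roughly like this (illustrative helper, not code from my branch):

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Illustrative helper: prefer the pooled [CLS] representation when the pooling
/// layer was enabled, otherwise fall back to the raw [CLS] embedding taken from
/// the last hidden states.
fn sentence_representation<B: Backend>(
    hidden_states: Tensor<B, 3>,         // [batch, sequence, embedding_dim]
    pooled_output: Option<Tensor<B, 2>>, // [batch, embedding_dim]
) -> Tensor<B, 2> {
    match pooled_output {
        Some(pooled) => pooled,
        None => {
            let [batch, _seq, dim] = hidden_states.dims();
            hidden_states
                .slice([0..batch, 0..1, 0..dim])
                .reshape([batch, dim])
        }
    }
}
```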

Update: Solved - see the next comment below.

Previous troubleshooting details

From what I can tell, the BERT model here should support `bert-base-uncased` without any issues using the additional pooler layer (which is also loaded from the safetensors file), despite originally being written for RoBERTa. Unfortunately, I'm still getting very poor accuracy when loading from safetensors and then fine-tuning on the Snips dataset.

The training loop I use is defined here, in the text-classification pipeline of my early-stage transformers port.

======================== Learner Summary ========================
Model: Model[num_params=109489161]
Total Epochs: 10


| Split | Metric        | Min.     | Epoch    | Max.     | Epoch    |
|-------|---------------|----------|----------|----------|----------|
| Train | Loss          | 2.004    | 10       | 2.269    | 1        |
| Train | Learning Rate | 0.000    | 10       | 0.000    | 1        |
| Train | Accuracy      | 12.050   | 1        | 31.490   | 10       |
| Valid | Loss          | 1.955    | 10       | 2.205    | 1        |
| Valid | Accuracy      | 18.000   | 1        | 40.000   | 10       |

The learning rate seems like a clue - it shouldn't be zero, right?

By comparison, the training implementation for the Snips dataset in the JointBERT repo (which just combines text and token classification into one) hits accuracy in the 90% range within the first few epochs.

I'll probably move on to token classification for now and come back to text classification once I get some feedback and try to figure out what's going wrong. 😅

Thank you for any guidance you can provide!

The learning rate was indeed the clue. I had it set far too low, based on the default value in the JointBERT repo I was learning from. 😅 Setting it to 1e-2 solves my problem, so I think my branch is ready for some review to see whether this is the right approach to enabling custom BERT model outputs. 👍

======================== Learner Summary ========================
Model: Model[num_params=109489161]
Total Epochs: 10

| Split | Metric        | Min.     | Epoch    | Max.     | Epoch    |
|-------|---------------|----------|----------|----------|----------|
| Train | Loss          | 0.003    | 10       | 0.246    | 1        |
| Train | Accuracy      | 92.540   | 1        | 99.900   | 10       |
| Train | Learning Rate | 0.000    | 10       | 0.000    | 1        |
| Valid | Loss          | 0.060    | 10       | 0.109    | 8        |
| Valid | Accuracy      | 96.900   | 8        | 98.100   | 10       |