prajdabre / yanmtt

Yet Another Neural Machine Translation Toolkit

mBART embedding matrix pruning while finetuning on a single language

GorkaUrbizu opened this issue

Finetuning mBART-large on my machines is possible with gradient accumulation, but training could be faster if I could reduce the size of the loaded model.

Is there an easy way to reduce the size of mBART's vocabulary with your tool, pruning the embedding parameters that won't be used when finetuning the LM on a monolingual task?

If pruning embeddings is not possible with your toolkit during continued pretraining or finetuning, I found this thread, which could work to trim the vocabulary and reduce the model size during finetuning. But I am not yet sure whether it can be applied directly to the Hugging Face mBART models.
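To make the idea concrete, here is a minimal sketch of what such trimming amounts to, using plain Python lists instead of real tensors. The helper names `trim_embeddings` and `remap`, the special-token ids, and the toy values are all made up for illustration; none of this is part of yanmtt or of any Hugging Face API, and the tokenizer and any layer tied to the embeddings would presumably need the same remapping.

```python
def trim_embeddings(embedding_rows, keep_ids):
    """Keep only the embedding rows for `keep_ids`.

    Returns the trimmed matrix and an old-id -> new-id mapping.
    """
    # Sort and deduplicate so the new ids are deterministic.
    kept = sorted(set(keep_ids))
    for old_id in kept:
        if not 0 <= old_id < len(embedding_rows):
            raise ValueError(f"token id {old_id} is outside the vocabulary")
    old_to_new = {old_id: new_id for new_id, old_id in enumerate(kept)}
    trimmed = [list(embedding_rows[old_id]) for old_id in kept]
    return trimmed, old_to_new


def remap(token_ids, old_to_new, unk_new_id):
    """Translate original token ids to trimmed ids; dropped ids become <unk>."""
    return [old_to_new.get(token_id, unk_new_id) for token_id in token_ids]


# Toy 6-token vocabulary with 2-dimensional embeddings (values are made up).
embeddings = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1], [5.0, 5.1]]
special_ids = [0, 1]            # assumed <pad> and <unk>; always kept
ids_seen_in_corpus = [4, 2, 4]  # ids obtained by tokenizing the monolingual data

trimmed, old_to_new = trim_embeddings(embeddings, special_ids + ids_seen_in_corpus)
new_ids = remap([4, 2, 5], old_to_new, unk_new_id=old_to_new[1])
```

Here token 5 never occurs in the corpus, so its row is dropped and it falls back to the unknown token after remapping.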

Hi Gorka,

Sorry for the late reply. It's a national holiday here.

When it comes to embedding matrix pruning, my code is not explicitly designed for it.
I have a few solutions for you:

  1. The mBART implementation has a mechanism called attention head pruning, where unnecessary attention heads are removed. That helps reduce the model size prior to fine tuning.
  2. Alternatively, you can consider doing PCA on the embedding matrix (which is shared between the encoder, the decoder and the softmax linear layer) to reduce it to a lower dimensionality. Save a new checkpoint with this reduced embedding and use it for fine tuning. However, you will have to insert a new projection layer between the reduced embedding matrix and the first layer of the encoder/decoder, or else remove rows and columns from the weight matrices of the layers. This is complicated for sure.
  3. Prior to fine tuning, you can first save the mBART model locally, then delete some layers from the checkpoint, change the config file to reflect the number of remaining layers, and save a new checkpoint. Use this for fine tuning.
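As a rough illustration of option 3, the sketch below drops layers from a checkpoint represented as a plain dictionary and renumbers the ones that remain. The key pattern (`model.encoder.layers.<index>...`), the config field `encoder_layers`, and the helper `drop_layers` are assumptions modeled on Hugging Face-style checkpoints, and placeholder strings stand in for the real tensors.

```python
import re

# Assumed key pattern: "<prefix>.layers.<index>.<rest>".
LAYER_KEY = re.compile(r"^(?P<prefix>.*\.layers\.)(?P<index>\d+)(?P<rest>\..*)$")


def drop_layers(state_dict, stack_prefix, keep_layers):
    """Keep only `keep_layers` of the stack under `stack_prefix`, renumbered from 0."""
    new_index = {old: new for new, old in enumerate(sorted(set(keep_layers)))}
    pruned = {}
    for key, value in state_dict.items():
        match = LAYER_KEY.match(key)
        if match is None or not key.startswith(stack_prefix):
            pruned[key] = value  # not a layer of this stack: copy unchanged
            continue
        old = int(match.group("index"))
        if old in new_index:  # layers not listed in keep_layers are dropped
            new_key = match.group("prefix") + str(new_index[old]) + match.group("rest")
            pruned[new_key] = value
    return pruned


# Toy checkpoint: strings stand in for weight tensors.
checkpoint = {
    "model.shared.weight": "E",
    "model.encoder.layers.0.fc1.weight": "enc0",
    "model.encoder.layers.1.fc1.weight": "enc1",
    "model.encoder.layers.2.fc1.weight": "enc2",
    "model.decoder.layers.0.fc1.weight": "dec0",
    "model.decoder.layers.1.fc1.weight": "dec1",
}
config = {"encoder_layers": 3, "decoder_layers": 2}

# Remove the middle encoder layer and update the config to match.
pruned = drop_layers(checkpoint, "model.encoder.", keep_layers=[0, 2])
config["encoder_layers"] = 2
```

The renumbering matters because a loader that builds two encoder layers will look for indices 0 and 1, not 0 and 2.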

I also had to use gradient accumulation with mBART-large on my machines, so there's no workaround other than using multiple GPUs. You may consider finding a way to distill the mBART model, but that's a nightmare.
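For reference, gradient accumulation trades time for memory: several small micro-batches are processed one after another, and their scaled gradients are summed before a single parameter update. The toy sketch below uses a one-parameter linear model in plain Python (the functions `grad_mse` and `accumulated_step` and the data are made up for illustration, not taken from the toolkit) and shows that, with equal-sized micro-batches, the update matches the full-batch one.

```python
def grad_mse(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over `batch`."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_step(w, micro_batches, lr):
    """One update whose gradient is accumulated over several micro-batches."""
    accum = 0.0
    for micro_batch in micro_batches:
        # Scale each micro-batch gradient so the sum is the mean over all of them.
        accum += grad_mse(w, micro_batch) / len(micro_batches)
    return w - lr * accum


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
micro_batches = [data[:2], data[2:]]  # two micro-batches of equal size

w_accumulated = accumulated_step(0.0, micro_batches, lr=0.01)
w_full_batch = 0.0 - 0.01 * grad_mse(0.0, data)
```

Only one micro-batch has to fit in memory at a time, which is why this works on a single GPU at the cost of longer steps.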

Hi Raj,
Thanks for the fast response. Feel free not to reply to me on weekends, holidays, or outside office hours in the future.

I was considering this approach just to speed up finetuning. Luckily, I don't have OOM issues with small batch sizes, even if training takes a bit longer. I will check the solutions you propose and decide whether to spend time on them or keep using the whole model for now.

Thanks again, your help is really appreciated.

Gorka