airaria / GRAIN

GRAIN: Gradient-based Intra-attention Pruning on Pre-trained Language Models


GRAIN

This is the repository for the paper "Gradient-based Intra-attention Pruning on Pre-trained Language Models", accepted to ACL 2023.

The repo is under construction.

Overview (figure): The workflow of GRAIN

Usage

Step 1: Preparation

  1. Prepare the teacher models. We provide teacher models for the GLUE tasks (MNLI, QNLI, QQP and SST-2) and for SQuAD, which can be downloaded from Google Drive. Unzip teacher_models.zip; the contents of teacher_models should be:
teacher_models\
    mnli\
      pytorch_model.bin
    qnli\
      pytorch_model.bin
    qqp\
      pytorch_model.bin
    sst2\
      pytorch_model.bin
    squad\
      pytorch_model.bin
    config.json
    vocab.txt
  2. Prepare the GLUE and SQuAD datasets and put them in the datasets directory.
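Before training, it can help to confirm that the unzipped teacher models match the layout listed in step 1. The sketch below is a hypothetical helper (not part of the repository) that assumes teacher_models sits in the current working directory; it only checks for the file names shown in the tree above.

```python
from pathlib import Path
from typing import List

# Sub-directories and shared files taken from the teacher_models layout in step 1.
# missing_teacher_files() is a hypothetical helper, not part of the GRAIN repo.
TASKS = ["mnli", "qnli", "qqp", "sst2", "squad"]
SHARED_FILES = ["config.json", "vocab.txt"]


def missing_teacher_files(root: str = "teacher_models") -> List[str]:
    """Return the expected teacher-model files that are absent under `root`."""
    base = Path(root)
    expected = [base / task / "pytorch_model.bin" for task in TASKS]
    expected += [base / name for name in SHARED_FILES]
    return [str(path) for path in expected if not path.is_file()]


if __name__ == "__main__":
    missing = missing_teacher_files()
    if missing:
        print("Missing files:")
        for path in missing:
            print("  " + path)
    else:
        print("teacher_models layout looks complete.")
```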

Step 2: Training/Distillation with Pruning

We offer examples of training on GLUE and SQuAD.

GLUE

cd scripts
bash run_glue.sh

Set TASK in run_glue.sh to one of sst2|mnli|qnli|qqp to run a different task.
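Each valid TASK value corresponds to one sub-directory of teacher_models. The sketch below is a hypothetical helper that validates a task name and maps it to its teacher checkpoint. Whether run_glue.sh locates the teacher exactly this way is an assumption; check the script itself for the paths it actually uses.

```python
from pathlib import Path

# GLUE tasks accepted by run_glue.sh according to this README.
GLUE_TASKS = ("sst2", "mnli", "qnli", "qqp")


def teacher_checkpoint(task: str, root: str = "teacher_models") -> Path:
    """Map a GLUE task name to its teacher checkpoint path.

    Hypothetical helper: assumes the teacher_models layout from step 1.
    """
    task = task.lower()
    if task not in GLUE_TASKS:
        raise ValueError(f"TASK must be one of {'|'.join(GLUE_TASKS)}, got {task!r}")
    return Path(root) / task / "pytorch_model.bin"
```

For example, teacher_checkpoint("MNLI") returns teacher_models/mnli/pytorch_model.bin, while an unsupported name such as "cola" raises a ValueError before any training starts.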

SQuAD

cd scripts
bash run_squad.sh

Post Pruning

The models obtained in the step above are stored with their full parameters together with the pruning masks. We then perform a post-pruning operation that physically removes the pruned weights from the model.
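Why is physically removing masked weights safe? Zeroing a dimension with a mask gives the same output as deleting the matching row of the layer that produces it and the matching column of the layer that consumes it. The sketch below shows this for a generic pair of linear maps using plain Python lists. It is an illustration only, not the code in PostPruning.ipynb. In an attention block, w_in could stand for a per-head projection such as the value projection and w_out for the output projection, but that mapping is an assumption here. All function names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def masked_forward(w_in: Matrix, w_out: Matrix, mask: List[int], x: Vector) -> Vector:
    """Training-time view: full weights, pruned dimensions zeroed by the mask."""
    h = [hi * m for hi, m in zip(matvec(w_in, x), mask)]
    return matvec(w_out, h)


def remove_masked_dims(w_in: Matrix, w_out: Matrix, mask: List[int]) -> Tuple[Matrix, Matrix]:
    """Post-pruning view: drop every intermediate dimension whose mask entry is 0.

    w_in has one row per intermediate dimension; w_out has one column per
    intermediate dimension.
    """
    if len(mask) != len(w_in) or any(len(row) != len(mask) for row in w_out):
        raise ValueError("mask length must equal the intermediate dimension")
    keep = [i for i, m in enumerate(mask) if m]
    return [list(w_in[i]) for i in keep], [[row[i] for i in keep] for row in w_out]


if __name__ == "__main__":
    w_in = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 intermediate dims x 2 inputs
    w_out = [[1.0, 0.5, -1.0], [2.0, 1.0, 0.0]]  # 2 outputs x 3 intermediate dims
    mask = [1, 0, 1]                              # prune intermediate dim 1
    x = [1.0, -1.0]
    small_in, small_out = remove_masked_dims(w_in, w_out, mask)
    print(masked_forward(w_in, w_out, mask, x))   # [0.0, -2.0]
    print(matvec(small_out, matvec(small_in, x))) # [0.0, -2.0]
```

The smaller matrices give the same result with less computation, which is where the inference speed-up of the pruned model comes from.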

Run PostPruning.ipynb and follow the steps there to remove the redundant weights and test the inference speed of the pruned model.
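When comparing the inference speed of the original and pruned models, a fair measurement should include a few warm-up calls and average over several runs. The sketch below is a generic, hypothetical timing helper based on the standard library. It is not the benchmark used in PostPruning.ipynb, and the two toy workloads only stand in for a dense and a pruned model.

```python
import statistics
import time
from typing import Callable


def mean_latency(fn: Callable[[], object], warmup: int = 3, runs: int = 20) -> float:
    """Return the mean wall-clock time of fn() in seconds, measured after warm-up calls."""
    if runs < 1:
        raise ValueError("runs must be >= 1")
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings)


def dense_workload() -> int:
    # Stand-in for a forward pass of the full model.
    return sum(i * i for i in range(200_000))


def pruned_workload() -> int:
    # Stand-in for a forward pass of the smaller, pruned model.
    return sum(i * i for i in range(50_000))


if __name__ == "__main__":
    t_dense = mean_latency(dense_workload)
    t_pruned = mean_latency(pruned_workload)
    print(f"dense: {t_dense * 1e3:.2f} ms, pruned: {t_pruned * 1e3:.2f} ms, "
          f"speed-up: {t_dense / t_pruned:.2f}x")
```

For a real PyTorch model, fn would wrap a forward pass under torch.inference_mode(); on GPU, calling torch.cuda.synchronize() before reading the timer is needed because CUDA kernels run asynchronously.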


License: Apache License 2.0

