yuulin / knowledge_pretrain

Introduction

K-PLUG: Knowledge-injected Pre-trained Language Model for Natural Language Understanding and Generation.

Installation

# Install fairseq from source, pinned to a specific commit
git clone https://github.com/pytorch/fairseq.git
cd fairseq
git checkout 49940c8d25d61a251e290d96fe3bbbc9f210408f
pip install --editable ./

Pre-training

export CUDA_VISIBLE_DEVICES=0,1,2,3

function join_by { local IFS="$1"; shift; echo "$*"; }
DATA_DIR=$(join_by : data/knowledge_pretrain/bin/part*)
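The standalone sketch below (not part of the training script) shows what `join_by` produces: the glob expands to every shard directory, and the helper joins them with `:`. The shard names `part0` to `part2` are hypothetical and created in a temporary directory. Assuming the custom `multitask_lm` task follows fairseq's usual convention, a colon-separated data path is treated as a list of shards that training cycles through across epochs.

```shell
#!/usr/bin/env bash
# Sketch: show how join_by turns a glob of shard directories into one
# colon-separated data path string.
set -euo pipefail

# Same helper as in the script: join all remaining arguments with the first one.
function join_by { local IFS="$1"; shift; echo "$*"; }

# Hypothetical shard layout, built in a temporary directory for the demo.
demo_root=$(mktemp -d)
mkdir -p "$demo_root/bin/part0" "$demo_root/bin/part1" "$demo_root/bin/part2"

# The glob expands (in sorted order) before join_by receives the arguments.
demo_data_dir=$(join_by : "$demo_root"/bin/part*)
echo "$demo_data_dir"

rm -rf "$demo_root"
```

Because the glob sorts lexicographically, `part10` would come before `part2`; zero-padding the shard names keeps them in numeric order if that matters.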

USER_DIR=src
TOKENS_PER_SAMPLE=512
WARMUP_UPDATES=10000
PEAK_LR=0.0005
TOTAL_UPDATES=125000
#MAX_SENTENCES=8
MAX_SENTENCES=16
UPDATE_FREQ=16   # effective batch size = update_freq * max_sentences * nGPU = 16*16*4 = 1024 sequences
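The batch-size comment above can be checked with a small standalone sketch (not part of the training script). It counts GPUs from `CUDA_VISIBLE_DEVICES` (falling back to the four devices used here) and multiplies per-GPU sentences, gradient-accumulation steps, and GPU count. The helper names `count_gpus` and `effective_batch` are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: recompute the effective batch size from the pre-training settings.
set -euo pipefail

# count_gpus "0,1,2,3" -> 4 (one entry per comma-separated device ID)
count_gpus() {
    local -a ids
    IFS=, read -ra ids <<< "$1"
    echo "${#ids[@]}"
}

# effective_batch MAX_SENTENCES UPDATE_FREQ N_GPU -> sequences per optimizer update
effective_batch() {
    echo $(( $1 * $2 * $3 ))
}

n_gpu=$(count_gpus "${CUDA_VISIBLE_DEVICES:-0,1,2,3}")
batch=$(effective_batch "${MAX_SENTENCES:-16}" "${UPDATE_FREQ:-16}" "$n_gpu")
# With --sample-break-mode none, samples are contiguous blocks of
# TOKENS_PER_SAMPLE tokens, so this is roughly the token count per update.
tokens=$(( batch * ${TOKENS_PER_SAMPLE:-512} ))
echo "GPUs=$n_gpu sequences/update=$batch tokens/update~$tokens"
```

To keep the batch at 1024 when changing other settings, scale `UPDATE_FREQ` inversely: the commented-out `MAX_SENTENCES=8` needs `UPDATE_FREQ=32` on 4 GPUs, and 8 GPUs with `MAX_SENTENCES=16` need `UPDATE_FREQ=8`.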

SUB_TASK=mlm_clm_sentcls_segcls_titlegen
## ablation tasks: each setting below drops exactly one objective
#SUB_TASK=clm_sentcls_segcls_titlegen
#SUB_TASK=mlm_sentcls_segcls_titlegen
#SUB_TASK=mlm_clm_sentcls_segcls
#SUB_TASK=mlm_clm_segcls_titlegen
#SUB_TASK=mlm_clm_sentcls_titlegen
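The commented-out ablations follow a leave-one-out pattern. This standalone sketch (not part of the training script) derives them from the full setting by removing one underscore-separated name at a time; it only manipulates the strings as written here and makes no assumption about how the code in `src` parses `--sub-task`. The variable `full_sub_task` is hypothetical and chosen so it does not overwrite `SUB_TASK`.

```shell
#!/usr/bin/env bash
# Sketch: derive the leave-one-out ablation settings from the full SUB_TASK.
set -euo pipefail

full_sub_task=mlm_clm_sentcls_segcls_titlegen
IFS=_ read -ra objectives <<< "$full_sub_task"

ablations=()
for drop in "${objectives[@]}"; do
    kept=()
    for obj in "${objectives[@]}"; do
        # Keep every objective except the one being ablated.
        [ "$obj" = "$drop" ] || kept+=("$obj")
    done
    # Re-join the remaining objectives with '_' (IFS changes only in the subshell).
    ablations+=("$(IFS=_; echo "${kept[*]}")")
done

printf '%s\n' "${ablations[@]}"
```

Each printed line matches one of the five commented `SUB_TASK` values; setting `SUB_TASK` to one of them before running `fairseq-train` reproduces that ablation.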

fairseq-train $DATA_DIR \
    --user-dir $USER_DIR \
    --task multitask_lm \
    --sub-task $SUB_TASK \
    --arch transformer_pretrain_base \
    --min-loss-scale=0.000001 \
    --sample-break-mode none \
    --tokens-per-sample $TOKENS_PER_SAMPLE \
    --criterion multitask_lm \
    --apply-bert-init \
    --max-source-positions 512 --max-target-positions 512 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $PEAK_LR \
    --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --max-sentences $MAX_SENTENCES --update-freq $UPDATE_FREQ \
    --ddp-backend=no_c10d \
    --tensorboard-logdir tensorboard \
    --classification-head-name pretrain_head --num-classes 40 \
    --tagging-head-name pretrain_tag_head --tag-num-classes 2 \
    --fp16
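With `--lr-scheduler polynomial_decay`, the learning rate warms up linearly to `PEAK_LR` over `WARMUP_UPDATES` and then decays toward an end learning rate by `TOTAL_UPDATES`. The sketch below computes the resulting schedule under the assumption that the end learning rate and the decay power keep the scheduler's defaults (0 and 1), which makes the decay linear. The function `lr_at` is hypothetical and not part of fairseq.

```shell
#!/usr/bin/env bash
# Sketch: learning rate at a given update under polynomial_decay,
# i.e. linear warmup to the peak LR, then polynomial decay to an end LR.
# Assumption: end LR 0 and power 1 by default, so the decay is linear.
set -euo pipefail

# lr_at UPDATE PEAK_LR WARMUP_UPDATES TOTAL_UPDATES [END_LR] [POWER]
lr_at() {
    awk -v n="$1" -v peak="$2" -v warm="$3" -v total="$4" \
        -v endlr="${5:-0}" -v power="${6:-1}" 'BEGIN {
        if (warm > 0 && n <= warm) {
            lr = peak * n / warm                      # linear warmup
        } else if (n >= total) {
            lr = endlr                                # schedule finished
        } else {
            remaining = 1 - (n - warm) / (total - warm)
            lr = (peak - endlr) * remaining ^ power + endlr
        }
        printf "%.6g\n", lr
    }'
}

for step in 0 5000 10000 67500 125000; do
    echo "update $step: lr=$(lr_at "$step" 0.0005 10000 125000)"
done
```

With the settings above, the peak of 0.0005 is reached at update 10,000, the rate is halved at update 67,500 (midway through the decay phase), and it reaches 0 at update 125,000.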

Fine-tuning and Inference

Fine-tuning on JDDC (Response Generation)

Fine-tuning on ECD Corpus (Response Retrieval)

Fine-tuning on JD Product Dataset (Abstractive Summarization)

Fine-tuning on MEPAVE Dataset (Sequence Tagging)
