RWKV_LM_EXT

This project extends RWKV LM's capabilities, including sequence classification, sequence embedding, PEFT, cross encoders, bi encoders, multi-modality, and more.

src/model_ext.py

We extend two types of models based on RWKV (v5 and v6).

  • RwkvForClassification

This class is used for sequence classification; a code sketch of both flows appears after this list.

graph LR
    A(idx) --> B[embeddings]
    B --> C[Apply rwkv blocks]
    C --> D[Find the eos token's embedding]
    D --> E[Score the embeddings]
    E --> F(Return scores)
  • RwkvForSequenceEmbedding

This class is used for sequence embedding.

graph LR
    A(idx) --> B[embeddings]
    B --> C[Apply rwkv blocks]
    C --> D[Apply pooling method]
    D --> E(Return embeddings)
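Both diagrams share the same backbone pass (token embeddings followed by the RWKV blocks); only the head differs. Below is a minimal PyTorch sketch of the two heads with hypothetical names; the real implementations are RwkvForClassification and RwkvForSequenceEmbedding in src/model_ext.py.

```python
import torch
import torch.nn as nn

def classification_head(hidden, idx, eos_id, score: nn.Linear):
    # hidden: (batch, seq, dim) from the RWKV blocks; idx: (batch, seq) token ids.
    # Find each row's first eos token and score its hidden state.
    eos_pos = (idx == eos_id).int().argmax(dim=-1)            # (batch,)
    eos_hidden = hidden[torch.arange(idx.size(0)), eos_pos]   # (batch, dim)
    return score(eos_hidden)                                  # (batch, num_labels)

def embedding_head(hidden, mask, method="lasttoken"):
    # mask: (batch, seq), 1 for real tokens. Pool hidden states into one vector.
    if method == "mean":
        summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)
        return summed / mask.sum(dim=1, keepdim=True)
    if method == "lasttoken":
        last = mask.sum(dim=1).long() - 1                     # last real token per row
        return hidden[torch.arange(hidden.size(0)), last]
    raise ValueError(f"unknown pooling method: {method}")
```

The pooling methods shown (mean and last-token) are common choices and an assumption here; check RwkvForSequenceEmbedding for the methods the project actually supports.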

Some LoRA checkpoints:

Try python peft_train/peft_test.py to see how the two SFT adapters work together seamlessly. In the future, more SFT adapters can be combined to build a more sophisticated AI assistant.

Furthermore, some utilities:

  • peft_train/hf2rwkv_lora.py converts a LoRA checkpoint trained with Hugging Face PEFT to RWKV's LoRA format. This utility can be used to convert the BiEncoder checkpoint.

  • peft_train/peft_test.py is a better example showing how one base model can do different tasks at runtime by switching LoRA adapters only (see the sketch after this list).

    • If you're using a PISSA checkpoint, please set --chat_lora_alpha 8 instead of the --chat_lora_alpha 32 used for LoRA.
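
Conceptually, switching adapters at runtime keeps one frozen base weight and re-derives the effective weight from a low-rank pair per task. A sketch of the idea (illustrative names only, not the repo's actual API):

```python
import torch

def lora_delta(A: torch.Tensor, B: torch.Tensor, alpha: float, r: int):
    # A: (r, in_features), B: (out_features, r); standard LoRA scaling alpha / r.
    return (alpha / r) * (B @ A)

def with_adapter(base_weight: torch.Tensor, adapter: dict):
    # adapter holds the low-rank factors and scaling for one task
    return base_weight + lora_delta(adapter["A"], adapter["B"],
                                    adapter["alpha"], adapter["r"])

# Switching tasks means selecting a different adapter for the same base, e.g.
# with_adapter(W, adapters["chat"]) for generation and
# with_adapter(W, adapters["bi_encoder"]) for embedding.
```

This is also why the --chat_lora_alpha note above matters: alpha scales the adapter's contribution, so a checkpoint must be loaded with the alpha it was trained with.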

Training process for SFT using RWKV's LoRA/PISSA/state tuning:

  • Convert the raw JSON files into tokenized datasets:

python data/SftUtilities.py --input_dir JSON_DIR --output_dir DATASET_DIR --tokenizer_file PATH_TO_rwkv_vocab_v20230424.txt
  • With datasets in DATASET_DIR, we can now train our SFT model. The following are the scripts we use for LoRA/PISSA/state tuning; assume the SFT PEFT model is written to OUTPUT_DIR. On a 4090 Ti, we can train with batch sizes 32/16/8/4/2/1 for max lengths 64/128/256/512/1024/2048 respectively.

    • Lora
    RWKV_TRAIN_TYPE=lora python peft_train/peft_train_sft.py --train_data /home/rwkv/data/tigerbot_sft_dataset --model_file MODEL_DIR/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth --output_dir OUTPUT_DIR/lora_rwkv --train_type lora --log_every_n_steps 100000 --target_modules att ffn --wandb lora_tigerbots --num_epochs 3 --train_lengths 64 128 256 512 1024 2048 --train_batch_sizes 32 16 8 4 2 1
    • Pissa
    RWKV_TRAIN_TYPE=pissa python peft_train/peft_train_sft.py --train_data /home/rwkv/data/tigerbot_sft_dataset --model_file MODEL_DIR/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth --output_dir OUTPUT_DIR/pissa_rwkv --train_type pissa --log_every_n_steps 100000 --target_modules att ffn --wandb pissa_tigerbots --num_epochs 3 --train_lengths 64 128 256 512 1024 2048 --train_batch_sizes 32 16 8 4 2 1
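
State tuning is listed above but its command is not shown; by analogy with the two commands above, it would presumably look like the following (the state value for RWKV_TRAIN_TYPE and --train_type is an assumption, and --target_modules may not apply to state tuning, so check peft_train/peft_train_sft.py first):

    • State tuning (assumed, by analogy)
    RWKV_TRAIN_TYPE=state python peft_train/peft_train_sft.py --train_data /home/rwkv/data/tigerbot_sft_dataset --model_file MODEL_DIR/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth --output_dir OUTPUT_DIR/state_rwkv --train_type state --log_every_n_steps 100000 --wandb state_tigerbots --num_epochs 3 --train_lengths 64 128 256 512 1024 2048 --train_batch_sizes 32 16 8 4 2 1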
    

Beam search with logits processors

Users can now use beam search with logits processors to obtain varied generations. Try src/tests/TestBeamSearch.py.
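
A logits processor hooks into each decoding step to reshape the next-token distribution before beams are expanded, which is what makes the returned candidates differ. A generic sketch of beam search with a repetition-penalty processor (not the exact code in src/tests/TestBeamSearch.py):

```python
import torch

def repetition_penalty(input_ids, logits, penalty=1.2):
    # Damp tokens that already appear in the beam so candidates diversify.
    for tok in set(input_ids.tolist()):
        logits[tok] = logits[tok] / penalty if logits[tok] > 0 else logits[tok] * penalty
    return logits

def beam_search(step_fn, bos_id, eos_id, num_beams=4, max_len=32):
    # step_fn(ids) -> 1-D tensor of next-token logits for the sequence `ids`
    beams = [([bos_id], 0.0)]
    for _ in range(max_len):
        candidates = []
        for ids, score in beams:
            if ids[-1] == eos_id:                # finished beams carry over
                candidates.append((ids, score))
                continue
            logits = repetition_penalty(torch.tensor(ids), step_fn(ids))
            logprobs = torch.log_softmax(logits, dim=-1)
            top = torch.topk(logprobs, num_beams)
            for lp, tok in zip(top.values, top.indices):
                candidates.append((ids + [tok.item()], score + lp.item()))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:num_beams]
    return beams                                 # best-first list of (ids, logprob)
```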

[Screenshot: beam search results, 2024-05-24]

Encoders for inference

Please refer to infer/encoders to see how to drive multiple LoRA adapters with a pure RWKV model. [Screenshot: encoder inference, 2024-05-24]
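
For a bi-encoder, inference reduces to embedding the query and candidates separately and comparing them with cosine similarity. A usage sketch, assuming an encode callable wraps the embedding adapter (hypothetical name; see infer/encoders for the real entry points):

```python
import torch
import torch.nn.functional as F

def bi_encoder_scores(encode, query, passages):
    # encode(text) -> 1-D embedding tensor produced by the embedding adapter
    q = encode(query)                                      # (dim,)
    p = torch.stack([encode(t) for t in passages])         # (n, dim)
    return F.cosine_similarity(q.unsqueeze(0), p, dim=-1)  # (n,) scores, higher = closer
```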
