This repo includes the official implementations of "Fine-tune the pretrained ATST model for sound event detection".

ATST-SED

The official implementations of "Fine-tune the pretrained ATST model for sound event detection" (accepted by ICASSP 2024).

This work is closely related to ATST and ATST-Frame. Please check these works if you want to understand the principles behind ATST-SED.

PWC | License: MIT

Paper 🤩 | Issues 😅 | Lab 🙉 | Contact 😘

Introduction

In a nutshell, ATST-SED proposes a fine-tuning strategy for the pretrained model integrated with the CRNN SED system.

The proposed fine-tuning method for ATST-SED

Updating Notice

  • Real dataset download: The 7000+ strongly-labelled audio clips extracted from AudioSet are provided in this issue.

  • Strong val dataset: The meta files for this dataset have now been added to the repo.

  • About batch sizes: If you change the batch sizes when fine-tuning ATST-Frame (Stage 1/2), you will probably need to adjust n_epochs and n_epochs_warmup in the configuration file train/local/confs/stage2.yaml accordingly. Fine-tuning ATST-SED is sensitive to the batch sizes, so you might not reproduce the reported results with smaller batch sizes. The ablation study on batch size setups is shown in the Performance section below.
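As a rough illustration of the batch size note above, the sketch below scales n_epochs and n_epochs_warmup linearly with the batch size. The reference values (batch size 24, 250 epochs, 10 warmup epochs) are an assumption read off the last row of the batch size ablation table (batch size 4 with accm_grad 6), not confirmed defaults, and the helper name scale_schedule is hypothetical. The epoch counts reported in the ablation are close to, but not exactly, what this linear rule gives, so treat the output as a starting point only.

```python
def scale_schedule(batch_size, ref_batch_size=24, ref_epochs=250, ref_warmup=10):
    """Heuristic: scale epoch counts in proportion to the batch size.

    The reference values are assumptions taken from the ablation table,
    not values confirmed by the configuration files.
    """
    ratio = batch_size / ref_batch_size
    n_epochs = max(1, round(ref_epochs * ratio))
    n_epochs_warmup = max(1, round(ref_warmup * ratio))
    return n_epochs, n_epochs_warmup


if __name__ == "__main__":
    # First entry of each batch size list used in the ablation.
    for first_batch_size in (4, 8, 12, 24):
        n_epochs, n_warmup = scale_schedule(first_batch_size)
        print(f"batch={first_batch_size}: n_epochs={n_epochs}, n_epochs_warmup={n_warmup}")
```

For a batch size of 12 this gives 125 epochs and 5 warmup epochs, which matches the corresponding ablation row; for the smaller batch sizes the table uses slightly different rounded values.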

Comparing with DCASE code

To help the SED community better understand the code and implementation details, we developed the algorithm on top of the baseline code of DCASE 2023 challenge task 4. In particular, the training process is built on pytorch-lightning.

We changed

The other parts of desed_task are left unchanged.

Get started

  1. To reproduce our experiments, please first ensure you have the full DESED dataset (including 3000+ strongly labelled real audio clips from the AudioSet).

  2. Ensure you have the correct environment. The environment for this code is the same as that of the DCASE 2023 baseline; please refer to their docs/code to configure it.

  3. Download the pretrained ATST checkpoint (atst_as2M.ckpt). Note that this checkpoint is fine-tuned on AudioSet-2M.

  4. Clone the ATST-SED code by:

git clone https://github.com/Audio-WestlakeU/ATST-SED.git

  5. Install our desed_task package by:

cd ATST-SED
pip install -e .

  6. Change all required paths in train/local/confs/stage1.yaml and train/local/confs/stage2.yaml to your own paths. Note that the pretrained ATST checkpoint path must be changed in both files.

  7. Start training stage 1 by:

python train_stage1.py --gpus YOUR_DEVICE_ID,

We also supply a pretrained stage 1 checkpoint, Stage_1.ckpt, that you can fine-tune directly. If you cannot run stage 1 with accm_grad=1 (that is, without gradient accumulation), we recommend starting from this checkpoint.

  8. When stage 1 training has finished, change the model_init path in train/local/confs/stage2.yaml to the stage 1 checkpoint path. We save the top-5 models in both training stages; you can use the best one to initialize stage 2, but any of the top-5 models should give similar results.

  9. Start training stage 2 by:

python train_stage2.py --gpus YOUR_DEVICE_ID,
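The model_init update described in the steps above can also be scripted. The following is a minimal sketch that works on an already-loaded configuration dictionary. Since the exact layout of stage2.yaml is not described here, the helper (a hypothetical name, set_nested_key) searches for the key at any nesting depth, and the example configuration and checkpoint path are placeholders, not the real file contents.

```python
def set_nested_key(config, key, value):
    """Set every occurrence of `key` in a nested dict.

    Returns the number of entries that were updated, so the caller can
    check that the key was actually found.
    """
    updated = 0
    for name, item in config.items():
        if name == key:
            config[name] = value
            updated += 1
        elif isinstance(item, dict):
            updated += set_nested_key(item, key, value)
    return updated


if __name__ == "__main__":
    # Placeholder structure: the real stage2.yaml layout may differ.
    config = {
        "training": {"n_epochs": 250, "n_epochs_warmup": 10},
        "model": {"model_init": None},
    }
    count = set_nested_key(config, "model_init", "path/to/your/stage1.ckpt")
    print(count, config["model"]["model_init"])
```

With PyYAML installed, the dictionary could be read with yaml.safe_load and written back with yaml.safe_dump. Note that writing the file back this way drops any comments in the original YAML, so editing the path by hand remains the safer option.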

Performance

We report results on both the DESED development set and the public evaluation set. The external set is the extra data extracted from AudioSet/AudioSetStrong. Please do not confuse it with the 3000+ strongly labelled real audio clips from AudioSet.

Please note that ATST-SED also achieves top-ranked performance on the public evaluation dataset without using the external dataset, but we did not report this in our paper because of limited space. The top-1 model used extra weakly-labelled data from AudioSet; we are still mining this part of the data to improve the model performance.

| Dataset | External set | PSDS_1 | PSDS_2 | ckpt |
|---|---|---|---|---|
| DCASE dev. set | - | 0.583 | 0.810 | Stage2_wo_ext.ckpt |
| DCASE public eval. set | - | 0.631 | 0.833 | same as the above |
| DCASE dev. set | Used | 0.587 | 0.812 | Stage2_w_ext.ckpt |
| DCASE public eval. set | Used | 0.631 | 0.846 | same as the above |

Two fine-tuned ATST-SED checkpoints are also released in the table above. You can download them and use them directly.

If you want to check the performance of the fine-tuned checkpoint:

python train_stage2.py --gpus YOUR_DEVICE_ID, --test_from_checkpoint YOUR_CHECKPOINT_PATH

Ablation on batch sizes:

We report the model performances on the development set with the following setups:

| Batch sizes | n_epochs | n_epochs_warmup | accm_grad | PSDS_1 | PSDS_2 |
|---|---|---|---|---|---|
| [4, 4, 8, 8] | 40 | 2 | - | 0.535 | 0.784 |
| [8, 8, 16, 16] | 80 | 2 | - | 0.562 | 0.802 |
| [12, 12, 24, 24] | 125 | 5 | - | 0.570 | 0.805 |
| [4, 4, 8, 8] | 250 | 10 | 6 | 0.579 | 0.811 |

As shown in the table, if you cannot afford the default batch sizes, please keep them at a reasonable level. Alternatively, we recommend using the accm_grad hyperparameter in stage2.yaml to enlarge the effective batch sizes. However, using accm_grad also degrades model performance, because of its influence on the batch norm layers of the CNN model. Compared with the reported results, you might get a lower PSDS_1 of about 56% to 58% (using the last ckpt for validation).
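To make the accm_grad trade-off more concrete, here is a small sketch in plain Python. It assumes that accm_grad multiplies each batch size to give the effective batch size, which is the usual meaning of gradient accumulation. It then shows one plausible reading of the batch norm remark: normalization statistics are computed on each small micro-batch, not on the accumulated batch, so they can differ noticeably from the full-batch statistics. The function names and the toy activation values are illustrative only.

```python
from statistics import fmean, pvariance


def effective_batch_sizes(batch_sizes, accm_grad):
    """Effective batch sizes when gradients are accumulated over accm_grad steps."""
    return [size * accm_grad for size in batch_sizes]


def batch_stats(values):
    """Mean and (population) variance, as a batch norm layer would estimate them."""
    return fmean(values), pvariance(values)


def micro_batch_stats(values, micro_size):
    """Statistics seen by batch norm when the batch is split into micro-batches."""
    return [
        batch_stats(values[start:start + micro_size])
        for start in range(0, len(values), micro_size)
    ]


if __name__ == "__main__":
    print(effective_batch_sizes([4, 4, 8, 8], 6))  # [24, 24, 48, 48]

    # Toy activations for one channel: two micro-batches with different content.
    activations = [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]
    print("full batch:", batch_stats(activations))
    print("micro-batches:", micro_batch_stats(activations, 4))
```

Gradient accumulation reproduces the gradient averaging of a larger batch, but each forward pass still normalizes with the statistics of its own micro-batch, which is why the last row of the table does not fully recover the large-batch result.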

Citation

If you want to cite this paper:

@INPROCEEDINGS{10446159,
  author={Shao, Nian and Li, Xian and Li, Xiaofei},
  booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 
  title={Fine-Tune the Pretrained ATST Model for Sound Event Detection}, 
  year={2024},
  volume={},
  number={},
  pages={911-915},
  keywords={Training;Event detection;Self-supervised learning;Feature extraction;Transformers;Task analysis;Speech processing;sound event detection;self-supervised learning;ATST;fine-tuning pretrained model},
  doi={10.1109/ICASSP48485.2024.10446159}}
