AVES: Animal Vocalization Encoder based on Self-Supervision

AVES (Animal Vocalization Encoder based on Self-Supervision) is a self-supervised, transformer-based audio representation model for encoding animal vocalizations ("BERT for animals"). It is based on HuBERT (Hsu et al., 2021), a powerful self-supervised model for human speech, but pretrained on large-scale unannotated audio datasets (FSD50K, AudioSet, and VGGSound) which include animal sounds.

Comprehensive experiments with a suite of classification and detection tasks (from the BEANS benchmark) have shown that AVES outperforms all the strong baselines and even the supervised "topline" models trained on annotated audio classification datasets.

See our paper for more details.

How to use AVES

Create a conda environment by running, for example:

conda create -n aves python=3.8 pytorch cudatoolkit=11.3 torchvision torchaudio cudnn -c pytorch -c conda-forge

Create your working directory:

mkdir aves

Or simply clone this repository:

git clone https://github.com/earthspecies/aves.git

AVES is based on HuBERT, which is implemented in fairseq, a sequence modeling toolkit developed by Meta AI. Check out the specific commit of fairseq that AVES is based on, and install it via pip. Note that you may encounter import issues if you install fairseq directly under your working directory; in the code below, we demonstrate installing it in a sibling directory.

git clone https://github.com/facebookresearch/fairseq.git
cd fairseq
git checkout eda70379
pip install --editable ./
cd ../aves
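
To make sure the editable install is picked up (and not shadowed by a local directory named fairseq), you can run an optional sanity check such as the following; this check is our suggestion, not part of the original setup:

python -c "import fairseq; print(fairseq.__file__)"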

Download the pretrained weights (see the table below for details). We recommend the AVES-bio configuration, as it was the best-performing model overall in our paper.

wget https://storage.googleapis.com/esp-public-files/aves/aves-base-bio.pt

You can load the model via the fairseq.checkpoint_utils.load_model_ensemble_and_task() method and implement a PyTorch classifier that uses AVES as follows. See test_aves.py for a working example of an AVES-based classifier. Note that AVES takes raw waveforms as input.

import fairseq
import torch.nn as nn

class AvesClassifier(nn.Module):
    def __init__(self, model_path, num_classes, embeddings_dim=768, multi_label=False):

        super().__init__()

        # Load the pretrained AVES checkpoint via fairseq
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_path])
        self.model = models[0]
        # Freeze the CNN feature extractor; only the transformer and the head are trained
        self.model.feature_extractor.requires_grad_(False)
        self.head = nn.Linear(in_features=embeddings_dim, out_features=num_classes)

        if multi_label:
            self.loss_func = nn.BCEWithLogitsLoss()
        else:
            self.loss_func = nn.CrossEntropyLoss()

    def forward(self, x, y=None):
        out = self.model.extract_features(x)[0]
        out = out.mean(dim=1)             # mean pooling
        logits = self.head(out)

        loss = None
        if y is not None:
            loss = self.loss_func(logits, y)

        return loss, logits
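
For example, you can instantiate the classifier and run it on a batch of raw waveforms. This is a minimal sketch: the number of classes, the labels, and the one-second 16 kHz inputs are illustrative placeholders; only the checkpoint file name comes from the download step above.

import torch

model = AvesClassifier(model_path='aves-base-bio.pt', num_classes=10)
model.eval()

waveforms = torch.randn(2, 16000)   # (batch, samples) -- raw 16 kHz waveforms
labels = torch.tensor([3, 7])       # integer class labels for CrossEntropyLoss
loss, logits = model(waveforms, labels)
print(logits.shape)                 # torch.Size([2, 10])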

Ported versions

The original model is implemented with fairseq. We have also ported the models to TorchAudio and Onnx formats.

TorchAudio

Download both the parameters and the model config under TorchAudio version in Pretrained models.

import json

import torch
import torch.nn as nn
from torchaudio.models import wav2vec2_model

class AvesTorchaudioWrapper(nn.Module):

    def __init__(self, config_path, model_path):

        super().__init__()

        # reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html

        self.config = self.load_config(config_path)
        self.model = wav2vec2_model(**self.config, aux_num_out=None)
        self.model.load_state_dict(torch.load(model_path))
        self.model.feature_extractor.requires_grad_(False)

    def load_config(self, config_path):
        with open(config_path, 'r') as ff:
            obj = json.load(ff)

        return obj

    def forward(self, sig):
        # extract_features in the torchaudio version outputs all 12 transformer layers' features; [-1] selects the final layer
        out = self.model.extract_features(sig)[0][-1]

        return out

torchaudio_model = AvesTorchaudioWrapper(config_path, model_path)
torchaudio_model.eval()
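
As a quick check (the one-second input below is an illustrative placeholder), the wrapper returns frame-level embeddings for a batch of raw waveforms:

sig = torch.randn(1, 16000)              # (batch, samples) -- ~1 s of raw 16 kHz audio
with torch.no_grad():
    embeddings = torchaudio_model(sig)   # (batch, frames, 768) final-layer features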

Onnx

Download the parameters and the model config under Onnx version in Pretrained models. NOTE: We observed that the Onnx version of AVES-all could have large relative differences compared to the original version when the output values are close to zero. The TorchAudio versions don't have this problem.
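
Before running the session below, prepare the raw waveform as a float32 NumPy array; the shape here is an illustrative placeholder matching the raw-waveform input of the other versions.

import numpy as np

sig = np.random.randn(1, 16000).astype(np.float32)   # (batch, samples) raw waveform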

import onnxruntime

ort_session = onnxruntime.InferenceSession(model_path)
ort_inputs = {ort_session.get_inputs()[0].name: sig}
ort_outs = ort_session.run(None, ort_inputs)
onnx_out = ort_outs[0]

Pretrained models

| Configuration | Pretraining data           | Hours | Link to pretrained weights | TorchAudio version | Onnx version |
|---------------|----------------------------|-------|----------------------------|--------------------|--------------|
| AVES-core     | FSD50k + AS (core)         | 153   | Download                   | Download, Config   | Download     |
| AVES-bio      | core + AS/VS (animal)      | 360   | Download                   | Download, Config   | Download     |
| AVES-nonbio   | core + AS/VS (non-animal)  | 360   | Download                   | Download, Config   | Download     |
| AVES-all      | core + AS/VS (all)         | 5054  | Download                   | Download, Config   | Download     |

Colab Notebooks


License: MIT License

