AVES (Animal Vocalization Encoder based on Self-Supervision) is a self-supervised, transformer-based audio representation model for encoding animal vocalizations ("BERT for animals"). It is based on HuBERT (Hsu et al., 2021), a powerful self-supervised model for human speech, but pretrained on large-scale unannotated audio datasets (FSD50K, AudioSet, and VGGSound) which include animal sounds.
Comprehensive experiments with a suite of classification and detection tasks (from the BEANS benchmark) have shown that AVES outperforms all the strong baselines and even the supervised "topline" models trained on annotated audio classification datasets.
See our paper for more details.
Create a conda environment by running, for example:
conda create -n aves python=3.8 pytorch cudatoolkit=11.3 torchvision torchaudio cudnn -c pytorch -c conda-forge
Create your working directory:
mkdir aves
Or simply clone this repository:
git clone https://github.com/earthspecies/aves.git
AVES is based on HuBERT, which is implemented in fairseq, a sequence modeling toolkit developed by Meta AI. Check out the specific commit of fairseq that AVES is based on, and install it via pip. Please note that you might encounter import issues if you install fairseq directly under your working directory. In the code below, we demonstrate installation under a sibling directory.
git clone https://github.com/facebookresearch/fairseq.git
cd fairseq
git checkout eda70379
pip install --editable ./
cd ../aves
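If an import error does appear, it can help to check how Python resolves the package name. The sketch below is not part of AVES or fairseq: `describe_package` is a hypothetical helper that uses only the standard library. It assumes (the text above does not state the cause) that one source of trouble is a cloned `fairseq/` directory in the working directory being picked up as a namespace package, which the helper reports as "namespace".

```python
import importlib
import importlib.util


def describe_package(name):
    """Return (kind, location) describing how `import name` would resolve.

    kind is "missing", "namespace" (a directory without __init__.py),
    or "regular" (a normal module or package).
    """
    importlib.invalidate_caches()
    spec = importlib.util.find_spec(name)
    if spec is None:
        return "missing", None
    if spec.origin is None and spec.submodule_search_locations is not None:
        # No __init__.py was found: possibly a source checkout that shadows
        # the installed package (an assumption about the failure mode).
        return "namespace", list(spec.submodule_search_locations)
    return "regular", spec.origin


if __name__ == "__main__":
    print(describe_package("fairseq"))
```

A "regular" result with a path inside the sibling fairseq checkout suggests the editable install is being picked up as intended.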
Download the pretrained weights. See the table below for the details. We recommend the AVES-bio configuration, as it was the best performing model overall in our paper.
wget https://storage.googleapis.com/esp-public-files/aves/aves-base-bio.pt
You can load the model with the fairseq.checkpoint_utils.load_model_ensemble_and_task() function and implement a PyTorch classifier that uses AVES as follows. See test_aves.py for a working example of an AVES-based classifier. Note that AVES takes raw waveforms as input.
import fairseq.checkpoint_utils
import torch.nn as nn


class AvesClassifier(nn.Module):

    def __init__(self, model_path, num_classes, embeddings_dim=768, multi_label=False):
        super().__init__()
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_path])
        self.model = models[0]
        # keep the convolutional feature extractor frozen
        self.model.feature_extractor.requires_grad_(False)
        self.head = nn.Linear(in_features=embeddings_dim, out_features=num_classes)

        if multi_label:
            self.loss_func = nn.BCEWithLogitsLoss()
        else:
            self.loss_func = nn.CrossEntropyLoss()

    def forward(self, x, y=None):
        # x: batch of raw waveforms
        out = self.model.extract_features(x)[0]
        out = out.mean(dim=1)  # mean pooling over the time dimension
        logits = self.head(out)

        loss = None
        if y is not None:
            loss = self.loss_func(logits, y)

        return loss, logits
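To make the pooling step concrete, the following sketch estimates how many feature frames the encoder produces for a raw waveform, and what averaging over them does. It rests on assumptions not stated above: that AVES keeps the standard HuBERT Base convolutional feature extractor (seven layers with the kernel sizes and strides listed in the code) and that the input is sampled at 16 kHz. `num_frames` and `mean_pool` are illustrative helpers, not part of AVES or fairseq.

```python
# (kernel, stride) of each conv layer in the HuBERT Base feature extractor.
# Assumption: AVES uses this standard configuration unchanged.
CONV_LAYERS = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]


def num_frames(num_samples):
    """Number of feature frames produced for a waveform of num_samples samples."""
    length = num_samples
    for kernel, stride in CONV_LAYERS:
        # output length of an unpadded 1-D convolution
        length = (length - kernel) // stride + 1
    return length


def mean_pool(frames):
    """Average a list of equal-length frame vectors into a single vector."""
    count = len(frames)
    return [sum(column) / count for column in zip(*frames)]


print(num_frames(16000))                      # 49 frames for one second at 16 kHz
print(mean_pool([[1.0, 2.0], [3.0, 4.0]]))    # [2.0, 3.0]
```

Under these assumptions each frame covers roughly 20 ms of audio, so the classifier head sees one 768-dimensional vector per clip regardless of clip length.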
The original models are fairseq models. We have also ported them to TorchAudio and ONNX formats.
Download both the parameters and the model config listed under "TorchAudio version" in the pretrained models table below.
import json

import torch
import torch.nn as nn
from torchaudio.models import wav2vec2_model


class AvesTorchaudioWrapper(nn.Module):

    def __init__(self, config_path, model_path):
        super().__init__()
        # reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html
        self.config = self.load_config(config_path)
        self.model = wav2vec2_model(**self.config, aux_num_out=None)
        self.model.load_state_dict(torch.load(model_path))
        self.model.feature_extractor.requires_grad_(False)

    def load_config(self, config_path):
        with open(config_path, 'r') as ff:
            obj = json.load(ff)
        return obj

    def forward(self, sig):
        # extract_features in the TorchAudio version returns the outputs of all 12 layers; index -1 selects the final layer
        out = self.model.extract_features(sig)[0][-1]
        return out


# config_path and model_path point to the downloaded config and parameter files
torchaudio_model = AvesTorchaudioWrapper(config_path, model_path)
torchaudio_model.eval()
Download the parameters and the model config listed under "ONNX version" in the pretrained models table below.
NOTE: We observed that the ONNX version of AVES-all can show large relative differences from the original version when the output values are close to zero. The TorchAudio versions don't have this problem.
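import onnxruntime

# model_path points to the downloaded ONNX file
ort_session = onnxruntime.InferenceSession(model_path)
# sig is the input waveform; ONNX Runtime expects a NumPy array here, not a torch.Tensor
ort_inputs = {ort_session.get_inputs()[0].name: sig}
ort_outs = ort_session.run(None, ort_inputs)
onnx_out = ort_outs[0]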
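The note above about relative differences near zero can be illustrated with a small standard-library sketch. It is not how the discrepancy was measured; the helper names, the sample values, and the tolerances `rel_tol=1e-3` and `abs_tol=1e-4` are all hypothetical. It shows why a purely relative comparison flags tiny outputs and how adding an absolute tolerance avoids that when comparing ONNX outputs against the original model.

```python
import math


def relative_difference(reference, candidate):
    """|candidate - reference| / |reference|; blows up as reference approaches 0."""
    return abs(candidate - reference) / abs(reference)


def outputs_match(reference, candidate, rel_tol=1e-3, abs_tol=1e-4):
    """True if every pair of values agrees within the relative OR absolute tolerance."""
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, candidate)
    )


# Made-up values: one ordinary output and one output close to zero.
reference = [0.5, 1e-6]
candidate = [0.5001, 3e-6]

print(relative_difference(reference[0], candidate[0]))    # small, about 2e-4
print(relative_difference(reference[1], candidate[1]))    # large, about 2
print(outputs_match(reference, candidate))                # True
print(outputs_match(reference, candidate, abs_tol=0.0))   # False
```

The second value differs by only 2e-6 in absolute terms, yet its relative difference is about 200%, which is the kind of effect the note describes.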
| Configuration | Pretraining data | Hours | Link to pretrained weights | TorchAudio version | ONNX version |
|---|---|---|---|---|---|
| AVES-core | FSD50k + AS (core) | 153 | Download | Download Config | Download |
| AVES-bio | core + AS/VS (animal) | 360 | Download | Download Config | Download |
| AVES-nonbio | core + AS/VS (non-animal) | 360 | Download | Download Config | Download |
| AVES-all | core + AS/VS (all) | 5054 | Download | Download Config | Download |