executor-audio-ASTransformer
This executor implements the purely attention-based Audio Spectrogram Transformer (AST) model, which has reported state-of-the-art results on public audio classification benchmarks such as AudioSet, ESC-50 and Speech Commands V2. For each audio file, the executor returns an embedding and a class prediction from the trained model.
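The sketch below shows how an AST-style model tiles a spectrogram into patches before the transformer sees them, which helps explain why longer clips cost more to process. It assumes the configuration described in the AST paper (128 mel bins, 16×16 patches overlapping by 6 in both dimensions, so a stride of 10). The function name is hypothetical and not part of the executor.

```python
def ast_patch_count(num_frames: int, num_mel_bins: int = 128,
                    patch_size: int = 16, stride: int = 10) -> int:
    """Number of patches extracted from a (num_mel_bins x num_frames) spectrogram.

    Patches are square (patch_size x patch_size) and move by `stride` along
    both the frequency and the time axis; stride 10 with patch 16 means an
    overlap of 6.
    """
    if num_frames < patch_size or num_mel_bins < patch_size:
        raise ValueError("spectrogram is smaller than a single patch")
    freq_patches = (num_mel_bins - patch_size) // stride + 1
    time_patches = (num_frames - patch_size) // stride + 1
    return freq_patches * time_patches


if __name__ == "__main__":
    # 1024 frames is the usual input for 10-second AudioSet clips,
    # 512 the input used for 5-second ESC-50 clips later in this README.
    for frames in (1024, 512):
        print(frames, "frames ->", ast_patch_count(frames), "patches")
```

With these settings a 1024-frame input yields 12 × 101 = 1212 patches and a 512-frame input yields 12 × 50 = 600.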
Initialisation
This executor takes a DocumentArray on the /index request. Each Document should contain the full path of the audio file in its filename tag.
```python
import os

from jina import Document, DocumentArray

# Full path to an example clip in the repository's data folder
filename = os.path.join(os.getcwd(), 'data', '1-4211-A-12.wav')
docs = DocumentArray([Document(tags={'filename': filename})])
```
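To index a whole folder rather than a single file, the tags can be collected first. This is a minimal sketch using only the standard library; collect_filename_tags is a hypothetical helper, and it assumes the audio files use the .wav extension.

```python
from pathlib import Path
from typing import Dict, List


def collect_filename_tags(data_dir: str, pattern: str = "*.wav") -> List[Dict[str, str]]:
    """Return one tags dict per matching audio file, sorted by path.

    Each dict stores the absolute path under the 'filename' key,
    which is the tag this executor reads.
    """
    root = Path(data_dir).resolve()
    return [
        {"filename": str(path)}
        for path in sorted(root.glob(pattern))
        if path.is_file()
    ]


if __name__ == "__main__":
    for tags in collect_filename_tags("data"):
        print(tags)
```

Each returned dict can then be wrapped as Document(tags=tags) when building the DocumentArray.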
The executor reads the following field from each Document and writes the fields below back to it.
Input
- filename (tag): full path of the audio file.
Output
- embedding: the embedding generated by the transformer model.
- prediction (tag): the class predicted by the model.
Usage
The executor can be used with a pre-trained model or a fine-tuned model.
Via Docker image (recommended)
```python
from jina import Flow

f = Flow().add(uses='jinahub+docker://executor-audio-ASTransformer')
```
Via source code
```python
from jina import Flow

f = Flow().add(uses='jinahub://executor-audio-ASTransformer')
```
- To override __init__ args and kwargs, use .add(..., uses_with={'key': 'value'})
- To override class metas, use .add(..., uses_metas={'key': 'value'})
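Conceptually, the entries of uses_with end up as keyword arguments of the executor's __init__. The sketch below is a simplified model of that mapping using a stand-in class; HypotheticalASTEncoder, its None defaults and instantiate_with are assumptions for illustration, not the executor's real signature or Jina's internals. The parameter names are the ones listed for the fine-tuned setup further down.

```python
from typing import Any, Dict, List, Optional


class HypotheticalASTEncoder:
    """Stand-in for the executor: it only records the init arguments it receives."""

    def __init__(self, total_labels: Optional[int] = None,
                 input_target_dim: Optional[int] = None,
                 dataset_mean_std: Optional[List[float]] = None,
                 model_path: Optional[str] = None, **kwargs: Any) -> None:
        self.total_labels = total_labels
        self.input_target_dim = input_target_dim
        self.dataset_mean_std = dataset_mean_std
        self.model_path = model_path
        self.extra = kwargs  # keys that do not match a named parameter


def instantiate_with(cls, uses_with: Dict[str, Any]):
    """Simplified view of uses_with: unpack the dict as __init__ keyword arguments."""
    return cls(**uses_with)


if __name__ == "__main__":
    enc = instantiate_with(HypotheticalASTEncoder, {"total_labels": 50})
    print(enc.total_labels, enc.input_target_dim)  # unspecified keys keep their defaults
```

Keys left out of uses_with keep their defaults, which is why the pre-trained example below needs no parameters at all.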
Examples
Using the model pre-trained on the AudioSet dataset
This is the executor's default setting and requires no parameters. The default parameter values are the ones computed in the original AST implementation and can be found in the executor's source code.
```python
import os

from jina import Document, DocumentArray, Flow

filename = os.path.join(os.getcwd(), 'data', '1-4211-A-12.wav')
docs = DocumentArray([Document(tags={'filename': filename})])

f = Flow().add(uses='jinahub://ASTransformer_encoder', install_requirements=True, force=True)

with f:
    responses = f.post(on='/index', inputs=docs, return_results=True)
print(responses)
```
Using a fine-tuned model
The AST can be fine-tuned on a specific dataset, and the resulting model can also be used with the executor. In that case the executor requires the following parameters:
- total_labels: total number of classes in the dataset the model was fine-tuned on.
- input_target_dim: length of the input in spectrogram frames. Features are computed every 10 ms, so this is about 100*t, where t is the duration of an audio clip in seconds (the ESC-50 example below uses 512 for its 5-second clips).
- dataset_mean_std: mean and standard deviation of the dataset's features. The procedure for computing both is described in the AST repository.
- model_path (str): path to the fine-tuned model.
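The two derived parameters can be estimated with a short sketch. It assumes the feature settings from the AST paper (16 kHz audio, 25 ms window, 10 ms shift, no edge padding) and that fbank features have already been extracted, for example with torchaudio. The helper names are hypothetical. For the statistics it averages the per-clip mean and per-clip sample standard deviation, which is our understanding of the AST normalisation script; a pooled estimate over all values is a reasonable alternative.

```python
import statistics
from typing import Sequence, Tuple


def num_fbank_frames(duration_s: float, sample_rate: int = 16000,
                     window_ms: int = 25, shift_ms: int = 10) -> int:
    """Frames produced by a 25 ms window moved every 10 ms, without edge padding.

    This is roughly 100 frames per second; choose input_target_dim at least
    this large (the AST reference code pads or truncates features to the
    target length, and we assume the executor does the same).
    """
    num_samples = int(round(duration_s * sample_rate))
    window = sample_rate * window_ms // 1000
    shift = sample_rate * shift_ms // 1000
    if num_samples < window:
        return 0
    return 1 + (num_samples - window) // shift


def dataset_mean_std(clips: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Average the per-clip mean and per-clip sample std of fbank values.

    Each clip is given as a flat sequence of its fbank feature values
    (at least two values per clip, as required by statistics.stdev).
    """
    means = [statistics.fmean(clip) for clip in clips]
    stds = [statistics.stdev(clip) for clip in clips]
    return statistics.fmean(means), statistics.fmean(stds)


if __name__ == "__main__":
    print(num_fbank_frames(5.0))   # 5-second ESC-50 clip
    print(num_fbank_frames(10.0))  # 10-second AudioSet clip
```

For a 5-second ESC-50 clip this gives 498 frames, which the example below covers with input_target_dim set to 512.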
Below is an example using a model fine-tuned on the ESC-50 dataset.
```python
import os

from jina import Document, DocumentArray, Flow

filename = os.path.join(os.getcwd(), 'data', '1-4211-A-12.wav')
docs = DocumentArray([Document(tags={'filename': filename})])

# Parameters for a model fine-tuned on ESC-50 (50 classes, 5-second clips)
params = {
    'total_labels': 50,
    'input_target_dim': 512,
    'dataset_mean_std': [-6.6268077, 5.358466],
    # Replace with the path to your own fine-tuned checkpoint
    'model_path': 'pretrained_models/ast_esc50_best_model.pth',
}

f = Flow().add(uses='jinahub://ASTransformer_encoder', uses_with=params,
               install_requirements=True, force=True)

with f:
    responses = f.post(on='/index', inputs=docs, return_results=True)
print(responses)
```
Example output
The executor stores its prediction in the Document's tags under the key prediction, and the audio embedding in the Document's embedding field.
```
{'docs': [{'id': '9fbafd41-34a0-11ec-8d14-b06ebf2c5fd9',
           'tags': {'filename': '<PATH>/data/1-4211-A-12.wav', 'prediction': 12.0},
           'embedding': {'dense': {'buffer': 'mqpsPwh9EUBIb0E+Go4+PqzGV77ZN3s/CSAfPwnS673iX/C9...',
                                   'shape': [1, 768],
                                   'dtype': '<f4'}}}]
}
```
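The embedding buffer in the printed response is base64-encoded raw bytes. Given 'dtype': '<f4' and 'shape': [1, 768], it can be decoded with the standard library. This is a sketch that assumes exactly that encoding (inferred from the printed fields) and a row-major layout; decode_dense_buffer is a hypothetical helper. In Python code you would normally read doc.embedding directly instead of decoding the printed dict.

```python
import base64
import struct
from typing import List


def decode_dense_buffer(buffer_b64: str, shape: List[int], dtype: str = "<f4") -> List[List[float]]:
    """Decode a base64 'dense' buffer into a list of rows.

    Assumes raw little-endian float32 values ('<f4') stored row by row,
    and a 2-D shape such as [1, 768].
    """
    if dtype != "<f4":
        raise ValueError(f"unsupported dtype: {dtype!r}")
    if len(shape) != 2:
        raise ValueError("expected a 2-D shape such as [1, 768]")
    rows, cols = shape
    raw = base64.b64decode(buffer_b64)
    # struct.error is raised if the byte count does not match rows * cols floats
    values = struct.unpack(f"<{rows * cols}f", raw)
    return [list(values[r * cols:(r + 1) * cols]) for r in range(rows)]
```

Applied to the full buffer above with shape [1, 768], this returns a single row of 768 floats.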