nan-wang / executor-audio-ASTransformer

executor-audio-ASTransformer

This executor implements the Audio Spectrogram Transformer (AST), a purely attention-based audio classification model. AST has achieved state-of-the-art results on several public audio benchmarks, including AudioSet, ESC-50 and Speech Commands V2. For each audio file, the executor returns an embedding and a prediction from the trained model.
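AST does not consume raw audio directly. It turns each clip into a log Mel spectrogram and feeds the transformer a sequence of overlapping patches cut from it. The arithmetic below is a minimal sketch of how many patches one input produces, assuming the settings of the original AST paper and code (128 Mel bins, 16x16 patches, stride 10 along both axes); these values are not configured anywhere in this README.

```python
def num_patches(t_dim: int, f_dim: int = 128, patch: int = 16, stride: int = 10) -> int:
    """Number of overlapping patches cut from a (t_dim x f_dim) spectrogram.

    Same arithmetic as a 2-D convolution with kernel `patch` and stride `stride`.
    Patch size 16 and stride 10 are assumed from the original AST setup.
    """
    f_patches = (f_dim - patch) // stride + 1  # patches along the frequency axis
    t_patches = (t_dim - patch) // stride + 1  # patches along the time axis
    return f_patches * t_patches


if __name__ == "__main__":
    print(num_patches(1024))  # 1024-frame AudioSet input -> 1212 patches
    print(num_patches(512))   # 512-frame input -> 600 patches
```

More input frames means more patches and a longer attention sequence, which is presumably why the input length must match the model that is loaded.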

Initialisation

This executor takes a DocumentArray on the /index endpoint. Each Document must store the full path of its audio file under the filename key of its tags.

from jina import Document, DocumentArray
import os

# Full path to the audio file; the executor reads it from tags['filename']
data_dir = os.getcwd()
filename = os.path.join(data_dir, 'data/1-4211-A-12.wav')
doc = DocumentArray([Document(tags={'filename': filename})])
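To index a whole folder instead of a single file, the same filename tag can be built for every .wav file in a directory. A minimal sketch with a hypothetical helper `wav_tags`; each returned dict can be passed as `Document(tags=...)`, e.g. `DocumentArray([Document(tags=t) for t in wav_tags('data')])`.

```python
import os
from typing import Dict, Iterable, List, Optional


def wav_tags(data_dir: str, names: Optional[Iterable[str]] = None) -> List[Dict[str, str]]:
    """Return one {'filename': <absolute path>} dict per .wav file in data_dir.

    `names` defaults to the directory listing; pass it explicitly to filter or reorder.
    """
    if names is None:
        names = os.listdir(data_dir)
    return [
        # Absolute path, so the value does not depend on the current working directory
        {'filename': os.path.abspath(os.path.join(data_dir, name))}
        for name in sorted(names)
        if name.lower().endswith('.wav')
    ]
```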

The fields that the executor reads from and writes to each Document are listed below.

Input

  • tags['filename']: full path of the audio file.

Output

  • embedding: the audio embedding generated by the transformer model.
  • tags['prediction']: the class predicted by the model.

Usage

The executor can be used with a pre-trained model or a fine-tuned model.

via Docker image (recommended)

from jina import Flow

f = Flow().add(uses='jinahub+docker://executor-audio-ASTransformer')

via source code

from jina import Flow

f = Flow().add(uses='jinahub://executor-audio-ASTransformer')

  • To override __init__ args & kwargs, use .add(..., uses_with={'key': 'value'})
  • To override class metas, use .add(..., uses_metas={'key': 'value'})
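In Jina, the keys passed through uses_with override the keyword arguments of the executor's __init__. The plain-Python sketch below mimics that behaviour with a hypothetical `ToyASTExecutor` class. Its parameter names mirror the ones documented for this executor, while the defaults (527 AudioSet classes, 1024 frames) are assumptions, not values read from the executor.

```python
from typing import Optional, Sequence


class ToyASTExecutor:
    """Hypothetical stand-in for the AST executor; it only records its init arguments."""

    def __init__(self,
                 total_labels: int = 527,        # assumed AudioSet default
                 input_target_dim: int = 1024,   # assumed AudioSet default
                 dataset_mean_std: Optional[Sequence[float]] = None,
                 model_path: Optional[str] = None):
        self.total_labels = total_labels
        self.input_target_dim = input_target_dim
        self.dataset_mean_std = dataset_mean_std
        self.model_path = model_path


# Similar in spirit to Flow().add(..., uses_with=uses_with):
# every given key replaces the matching __init__ default, the rest stay unchanged.
uses_with = {'total_labels': 50, 'input_target_dim': 512}
executor = ToyASTExecutor(**uses_with)
print(executor.total_labels, executor.input_target_dim, executor.model_path)  # prints: 50 512 None
```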

Examples

Using the model pre-trained on the AudioSet dataset

This is the default setting of the executor, so no parameters need to be passed. The default parameter values are the ones computed in the original AST implementation.

from jina import Document, DocumentArray, Flow
import os

filename = os.path.join(os.getcwd(), 'data/1-4211-A-12.wav')
doc = DocumentArray([Document(tags={'filename': filename})])

# No uses_with: the executor falls back to its AudioSet defaults
f = Flow().add(uses='jinahub://ASTransformer_encoder', install_requirements=True, force=True)
with f:
    # Send the Documents to the executor's /index endpoint
    responses = f.post(on='/index', inputs=doc, return_results=True)
    print(responses)

Using a fine-tuned model

AST can also be fine-tuned on a specific dataset, and the resulting model can be loaded by the executor. In that case, the executor requires the following parameters.

  • total_labels: number of classes in the dataset on which the model was fine-tuned.
  • input_target_dim: number of time frames in the input spectrogram. AST extracts one frame every 10 ms, so this is roughly 100*t, where t is the clip duration in seconds (the ESC-50 example below uses 512 for 5-second clips).
  • dataset_mean_std: mean and standard deviation of the dataset's input features, given as [mean, std].
  • model_path (str): path to the fine-tuned model checkpoint.
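A minimal sketch for choosing input_target_dim, assuming the executor follows the original AST recipe: audio resampled to 16 kHz, filterbank frames taken with a 25 ms window and a 10 ms shift, and the frame sequence zero-padded or cut to the target length. The helpers `num_frames` and `pad_or_trim` are hypothetical.

```python
from typing import List, Sequence


def num_frames(duration_s: float, sample_rate: int = 16000,
               win_ms: int = 25, shift_ms: int = 10) -> int:
    """Frames produced by sliding a win_ms window in shift_ms steps (no edge padding)."""
    n_samples = int(round(duration_s * sample_rate))
    win = sample_rate * win_ms // 1000      # 400 samples at 16 kHz
    shift = sample_rate * shift_ms // 1000  # 160 samples at 16 kHz
    return max(0, 1 + (n_samples - win) // shift)


def pad_or_trim(frames: Sequence[Sequence[float]], target_dim: int) -> List[List[float]]:
    """Zero-pad or cut a (time x mel) feature matrix to exactly target_dim frames."""
    rows = [list(f) for f in frames[:target_dim]]
    n_mels = len(frames[0]) if frames else 0
    # Append all-zero frames until the target length is reached
    rows.extend([0.0] * n_mels for _ in range(target_dim - len(rows)))
    return rows


if __name__ == "__main__":
    print(num_frames(5.0))   # 5 s ESC-50 clip -> 498 frames
    print(num_frames(10.0))  # 10 s clip -> 998 frames
```

Under these assumptions, the 498 frames of a 5-second ESC-50 clip fit into the 512 used in the example below, with 14 zero frames of padding.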

The example below uses a model fine-tuned on the ESC-50 dataset.

from jina import Document, DocumentArray, Flow
import os

filename = os.path.join(os.getcwd(), 'data/1-4211-A-12.wav')
doc = DocumentArray([Document(tags={'filename': filename})])

# Parameters of the model fine-tuned on ESC-50 (50 classes, 5-second clips)
params = {'total_labels': 50,
          'input_target_dim': 512,
          'dataset_mean_std': [-6.6268077, 5.358466],
          # Path to your fine-tuned checkpoint; adjust to where it is stored
          'model_path': os.path.join(os.getcwd(), 'pretrained_models/ast_esc50_best_model.pth')}

f = Flow().add(uses='jinahub://ASTransformer_encoder', uses_with=params, install_requirements=True, force=True)
with f:
    responses = f.post(on='/index', inputs=doc, return_results=True)
    print(responses)
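The dataset_mean_std values in the example above are statistics of the model's input features over the fine-tuning dataset. A minimal sketch, assuming the features are available as per-clip (time x mel) lists of log Mel filterbank values, and assuming the AST convention of normalizing each value as (x - mean) / (2 * std). It computes one global mean and standard deviation; the AST repository's own script may aggregate differently (for example per batch), so the results can differ slightly. Both helper names are hypothetical.

```python
import statistics
from typing import Iterable, List, Sequence, Tuple


def dataset_mean_std(features: Iterable[Sequence[Sequence[float]]]) -> Tuple[float, float]:
    """Global mean and population std over every value of every clip's feature matrix."""
    values: List[float] = [v for clip in features for frame in clip for v in frame]
    return statistics.fmean(values), statistics.pstdev(values)


def normalize(frame: Sequence[float], mean: float, std: float) -> List[float]:
    """AST-style normalisation of one frame: (x - mean) / (2 * std)."""
    return [(v - mean) / (2 * std) for v in frame]


# Two toy "clips", each with a single 2-bin frame
mean, std = dataset_mean_std([[[1.0, 3.0]], [[5.0, 7.0]]])
print(mean)  # 4.0
```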

Example output

The executor stores the predicted class under the prediction key of the Document's tags, and the audio embedding in the Document's embedding field (a 768-dimensional float32 vector, serialized below as a base64 buffer).

{'docs': [{'id': '9fbafd41-34a0-11ec-8d14-b06ebf2c5fd9',
           'tags': {'filename': '<PATH>/data/1-4211-A-12.wav', 'prediction': 12.0},
           'embedding': {'dense': {'buffer': 'mqpsPwh9EUBIb0E+Go4+PqzGV77ZN3s/...',
                                   'shape': [1, 768],
                                   'dtype': '<f4'}}}]}
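To use the result outside Jina, the embedding can be decoded from its serialized buffer, and for ESC-50 the prediction can be checked against the ground-truth class encoded in the file name (ESC-50 names clips {fold}-{clip_id}-{take}-{target}.wav). A minimal stdlib-only sketch; the helper names are hypothetical, and it only handles the '<f4' (little-endian float32) dtype shown in the output.

```python
import base64
import os
import struct
from typing import List


def decode_dense(buffer_b64: str, dtype: str = '<f4') -> List[float]:
    """Decode the base64 'buffer' of a dense embedding into Python floats."""
    if dtype != '<f4':
        raise ValueError('this sketch only handles little-endian float32')
    raw = base64.b64decode(buffer_b64)
    if len(raw) % 4:
        raise ValueError('buffer length is not a multiple of 4 bytes')
    return list(struct.unpack('<%df' % (len(raw) // 4), raw))


def esc50_target(path: str) -> int:
    """Return the target class encoded as the last '-' field of an ESC-50 file name."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split('-')[-1])


print(esc50_target('<PATH>/data/1-4211-A-12.wav'))  # 12, matching 'prediction': 12.0
```

For a complete buffer with shape [1, 768], decode_dense returns 768 values.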
