aqibsaeed / Sensor-Transformer

Transformer Network for Time-Series, Sensor and Wearable Data



Sensor Transformer (SeT)

Adaptation of the Vision Transformer (ViT) to time-series and sensor data, implemented in TensorFlow.
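ViT splits an image into fixed-size patches and treats each flattened patch as a token. A natural 1D analogue is to cut each (signal_length, channels) window into non-overlapping segments of segment_size time steps and flatten each segment into one token vector. The sketch below illustrates that idea under the assumption that SeT segments signals this way (this README does not spell out the internals). The helper `segment_signal` is hypothetical, and the sketch assumes signal_length is an exact multiple of segment_size.

```python
import numpy as np


def segment_signal(x: np.ndarray, segment_size: int) -> np.ndarray:
    """Hypothetical helper: ViT-style patching for 1D signals.

    Maps (batch, signal_length, channels) to
    (batch, num_segments, segment_size * channels).
    Assumes signal_length is divisible by segment_size.
    """
    batch, signal_length, channels = x.shape
    if signal_length % segment_size != 0:
        raise ValueError("signal_length must be divisible by segment_size")
    num_segments = signal_length // segment_size
    # In C order, consecutive time steps are contiguous, so this reshape
    # groups segment_size steps (all channels) into one flat token vector.
    return x.reshape(batch, num_segments, segment_size * channels)


# Example: 2 windows of 128 time steps from a 3-axis accelerometer.
x = np.random.randn(2, 128, 3).astype(np.float32)
tokens = segment_signal(x, segment_size=16)
print(tokens.shape)  # (2, 8, 48)

# Each flattened segment would then be linearly projected to d_model
# dimensions, as ViT does with image patches (random weights here).
d_model = 64
w = np.random.randn(tokens.shape[-1], d_model).astype(np.float32)
embeddings = tokens @ w
print(embeddings.shape)  # (2, 8, 64)
```

With signal_length=128 and segment_size=16, the transformer would see a sequence of 8 tokens, so segment_size trades sequence length against the amount of raw signal packed into each token.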

Problems/Datasets

Tools

Install

pip install sensortransformer

Usage

import argparse
import tensorflow as tf
from sensortransformer import set_network

# Hyperparameters describing the input data, passed on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--signal-length", type=int, required=True)
parser.add_argument("--segment-size", type=int, required=True)
parser.add_argument("--num_channels", type=int, required=True)
parser.add_argument("--num_classes", type=int, required=True)
args = parser.parse_args()

# tf.data datasets; see the data.load_data function.
# Inputs must have shape x = (batch, signal_length, num_channels)
# and one-hot labels shape y = (batch, num_classes).
ds_train, ds_test = ...

# Build the Sensor Transformer classifier.
model = set_network.SensorTransformer(
    signal_length=args.signal_length,
    segment_size=args.segment_size,
    channels=args.num_channels,
    num_classes=args.num_classes,
    num_layers=4,
    d_model=64,
    num_heads=4,
    mlp_dim=64,
    dropout=0.1,
)

# from_logits=True: the loss applies softmax to the raw model outputs itself.
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
model.fit(ds_train, epochs=50, verbose=1)
model.evaluate(ds_test)
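The data comment in the usage example expects labels of shape (batch, num_classes), i.e. one-hot rows, which is what CategoricalCrossentropy consumes; from_logits=True means the model's outputs are treated as unnormalized scores and softmax is applied inside the loss. The NumPy sketch below illustrates both pieces. `one_hot` and `softmax_cross_entropy` are hypothetical helpers written for illustration only, not part of sensortransformer.

```python
import numpy as np


def one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Turn integer class ids of shape (batch,) into rows of shape (batch, num_classes)."""
    return np.eye(num_classes, dtype=np.float32)[labels]


def softmax_cross_entropy(y_true: np.ndarray, logits: np.ndarray) -> np.ndarray:
    """Per-example categorical cross-entropy computed directly from logits."""
    # Subtract each row's max before exponentiating for numerical stability;
    # this does not change the softmax result.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(y_true * log_probs).sum(axis=1)


y = one_hot(np.array([0, 2]), num_classes=3)
print(y)
# [[1. 0. 0.]
#  [0. 0. 1.]]

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
print(softmax_cross_entropy(y, logits))
```

Keras averages the per-example losses over the batch by default. Arrays prepared like x and y could then be wrapped with tf.data.Dataset.from_tensor_slices((x, y)).batch(...) to obtain datasets in the shape the example expects.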

Thanks to Phil Wang for open-sourcing his PyTorch implementation of ViT.

About


License: MIT

