tk-rusch / coRNN

Official code for Coupled Oscillatory RNN (ICLR 2021, Oral)

Coupled Oscillatory Recurrent Neural Network (coRNN)
[ICLR 2021 Oral]

This repository contains the implementation needed to reproduce the numerical experiments of the International Conference on Learning Representations (ICLR) 2021 [oral] paper "Coupled Oscillatory Recurrent Neural Network (coRNN): An accurate and (gradient) stable architecture for learning long time dependencies".

Requirements

pytorch 1.3+
torchvision 0.4+
torchtext 0.6+
numpy 1.17+
spacy 2.2+

If you want to run the experiments on a GPU, please make sure you have installed the corresponding CUDA packages.
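A quick way to confirm the GPU setup is to ask PyTorch whether it can see a CUDA device. The sketch below is only a sanity check and is not part of the repository; the `pick_device` helper name is hypothetical.

```python
import torch


def pick_device() -> torch.device:
    """Return the CUDA device if PyTorch can see one, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()
print(f"PyTorch {torch.__version__}, running on: {device}")

# Tensors (and modules) have to be moved to the chosen device explicitly.
x = torch.zeros(2, 3).to(device)
```

If this reports `cpu` on a machine with a GPU, the installed PyTorch build most likely does not match the local CUDA installation.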

Example

The coRNN cell can be implemented in PyTorch as simply as this:

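import torch
from torch import nn

class coRNNCell(nn.Module):
    def __init__(self, n_inp, n_hid, dt, gamma=1., epsilon=1.):
        super(coRNNCell, self).__init__()
        self.dt = dt            # time step of the discretization
        self.gamma = gamma      # weight of the hy term in the update
        self.epsilon = epsilon  # weight of the hz (damping) term in the update
        # A single linear map acting on the concatenation [x, hz, hy]
        self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)

    def forward(self, x, hy, hz):
        # x: (batch, n_inp); hy, hz: (batch, n_hid)
        # Update the velocity-like state hz first ...
        hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
                             - self.gamma * hy - self.epsilon * hz)
        # ... then advance the hidden state hy using the new hz
        hy = hy + self.dt * hz

        return hy, hz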
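The cell only advances the two hidden states by one time step, so a full model has to unroll it over the input sequence and map the final state to an output. The following is a minimal sketch of that loop, not necessarily how the repository organizes it: the `coRNN` wrapper, its `readout` layer, the sequence-first input layout and all hyperparameter values are assumptions chosen for illustration. The cell is repeated so that the block runs on its own.

```python
import torch
from torch import nn


class coRNNCell(nn.Module):
    # Same cell as above, repeated so that this sketch is self-contained.
    def __init__(self, n_inp, n_hid, dt, gamma=1., epsilon=1.):
        super().__init__()
        self.dt = dt
        self.gamma = gamma
        self.epsilon = epsilon
        self.i2h = nn.Linear(n_inp + n_hid + n_hid, n_hid)

    def forward(self, x, hy, hz):
        hz = hz + self.dt * (torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
                             - self.gamma * hy - self.epsilon * hz)
        hy = hy + self.dt * hz
        return hy, hz


class coRNN(nn.Module):
    """Hypothetical wrapper: unroll the cell over time, read out the last state."""

    def __init__(self, n_inp, n_hid, n_out, dt, gamma=1., epsilon=1.):
        super().__init__()
        self.n_hid = n_hid
        self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
        self.readout = nn.Linear(n_hid, n_out)

    def forward(self, x):
        # x is assumed to have shape (seq_len, batch, n_inp).
        hy = x.new_zeros(x.size(1), self.n_hid)  # both states start at zero
        hz = x.new_zeros(x.size(1), self.n_hid)
        for t in range(x.size(0)):
            hy, hz = self.cell(x[t], hy, hz)
        return self.readout(hy)


torch.manual_seed(0)
model = coRNN(n_inp=2, n_hid=16, n_out=1, dt=0.05)  # illustrative values only
x = torch.rand(20, 4, 2)                            # 20 steps, batch of 4
out = model(x)
print(out.shape)  # torch.Size([4, 1])
```

Because the loop is plain Python, autograd backpropagates through all time steps, and calling `out.sum().backward()` fills the gradients of the cell and readout weights.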

Datasets

This repository contains the code to reproduce the results of the following experiments for the proposed coRNN:

  • The Adding Problem
  • Sequential MNIST
  • Permuted Sequential MNIST
  • Noise padded CIFAR-10
  • HAR-2
  • IMDB

The data sets for the MNIST/CIFAR-10 tasks and the IMDB task are downloaded automatically through torchvision and torchtext, respectively. The data set for HAR-2 has to be downloaded and preprocessed according to the instructions given in the paper.
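For the sequential MNIST tasks, each image is fed to the network one pixel at a time, and the permuted variant applies one fixed pixel permutation to every image. The sketch below shows that reshaping on a random stand-in batch, so nothing is downloaded. The `to_pixel_sequence` helper and the sequence-first output layout are assumptions for illustration, and the repository's own preprocessing may differ.

```python
import torch


def to_pixel_sequence(images, perm=None):
    """Flatten (batch, 1, 28, 28) images into (784, batch, 1) pixel sequences.

    If `perm` is given, the same pixel permutation is applied to every image,
    which is the usual permuted sequential MNIST setup.
    """
    batch = images.size(0)
    seq = images.reshape(batch, -1)   # (batch, 784), pixels in row-major order
    if perm is not None:
        seq = seq[:, perm]            # reorder the pixels of every image
    return seq.t().unsqueeze(-1)      # (784, batch, 1)


torch.manual_seed(0)
images = torch.rand(8, 1, 28, 28)     # stand-in for a torchvision MNIST batch
perm = torch.randperm(28 * 28)        # drawn once and reused for all batches

s_seq = to_pixel_sequence(images)         # sMNIST-style input
ps_seq = to_pixel_sequence(images, perm)  # psMNIST-style input
print(s_seq.shape, ps_seq.shape)
```

The permutation has to stay the same for training and testing; drawing a new one per batch would turn the task into a different problem.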

Results

The results of the coRNN for each of the experiments are:

Experiment               Result
sMNIST                   99.4% test accuracy
psMNIST                  97.3% test accuracy
Noise padded CIFAR-10    59.0% test accuracy
HAR-2                    97.2% test accuracy
IMDB                     87.4% test accuracy

Citation

If you found this work useful, please consider citing:

@inproceedings{rusch2021coupled,
  title={Coupled Oscillatory Recurrent Neural Network (coRNN): An accurate and (gradient) stable architecture for learning long time dependencies},
  author={Rusch, T. Konstantin and Mishra, Siddhartha},
  booktitle={International Conference on Learning Representations},
  year={2021}
}

License:MIT License

