Gated State Spaces - Pytorch

Implementation of Gated State Spaces, from the paper Long Range Language Modeling via Gated State Spaces, in Pytorch. In particular, it will include the hybrid version that combines local self attention with the long-range GSS.
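Purely as an illustration of that hybrid layout (not this repo's actual module, whose names and API may differ), here is a sketch that interleaves a banded-mask local self attention stand-in with GSS blocks; the placement of attention every few blocks is my assumption.

import torch
from torch import nn
from gated_state_spaces_pytorch import GSS

class LocalSelfAttention(nn.Module):
    def __init__(self, dim, window = 128, heads = 8):
        super().__init__()
        self.window = window
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x):
        seq_len = x.shape[1]
        pos = torch.arange(seq_len, device = x.device)
        # causal band: each token attends to at most `window` preceding tokens
        mask = (pos[None, :] > pos[:, None]) | ((pos[:, None] - pos[None, :]) >= self.window)
        out, _ = self.attn(x, x, x, attn_mask = mask)
        return out

dim = 512

layers = nn.ModuleList([])
for i in range(12):
    layers.append(GSS(dim = dim, dim_expansion_factor = 4, dss_kernel_N = 512, dss_kernel_H = 256))
    if (i + 1) % 4 == 0:
        # interleave a local attention block every 4 GSS blocks (an assumption)
        layers.append(LocalSelfAttention(dim))

x = torch.randn(1, 2048, dim)

for layer in layers:
    x = x + layer(x) # residual around every block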

It will also contain a few more settings for comparing state spaces against a sequence-wise GLU depthwise convolution and, even simpler, a parameterized exponential moving average along the sequence dimension. That way we can get to the bottom of whether state spaces are worth it, or whether the gains really come from the O(L log L) FFT convolution trick. Results will be shared in the readme.
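For reference, the O(L log L) trick in question: a length-L learned kernel is convolved with the whole sequence at once by pointwise multiplication in frequency space. The sketch below is my own minimal illustration with torch.fft, not code from this repo.

import torch

def fft_conv(x, kernel):
    # x: (batch, length, dim), kernel: (length, dim)
    length = x.shape[1]

    # pad to 2 * length so the circular FFT convolution becomes a causal linear one
    x_f = torch.fft.rfft(x, n = 2 * length, dim = 1)
    k_f = torch.fft.rfft(kernel, n = 2 * length, dim = 0)

    out = torch.fft.irfft(x_f * k_f, n = 2 * length, dim = 1)
    return out[:, :length] # O(L log L) overall, vs O(L^2) for a naive convolution

x = torch.randn(1, 65536, 512)
kernel = torch.randn(65536, 512) # one learned kernel value per position and channel

out = fft_conv(x, kernel) # (1, 65536, 512)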

I will also pit the GSS module against the Path-X challenge and see how well it does.

Update: This paper has beaten S4 on LRA using a multi-headed EMA combined with single-headed attention.
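For concreteness, the parameterized EMA mentioned above can be as simple as a learned per-channel decay. A minimal sketch of my own follows, written as an explicit recurrence for clarity; since the recurrence is equivalent to convolving with the kernel alpha * (1 - alpha)^t, it can also be computed with the FFT trick above.

import torch
from torch import nn

class LearnedEMA(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learned per-channel decay, squashed to (0, 1) by a sigmoid
        self.alpha_logit = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # x: (batch, length, dim)
        alpha = self.alpha_logit.sigmoid()
        ema = torch.zeros_like(x[:, 0])
        out = []
        for t in range(x.shape[1]):
            ema = alpha * x[:, t] + (1 - alpha) * ema
            out.append(ema)
        return torch.stack(out, dim = 1)

ema = LearnedEMA(512)
x = torch.randn(1, 1024, 512)
out = ema(x) # (1, 1024, 512)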

Install

$ pip install gated-state-spaces-pytorch

Usage

import torch
from gated_state_spaces_pytorch import GSS

gss = GSS(
    dim = 512,                  # dimension
    dim_expansion_factor = 4,   # hidden dimension (expansion factor x dim) = 2048
    dss_kernel_N = 512,         # size N of the DSS state space kernel
    dss_kernel_H = 256          # dimension H over which the DSS kernel operates
)

x = torch.randn(1, 65536, 512)

out = gss(x) # (1, 65536, 512)

Gated state spaces language model

import torch
from gated_state_spaces_pytorch import GatedStateSpacesLM

gss_lm = GatedStateSpacesLM(
    num_tokens = 20000,         # vocabulary size
    depth = 12,                 # number of GSS blocks
    dim = 512,
    dim_expansion_factor = 4,
    dss_kernel_N = 512,
    dss_kernel_H = 256
)

ids = torch.randint(0, 20000, (1, 1024))

logits = gss_lm(ids) # (1, 1024, 20000)
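A minimal, hypothetical next-token training step built on the example above; the loss wiring is my own sketch, assuming the LM returns raw logits as shown.

import torch.nn.functional as F

ids = torch.randint(0, 20000, (1, 1025))
inputs, targets = ids[:, :-1], ids[:, 1:] # shift by one for next-token prediction

logits = gss_lm(inputs) # (1, 1024, 20000)

loss = F.cross_entropy(
    logits.transpose(1, 2), # cross_entropy expects (batch, classes, length)
    targets
)
loss.backward()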

Todo

  • enwik8
  • gss lm class
  • add dsconv + learned ema
  • add attention

Citations

@inproceedings{Mehta2022LongRL,
    title   = {Long Range Language Modeling via Gated State Spaces},
    author  = {Harsh Mehta and Ankit Gupta and Ashok Cutkosky and Behnam Neyshabur},
    year    = {2022},
    eprint  = {2206.13947},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
