This branch provides preliminary examples of using S5 for language modeling. We reused the PyTorch dataloading code and ported some of the PyTorch source from the Safari: Convolutions for Sequence Modeling repository into JAX, including the core code required for the associative recall in-context learning task and WikiText-103 as used in the H3 and Hyena papers. The training loop implementation in this branch was inspired by the TECO repository.
We hope to continue to add experiments on larger datasets and plan to merge the infrastructure developed in this branch (pmapping, checkpointing, more general experiment loading) into the main branch.
H3 proposed combining data-controlled multiplicative gating with SSMs (in particular S4D) and
showed strong performance (when combined with a couple of attention layers in a hybrid model) on OpenWebText, WikiText-103, and The Pile.
Hyena generalized the notion of multiplicative gating and also replaced the SSMs of H3 with implicitly parameterized convolutions. The latter change was motivated by the observation that SSM-based convolutions seemed not to perform as well on
challenging versions of synthetic in-context learning tasks (e.g. associative recall with 131K-length sequences and a vocab size of 30). Hyena showed strong attention-free performance and the ability to match the perplexity and downstream performance of Transformers trained on The Pile. The Hyena paper also shows interesting results suggesting a correlation between performance on the associative recall tasks and performance on The Pile.
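For concreteness, the associative recall task asks the model to read a sequence of key-value pairs and, given a final query key, emit the value that was paired with it. A minimal generator sketch (purely illustrative; function and variable names are ours, and the Safari repo's exact data format may differ):

```python
import random

def make_assoc_recall_example(vocab_size, num_pairs, seed=None):
    """Build one toy associative recall example: interleaved key-value
    pairs followed by a query key; the target is that key's value.
    (Illustrative sketch, not the Safari repo's implementation.)"""
    rng = random.Random(seed)
    # Use disjoint token ranges for keys and values so pairs are unambiguous.
    keys = rng.sample(range(vocab_size // 2), num_pairs)       # unique keys
    values = [rng.randrange(vocab_size // 2, vocab_size) for _ in keys]
    seq = []
    for k, v in zip(keys, values):
        seq += [k, v]                 # ... key, value, key, value, ...
    query = rng.choice(keys)          # ask about one of the seen keys
    target = values[keys.index(query)]
    return seq + [query], target
```

In the paper's hardest setting the sequence is 131K tokens long with a vocab size of 30, so the model must retrieve a binding seen tens of thousands of tokens earlier.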
A disadvantage of Hyena's implicit convolutions is that they lose the recurrent formulation (and the fast autoregressive generation) that SSM methods such as S4/S5 provide.
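This trade-off can be seen concretely for a linear SSM: the same output can be computed either as a long convolution with kernel h_j = C A^j B (convenient for parallel training) or via a recurrence with O(1) state update per token (convenient for autoregressive generation). A toy scalar check of the equivalence (illustrative only; real S4/S5 layers use vector states and diagonal or structured A):

```python
def ssm_conv(u, A, B, C):
    """Convolutional view: y_k = sum_j (C * A**j * B) * u[k - j]."""
    L = len(u)
    h = [C * (A ** j) * B for j in range(L)]   # SSM convolution kernel
    return [sum(h[j] * u[k - j] for j in range(k + 1)) for k in range(L)]

def ssm_recurrent(u, A, B, C):
    """Recurrent view: x_k = A x_{k-1} + B u_k, y_k = C x_k.
    One O(1) state update per token enables fast autoregressive decoding."""
    x, ys = 0.0, []
    for u_k in u:
        x = A * x + B * u_k
        ys.append(C * x)
    return ys
```

An implicitly parameterized filter has no such state-space form in general, so generation must re-apply the full convolution over the growing history.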
Are the implicitly parameterized convolutions necessary, and are SSM methods really not expressive enough? We were curious how S5 would do in this setting, so we simply replaced Hyena's implicit convolution filters with S5 SSMs. We refer to this as Hyena-S5, since we used the exact Hyena operator implementation (ported into JAX) and only replaced the implicitly parameterized convolutions. Note, however, that when only order-2 operators are used (as we do here), this is essentially H3 with S5 used for the SSM.
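As a sketch of what "replacing the filter" means: an order-2 gated operator forms projections q, k, v of the input and computes roughly q ⊙ filter(k ⊙ v), where the long-convolution `filter` is the only piece that differs between Hyena (implicit parameterization) and Hyena-S5/H3 (an SSM). A toy scalar-channel version with a fixed FIR kernel standing in for the learned filter (all names illustrative; the real operator works on projected feature maps):

```python
def causal_fir(x, h):
    """Causal 1-D convolution of sequence x with kernel h; a stand-in for
    the learned long convolution (implicit filter or SSM kernel)."""
    return [sum(h[j] * x[t - j] for j in range(min(len(h), t + 1)))
            for t in range(len(x))]

def order2_gated_operator(q, k, v, h):
    """Order-2 multiplicative gating in the H3/Hyena style (toy sketch):
    elementwise gate, long convolution, elementwise gate again."""
    gated = [ki * vi for ki, vi in zip(k, v)]          # k ⊙ v
    filtered = causal_fir(gated, h)                    # long conv over sequence
    return [qi * fi for qi, fi in zip(q, filtered)]    # q ⊙ filter(k ⊙ v)
```

Swapping the convolution implementation leaves the surrounding gating untouched, which is why the substitution is a one-piece change.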
Our results indicate that Hyena-S5 can solve the challenging associative recall in-context learning tasks and can also outperform Transformers on WikiText-103 without any attention layers. We fully recognize that these are relatively small-scale results by today's standards, but they are at least another datapoint suggesting that SSM methods such as S5, combined with multiplicative gating, can be very effective. We hope to scale up these and other ideas to larger datasets soon, stay tuned!
Use these commands to run Hyena and Hyena-S5 on the 131K-sequence-length associative recall tasks with vocab sizes 30 and 40 (be sure to fill in the data_dir and wandb info in the corresponding config file):
You can first run a small test (seq_len=20, vocab=10) that should run on a CPU with this command:
python train.py -o output_dir_name -c configs/hyena_S5/associative_recall_20_10_S5.yaml
Hyena
python train.py -o output_dir_name -c configs/hyena/associative_recall_131K_30_hyena.yaml
Hyena-S5
python train.py -o output_dir_name -c configs/hyena_S5/associative_recall_131K_30_S5.yaml
We averaged Hyena-S5 over 10 random seeds and obtained an average accuracy of 97.2%, which matches the performance reported for Hyena with implicit convolutions. For comparison we have added our result to the table from the Hyena paper:
See an example wandb run for both Hyena and Hyena-S5 here.
We also include runs for the vocab 40 setting.
Hyena
python train.py -o output_dir_name -c configs/hyena/associative_recall_131K_40_hyena.yaml
Hyena-S5
python train.py -o output_dir_name -c configs/hyena_S5/associative_recall_131K_40_S5.yaml
See an example wandb run for both Hyena and Hyena-S5 here.
Use these commands to run Hyena and Hyena-S5 on WikiText-103:
Hyena
python train.py -o output_dir_name -c configs/hyena/wikitext_hyena.yaml
Hyena-S5
python train.py -o output_dir_name -c configs/hyena_S5/wikitext_S5.yaml
See an example wandb run for both Hyena and Hyena-S5 here. Hyena-S5 achieves 18.3 perplexity after 100K steps. For comparison, we have added our result to the table from the Hyena paper:
Please reach out if you have any questions or feedback.
-- The S5 authors.