mcbal / spin-model-transformers

Physics-inspired transformer modules based on mean-field dynamics of vector-spin models in JAX
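The phrase "mean-field dynamics of vector-spin models" can be made concrete with the textbook naive mean-field approximation for classical unit-vector spins. The sketch below is a generic illustration and not the code used in this repository. It assumes three-component spins on the unit sphere, a symmetric coupling matrix `J`, external fields `h`, and an inverse temperature `beta`. In that setting each spin's magnetization is the Langevin function of its effective field, m_i = L(beta |h_i + Σ_j J_ij m_j|) times that field's unit direction, with L(x) = coth(x) - 1/x. The fixed point is found by damped iteration. All function names (`langevin`, `mean_field_step`, `mean_field_fixed_point`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def langevin(x):
    """Langevin function L(x) = coth(x) - 1/x for x >= 0 (hypothetical helper).

    Near zero it switches to the series limit L(x) ~ x/3 to avoid 0/0.
    """
    small = x < 1e-3
    x_safe = jnp.where(small, 1.0, x)
    exact = 1.0 / jnp.tanh(x_safe) - 1.0 / x_safe
    return jnp.where(small, x / 3.0, exact)


def mean_field_step(m, J, h, beta):
    """One naive mean-field update for classical 3D unit-vector spins.

    m: (N, 3) magnetizations, J: (N, N) couplings, h: (N, 3) external fields.
    """
    h_eff = h + J @ m  # effective field acting on each spin
    norm = jnp.maximum(jnp.linalg.norm(h_eff, axis=-1, keepdims=True), 1e-12)
    # Magnitude from the Langevin function, direction from the effective field.
    return langevin(beta * norm) * h_eff / norm


def mean_field_fixed_point(J, h, beta, num_steps=200, damping=0.5):
    """Damped fixed-point iteration m <- (1 - d) m + d f(m), starting from m = 0."""

    def body(m, _):
        m_new = mean_field_step(m, J, h, beta)
        return (1.0 - damping) * m + damping * m_new, None

    m, _ = jax.lax.scan(body, jnp.zeros_like(h), None, length=num_steps)
    return m


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
N = 16
J = jax.random.normal(k1, (N, N)) / jnp.sqrt(N)
J = 0.5 * (J + J.T)  # symmetric couplings
J = J - jnp.diag(jnp.diag(J))  # no self-coupling
h = 0.1 * jax.random.normal(k2, (N, 3))

m = mean_field_fixed_point(J, h, beta=1.0)
print(m.shape)  # (16, 3)
```

Because L(x) < 1, every magnetization stays strictly inside the unit ball, and a larger `beta` (lower temperature) pushes the norms toward 1. Whether the `beta` argument of `SpinTransformer` plays exactly this role is an assumption based on the usual naming.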

Home page: https://mcbal.github.io/post/spin-model-transformers/

Spin-model transformers

Install

pip install -e ".[dev]"
pre-commit install
pre-commit run --all-files

Examples

import jax
from spin_model_transformers import SpinTransformer


key = jax.random.PRNGKey(2666)
x_key, mod_key = jax.random.split(key)

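# Input of shape (batch, tokens, dim); dim must match the model's dim.
x = jax.random.normal(x_key, shape=(1, 256, 512))
transformer = SpinTransformer(depth=6, dim=512, num_heads=1, beta=1.0, key=mod_key)

# jax.vmap maps the model over the leading batch axis.
out = jax.vmap(transformer)(x)  # (1, 256, 512)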
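The example above calls the model through `jax.vmap`, which suggests (an assumption, in line with Equinox-style modules that take a `key` argument) that `SpinTransformer` processes one unbatched sequence of shape `(tokens, dim)` per call. The sketch below shows that batching pattern with a hypothetical stand-in, `toy_layer`, instead of the real model, and checks that `vmap` matches an explicit loop over the batch.

```python
import jax
import jax.numpy as jnp


def toy_layer(x, w):
    """Hypothetical stand-in for a per-sequence module: (tokens, dim) -> (tokens, dim)."""
    return jnp.tanh(x @ w)


x_key, w_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(x_key, shape=(4, 256, 512))  # (batch, tokens, dim)
w = jax.random.normal(w_key, shape=(512, 512)) / jnp.sqrt(512.0)

# Map over axis 0 of x; share w across the batch (in_axes=None).
out = jax.vmap(toy_layer, in_axes=(0, None))(x, w)
print(out.shape)  # (4, 256, 512)

# Equivalent, slower reference: loop over the batch by hand.
looped = jnp.stack([toy_layer(x[i], w) for i in range(x.shape[0])])
```

Writing the module for a single sequence and batching with `vmap` keeps the per-example code free of batch-axis bookkeeping.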