
FF-jax [Paper]

Unofficial implementation of the Forward-Forward algorithm in JAX.

Usage

# clone the repository
git clone https://github.com/dslisleedh/FF-jax.git
cd FF-jax

# create environment
conda create --name <env> --file requirements.txt
conda activate <env>
# if this doesn't work, install the packages below manually:
# jax, jaxlib (https://github.com/google/jax#installation)
# einops, tensorflow, tensorflow_datasets, tqdm, hydra-core, hydra-colorlog, omegaconf, gin-config

# run!
python train.py

You can easily change training settings in ./config/hparams.gin (config.yaml is for Hydra, which creates and sets the working directory).
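
For reference, a gin file binds values to configurable names. A hypothetical excerpt is shown below; these binding names are illustrative assumptions, so check hparams.gin for the names this repo actually defines:

# hparams.gin (illustrative excerpt; binding names are assumptions)
Model.n_layers = 4            # number of FF layers
Model.n_units = 2000          # hidden units per layer
train.loss = 'softplus_loss'  # one of the losses listed below
train.optimizer = 'Adam'
train.learning_rate = 1e-4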

Hyperparameters

  • Losses
    • mse_loss
    • softplus_loss (used in the original paper; see the sketch after this list)
    • probabilistic_loss
    • symba_loss
    • swish_symba_loss
  • Optimizers
    • SGD
    • MomentumSGD
    • NesterovMomentumSGD
    • AdaGrad
    • RMSProp
    • Adam
    • AdaBelief
  • Initializers
    • jax.nn.initializers.lecun_normal
    • jax.nn.initializers.glorot_normal
    • jax.nn.initializers.he_normal
    • jax.nn.initializers.variance_scaling
  • and other settings such as n_layers, n_units, ...
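
For reference, the softplus_loss from the original paper can be written in JAX roughly as follows. This is a minimal sketch with our own function and argument names, not necessarily the ones this repo uses; it assumes goodness is the sum of squared activities and theta is the goodness threshold.

import jax
import jax.numpy as jnp

def goodness(h):
    # "Goodness" of a layer: sum of squared activities, per example.
    return jnp.sum(jnp.square(h), axis=-1)

def softplus_loss(h_pos, h_neg, theta=2.0):
    # Positive (real) data should get goodness above theta, negative
    # data below it. The log-loss of sigmoid(goodness - theta)
    # reduces to these softplus terms.
    loss_pos = jax.nn.softplus(theta - goodness(h_pos))
    loss_neg = jax.nn.softplus(goodness(h_neg) - theta)
    return jnp.mean(loss_pos) + jnp.mean(loss_neg)

Each layer is trained with its own local loss of this form, so no gradients flow between layers.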

TODO

  • Add Local Conv With Peer Normalization (a sketch of the peer-normalization penalty appears below)
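
Peer normalization, as described in the Forward-Forward paper, penalizes each hidden unit's mean activity for drifting away from the layer-wide average so that all units stay in use. A minimal sketch with our own naming, using the batch mean as a stand-in for the paper's running average:

import jax.numpy as jnp

def peer_normalization_loss(h, alpha=0.03):
    # h: layer activations of shape (batch, units).
    # alpha: penalty weight (illustrative value, not from this repo).
    mean_act = jnp.mean(h, axis=0)    # mean activity of each unit over the batch
    layer_avg = jnp.mean(mean_act)    # average activity across all units
    return alpha * jnp.mean((mean_act - layer_avg) ** 2)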

What about...

  • Add an online training model?
