google / autol2

This code trains a Wide ResNet on different datasets and includes the AutoL2 algorithm described in the paper. Implemented by Aitor Lewkowycz, based on code by Sam Schoenholz. Requirements can be installed from requirements.txt. The code is written for TPUs; it can also run on GPU by adding the -noTPU flag and installing the GPU jaxlib package from https://github.com/google/jax.
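A minimal setup sketch (the exact GPU jaxlib install command depends on your accelerator and is described at https://github.com/google/jax; the flag values below are illustrative, not the paper's settings):

pip install -r requirements.txt
# GPU only: install the GPU jaxlib package per https://github.com/google/jax, then add -noTPU:
python3 jax_wideresnet_exps.py -noTPU -L2=0.01 -epochs=200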

Commands to generate data used for figures

Figure 1a.

for L2 in $L2LIST; do
  python3 jax_wideresnet_exps.py -L2=$L2 -epochs=200 -std_wrn_sch
  python3 jax_wideresnet_exps.py -L2=$L2 -physicalL2 -epochs=0.02 -std_wrn_sch  # This is evolved for a time 0.02/(eta*lambda) = 0.1/lambda epochs.
done
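Here L2LIST stands for the set of L2 coefficients to sweep over; the values below are only an illustrative placeholder, not the grid used in the paper:

L2LIST="0.1 0.01 0.001 0.0001"  # hypothetical example values; substitute your own sweep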

Figure 1b is generated by comparing the performance of these models against our prediction.

for L2 in $L2LIST; do
  python3 jax_wideresnet_exps.py -L2=$L2 -epochs=2000
done

To obtain the t* prediction, we run the following.

python3 jax_wideresnet_exps.py -L2=0.01  -epochs=2

Figure 1c: evolve with lr=0.2 for 200 epochs, comparing the L2 schedule (-L2_sch) starting from L0=0.1 against a fixed L2=0.0001.

python3 jax_wideresnet_exps.py -L2=0.1 -L2_sch
python3 jax_wideresnet_exps.py -L2=0.0001 -noL2_sch

The Wide ResNet experiments in Figure 2 are similar.

for lr in $LRLIST; do
  for L2 in $L2LIST; do
    # The learning rate $lr is set via the script's corresponding flag (not shown in this listing).
    python3 jax_wideresnet_exps.py -L2=$L2 -physicalL2 -epochs=0.1 -nomomentum -noaugment
  done
done
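As before, LRLIST is a placeholder for the learning rates being swept; for example (values purely illustrative, not the paper's grid):

LRLIST="0.4 0.2 0.1"  # hypothetical learning-rate sweep; substitute your own values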

About

License: Apache License 2.0

