ViT with SAM (Sharpness-Aware Minimization)
This repo tries to reproduce the results of "When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations".
That paper applies Sharpness-Aware Minimization (SAM) to ViT.
There is an official implementation of SAM in JAX, and a PyTorch implementation also exists.
I implemented a TensorFlow version based on the JAX implementation.
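To make the idea concrete, here is a minimal sketch of a single SAM update in plain Python. It is not the TensorFlow code from this repo: the toy quadratic loss, the function names (`loss_and_grad`, `sam_step`), and the values of `lr` and `rho` are all assumptions chosen for illustration. The step first moves the weights a distance `rho` along the normalized gradient (toward higher loss), then applies the gradient computed at that perturbed point to the original weights.

```python
import math


def loss_and_grad(w):
    """Toy loss L(w) = 0.5 * sum(c_i * w_i^2) with its analytic gradient.

    Hypothetical stand-in for a real model loss; the curvatures are arbitrary.
    """
    curvatures = [10.0, 1.0]
    loss = 0.5 * sum(c * x * x for c, x in zip(curvatures, w))
    grad = [c * x for c, x in zip(curvatures, w)]
    return loss, grad


def sam_step(w, grad_fn, lr=0.05, rho=0.05):
    """One SAM update: ascend to a nearby perturbed point, then descend."""
    # 1) Gradient at the current weights.
    _, g = grad_fn(w)
    norm = math.sqrt(sum(x * x for x in g)) + 1e-12  # avoid division by zero

    # 2) Perturb the weights by rho along the normalized gradient.
    w_adv = [x + rho * gi / norm for x, gi in zip(w, g)]

    # 3) Gradient at the perturbed weights.
    _, g_adv = grad_fn(w_adv)

    # 4) Apply that gradient to the ORIGINAL weights (plain SGD here).
    return [x - lr * gi for x, gi in zip(w, g_adv)]


w = [1.0, 1.0]
initial_loss, _ = loss_and_grad(w)
for _ in range(100):
    w = sam_step(w, loss_and_grad)
final_loss, _ = loss_and_grad(w)
```

Each SAM step needs two gradient evaluations, so it costs roughly twice as much as a plain gradient step. In a real training loop the base optimizer in step 4 can be anything (SGD, Adam, and so on); plain SGD is used here only to keep the sketch short.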
First Try
The plot above compares the ViT model with and without SAM on CIFAR-10 classification.
The original ViT is known to be very hard to train.
My version adds a representation MLP before the head layer, but without SAM it still fails to converge (blue line).
With SAM (orange line), it gradually converges.