Support for NormSoftmax
catid opened this issue
Based on this paper: https://openreview.net/pdf?id=4g7nCbpjNwd
Would require editing this line:
And replacing the * scale with:
```python
if self.norm_softmax:
    dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=1e-6)
else:
    dots *= scale
```
And then something similar in the other flash attention path
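For context, here is a minimal sketch of how that toggle could sit inside a plain (non-flash) attention forward pass. The `norm_softmax` flag and the 1e-6 floor come from the snippet above; the surrounding module is only illustrative and not the library's exact code:

```python
import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, norm_softmax=False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm_softmax = norm_softmax

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2))

        if self.norm_softmax:
            # normalize each query's row of logits by its standard deviation
            dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=1e-6)
        else:
            # the usual 1/sqrt(dim_head) scaling
            dots = dots * self.scale

        attn = dots.softmax(dim=-1)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```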
@catid oh interesting, reminds me a bit of https://arxiv.org/abs/2005.09561
there will also be a temperature involved
have you tried this? maybe i can run a quick experiment tonight
it won't be compatible with flash attention
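As a rough sketch of where a temperature could enter, assuming it acts as a learnable per-head scale applied to the normalized logits (the exact placement would need to be checked against the paper; the helper name here is made up):

```python
import torch
from torch import nn

# hypothetical helper: divide each row of logits by its standard deviation,
# then rescale by a learnable per-head temperature of shape (heads, 1, 1)
def normed_logits(dots, temperature, eps=1e-6):
    dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=eps)
    return dots * temperature

# usage inside an attention module (names are illustrative):
#   self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
#   dots = normed_logits(dots, self.temperature)
```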
NormSoftmax CIFAR-10 benchmark results at epoch=60 using ViT-tiny:
baseline: 77.69%
sqrtd: 76.39%
inf: 77.53%
NormSoftmax CIFAR-10 benchmark results at epoch=300 using ViT-tiny:
baseline: 85.19%
inf: 85.07%
Manages to get about the same result without the extra parameters
another engineering obstacle would be handling a masked standard dev
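A sketch of one way that could be handled (not from the thread): compute the mean and variance over only the unmasked keys, with `mask` True at positions that may be attended to, and keep the usual additive masking before the softmax.

```python
import torch

def masked_std(dots, mask, eps=1e-6):
    # dots: (batch, heads, q_len, k_len) attention logits
    # mask: boolean tensor broadcastable to dots, True where a key is visible
    mask = mask.to(dots.dtype)
    count = mask.sum(dim=-1, keepdim=True).clamp(min=1.)
    mean = (dots * mask).sum(dim=-1, keepdim=True) / count
    # biased (population) variance over visible keys only,
    # unlike torch.std's default unbiased estimator
    var = ((dots - mean) ** 2 * mask).sum(dim=-1, keepdim=True) / count
    return var.clamp(min=eps).sqrt()

# e.g. with a causal mask: dots = dots / masked_std(dots, causal_mask)
# masked positions still get filled with a large negative value before softmax
```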
yea, let me run it tonight on enwik8, but if i don't see anything notable on the first or second try, probably will just drop this
@lucidrains The masked stddev is like this right? https://github.com/catid/cifar10deepspeed/blob/fe5b399c5ab5f3ed11235d3dbe72952ce7c2be46/models/vit_small.py#L75
I think that's what I'm testing
@catid i'm thinking of autoregressive text generation (gpt), with the triangular causal mask. are you masking out the diagonal?
Yeah I'm just copying your vit_for_small_dataset.py
@catid ohh ok, do you see anything? have you ran the experiments yourself? never trust anything a paper says unless you see the curves in front of you 😆
The results I shared above are from my setup
@catid wow! ok, i actually put a lot of weight on results from internet randos
ok, let me try it tonight!
@catid wait, your results show norm softmax to be worse than baseline? is that accuracy?
@catid can you share a wandb report with training curves?
I dunno, the numbers are pretty close, and I only ran an N=1 trial, so I'm not sure if one method produces better accuracy than the other. Also, I don't have wandb integrated into my scripts yet (haven't learned how to use it yet).
ah, looks to be a negative result.