pytorch / torchtitan

A native PyTorch Library for large model training

About the reference for weight init according to layer depth or layer id

SeunghyunSEO opened this issue · comments

Hi, first of all, thank you for the nice open-source project!
I have been reading your model code and found that it initializes model weights based on num_layers or layer_id.
This is not conventional like Kaiming init (std = 1/\sqrt{fan_in}) or GPT-2 init (std = 0.02),
and it also does not look like muP or anything similar.
So I just want to know whether there are any references for this, or whether it's purely empirical, for training stability.

Edit: I forgot that the std = 0.02/\sqrt{depth} init for the output layers of residual blocks is from the GPT-2 paper. Sorry! I'm just wondering where depth_init comes from.
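
To make what I mean concrete, here is a rough sketch of the schemes I'm contrasting. The constants and the depth scaling below are my own illustration of the idea, not torchtitan's actual code:

```python
import math
import torch.nn as nn

def init_linear(linear: nn.Linear, scheme: str, layer_id: int = 0, n_layers: int = 32):
    """Illustrative comparison of the init schemes discussed above; not torchtitan's code."""
    if scheme == "kaiming":
        # conventional: std = 1 / sqrt(fan_in)
        std = 1.0 / math.sqrt(linear.in_features)
    elif scheme == "gpt2":
        # GPT-2 style: 0.02, with residual-output projections scaled down by the total layer count
        std = 0.02 / math.sqrt(2 * n_layers)
    elif scheme == "depth":
        # depth-based: std shrinks with the block's own depth (layer_id) instead of the total count
        std = 0.02 / math.sqrt(2 * (layer_id + 1))
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=std)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)
```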

Hi @SeunghyunSEO - sorry for the delay, didn't see this earlier.
To your question: the depth init came out of research last summer when we were doing work on parallel attention blocks. I did a comparison sweep and adding the depth init was the winner, so we have continued to use it.
I'm not sure where the concept originally came from, though; it bubbled up in discussions with IBM Research a while back.
I did see that OLMo was also using this, and they referred to it as a "Mitchell init", but I was not able to find anything on arXiv about it.
Anyway, the short answer is that it's empirically based. We haven't done a sweep on it since Llama 3 came out, so maybe we will revisit it in the future, but it continues to perform well in our training runs.
Hope that helps!

Thank you for the kind answer, @lessw2020!
I guess it makes sense, because the Mitchell init's layerwise output variance would be more consistent compared to the GPT-2 init.
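
As a toy sanity check of that intuition (my own rough sketch, not torchtitan's code, with a single random pre-norm linear projection standing in for each transformer block), tracking how the residual-stream variance drifts with depth under a constant GPT-2-style std versus a depth-scaled std:

```python
import torch

torch.manual_seed(0)
dim, n_layers, batch = 1024, 32, 1024

def rmsnorm(x):
    # simple RMS normalization, standing in for the pre-norm in each block
    return x / x.pow(2).mean(dim=-1, keepdim=True).sqrt()

for scheme in ("gpt2_constant", "depth_scaled"):
    x = torch.randn(batch, dim)
    for layer_id in range(n_layers):
        std = 0.02 if scheme == "gpt2_constant" else 0.02 / (2 * (layer_id + 1)) ** 0.5
        w = torch.randn(dim, dim) * std
        x = x + rmsnorm(x) @ w  # pre-norm residual update with a randomly initialized projection
    print(f"{scheme}: residual-stream variance after {n_layers} blocks = {x.var().item():.3f}")
```

In this toy setup the constant std lets the stream variance grow roughly linearly with depth, while the depth-scaled std keeps it much flatter.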