inv_freq seems not calculated right
dwzhu-pku opened this issue
Hello, I'm thrilled to see that linear and NTK interpolation have been elegantly combined to create a much stronger interpolation strategy, YaRN. However, while going through the code in modeling_llama.py, I find myself a bit confused by the calculation of inv_freq, particularly at line 398.
According to the YaRN paper, equation 23 applies the interpolation to the wavelength $\lambda_d$:

$$\lambda'_d = \big((1-\gamma_d)\,s + \gamma_d\big)\,\lambda_d$$

Consequently, since $\theta_d = \frac{2\pi}{\lambda_d}$, we can derive:

$$\theta'_d = \frac{\theta_d}{(1-\gamma_d)\,s + \gamma_d}$$

However, in the paper, the calculation in equation 25 (and at line 398 of the code) instead mixes the frequencies directly:

$$\theta'_d = (1-\gamma_d)\,\frac{\theta_d}{s} + \gamma_d\,\theta_d$$

These are not the same. Hence, I think there might be some problem with equation 25 and also with line 398. Perhaps we can revise the `yarn` function as follows, since I've empirically found that this fix can further enhance performance:
```python
def revised_yarn(self, device):
    # Standard RoPE frequencies
    inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
    low, high = _yarn_find_correction_range(
        self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings
    )
    # inv_freq_mask == 1 -> extrapolate (keep theta), == 0 -> interpolate (divide theta by scale)
    inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor
    # Mix in the wavelength domain: lambda' = ((1 - mask) * s + mask) * lambda
    inv_freq = inv_freq / ((1 - inv_freq_mask) * self.scale + inv_freq_mask)
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.mscale = float(_yarn_get_mscale(self.scale) * self.attn_factor)
```
So in the code (and in equation 25), the scaling and mixing are applied to the frequency $\theta$, but according to equation 23 (which is probably incorrect) the scaling is applied to the wavelength $\lambda$, so the division happens in a different place.
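To make the difference concrete, here is a small self-contained sketch (NumPy; the function names are invented for illustration and are not taken from the repository) contrasting mixing in the frequency domain with mixing in the wavelength domain:

```python
import numpy as np

def mix_frequencies(theta, gamma, s):
    # Original code / eq. 25: theta' = (1 - gamma) * theta / s + gamma * theta
    return (1.0 - gamma) * theta / s + gamma * theta

def mix_wavelengths(theta, gamma, s):
    # Eq. 23 / revised code: lambda' = ((1 - gamma) * s + gamma) * lambda,
    # which on frequencies reads theta' = theta / ((1 - gamma) * s + gamma)
    return theta / ((1.0 - gamma) * s + gamma)

theta = 1.0 / (10000.0 ** (np.arange(0, 64, 2) / 64))  # standard RoPE frequencies
gamma = np.linspace(1.0, 0.0, theta.shape[0])          # toy ramp: 1 = extrapolate, 0 = interpolate
s = 16.0

a = mix_frequencies(theta, gamma, s)
b = mix_wavelengths(theta, gamma, s)
print(np.allclose(a[[0, -1]], b[[0, -1]]))  # -> True: endpoints (gamma = 0 or 1) agree
print(np.allclose(a, b))                    # -> False: interior of the ramp differs
```

The two only coincide where $\gamma_d$ is exactly 0 or 1; everywhere inside the ramp they disagree.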
What data do you have on the performance of your YaRN implementation?
Hello! That's an interesting observation... Maybe my derivation was wrong, but shouldn't a "stretching" of the wavelength be exactly equivalent to a "compression" of the frequency? Both quantities are related by $\lambda_d = \frac{2\pi}{\theta_d}$.
Let's pick a dimension $d$ and write out both formulations.
Maybe I'm missing something here, please correct me if I'm wrong...
> shouldn't a "stretching" of the wavelength be exactly equivalent to a "compression" of the frequency
That is certainly true, and for plain interpolation, multiplying the wavelength by the scale is the same as dividing $\theta$ by the scale. But the difference is in the mixing. This is what the Python implementation does:

$$\theta'_d = (1-\gamma_d)\,\frac{\theta_d}{s} + \gamma_d\,\theta_d$$

And this is what equation 23 in the paper implies:

$$\theta'_d = \frac{\theta_d}{(1-\gamma_d)\,s + \gamma_d}$$
Yes, but because the wavelength is just the inverse of the frequency (up to a factor of $2\pi$), shouldn't substituting one form into the other give the same result?
No, they are not equivalent. Substituting e.g. $\gamma_d = \frac{1}{2}$ into the implementation's formula gives $\theta'_d = \frac{1}{2}\left(\frac{\theta_d}{s} + \theta_d\right)$, which corresponds to $\lambda'_d = \frac{2s}{s+1}\,\lambda_d$. But we get this for equation 23: $\lambda'_d = \frac{s+1}{2}\,\lambda_d$.
To put it another way: the inverse of a sum is not equal to the sum of the inverses. This is essentially what is going on when the implementation mixes the frequencies while equation 23 mixes the wavelengths.
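A quick numeric check of this point, with arbitrarily chosen values $\gamma_d = 0.5$ and $s = 2$ (the $2\pi$ in $\lambda = 2\pi/\theta$ cancels and is dropped):

```python
# Arbitrary example values: gamma = 0.5, scale s = 2, base wavelength lam = 1.0.
gamma, s, lam = 0.5, 2.0, 1.0

# Mixing wavelengths (equation 23): lambda' = ((1 - gamma) * s + gamma) * lambda
lam_wave = ((1 - gamma) * s + gamma) * lam             # 1.5

# Mixing frequencies (the implementation): theta' = (1 - gamma) * theta / s + gamma * theta
theta = 1.0 / lam
theta_mixed = (1 - gamma) * theta / s + gamma * theta  # 0.75
lam_freq = 1.0 / theta_mixed                           # 4/3 = 1.333...

print(lam_wave, lam_freq)  # 1.5 vs 1.333... -- not equal
```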
Thanks, it has never occurred to me that mixing frequencies vs. mixing wavelengths with a ramp function would give such a difference. I will take some time and correct this mistake...
Here are two contour plots from Wolfram; it is clear that the ramp function is not equivalent under the two scenarios:
For mathematical correctness' sake, I will fix this mistake in the v2 of the preprint. However, please understand that this ramp is arbitrary and did not have any robust hypothesis supporting its existence. We could have just used a Heaviside step function instead of the linear ramp. (However, since neural networks might not like that discontinuity, we chose the linear ramp.)
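For illustration, here is a sketch of the two mask choices mentioned above: a linear ramp clamped to $[0, 1]$ versus a step at the ramp's midpoint. The `low`/`high` endpoints here are illustrative values, not taken from the repository:

```python
import numpy as np

def linear_ramp_mask(low, high, dim):
    # Linear ramp over dimension indices, clamped to [0, 1] outside [low, high].
    ramp = (np.arange(dim, dtype=np.float64) - low) / (high - low)
    return np.clip(ramp, 0.0, 1.0)

def heaviside_mask(low, high, dim):
    # Discontinuous alternative: step at the midpoint of the correction range.
    center = (low + high) / 2.0
    return (np.arange(dim, dtype=np.float64) >= center).astype(np.float64)

low, high, dim = 4.0, 12.0, 32
print(linear_ramp_mask(low, high, dim)[:16])  # smooth transition between indices 4 and 12
print(heaviside_mask(low, high, dim)[:16])    # abrupt jump at index 8
```

The step function assigns each dimension entirely to interpolation or extrapolation, while the ramp blends the two in the transition band.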
Thanks to Cebtenzzre and bloc97 for your insightful discussion! I conducted tests on the two implementations mentioned earlier, the original YaRN and the revised YaRN, using the LLaMA-7B model (version 1, not Llama 2) on GovReport. I configured the scaling factor to 48 (i.e., stretching to support 96k) and sampled 50 texts longer than 64k from GovReport for evaluation. Note that fine-tuning was not performed, so these results should be easy to reproduce.
I varied the input length from 1k to 64k and used the stretched model to calculate perplexity. Since this is entirely within the supported context window of 96k, I did not use a sliding window, for simplicity. The results are presented in the plot below. Interestingly, in this scenario, the original version performs relatively well when the input length is below 8k, whereas the revised YaRN shows an advantage as the input length increases. This phenomenon appears to be non-coincidental.
As previously discussed, in the revised version, the wavelength becomes $\lambda'_d = \big((1-\gamma_d)\,s + \gamma_d\big)\,\lambda_d$. On the other hand, in the original version, the wavelength becomes $\lambda'_d = \dfrac{\lambda_d}{\frac{1-\gamma_d}{s} + \gamma_d}$.
Here are some observations:
- The original YaRN stretches the wavelength less than the revised one does. This may explain why the original performs better when the input sequence is short but worse as it gets longer.
- Although these two formulations can be viewed as different ramp functions, the major problem with the original one is that, as the scaling factor $s$ becomes very large, it weakens the impact it imposes on the wavelength. The reason is that $\frac{1-\gamma_d}{s} \rightarrow 0$ as $s$ becomes very large, for instance, 48. Consequently, $\lambda'_d \rightarrow \frac{\lambda_d}{\gamma_d}$.
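A quick numeric sketch of this limiting behavior, with $\gamma_d$ fixed at 0.5 and arbitrary values of $s$: the original formulation's wavelength stretch factor saturates at $\frac{1}{\gamma_d}$, while the revised one keeps growing with $s$:

```python
# Compare the wavelength stretch factor lambda'_d / lambda_d under the two
# formulations as the scale s grows (gamma fixed at 0.5 for illustration).
gamma = 0.5
original = {}  # original YaRN: 1 / ((1 - gamma) / s + gamma) -> saturates at 1/gamma
revised = {}   # revised YaRN:  (1 - gamma) * s + gamma       -> grows linearly in s
for s in (2.0, 48.0, 1000.0):
    original[s] = 1.0 / ((1.0 - gamma) / s + gamma)
    revised[s] = (1.0 - gamma) * s + gamma

print(original)  # approaches 2.0 = 1/gamma
print(revised)   # 1.5, 24.5, 500.5
```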
Interestingly, we are also currently investigating the special case where the ramp $\gamma_d$ is replaced by a Heaviside step function whose center sits at the dimension whose wavelength matches the pretraining context length. Consequently, dimensions whose wavelength exceeds the pretraining context window are truncated entirely.
This "Truncated" RoPE or "NoRoPE" embedding scheme, applied during pretraining, could potentially allow very long extrapolation capabilities; we are looking at testing this hypothesis in the near future.