NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
Posted by bloc97@reddit | LocalLLaMA | View on Reddit | 8 comments
I've seen the posts about SuperHOT and, just recently, the paper from Meta which uses RoPE interpolation, and I've noticed an immediate improvement that can be brought to this method. Basically, if you apply Neural Tangent Kernel (NTK) theory to this problem, it becomes clear that simply interpolating the RoPE's Fourier space "linearly" is very sub-optimal, as it prevents the network from distinguishing the order and positions of tokens that are very close by. Borrowing from the NTK literature, scaling down the Fourier features too much will eventually even prevent successful finetunes (this is corroborated by the recent paper by Meta that suggests an upper bound of ~600x).
Instead of the simple linear interpolation scheme, I've tried to design a nonlinear interpolation scheme using tools from the NTK literature. Basically, this interpolation scheme changes the base of the RoPE instead of the scale, which intuitively changes the "spinning" speed of each of the RoPE's dimension vectors compared to the next. Because it does not scale the Fourier features directly, all the positions are perfectly distinguishable from each other, even when taken to the extreme (e.g. stretched 1 million times, which is effectively a context size of 2 billion).
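To make that concrete, here's a tiny sketch (my own illustration with example numbers, not the notebook code): with LLaMA's head dimension of 128 and a stretch factor of 8, linear interpolation slows every dimension down, while the base change leaves the fastest dimension alone.

```python
import numpy as np

dim, base, scale = 128, 10000.0, 8.0       # LLaMA head dim, default RoPE base, 8x stretch

i = np.arange(0, dim, 2)
theta = base ** (-i / dim)                 # original per-dimension rotation speeds

theta_linear = theta / scale               # linear (SuperHOT-style) interpolation
new_base = base * scale ** (dim / (dim - 2))
theta_ntk = new_base ** (-i / dim)         # NTK-aware: change the base instead

print(theta_linear[0] / theta[0])          # 0.125 -> fastest dimension slowed 8x, neighbours blur
print(theta_ntk[0] / theta[0])             # 1.0   -> fastest dimension untouched
print(theta_ntk[-1] / theta[-1])           # 0.125 -> slowest dimension still stretched the full 8x
```

The fast dimensions are what let the network tell adjacent tokens apart, which is why leaving them alone matters.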
To my surprise, this method works extremely well, so much so that you don't even need to fine tune the LLaMA 7B model for 4096 context size! The perplexity degradation is minimal. I'm sure with fine tuning this would become even better.
Enough explanations, here are some empirical results. All the perplexity measurements are done on LLaMA 7B with the tau/scrolls dataset on Hugging Face (I only used a subset of gov_report).
Here's a graph showing the average perplexity of LLaMA 7b on a set of 40 very long prompts (12k+ context size). Compared to changing the scale (from SuperHOT, which was set to 4), we change a factor alpha, which when equal to 8 provides the same context size increase but with much less perplexity degradation. All without any finetuning!
[Graph: average perplexity of LLaMA 7B on a set of 40 very long prompts (12k+ context size) with the previous and new interpolation schemes]
Here are more results, showing more scale and alpha factors.
[Graph: additional scale and alpha factors]
Zoomed-in version of the second graph, showing the details.
[Graph: zoomed-in view of the second graph]
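Roughly speaking, the measurement loop looks something like this (a simplified sketch under my own assumptions about the checkpoint name, truncation length, and single-pass scoring, not the exact evaluation code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes the NTK-aware RoPE patch has already been applied to the model,
# otherwise it can't go past its original 2048-token window.
model_name = "huggyllama/llama-7b"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

def perplexity(text, max_len=12288):
    # Score one long gov_report document in a single forward pass.
    ids = tokenizer(text, return_tensors="pt").input_ids[:, :max_len].to(model.device)
    with torch.no_grad():
        loss = model(ids, labels=ids).loss  # mean next-token NLL
    return torch.exp(loss).item()
```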
Code can be found in a Colab notebook with a test example: NTKAwareScaledRotaryEmbedding.ipynb (Google Colab).
Again, the change to the RoPE code is only 3 lines.
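For those who don't want to open the notebook, the idea in isolation looks something like this (a simplified, self-contained sketch; the names and defaults here are mine, the actual patch lives in the notebook):

```python
import torch

class NTKAwareRotaryEmbedding(torch.nn.Module):
    """Standard RoPE tables, except the base is raised so that the slowest
    dimension gets interpolated by 1/alpha while the fastest stays untouched."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0, alpha=8.0):
        super().__init__()
        # The actual change versus stock RoPE is just these two lines:
        base = base * alpha ** (dim / (dim - 2))
        self.max_seq_len = int(max_position_embeddings * alpha)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device=None):
        # Usual RoPE: outer product of positions and inverse frequencies.
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()
```

The base adjustment and the extended max_seq_len are the whole trick; the positions themselves are not rescaled.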
Also, just for fun, I've pasted the entirety of the Meta paper, [2306.15595] Extending Context Window of Large Language Models via Positional Interpolation (arxiv.org), with the abstract and conclusions removed, and asked OASST-sft-6-llama-30b some questions. I think it fared pretty well given it hasn't been fine-tuned at all for 6000+ context size. I used an alpha of 8 for this prompt.
The entire paper is about 6200 tokens. These prompts can be found in the colab notebook linked above. The following questions are simply appended at the end of the prompt which contains the entire research paper.
QUESTION
Please give me a brief summary of this research paper in a few bullet points.
ANSWER
* Large Language models (LLMs) often have a pre-defined context window size, but sometimes they need to operate on longer contexts, e.g., in conversation systems or summarization tasks.
* Extending the context window of an existing pre-trained LLM is challenging because direct extrapolation of positional encodings can lead to poor performance.
* Researchers propose "Position Interpolation" to extend the context window of LLMs by downscaling the positional encodings.
* They prove that interpolated positional encodings have a much smaller upper bound than extrapolated ones, allowing the model to adapt to longer contexts more efficiently.
* They experimentally validate the effectiveness of Position Interpolation on LLaMA models, reducing the perplexity gap between the original and extended contexts.
* They compare Position Interpolation with direct extrapolation and find that Position Interpolation works better.
* They apply Position Interpolation to LLaMa models and fine-tune them on the Pile dataset, finding that the models can effectively utilize longer contexts.
* They also evaluate the models' abilities to attend to longer contexts and perform well on benchmarks within the original context size.
* They apply Position Interpolation to long document summarization and obtain competitive results.
* They provide insights into the interpolation/extrapolation tradeoff and discuss potential improvements.
* They plan to explore more advanced methods like data augmentation and multi-task learning to improve the interpolation.
QUESTION
Please write me the abstract for this paper.
ANSWER
This paper presents a technique called "Position Interpolation" for extending the context window size of pre-trained Large Language Models (LLMs) without modifying their architecture. Previous methods like ALiBi and LeX were unable to efficiently extend the context window size of LLMs, leading to poor performance. The authors propose Position Interpolation, which involves downscaling the positional encoding of the LLM's attention mechanism to match the desired longer context window size. They prove mathematically that interpolated attention scores are more stable than extrapolated ones, allowing the model to adapt to longer contexts with just a few hundred fine-tuning steps. Experiments showed that Position Interpolation enables LLMs to perform well on longer contexts, reducing perplexity by up to 0.8 points on average, and generating strong summarizations. Extending the context window size to 32768 led to slight degradation on the original benchmarks, but the models adapted to longer texts. The authors conclude that interpolating positional encodings allows the model to attend to relevant evidence in the longer contexts, and the stability of the interpolated attention score helps the model avoid catastrophic behaviors.
QUESTION
How many steps was the model fine tuned for the final results? Give a short answer.
ANSWER
The authors of the paper "Extending Context Windows of Large Language Models via Position Interpolation" fine-tuned the model for 1000 batches for the experiments presented in the paper.
QUESTION
How big is the interpolation bound compared to the extrapolation bound? Give a short answer.
ANSWER
The interpolation bound is at least 2 - 29.7 * 10^(-6) times smaller than the extrapolation bound.
Sorry for the large dump of information on reddit, I'm too lazy to write a blog post for this. I might give a more detailed explanation on how I derived the formula used to calculate the base if enough people are interested.
I did not test fine-tuning performance, as I do not have the resources or the time to fine-tune an LLM; I just derived this formula during lunch and experimented with it. However, I think this method will do even better with fine-tuning. Also, thanks to the people behind the SuperHOT blog post; it was their hard work that inspired me and allowed me to make this contribution for everyone!
Finally, I really hope this post will inspire others to start experimenting on ways to improve LLMs. There's so much to learn and so much left to discover! What a time to be alive!
MoonOrigin@reddit
So great. And so chill. Fantastic work.
shamblack19@reddit
Hey, I'm trying to understand the math here and have some questions:
1: Seems like the same base will be used at all positions along the embedding dimension, right? I thought the goal was to increase the base at higher embedding positions. Where is that happening here?
2: I don't understand what (dim/(dim-2)) is doing. Why (dim-2)??
Let me know if my understanding is wrong, still trying to wrap my head around the intuition
disperaller@reddit
You have to remember to raise the base to the power of -2i/d. If you separate the term inside the parentheses, (base * a ** (dim / (dim-2))), into (base, a ** (dim / (dim-2))), the first element raised to the -2i/d power is the same as originally designed, while the second term raised to the power of -2i/d changes with i: when i is small it is close to 1 and has no effect on the low-dimension values, and as i grows the interpolation effect grows, causing the high-dimension values to show interpolation.
me neither :<
shamblack19@reddit
You’re awesome!!!! I actually had to go onto Desmos and plot out the math but I fully understand it now!! Really appreciate you haha
Glad you're also confused about that exponent. It always evaluates to a constant that's close to 1; I'm thinking it's redundant.
pseudonerv@reddit
can you eli18 how you go from NTK theory to
base * a**(dim/(dim-2))
???
bloc97@reddit (OP)
Let me try an ELI14 instead.
RoPE behaves like a clock. Your 12-hour wall clock is basically a RoPE of dimension 3 with a base of 60. So for each second, the minute hand turns 1/60th of a minute, and for each minute, the hour hand turns 1/60th.
Now if you slowed down time by a factor of 4x, that is the linear RoPE scaling used in SuperHOT. Unfortunately, it is now really, really hard to distinguish each second, because the seconds hand barely moves each second. So if someone gave you two different times that differ by only a single second, you won't be able to distinguish them from afar (let's say the NNs have myopia, because that's basically what NTK predicts).
Now, NTK-Aware RoPE scaling does not slow down the seconds. One second is still one second, but it slows down the minutes by a factor of, let's say, 1.5, and the hours by a factor of 2. This way you can fit 90 minutes in an hour, and fit 24 hours in half a day. So now you basically have a clock that can measure 129.6k seconds instead of 43.2k seconds.
Because you don't need a precise measurement of the hour hand when looking at the time, scaling the hours more compared to seconds is crucial. You don't want to lose the precision of the seconds hand, but you can afford to lose precision on the minutes hand and even more on the hours hand.
Then, it's just a matter of deriving the base change formula in order to obtain such a scaling. (where less precise dimensions are scaled more and more)
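Concretely, the derivation goes something like this (b is the original base, s the scale factor, d the head dimension; this is just the clock argument written out):

```latex
% RoPE rotates dimension pair i (i = 0, \dots, d/2 - 1) at speed \theta_i = b^{-2i/d}.
% Keep the fastest pair (i = 0) unchanged, and slow the slowest pair (i = d/2 - 1)
% by exactly the overall scale factor s:
(b')^{-(d-2)/d} = \frac{1}{s}\, b^{-(d-2)/d}
% Raise both sides to the power -d/(d-2):
b' = b \cdot s^{d/(d-2)}
```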
I hope that makes sense!
Edit: Fixed small calculation error, the seconds shouldn't change...
pseudonerv@reddit
Thanks for the detailed eli14. So using my 14 yo math skills, the factor,
a**(dim/(dim-2))
is just a constant
8**(128/126)
The angle theta now has an additional factor of 8**(-2(i-1)/126).
For i=1, the factor is 1. For i=d/2=64, the factor is 1/8.
Perfect!
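A few lines of numpy reproduce those factors (d = 128, alpha = 8; the index below starts at 0, so it plays the role of i-1 above):

```python
import numpy as np

d, base, alpha = 128, 10000.0, 8.0
i = np.arange(d // 2)                        # 0 .. 63, i.e. (i - 1) in the comment above

theta_old = base ** (-2 * i / d)
theta_new = (base * alpha ** (d / (d - 2))) ** (-2 * i / d)

factor = theta_new / theta_old               # equals alpha ** (-2 * i / (d - 2))
print(factor[0], factor[-1])                 # 1.0 and 0.125 (= 1/8)
```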
disperaller@reddit
Hi sir, the math is quite straightforward. However, according to this post I somehow got confused again (gradientai/Llama-3-8B-Instruct-262k · Hugging Face). Gradient used NTK-aware scaling to expand Llama 3 8B from 8k to 65k, and then from 65k to 262k. The rope theta they used to expand from 8k to 65k is 15.3 million, but I don't see how that 15.3 million is calculated: 65 / 8 should be about 8, and multiplying the original Llama 3 rope theta (0.5 million) by that should give roughly 4M instead of 15.3M. I was hoping someone could help explain the math behind this, thanks in advance.