CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.

How does the speed drop when the context length gets large, compared with vanilla LLaMA?

lucasjinreal opened this issue · comments

How does the speed drop when the context length gets large, compared with vanilla LLaMA?

Our method should be faster than Hugging Face LLaMA, as it uses the extended context only in 3 layers (out of 26 in the case of the 3B model). For example, for an 8k context, LLaMA takes ~3.3s to process it and generate an additional 12 tokens, whereas LongLLaMA takes only ~1.3s (numbers from a 40GB A100 GPU with bfloat16). For a 16k context, I got an OOM error with the LLaMA code (I have not used any special optimizations to fit the model here). In #7 you can find additional numbers regarding the speed of LongLLaMA.
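
For reference, a minimal timing sketch along these lines, assuming the `syzymon/long_llama_3b` checkpoint and the loading snippet from the repository README; the prompt construction and dtype choice here are only illustrative, not the script used for the numbers above:

```python
import time

import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_3b")
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # the FoT memory layers live in custom modeling code
).cuda()

# Build a ~8k-token prompt (the content does not matter for a pure speed test).
prompt = "The quick brown fox jumps over the lazy dog. " * 900
input_ids = tokenizer(prompt, return_tensors="pt").input_ids[:, :8192].cuda()

torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    model.generate(input_ids, max_new_tokens=12, do_sample=False)
torch.cuda.synchronize()
print(f"{input_ids.shape[1]} prompt tokens -> 12 new tokens in {time.time() - start:.2f}s")
```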

@CStanKonrad how about a comparison with LongChat, which was recently added to Vicuna? Are there any pros and cons compared with it?

We haven't compared inference time to longchat, since we haven't tried 7B/13B LongLLaMA models yet - they are still to come. The advantage of our approach is that long context is processed only in a subset of layers (3 out of 26, or out of 32 for the 7B), which means the effective compute dedicated to attention layers is around 10x smaller than in longchat. This should make inference time nearly proportional to the total prompt length, if your unit is one LLaMA context window (2048 tokens) and you consider lengths up to 100k tokens. A potential drawback is that our model is less expressive, as it processes long context only in a subset of layers rather than in every layer. It is still unclear whether using every layer would bring large performance gains; we have not observed such gains, at least in perplexity measurements, which motivated this design choice.
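
As a rough back-of-the-envelope illustration of that ~10x figure, using only the numbers from this thread (3 memory layers out of 26, a 2048-token local window in the remaining layers); the per-token cost model below is my own simplification, and the same numbers also show why the gap grows with context length, as discussed further down:

```python
# Hypothetical per-token attention cost, counted in key/value positions attended to.
LOCAL_WINDOW = 2048   # standard LLaMA context window
N_LAYERS = 26         # 3B model; 32 for the 7B
N_MEM_LAYERS = 3      # layers with access to the long (memory) context

def attn_positions_per_token(total_context, mem_layers):
    """Approximate attention work per generated token.

    Memory layers attend to the whole cached context; the remaining layers
    attend only to the local window.
    """
    local_layers = N_LAYERS - mem_layers
    return mem_layers * total_context + local_layers * min(total_context, LOCAL_WINDOW)

for ctx in (8_192, 32_768, 131_072):
    vanilla = attn_positions_per_token(ctx, N_LAYERS)       # every layer sees everything
    longllama = attn_positions_per_token(ctx, N_MEM_LAYERS)  # only 3 layers see everything
    print(f"{ctx:>7} context tokens: attention work ratio ~{vanilla / longllama:.1f}x")
```

Under these assumptions the ratio comes out to roughly 3x at 8k, 6x at 32k and 8x at 128k, approaching 26/3 asymptotically.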

@syzymon thanks for the explanation. I wonder what would happen if LLaMA processed long input in the same way as LongLLaMA:

"long context is processed only in a subset of layers (3 out of 26 or 32 for 7b)"

Since there are no further experimental details in the FoT paper, I don't know whether this is the setup that the contrast (baseline) experiments adopt or not. Thanks!

Roughly speaking, both at training and inference time LongLLaMA uses only around 10% of the layers for long context. This means we save about 80-90% of the FLOPs spent on attention. When the context length is really large (think 32K or 128K), this can amount to a multi-fold speed improvement per token, since the attention cost per token grows linearly with context length. I imagine you could expect at least a 2x, 3x or 5x speedup in practice for 32K and above, but we haven't run that experiment yet.

I'm working on an inference time comparison between LongLLaMA 7B and vanilla LLaMA 7B, in PyTorch with Hugging Face. I will also release an analysis of training time soon. Stay tuned for the results!

Is this comparison performed under the same conditions? I mean that LLaMA 7B would also process the long input chunk by chunk and cache previous results in the same layers; although LLaMA did nothing with these layers during training, the same procedure can still be applied to it.

I'm not sure I understand the question. The experiment I will do is the following: we take an input consisting of 128K tokens and feed it into both LLaMA 7B and LongLLaMA 7B, then generate a small number of tokens (say 256) from both models. I will measure the wall-clock time of this for both models. You can run this experiment yourself by just taking a LLaMA 7B checkpoint and using the Hugging Face codebase; we do not change the model architecture. Obviously the model is not trained to utilize these layers correctly, but you can perform the time comparison regardless; you can even take random model parameters to check inference time and it will be the same. I hope I understood your question.
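
A sketch of that wall-clock comparison, following the description above (random token ids stand in for real text, since output quality is irrelevant for a pure speed test); the 7B checkpoints discussed here are not released yet, so the 3B checkpoint names in the loop are placeholders, and whether a vanilla LLaMA fits a very long input in memory is a separate question:

```python
import time

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

def time_generation(model_name, n_prompt_tokens, n_new_tokens=256):
    """Measure wall-clock time to ingest a long prompt and generate a few tokens."""
    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).cuda()
    # Random token ids: the comparison only cares about speed, not output quality.
    input_ids = torch.randint(
        low=10, high=tokenizer.vocab_size, size=(1, n_prompt_tokens)
    ).cuda()
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        model.generate(input_ids, max_new_tokens=n_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    return time.time() - start

# Placeholder checkpoints; swap in the 7B weights you want to compare.
for name in ("syzymon/long_llama_3b", "openlm-research/open_llama_3b"):
    print(name, f"{time_generation(name, n_prompt_tokens=16_000):.1f}s")
```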

My point is, I don't think speed is an advantage or disadvantage of LongLLaMA during evaluation, since the speed comes from how the model handles long inputs. As described in the paper, a LLaMA checkpoint can also be loaded by the LongLLaMA code, and this results in the same speed even though LLaMA may produce awful results. The way LongLLaMA handles long inputs isn't something new. I have conducted some preliminary experiments by using the LongLLaMA code to load a LLaMA checkpoint, and found that LongLLaMA indeed outperforms LLaMA on the passkey task (which is exciting!), but they are tied on other tasks like long-document summarization under the zero-shot setting.
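
For context, the passkey test mentioned here hides a short random key inside long filler text and asks the model to repeat it. A minimal sketch of building such a prompt (the filler sentences and question wording below are my own illustration, not the repository's evaluation script):

```python
import random

def build_passkey_prompt(n_filler_lines=400, passkey=None):
    """Hide a passkey inside repetitive filler text and ask for it back."""
    passkey = passkey or str(random.randint(10_000, 99_999))
    filler = "The grass is green. The sky is blue. The sun is yellow.\n"
    lines = [filler] * n_filler_lines
    # Insert the passkey somewhere in the middle of the filler.
    lines.insert(n_filler_lines // 2, f"The pass key is {passkey}. Remember it.\n")
    prompt = (
        "There is important information hidden in the text below. "
        "Find it and memorize it.\n"
        + "".join(lines)
        + "What is the pass key? The pass key is"
    )
    return prompt, passkey

prompt, passkey = build_passkey_prompt()
# Feed `prompt` to the model, generate a few tokens, and check that the
# continuation contains `passkey`.
```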

Which LongLLaMA checkpoint did you use? (There is base v1, base v1.1 and instruct.) I agree that LongLLaMA is a research preview and is not as competitive as closed-source models (GPT-3.5 or Claude) on downstream tasks. We are currently working on releasing more powerful models and expect better performance.

However, I respectfully disagree with your statement that speed is not an advantage. For really long inputs, processing a 128K-token input with standard LLaMA 7B might be infeasible. By using LongLLaMA, which keeps only ~10% of the long-range computation, it could become possible to run this quickly on a single standard GPU like an A100.

Also, could you please point out why you think "The way LongLLaMA handles long inputs isn't something new", apart from Memorizing Transformers, which is the basis of our work? I think using a subset of layers for efficiency reasons is a distinct contribution of our work. I would be willing to cite any other related work if you provide it.

Sorry, I apologize for the loose statement "The way LongLLaMA handles long inputs isn't something new." I think "using a subset of layers" is a natural extension of Memorizing Transformers. However, this speedup procedure is not attractive to me if the base model OpenLLaMA has the same performance as LongLLaMA when OpenLLaMA uses the same procedure, so I ran some tests. I roughly tested the two models (LongLLaMA v1 and OpenLLaMA) under long-input conditions (context length > 2048), and I found that LongLLaMA indeed performs much better than OpenLLaMA on the passkey task. On other tasks like summarization, the two models performed almost the same.

I appreciate your work on LongLLaMA; it really opens a new way for me to tackle real tasks involving long inputs, such as summarization, QA, text-to-table generation, etc. I eagerly look forward to more powerful models and more experiments on other long-input generation tasks in the future.

Thank you for the clarification. Could you be more specific about which checkpoint you tried? I am afraid the 3B base model does not have much summarization capability; I also tried few-shot summarization, which unfortunately did not work. Have you tried long_llama_3b_instruct? https://huggingface.co/syzymon/long_llama_3b_instruct It should be capable of some zero-shot summarization, which we test in our colab notebook: https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_instruct_colab.ipynb
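
For anyone who wants to try that quickly outside the colab, a hedged sketch of zero-shot summarization with the instruct checkpoint; the plain "instruction + text" prompt below is an assumption on my part, and the notebook linked above shows the prompt format the model was actually tuned with:

```python
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

MODEL = "syzymon/long_llama_3b_instruct"

tokenizer = LlamaTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()

document = open("long_document.txt").read()  # any long input you want summarized
# Simple hypothetical prompt; see the linked colab for the exact template.
prompt = f"Summarize the following text in a few sentences.\n\n{document}\n\nSummary:"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True))
```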