With LLMs increasingly deployed for long-content generation, the KV cache has emerged as a critical bottleneck, growing linearly with sequence length (e.g., Llama2-7B-128K requires 64GB of KV cache on top of 14GB of model weights). We present TriForce, a scalable and robust hierarchical speculative decoding system that serves long-context LLMs (Llama2-7B-128K, LWM-Text-Chat-128K, Llama2-13B-128K, etc.) for long sequence generation at roughly 0.1s latency per token on consumer GPUs, losslessly (16-bit precision, preserving the original output distribution). We demonstrate that TriForce can efficiently serve Llama2-13B with 128K contexts on two RTX 4090s, reaching an average time between tokens (TBT) as low as 0.22 seconds, which is 7.8x faster than a highly optimized offloading system. Furthermore, with TriForce, Llama2-7B-128K can be served on two RTX 4090s with a TBT of 0.11s, only about half as slow as on a single A100. Additionally, TriForce achieves a 4.86x speedup over DeepSpeed-ZeRO-Inference on a single RTX 4090. Beyond offloading, TriForce also provides an on-chip solution for data-center GPUs such as the A100, which is discussed in detail in our paper.
TriForce enhances the efficiency of generating long sequences across a range of models. Our evaluation of TriForce covers LLMs such as Llama2-7B-128K, LWM-Text-Chat-128K, and Llama2-13B-128K on RTX 4090s, using prompts from PG-19 and NarrativeQA. Entries marked with an asterisk represent the baseline using DeepSpeed-ZeRO-Inference. The official implementation of DeepSpeed-ZeRO-Inference with KV cache offloading currently supports only a single GPU and computes attention on the CPU. Our offloading system transfers the KV cache from CPU to GPU and benefits from tensor parallelism.
| GPU | Target Model | TriForce TBT (ms) | Baseline TBT (ms) | Speedup |
|---|---|---|---|---|
| 2x RTX 4090 | Llama2-7B-128K | 108 | 840 | 7.78x |
| 2x RTX 4090 | LWM-Text-Chat-128K | 114 | 840 | 7.37x |
| 2x RTX 4090 | Llama2-13B-128K | 226 | 1794 | 7.94x |
| 1x RTX 4090 | Llama2-7B-128K | 312 | 1516* | 4.86x |
| 1x RTX 4090 | LWM-Text-Chat-128K | 314 | 1516* | 4.83x |
Here we present a demo of LWM-Text-Chat-128K inference on two RTX 4090s with a 127K-token context, with and without TriForce. We prefill the model with 127K tokens from a book in NarrativeQA and direct it to summarize the book's content. The video plays at normal speed (1x).
TriForce addresses this challenge while provably preserving model quality by integrating retrieval-based drafting and hierarchical speculation. It uses the original model weights together with a small, retrieved proportion of the KV cache as a draft model, which is in turn speculated by a lightweight model with a StreamingLLM cache to reduce drafting latency. By mitigating the dual bottlenecks of KV cache and model weights, TriForce significantly accelerates long-context LLM serving with offloading.
Moreover, in our paper, we show that: (1) TriForce scales with longer contexts, owing to its high acceptance rate and the growing gap between the draft and target models' latencies, since we keep a constant KV cache budget for drafting; (2) TriForce is robust across sampling temperatures, maintaining an acceptance rate above 0.9 even at a temperature of 1.0.
As the figure illustrates, for a long-context target model (e.g., Llama2-7B-128K), we leverage the original model weights but only a small proportion (e.g., 3%) of the KV cache as a draft to tackle the KV cache bottleneck. Hierarchically, this draft model is itself speculated by a lightweight model (e.g., Llama-68M) with a StreamingLLM cache to address the model-weights bottleneck. TriForce therefore integrates two models and three caches: a lightweight draft model and a target model, with a StreamingLLM cache for the draft model and both a retrieval cache and a full cache for the target model. The process starts with the lightweight model repeatedly drafting for a few steps, assisting the target model equipped with the retrieved partial KV cache in generating a longer run of tokens, which are then verified by the target model using the full KV cache.
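To make the interplay of the two models and three caches concrete, here is a minimal Python sketch of the two-level speculation loop as described above. The callable names (`draft_68m`, `target_retrieval`, `target_full`) and the drafting budgets are our own illustrative assumptions, not TriForce's actual API:

```python
# Sketch of TriForce's two-level speculation loop (our reading of the method,
# not the authors' code). Each "model" is abstracted as a callable; speculative
# acceptance/correction is hidden inside the callables.

from typing import Callable, List

def triforce_step(
    draft_68m: Callable[[List[int], int], List[int]],            # lightweight model + StreamingLLM cache
    target_retrieval: Callable[[List[int], List[int]], List[int]],  # target weights + retrieved partial KV cache
    target_full: Callable[[List[int], List[int]], List[int]],       # target weights + full KV cache
    prefix: List[int],
    gamma1: int = 4,   # tokens drafted by the lightweight model per round (assumed value)
    gamma2: int = 8,   # tokens produced with the retrieval cache before full verification (assumed value)
) -> List[int]:
    """One TriForce decoding step: draft cheaply with the 68M model, filter with
    the retrieval-cache target, then verify everything against the full KV cache."""
    middle_tokens: List[int] = []
    while len(middle_tokens) < gamma2:
        # Level 1: the lightweight model speculates gamma1 candidate tokens.
        candidates = draft_68m(prefix + middle_tokens, gamma1)
        # Level 2: the target model with ~3% retrieved KV cache accepts/corrects them.
        # (Standard speculative decoding always emits at least one token, so this terminates.)
        accepted = target_retrieval(prefix + middle_tokens, candidates)
        middle_tokens.extend(accepted)
    # Final verification with the FULL KV cache preserves the original output
    # distribution, which is what makes the acceleration lossless.
    return target_full(prefix, middle_tokens)
```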
Our design of TriForce is inspired by three critical empirical observations regarding LLMs when dealing with long contexts, detailed as follows.
As shown in the figure below, the Llama2-7B-128K model exhibits significant attention sparsity with a 120K context: over 96% of the attention scores can be recovered with merely 4K tokens across almost all layers. This sparsity within the attention blocks suggests that a small fraction of the KV cache can serve as a draft cache while attaining a high acceptance rate during self-speculative decoding.
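As a concrete illustration, this recovery rate can be measured per attention head by summing the post-softmax attention mass of the top-4K keys; the snippet below is our own sketch, with tensor shapes assumed for a single head:

```python
import torch

def topk_attention_recovery(attn_scores: torch.Tensor, k: int = 4096) -> float:
    """Fraction of the post-softmax attention mass captured by the top-k keys.

    attn_scores: [num_queries, context_len], softmax-normalized scores of one head.
    """
    topk_mass = attn_scores.topk(k, dim=-1).values.sum(dim=-1)  # [num_queries]
    return topk_mass.mean().item()

# Usage (random scores only to show the call; real scores come from the model):
scores = torch.softmax(torch.randn(8, 120_000), dim=-1)
print(topk_attention_recovery(scores, k=4096))
```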
Because our setting requires keeping the entire KV cache for verification anyway, we are free to select any subset of it for drafting. In our approach, the KV cache is segmented into small chunks. During the retrieval phase, we compute the attention between a given query and the average key cache within each chunk. This effectively highlights the most relevant chunks, letting us gather KV cache under a fixed budget based on these scores. By prioritizing relevance over recency, the retrieval-based policy can handle contextually dense datasets.
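A minimal sketch of this chunk-level retrieval policy, assuming a single head's key/value cache of shape `[context_len, head_dim]`; the chunk size and budget below are illustrative, not the paper's exact configuration:

```python
import torch

def retrieve_kv_chunks(
    query: torch.Tensor,        # [head_dim] current query vector
    key_cache: torch.Tensor,    # [context_len, head_dim]
    value_cache: torch.Tensor,  # [context_len, head_dim]
    chunk_size: int = 128,      # illustrative chunk size
    budget: int = 4096,         # total tokens kept in the draft cache
):
    """Score each chunk by the query's similarity to the chunk's mean key,
    then gather the top chunks' full keys/values into a fixed-budget draft cache."""
    ctx_len, head_dim = key_cache.shape
    num_chunks = ctx_len // chunk_size
    # Mean key per chunk: [num_chunks, head_dim]
    chunk_keys = key_cache[: num_chunks * chunk_size].view(num_chunks, chunk_size, head_dim).mean(dim=1)
    # Relevance score of each chunk for this query.
    scores = chunk_keys @ query                              # [num_chunks]
    top_chunks = scores.topk(budget // chunk_size).indices
    # Expand chunk indices to token indices (fixed budget, relevance over recency).
    token_idx = (top_chunks[:, None] * chunk_size + torch.arange(chunk_size)).flatten()
    return key_cache[token_idx], value_cache[token_idx]
```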
Our exploration reveals that the long-context information needed by adjacent tokens tends to be similar. With the context length set to 120K, we instruct the model to generate 256 tokens. Choosing the top-4K indices according to the attention scores of the last prefilled token, we use these same indices to gather the attention scores of the subsequently generated tokens and measure how much of their attention mass over the 120K prefilled tokens is recovered. This yields high recovery across most layers, with only a slow decline as more tokens are generated.
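The locality measurement can be sketched as follows: fix the top-4K indices chosen by the last prefilled token, then check how much attention mass those same indices capture for each subsequently generated token (our own illustration, not the paper's evaluation code):

```python
import torch

def locality_recovery(attn_scores: torch.Tensor, k: int = 4096) -> torch.Tensor:
    """attn_scores: [1 + num_generated, prefill_len] attention over the prefilled
    context; row 0 is the last prefilled token, the rest are generated tokens.

    Returns, for each generated token, the attention mass falling on the top-k
    indices selected once by the last prefilled token.
    """
    anchor_topk = attn_scores[0].topk(k).indices          # indices chosen once
    return attn_scores[1:, anchor_topk].sum(dim=-1)       # recovery per generated token
```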
This observation means a single construction of the draft cache can suffice for multiple decoding steps, amortizing its construction latency and boosting efficiency. As new KV cache entries are introduced, and guided by the understanding that recent words correlate most strongly with the tokens currently being decoded, these entries replace the less significant ones. Cache re-building can be scheduled at regular intervals or adaptively in response to a drop in the acceptance rate, keeping the cache dynamically aligned with the evolving context.
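A possible scheduling rule, rebuilding either after a fixed number of steps or when the measured acceptance rate drops below a threshold (both values below are assumptions, not TriForce's defaults):

```python
def should_rebuild_draft_cache(
    steps_since_rebuild: int,
    recent_acceptance_rate: float,
    max_interval: int = 256,        # rebuild at least this often (assumed value)
    min_acceptance: float = 0.8,    # rebuild early if acceptance degrades (assumed value)
) -> bool:
    """Rebuild the retrieval cache on a fixed schedule, or sooner when the
    acceptance rate signals that the cached chunks no longer match the context."""
    return steps_since_rebuild >= max_interval or recent_acceptance_rate < min_acceptance
```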
While addressing the KV cache bottleneck enhances efficiency, the need to load the whole model weights for drafting reintroduces latency, shifting the bottleneck back to model weights. To tackle this, we employ a hierarchical system: a secondary, lightweight model with a StreamingLLM cache performs initial speculations for the target model with the retrieval-based draft cache, which in turn serves as a draft model for the target model with the full KV cache. This sequential speculation hierarchy effectively reduces drafting latency and accelerates overall inference.
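For the lightweight drafter, a StreamingLLM cache keeps a few initial "attention sink" tokens plus a sliding window of recent tokens, so its size stays constant regardless of context length; below is a minimal eviction sketch with illustrative sink and window sizes:

```python
import torch

def streamingllm_evict(
    key_cache: torch.Tensor,    # [seq_len, head_dim]
    value_cache: torch.Tensor,  # [seq_len, head_dim]
    num_sink: int = 4,          # initial "attention sink" tokens kept permanently (assumed value)
    window: int = 1020,         # most recent tokens kept (assumed value)
):
    """Constant-size cache for the lightweight draft model: keep the first few
    tokens and the most recent window, drop everything in between."""
    seq_len = key_cache.shape[0]
    if seq_len <= num_sink + window:
        return key_cache, value_cache
    keep = torch.cat([torch.arange(num_sink), torch.arange(seq_len - window, seq_len)])
    return key_cache[keep], value_cache[keep]
```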
Leveraging the TriForce framework, anyone can host a chatbot capable of processing long texts of up to 128K or even 1M tokens without approximation on consumer GPUs such as the RTX 4090, making long-context LLMs accessible to a wide audience. TriForce can also be deployed on robots, expanding their ability to understand and interact through long-context conversations. Additionally, it can be integrated with various KV compression techniques (e.g., KV quantization) to further enhance performance. Our hierarchical speculative decoding algorithm is designed to be highly adaptable, catering to the diverse and evolving memory hierarchies of future hardware; TriForce bridges memory hierarchy gaps and can adapt alongside the hardware community to optimize performance.
@article{sun2024triforce,
title={TriForce: Lossless Acceleration of Long Sequence Generation with Hierarchical Speculative Decoding},
author={Sun, Hanshi and Chen, Zhuoming and Yang, Xinyu and Tian, Yuandong and Chen, Beidi},
journal={arXiv preprint arXiv:2404.11912},
year={2024}
}