TL;DR: We introduce MagicDec, which uses Speculative Decoding to improve both the throughput and latency of LLM inference. Our work identifies how the bottleneck shifts with increasing batch size and sequence length, and uses these insights to deploy Speculative Decoding more effectively for high-throughput inference. It challenges the existing wisdom that Speculative Decoding is ineffective at large batch sizes. More interestingly, we observe that the speedup improves with increasing batch size for moderate to long sequences. Our work theoretically motivates and empirically validates why Speculative Decoding, used wisely, is a potential solution for breaking the throughput-latency tradeoff.
With the emergence of long-context models and the growing popularity of applications like interactive chatbots, efficiently serving long-context requests has gained significant attention. At the same time, for commercial use cases, high-throughput serving directly translates to increased revenue. However, achieving both low latency and high throughput is challenging because the two are often at odds. While Speculative Decoding techniques have been proposed to improve latency at the cost of throughput, various continuous batching techniques prioritize throughput over latency. Quantization- and pruning-based methods can deliver both high throughput and low latency, but often sacrifice model quality. Given these trade-offs, we pose the following question:
Can we achieve high throughput and low latency for long-context LLM generation without sacrificing performance?
To answer this question, we revisit the efficacy of Speculative Decoding through an analytical approach. Existing works (e.g., Liu et al., 2024; Su et al., 2023) have claimed that although Speculative Decoding is useful for small batch sizes, it can become counter-productive when serving large batches. This inefficiency stems from two primary factors:

1. At large batch sizes, the target model's decoding becomes increasingly compute-bound, so verifying multiple speculated tokens in one forward pass is no longer much cheaper than decoding them one by one.
2. The draft model's own decoding cost also grows with batch size, so the relative overhead of drafting is no longer negligible.
In this blog post, we examine these claims more carefully and look for ways around them. With suitable drafting strategies, Speculative Decoding can in fact improve both throughput and latency, even at large batch sizes.
The key observation is that, for every model and hardware pair, there exists a critical sequence length beyond which LLM inference becomes memory bound even for large batch sizes. This is because, for long enough sequences, the KV cache becomes the primary memory bottleneck and, unlike the model parameters, it scales with batch size. Although computation time also increases linearly with batch size, the increase in KV loading time dominates. We use this insight to deploy Speculative Decoding in the high-throughput regime: by pairing the target model with drafts that keep only a sparse KV cache, the speedup actually improves as batch size grows.
KV sparsification for drafting is essential: otherwise, in the large-batch, long-sequence regime, the draft-to-target cost ratio is determined solely by the ratio of their memory footprints. For example, when we speculate Llama-3.1-70B with a Llama-3.1-8B draft, the draft-to-target memory loading ratio rises to about 0.4 and stays there at larger batch sizes, as can be seen in Figure 4. With a fixed KV budget for the draft model, however, the draft latency increases only very slowly with batch size. Consequently, the draft-to-target cost ratio keeps decreasing for larger batches, resulting in a higher speedup.
The fixed draft KV budget naturally raises concerns about the acceptance rate, which is the other important deciding factor. Fortunately, StreamingLLM with a small KV budget such as 512 retains a high acceptance rate even for sequences as long as 100K tokens! To improve the acceptance rate further, we can grow the drafting budget in different ways, as long as the draft-to-target cost ratio remains reasonably small. Contrary to existing beliefs (e.g., Liu et al., 2024; Su et al., 2023), this high acceptance rate allows us to use a longer speculation length for larger batches. And thus, as you guessed, we can keep getting better speedups with increasing batch size 💪.
Seems like magic? Well, let's get into the details of MagicDec! 🧙‍♂️
Given a speculation length γ for a sequence of length S and batch size B, let $T_T(B, S, 1)$ and $T_D(B, S, 1)$ denote the target and draft decoding latencies, respectively. The verification time, $T_V(B, S, \gamma)$, is the time it takes the target model to verify the γ speculated tokens in a single forward pass. Given an acceptance rate α ∈ [0, 1] and speculation length γ, Ω(γ, α) represents the expected number of tokens generated in one verification step, as described by Leviathan et al.:

$$\Omega(\gamma, \alpha) = \frac{1 - \alpha^{\gamma + 1}}{1 - \alpha}$$

The total time taken for Speculative Decoding, $T_{SD}^{Total}$, is given by:

$$T_{SD}^{Total} = \gamma \, T_D(B, S, 1) + T_V(B, S, \gamma)$$

The expected per-token latency for Speculative Decoding is simply:

$$T_{SD}^{Avg} = \frac{\gamma \, T_D(B, S, 1) + T_V(B, S, \gamma)}{\Omega(\gamma, \alpha)}$$

The ratio of Speculative Decoding latency to regular autoregressive decoding latency can be broken down as:

$$\frac{T_{SD}^{Avg}}{T_T(B, S, 1)} = \frac{1}{\Omega(\gamma, \alpha)} \left( \frac{T_V(B, S, \gamma)}{T_T(B, S, 1)} + \gamma \, \frac{T_D(B, S, 1)}{T_T(B, S, 1)} \right) \tag{1}$$
We shall come back to this equation shortly.
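To build intuition for Eq. 1, here is a small calculator, a minimal sketch rather than the authors' code, that plugs an acceptance rate and the two cost ratios into the formula. The example values at the bottom (α = 0.8, verification costing about as much as one decoding step, drafting at 5% of a target step) are purely illustrative assumptions.

```python
def expected_tokens(gamma: int, alpha: float) -> float:
    """Omega(gamma, alpha): expected number of tokens generated per
    verification step (Leviathan et al.)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(gamma: int, alpha: float, c_verify: float, c_draft: float) -> float:
    """Inverse of Eq. 1: autoregressive per-token latency divided by the
    Speculative Decoding per-token latency.

    c_verify = T_V(B, S, gamma) / T_T(B, S, 1)
    c_draft  = T_D(B, S, 1)     / T_T(B, S, 1)
    """
    return expected_tokens(gamma, alpha) / (c_verify + gamma * c_draft)


# Illustrative (assumed) numbers: verification roughly as costly as one decode
# step in the memory-bound regime, drafting at 5% of a target step.
print(expected_speedup(gamma=4, alpha=0.8, c_verify=1.1, c_draft=0.05))  # ~2.6x
```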
It is well understood that Speculative Decoding is beneficial when LLM inference is memory-bound. The following two observations are key to understanding why Speculative Decoding is useful even for large batches:
1. For moderately small sequence lengths and small batch sizes, LLM inference is essentially bottlenecked by parameter loading time. As batch size increases, the arithmetic intensity of the linear projections increases, making the model compute-bound. With increasing sequence length, on the other hand, KV loading starts to become the primary memory bottleneck.

2. Interestingly, unlike the model parameters, the KV cache grows linearly with batch size. Hence, once the KV cache becomes larger than the model's memory footprint, both memory loading time and computation time increase linearly with batch size. However, modern accelerators have much higher peak FLOPs per second than memory bandwidth, so the increase in memory loading time outpaces the increase in computation time. This makes LLM inference more memory-bound in the high-throughput regime, provided the sequence length is long enough.

The critical sequence length at which this happens depends on the model type and the hardware specifications. As can be seen in Figure 2-a, for GQA-style models the critical sequence length is expected to be higher due to their relatively lighter KV memory footprint. On the other hand, for GPUs with a higher peak-FLOPs-to-memory-bandwidth ratio, the critical point arrives earlier.
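To see where this crossover happens, the rough roofline-style sketch below compares the compute time and the memory-loading time of a single decoding step for a few batch sizes and sequence lengths. All constants are assumptions chosen for illustration (an A100-like accelerator with 312 TFLOPS peak BF16 compute and roughly 2 TB/s of memory bandwidth, and a Llama-3.1-8B-like GQA model in fp16); it is a time model only and ignores memory capacity and multi-GPU sharding, which scale compute and bandwidth together and so leave the comparison unchanged.

```python
# Assumed hardware constants (A100-like): peak BF16 compute and HBM bandwidth.
PEAK_FLOPS = 312e12          # FLOP/s
MEM_BW = 2.0e12              # bytes/s

# Assumed model constants (Llama-3.1-8B-like, fp16, GQA with 8 KV heads).
PARAM_BYTES = 8e9 * 2                      # ~8B parameters in fp16
KV_BYTES_PER_TOKEN = 2 * 32 * 8 * 128 * 2  # K+V * layers * KV heads * head dim * fp16
FLOPS_PER_TOKEN = 2 * 8e9                  # rough GEMM FLOPs per generated token
ATTN_FLOPS_PER_CTX_TOKEN = 0.5e6           # rough attention FLOPs per context token


def decode_step_time(batch: int, seqlen: int):
    """Compute time and memory-loading time of one decoding step (seconds)."""
    compute = batch * (FLOPS_PER_TOKEN + ATTN_FLOPS_PER_CTX_TOKEN * seqlen) / PEAK_FLOPS
    memory = (PARAM_BYTES + batch * seqlen * KV_BYTES_PER_TOKEN) / MEM_BW
    return compute, memory


for seqlen in (256, 4096, 32768):
    for batch in (8, 64, 256):
        t_compute, t_memory = decode_step_time(batch, seqlen)
        bound = "memory" if t_memory > t_compute else "compute"
        print(f"S={seqlen:6d} B={batch:4d} -> {bound}-bound "
              f"(compute {t_compute * 1e3:8.2f} ms, memory {t_memory * 1e3:8.2f} ms)")
```

With these assumed numbers, only the short-context, large-batch corner turns compute-bound; at 32K context the decoding step stays firmly memory-bound even at batch size 256.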
Now we dig deeper into the speedup breakdown in Eq. 1 to understand how we can achieve higher speedup with bigger batches:
With increasing batch size, verifying multiple tokens is expected to become significantly costlier than decoding a single token, since computation time grows linearly with the speculation length. However, once KV loading becomes the dominant bottleneck, its growth outweighs the growth in computation time, keeping the verification-to-decoding cost ratio in check. As illustrated in Figure 3, this cost ratio actually saturates at a value close to 1, especially for longer sequences.
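Under the same assumed roofline constants as the sketch above, the snippet below compares the time to verify γ = 4 tokens in one forward pass with the time of a single decoding step. For short contexts at large batch size the ratio sits well above 1, but for long contexts it collapses toward 1, mirroring the behaviour reported in Figure 3.

```python
# Same assumed roofline constants as before (A100-like device, Llama-3.1-8B-like model).
PEAK_FLOPS, MEM_BW = 312e12, 2.0e12
PARAM_BYTES, KV_BYTES_PER_TOKEN = 16e9, 2 * 32 * 8 * 128 * 2
FLOPS_PER_TOKEN, ATTN_FLOPS_PER_CTX_TOKEN = 2 * 8e9, 0.5e6


def forward_time(batch: int, seqlen: int, new_tokens: int) -> float:
    """max(compute, memory) time to process `new_tokens` fresh tokens per sequence."""
    compute = batch * new_tokens * (FLOPS_PER_TOKEN + ATTN_FLOPS_PER_CTX_TOKEN * seqlen) / PEAK_FLOPS
    memory = (PARAM_BYTES + batch * (seqlen + new_tokens) * KV_BYTES_PER_TOKEN) / MEM_BW
    return max(compute, memory)


GAMMA = 4  # speculation length
for seqlen in (512, 4096, 32768):
    for batch in (8, 256):
        ratio = forward_time(batch, seqlen, GAMMA) / forward_time(batch, seqlen, 1)
        print(f"S={seqlen:6d} B={batch:4d}  T_V / T_T ~= {ratio:.2f}")
```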
To address the KV bottleneck, we choose to speculate with drafts that have a fixed, small KV budget. Otherwise, in the large-batch, long-sequence regime, the draft-to-target cost ratio is determined solely by the draft-to-target memory loading ratio. From Figure 4, we can see that this ratio increases with batch size and then saturates for both the Llama-2-7B/Llama-2-70B and Llama-3.1-8B/Llama-3.1-70B draft-target pairs. To improve speedup with batch size, we want this ratio to go down instead.
Taking inspiration from TriForce, we fix a small KV budget for the draft model so that the draft cost does not increase sharply with batch size. As a result, the draft-to-target cost ratio decreases with increasing batch size, as shown in Figure 5. Furthermore, for a fixed KV budget, the ratio naturally drops more rapidly for longer sequence lengths.
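The sketch below illustrates this effect with the same assumed roofline constants, using Llama-3.1-70B-like and Llama-3.1-8B-like sizes; every number is a rough assumption rather than a measurement. With a full-KV draft the draft-to-target ratio climbs toward roughly 0.4 and flattens out, whereas with a 512-token fixed budget it keeps shrinking as the batch grows.

```python
# Assumed per-GPU roofline constants (A100-like); since target and draft would be
# sharded the same way, only the ratios below matter.
PEAK_FLOPS, MEM_BW = 312e12, 2.0e12


def decode_time(batch: int, kv_tokens: int, param_bytes: float,
                flops_per_token: float, kv_bytes_per_token: float) -> float:
    """max(compute, memory) time of one decoding step under this toy cost model."""
    compute = batch * flops_per_token / PEAK_FLOPS
    memory = (param_bytes + batch * kv_tokens * kv_bytes_per_token) / MEM_BW
    return max(compute, memory)


SEQ_LEN, DRAFT_KV_BUDGET = 32768, 512
# Assumed model sizes: fp16 weights; KV bytes/token = 2 * layers * KV heads * head dim * 2.
TARGET = dict(param_bytes=140e9, flops_per_token=2 * 70e9, kv_bytes_per_token=0.33e6)  # 70B-like
DRAFT = dict(param_bytes=16e9, flops_per_token=2 * 8e9, kv_bytes_per_token=0.13e6)     # 8B-like

for batch in (8, 32, 128, 256):
    t_target = decode_time(batch, SEQ_LEN, **TARGET)
    t_draft_full = decode_time(batch, SEQ_LEN, **DRAFT)            # draft keeps the full KV cache
    t_draft_sparse = decode_time(batch, DRAFT_KV_BUDGET, **DRAFT)  # draft keeps only 512 entries
    print(f"B={batch:4d}  full-KV draft/target ~= {t_draft_full / t_target:.2f}, "
          f"fixed-budget draft/target ~= {t_draft_sparse / t_target:.3f}")
```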
To ensure a high acceptance rate for our draft with sparse KV, we use the StreamingLLM method. Interestingly, we find this simple method to be quite effective even for very long sequences. On the PG-19 dataset, the acceptance rate of Llama-3.1-8B self-speculation with a StreamingLLM budget of 512 stays relatively stable for sequence lengths ranging from 4,000 to 100,000 tokens (a minimal sketch of such a cache follows the table below):
| Sequence Length | 4000 | 8000 | 16000 | 32000 | 64000 | 100000 |
|---|---|---|---|---|---|---|
| Acceptance Rate | 0.84 | 0.83 | 0.82 | 0.81 | 0.79 | 0.79 |
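For concreteness, here is a minimal sketch of a StreamingLLM-style KV cache for the draft model; it is an illustration rather than the actual MagicDec implementation, and the class name and interface are hypothetical. The idea is to keep a handful of attention-sink tokens from the beginning of the sequence plus a sliding window of the most recent tokens, so the draft's KV budget stays fixed no matter how long the context grows.

```python
import torch


class StreamingKVCache:
    """Per-layer KV cache that keeps `num_sink` initial tokens plus the most
    recent tokens, capped at a fixed total `budget` (e.g., 512)."""

    def __init__(self, num_sink: int = 4, budget: int = 512):
        assert budget > num_sink
        self.num_sink = num_sink
        self.budget = budget
        self.keys = None    # [batch, kv_heads, kept_len, head_dim]
        self.values = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """Append newly computed K/V and evict middle tokens beyond the budget."""
        if self.keys is None:
            self.keys, self.values = k, v
        else:
            self.keys = torch.cat([self.keys, k], dim=2)
            self.values = torch.cat([self.values, v], dim=2)
        if self.keys.shape[2] > self.budget:
            recent = self.budget - self.num_sink
            self.keys = torch.cat(
                [self.keys[:, :, : self.num_sink], self.keys[:, :, -recent:]], dim=2)
            self.values = torch.cat(
                [self.values[:, :, : self.num_sink], self.values[:, :, -recent:]], dim=2)
        return self.keys, self.values


# Toy usage: stream 10,000 decoding steps; the kept KV length never exceeds 512.
cache = StreamingKVCache(num_sink=4, budget=512)
for _ in range(10_000):
    k_new = torch.randn(1, 8, 1, 128)  # [batch, kv_heads, new_tokens, head_dim]
    v_new = torch.randn(1, 8, 1, 128)
    keys, values = cache.update(k_new, v_new)
assert keys.shape[2] <= 512
```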
With a high acceptance rate and a reasonable verification cost, we can afford a longer speculation length. Contrary to previous works (e.g., Liu et al., 2024; Su et al., 2023), we observe that the optimal speculation length goes up with increasing batch size.
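A toy calculation makes this concrete. In the sketch below, every cost ratio is a made-up assumption (in practice they would be measured, as in Figures 3 and 5): the verification-to-decoding ratio is modeled as max(1, compute_fraction · γ), and only the draft-to-target ratio differs between the "smaller batch" and "larger batch" settings, shrinking with batch size as it does for a fixed draft KV budget. Under these assumptions both the optimal γ and the predicted speedup increase with batch size.

```python
def omega(gamma: int, alpha: float) -> float:
    # Expected tokens generated per verification step (Leviathan et al.)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def speedup(gamma: int, alpha: float, compute_frac: float, c_draft: float) -> float:
    # T_V / T_T modeled as max(1, compute_frac * gamma): verification stays about as
    # cheap as one decode step until its compute exceeds the memory-bound floor.
    c_verify = max(1.0, compute_frac * gamma)
    return omega(gamma, alpha) / (c_verify + gamma * c_draft)


ALPHA = 0.8          # assumed per-token acceptance rate
COMPUTE_FRAC = 0.25  # assumed compute share of one target decode step

for label, c_draft in [("smaller batch", 0.30), ("larger batch ", 0.08)]:
    best = max(range(1, 9), key=lambda g: speedup(g, ALPHA, COMPUTE_FRAC, c_draft))
    print(f"{label}: optimal gamma = {best}, "
          f"predicted speedup ~= {speedup(best, ALPHA, COMPUTE_FRAC, c_draft):.2f}")
```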
We evaluate the speedup of MagicDec on Llama-2-7B-32K and Llama-3.1-8B-128K at different batch sizes and sequence lengths. The test sequences are sampled from the PG-19 dataset. For drafting, we have explored the following options: (i) self-speculation, where the target model drafts for itself with a StreamingLLM KV cache, and (ii) TinyLlama-1.1B as a standalone draft model.
| Target | Draft | Prefill Length | Batch Size | Optimal Speculation Length | Speedup |
|---|---|---|---|---|---|
| Llama-3.1-8B | Self-speculation | 32000 | 32 | 3 | 1.22 |
| Llama-3.1-8B | Self-speculation | 32000 | 64 | 3 | 1.38 |
| Llama-3.1-8B | Self-speculation | 32000 | 128 | 4 | 1.47 |
| Llama-2-7B | Self-speculation | 8000 | 32 | 3 | 1.18 |
| Llama-2-7B | Self-speculation | 8000 | 64 | 3 | 1.48 |
| Llama-2-7B | Self-speculation | 8000 | 128 | 4 | 1.63 |
| Llama-2-7B | Self-speculation | 32000 | 32 | 4 | 2.0 |
| Llama-2-7B | TinyLlama-1.1B | 8000 | 32 | 3 | 1.29 |
| Llama-2-7B | TinyLlama-1.1B | 8000 | 64 | 3 | 1.57 |
| Llama-2-7B | TinyLlama-1.1B | 8000 | 128 | 4 | 1.66 |
| Llama-2-7B | TinyLlama-1.1B | 32000 | 32 | 4 | 1.91 |
| Target | Draft | Prefill Length | Batch Size | Optimal Speculation Length | Speedup |
|---|---|---|---|---|---|
| Llama-3.1-8B | Self-speculation | 32000 | 32 | 3 | 1.23 |
| Llama-3.1-8B | Self-speculation | 32000 | 64 | 3 | 1.43 |
| Llama-2-7B | Self-speculation | 8000 | 32 | 3 | 1.19 |
| Llama-2-7B | Self-speculation | 8000 | 64 | 4 | 1.42 |
| Llama-2-7B | Self-speculation | 8000 | 128 | 4 | 1.65 |
| Llama-2-7B | TinyLlama-1.1B | 8000 | 32 | 2 | 1.32 |
| Llama-2-7B | TinyLlama-1.1B | 8000 | 64 | 3 | 1.48 |
| Llama-2-7B | TinyLlama-1.1B | 8000 | 128 | 3 | 1.64 |
This work reassesses the trade-off between throughput and latency in long-context scenarios. We demonstrate that Speculative Decoding can enhance throughput, reduce latency, and maintain accuracy. Our theoretical and empirical analysis reveals that, as sequence length and batch size increase, the bottleneck shifts from compute to memory. We then exploit the nature of this memory bottleneck to achieve even higher speedups for larger batches. MagicDec achieves up to 2x speedup for Llama-2-7B-32K and 1.84x for Llama-3.1-8B on 8 A100 GPUs. More importantly, the observation that speedup scales with batch size is promising for high-throughput serving of long-context requests. So, stay tuned for more exciting results in the future!! 🚀
```bibtex
@misc{chen2024magicdecbreakinglatencythroughputtradeoff,
      title={MagicDec: Breaking the Latency-Throughput Tradeoff for Long Context Generation with Speculative Decoding},
      author={Jian Chen and Vashisth Tiwari and Ranajoy Sadhukhan and Zhuoming Chen and Jinyuan Shi and Ian En-Hsu Yen and Beidi Chen},
      year={2024},
      eprint={2408.11049},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2408.11049},
}
```