With the rapid rise of large language models (LLMs), inference efficiency becomes increasingly important, and various approximation methods have been proposed to reduce cost at inference time. Contextual Sparsity (CS) is appealing for its training-free nature and its ability to reach a high compression ratio seemingly without quality degradation. However, after a comprehensive evaluation of contextual sparsity methods on various complex generation tasks, we find that although CS succeeds on prompt-understanding tasks, it significantly degrades model performance on reasoning, deduction, and knowledge-based tasks. Despite the gap in end-to-end accuracy, we observe that sparse models often share the full model's general problem-solving logic and require only a minor portion of token corrections to recover the original model's performance.
This paper introduces Sirius[1], an efficient correction mechanism that significantly recovers the quality of CS models on reasoning tasks while maintaining their efficiency gains. Sirius is evaluated on 6 models across 8 difficult generation tasks spanning reasoning, math, and coding, and shows consistent effectiveness and efficiency. We also carefully develop a system implementation for Sirius and show that it achieves roughly a 20% latency reduction for an 8B model on-chip and a 35% reduction for a 70B model with offloading.
[1] We draw inspiration from the astronomical namesake: Sirius is a two-body star system in which one member is the brightest star in the night sky, while the other is a dim companion.
Previous works focus on establishing the existence of contextual sparsity in LLMs and mainly evaluate CS models on classification or simple text summarization tasks. In this paper, we evaluate CS models comprehensively on various complex generation tasks. First, we classify CS methods as follows: coarse-grained sparsity (CSparse), which selects a set of neurons based on the prompt and reuses it for the entire generation [1], and fine-grained sparsity (FSparse), which re-selects neurons for every decoded token based on activation magnitude [2].
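To make the two regimes concrete, below is a minimal PyTorch sketch of how neuron selection in a gated MLP could differ between the two families. This is an illustrative sketch under our own assumptions (the function name, `keep_ratio` parameter, and masking-based "skipping" are ours), not the implementation of the cited methods, which use dedicated sparse kernels.

```python
import torch
import torch.nn.functional as F

def sparse_gated_mlp(x, W_gate, W_up, W_down, keep_ratio=0.5, mode="fine"):
    """Illustrative contextual-sparsity sketch for a SwiGLU-style MLP.

    x:       (seq, d_model) hidden states
    W_gate,
    W_up:    (d_model, d_ff) projections
    W_down:  (d_ff, d_model) output projection
    mode:    "fine"   -> per-token neuron selection (FSparse-like)
             "coarse" -> one neuron set chosen from the prompt and reused
                         for the whole generation (CSparse-like)
    """
    gate = F.silu(x @ W_gate)                              # (seq, d_ff)
    k = int(keep_ratio * gate.shape[-1])

    if mode == "fine":
        # keep the top-k most activated neurons independently for every token
        idx = gate.abs().topk(k, dim=-1).indices           # (seq, k)
        mask = torch.zeros_like(gate).scatter_(-1, idx, 1.0)
    else:
        # rank neurons once by their aggregate activation over the prompt
        score = gate.abs().sum(dim=0)                      # (d_ff,)
        idx = score.topk(k).indices                        # (k,)
        mask = torch.zeros_like(gate)
        mask[:, idx] = 1.0

    # neurons outside the selected set are zeroed here; a real sparse
    # kernel would avoid computing them at all
    hidden = (gate * mask) * (x @ W_up)
    return hidden @ W_down
```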
CS models are evaluated at their default sparsity (50% neuron sparsity). Across the evaluation, we present the following takeaways:

- CS models largely preserve quality on prompt-understanding tasks such as summarization (CNN/DailyMail), conversational QA (CoQA), and TruthfulQA.
- CS models degrade substantially on tasks requiring reasoning, deduction, or knowledge, such as GSM8K, HumanEval, and MMLU.
The table below shows the results. Figure (a) further contrasts the two regimes as sparsity varies: performance on CNN/DailyMail (coral) stays robust, while performance on GSM8K (green) collapses at 50% global sparsity.
The first three task columns (CNN/DailyMail, CoQA, TruthfulQA) are where CS succeeds; the last three (GSM8K, HumanEval, MMLU) are where CS fails.

| Experiment Settings | CNN/DailyMail (Unitxt Rouge) | CoQA (EM/F1) | TruthfulQA (Rouge-1/2 Acc) | GSM8K (Acc, strict/flexible) | HumanEval (Pass@1, GD) | MMLU* (Accuracy) |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.1237 | 0.6153/0.7825 | 0.4945/0.3647 | 0.7551/0.7544 | 0.560 | 0.6231 |
| Llama-3-8B-Instruct-CSparse | 0.1144 | 0.6633/0.7977 | 0.4725/0.3403 | 0.3859/0.3874 | 0.207 | 0.5558 |
| Llama-3-8B-Instruct-FSparse | 0.1166 | 0.6625/0.7984 | 0.5043/0.3305 | 0.5868/0.5891 | 0.457 | 0.5304 |
| Llama-2-7B-Chat | 0.1489 | 0.5982/0.7580 | 0.4480/0.3831 | 0.2396/0.2462 | 0.140 | 0.492 |
| Llama-2-7B-Chat-CSparse | 0.1448 | 0.6117/0.7639 | 0.4529/0.3843 | 0.1334/0.1380 | 0.067 | 0.4637 |
| Llama-2-7B-Chat-FSparse | 0.1521 | 0.5898/0.7540 | 0.4565/0.3660 | 0.1979/0.2017 | 0.134 | 0.4768 |

* MMLU is a classification task, not a generation task; we use MMLU-FLAN-COT.
The drastic loss on complex reasoning tasks is likely because the neuron activation patterns on these tasks are more complex and harder to capture with a fixed sparsity level, as illustrated in Figure (b). Furthermore, we study CS on the 70B model in Figure (c): at global sparsity below 50%, Llama-3-70B-Instruct with contextual sparsity already performs worse on GSM8K-COT than the full Llama-3-8B-Instruct, despite having 4X the parameter size. This observation shows that CS is not usable on 70B models for complex reasoning tasks.
[1] Dong, H., Chen, B., and Chi, Y. (2024). Prompt-prompted mixture of experts for efficient LLM generation.
[2] Lee, J.-Y., Lee, D., Zhang, G., Tiwari, M., and Mirhoseini, A. (2024). CATS: Contextually-aware thresholding for sparsity in large language models.
We study in detail the cases where CS models fail. The errors are miscalculations, wrong reasoning paths, or nonsensical statements (refer to the paper for more examples and analysis). The mistakes happen in the middle of the argument but propagate to the final result. We show some examples in Figure (c).
Can the generation be corrected by fixing just these minor mistakes in the middle? We run both the full model and the CS model and contrast their outputs token by token for Llama-3-8B-Instruct and Llama-2-7B-Chat; the results are shown in Figures (a) and (b). We find that the fraction of tokens needing correction is small: modifying roughly 11% of the tokens is enough to recover the full model's performance. This motivates us to develop an efficient correction mechanism that boosts CS models on complex generation tasks involving reasoning.
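As a rough illustration of how such a token-level comparison could be made, here is a hedged sketch assuming HuggingFace-style causal LMs and greedy decoding; `fraction_tokens_needing_correction` is a hypothetical helper, not the paper's measurement code.

```python
import torch

@torch.no_grad()
def fraction_tokens_needing_correction(full_model, sparse_model, prompt_ids):
    """Estimate how often the sparse model's greedy choice diverges from the
    full model's generation, treating the full model's output as reference."""
    # reference continuation from the full model (greedy)
    full_out = full_model.generate(prompt_ids, max_new_tokens=256, do_sample=False)
    gen = full_out[:, prompt_ids.shape[1]:]                       # (1, T)

    # teacher-force the sparse model on the full model's tokens and check
    # where its own greedy prediction would have diverged
    logits = sparse_model(full_out).logits[:, prompt_ids.shape[1] - 1:-1, :]
    sparse_choice = logits.argmax(dim=-1)                         # (1, T)
    return (sparse_choice != gen).float().mean().item()           # e.g. ~0.11
```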
We propose Sirius, an efficient correction mechanism for CS models. For a full description of Sirius's design choices, please refer to the paper. Sirius is a period-based approach: the full model is called only infrequently (typically once every 16 tokens). During correction, the full model directly rewrites the KV cache, interleaves new tokens, and rolls back unlikely tokens. Although Sirius resembles speculative decoding, we conduct rigorous studies showing that naive speculative decoding suffers significant efficiency limitations in the sparse-correction setting, because 1) the sparse model is too large to serve as a cheap draft model, and 2) the correction criteria are too strict. Moreover, a hardware-efficient tree is built to further boost Sirius's efficiency.
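To convey the flavor of the period-based correction loop, here is a minimal sketch that omits the KV-cache rewriting and the hardware-efficient tree; `sirius_style_decode`, the acceptance `threshold`, and the greedy rollback rule are illustrative assumptions of ours, not the actual Sirius algorithm.

```python
import torch

@torch.no_grad()
def sirius_style_decode(sparse_model, full_model, input_ids,
                        period=16, threshold=0.1, max_new_tokens=256):
    """Minimal period-based correction sketch: the sparse model drafts a
    period of tokens, then one full-model forward pass scores them, rolls
    back from the first token the full model finds too unlikely, and
    substitutes the full model's own token at that position."""
    seq = input_ids
    start = input_ids.shape[1]
    while seq.shape[1] - start < max_new_tokens:
        # 1) sparse model drafts a period of tokens cheaply
        #    (assumes exactly `period` new tokens, i.e. no early EOS)
        seq = sparse_model.generate(seq, max_new_tokens=period, do_sample=False)

        # 2) one parallel full-model pass scores the drafted tokens
        logits = full_model(seq).logits[:, -period - 1:-1, :]
        probs = torch.softmax(logits, dim=-1)
        drafted = seq[:, -period:]
        p_draft = probs.gather(-1, drafted.unsqueeze(-1)).squeeze(-1)  # (1, period)

        # 3) roll back at the first drafted token the full model rejects,
        #    interleaving the full model's own choice there
        bad = (p_draft < threshold).nonzero()
        if bad.numel() > 0:
            i = bad[0, 1].item()
            fixed = logits[:, i, :].argmax(dim=-1, keepdim=True)       # (1, 1)
            seq = torch.cat([seq[:, : seq.shape[1] - period + i], fixed], dim=1)
        # (a real implementation also rewrites the sparse model's KV cache
        #  with the full model's states during this verification pass)
    return seq
```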
We show Sirius's effectiveness and efficiency in the tables below. We select GSM8K for arithmetic reasoning, CSQA for commonsense reasoning, and HumanEval for code generation. Under the "Sirius Perf." column, we report A (B), where A is the accuracy after Sirius correction on the evaluated dataset and (B) is the optimal treewidth selected for the current model-dataset setting. Under the "AAL" column, we report X/Y, where X is the AAL and Y is the period. "Effective Density" refers to the average density of the overall system. In total, we evaluate 6 mainstream LLMs (Llama-3-8B, Llama-2-7B, and Llama-2-13B, together with their instruction-finetuned counterparts) on 8 tasks: GSM8K and AQuA-RAT-COT for arithmetic reasoning; CSQA, StrategyQA, Sports, and Dates for commonsense reasoning; and HumanEval and MBPP+ for coding. Sirius is effective and efficient across all these settings. For evaluation details, please refer to the full paper.
**GSM8K**

| Model | Full Perf. | CSparse Perf. | CSparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.7536 | 0.3844 | 0.65 | 0.7051 (8) | 15.22/16 | 0.706 |
| Llama-3-8B | 0.4966 | 0.2085 | 0.65 | 0.4177 (8) | 15.29/16 | 0.703 |

| Model | Full Perf. | FSparse Perf. | FSparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.7536 | 0.5868 | 0.76 | 0.7278 (4) | 15.37/16 | 0.807 |
| Llama-3-8B | 0.4966 | 0.3199 | 0.76 | 0.4579 (2) | 15.03/16 | 0.825 |

**CSQA**

| Model | Full Perf. | CSparse Perf. | CSparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.7073 | 0.6470 | 0.58 | 0.7076 (8) | 14.76/16 | 0.657 |
| Llama-3-8B | 0.6437 | 0.5585 | 0.58 | 0.6429 (8) | 15.43/16 | 0.628 |

| Model | Full Perf. | FSparse Perf. | FSparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.7073 | 0.6158 | 0.72 | 0.7043 (8) | 15.66/16 | 0.753 |
| Llama-3-8B | 0.6437 | 0.533 | 0.72 | 0.6388 (1) | 15.00/16 | 0.786 |
**HumanEval**

| Model | Full Perf. | CSparse Perf. | CSparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.561 | 0.207 | 0.65 | 0.524 (8) | 14.67/16 | 0.733 |
| Llama-3-8B | 0.262 | 0.067 | 0.65 | 0.243 (8) | 15.11/16 | 0.691 |

| Model | Full Perf. | FSparse Perf. | FSparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|
| Llama-3-8B-Instruct | 0.561 | 0.457 | 0.76 | 0.616 (6) | 15.42/16 | 0.804 |
| Llama-3-8B | 0.262 | 0.189 | 0.76 | 0.293 (6) | 15.54/16 | 0.797 |
On-chip wall-clock results across different GPUs (A40, L40, A100, H100), reporting task performance, latency, and the speedup ratio relative to the full model:

| Settings | Performance | A40 | Speedup Ratio | L40 | Speedup Ratio | Performance | A100 | Speedup Ratio | H100 | Speedup Ratio |
|---|---|---|---|---|---|---|---|---|---|---|
| Coarse-grained Sparsity | 0.3601 | 20.7 | 0.85 | 15.6 | 0.67 | 0.3601 | 9.6 | 0.72 | 6.6 | 0.76 |
| Sirius | 0.7127 | 24.1 | 0.77 | 18.2 | 0.78 | 0.7089 | 11.1 | 0.83 | 7.7 | 0.88 |
| Full | 0.7612 | 30.9 | 1.0 | 23.2 | 1.0 | 0.7612 | 13.3 | 1.0 | 8.6 | 1.0 |
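Reading the A40 column as an example: Sirius runs at 24.1 against the full model's 30.9, a ratio of roughly 24.1/30.9 ≈ 0.78, i.e., about a 20% latency reduction, consistent with the headline on-chip number, while recovering task accuracy from 0.3601 (sparsity alone) to 0.7127, close to the full model's 0.7612.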
We observe that contextual sparsity methods degrade significantly on reasoning and deduction tasks. However, we find that this degradation can in principle be recovered by having the original model correct about 11% of the tokens. Following this observation, we develop Sirius, which provides an effective solution to the performance degradation of contextual sparsity methods on complex reasoning tasks. By introducing an efficient correction mechanism, Sirius significantly boosts the performance of CS models while maintaining their efficiency gains. This work opens up new possibilities for deploying efficient LLMs in resource-constrained environments without compromising task performance.
@article{zhou2024sirius,
  title={Sirius: Contextual Sparsity with Correction for Efficient LLMs},
  author={Zhou, Yang and Chen, Zhuoming and Xu, Zhaozhuo and Lin, Victoria and Chen, Beidi},
journal={arXiv preprint arXiv:2404.11912},
year={2024}
}