Sirius: Contextual Sparsity with Correction for Efficient LLMs



1Carnegie Mellon University, 2Stevens Institute of Technology, 3Meta AI (FAIR)

  Introduction

With the rapid growth of large language models (LLMs), inference efficiency has become increasingly important, and various approximation methods have been proposed to reduce the cost at inference time. Contextual Sparsity (CS) is appealing for its training-free nature and its ability to reach a high compression ratio seemingly without quality degradation. However, after comprehensively evaluating contextual sparsity methods on various complex generation tasks, we find that although CS succeeds on prompt-understanding tasks, it significantly degrades model performance on reasoning, deduction, and knowledge-based tasks. Despite the gap in end-to-end accuracy, we observe that sparse models often share the full model's general problem-solving logic and require correction of only a minor portion of tokens to recover the original performance.

This paper introduces Sirius[1], an efficient correction mechanism that significantly recovers CS models' quality on reasoning tasks while maintaining their efficiency gains. Sirius is evaluated on 6 models across 8 difficult generation tasks in reasoning, math, and coding, and shows consistent effectiveness and efficiency. We also carefully develop a system implementation for Sirius and show that it achieves roughly a 20% reduction in latency for the 8B model on chip and a 35% reduction for the 70B model with offloading.

[1] We draw inspiration from the astronomical namesake: Sirius is a two-body star system in which one member is the brightest star in the night sky and the other is a dim companion.

  Observation: Contextual Sparsity Limitation

Previous works focus on establishing the existence of Contextual Sparsity in LLMs and mainly evaluate CS models on classification or simple text summarization tasks. In this paper, we evaluate CS models comprehensively on various complex generation tasks. First, we categorize CS methods as follows:

  1. Coarse-grained Sparsity (CSparse) Methods ([1]) - for a given input prompt, the sparsity pattern is fixed for all generated tokens.
  2. Fine-grained Sparsity (FSparse) Methods ([2]) - the sparsity pattern is re-selected per generated token to save resources (a minimal illustrative sketch of the two selection strategies follows this list).
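To make the distinction concrete, below is a minimal, illustrative sketch (not the exact algorithms of [1] or [2]) of how the two strategies could select MLP neurons in a gated MLP: CSparse fixes one neuron set from the prompt and reuses it for every generated token, while FSparse re-selects neurons from each token's own activations. The scoring rule, tensor names, and the 50% keep ratio are assumptions made for illustration.

```python
import torch

# Toy gated MLP (dimensions chosen for illustration only).
hidden, inter = 64, 256
gate_proj = torch.nn.Linear(hidden, inter, bias=False)
up_proj   = torch.nn.Linear(hidden, inter, bias=False)
down_proj = torch.nn.Linear(inter, hidden, bias=False)

def csparse_mask(prompt_hidden, keep_ratio=0.5):
    """Coarse-grained: pick one neuron set from the prompt hidden states and
    reuse it for every token generated for this request."""
    scores = gate_proj(prompt_hidden).abs().mean(dim=0)   # (inter,)
    mask = torch.zeros(inter, dtype=torch.bool)
    mask[scores.topk(int(keep_ratio * inter)).indices] = True
    return mask

def fsparse_mask(token_hidden, keep_ratio=0.5):
    """Fine-grained: re-select neurons from the current token's own activation."""
    scores = gate_proj(token_hidden).abs()                # (inter,)
    mask = torch.zeros(inter, dtype=torch.bool)
    mask[scores.topk(int(keep_ratio * inter)).indices] = True
    return mask

def sparse_mlp(x, mask):
    """Only the kept neurons contribute to the MLP output."""
    act = torch.nn.functional.silu(gate_proj(x)) * up_proj(x)
    return down_proj(act * mask)

prompt_hidden = torch.randn(10, hidden)   # hidden states of the prompt tokens
token_hidden  = torch.randn(hidden)       # hidden state of the token being decoded

fixed_mask = csparse_mask(prompt_hidden)      # reused for all generated tokens
per_token_mask = fsparse_mask(token_hidden)   # recomputed at every decoding step
print(sparse_mlp(token_hidden, fixed_mask).shape,
      sparse_mlp(token_hidden, per_token_mask).shape)
```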
[Figure: Illustration of CS limitations]

CS models are evaluated at their default sparsity (50% neuron sparsity). Across the evaluation, we present the following takeaways:

  1. CS models work well on prompt-understanding tasks, e.g., text summarization (CNN/DailyMail) and conversational question answering (CoQA).
  2. CS models perform significantly worse on generation tasks that require complex reasoning (GSM8K) or knowledge (MMLU-FLAN-COT).

We show the results in the table below. Further, Figure (a) above contrasts the two regimes: as sparsity varies, performance on CNN/DailyMail (coral) stays robust, while performance on GSM8K (green) collapses at a global sparsity of 50%.

Where CS succeeds (CNN/DailyMail, CoQA, TruthfulQA) vs. where CS fails (GSM8K, HumanEval, MMLU*):

| Model | CNN/DailyMail (Unitxt Rouge) | CoQA (EM/F1) | TruthfulQA (Rouge-1/2 Acc) | GSM8K (Acc strict/flexible) | HumanEval (Pass@1, GD) | MMLU* (Accuracy) |
| Llama-3-8B-Instruct | 0.1237 | 0.6153/0.7825 | 0.4945/0.3647 | 0.7551/0.7544 | 0.560 | 0.6231 |
| Llama-3-8B-Instruct-CSparse | 0.1144 | 0.6633/0.7977 | 0.4725/0.3403 | 0.3859/0.3874 | 0.207 | 0.5558 |
| Llama-3-8B-Instruct-FSparse | 0.1166 | 0.6625/0.7984 | 0.5043/0.3305 | 0.5868/0.5891 | 0.457 | 0.5304 |
| Llama-2-7B-Chat | 0.1489 | 0.5982/0.7580 | 0.4480/0.3831 | 0.2396/0.2462 | 0.140 | 0.492 |
| Llama-2-7B-Chat-CSparse | 0.1448 | 0.6117/0.7639 | 0.4529/0.3843 | 0.1334/0.1380 | 0.067 | 0.4637 |
| Llama-2-7B-Chat-FSparse | 0.1521 | 0.5898/0.7540 | 0.4565/0.3660 | 0.1979/0.2017 | 0.134 | 0.4768 |

* MMLU is a classification task, not a generation task; we use the MMLU-FLAN-COT variant.

The drastic loss on complex reasoning tasks is likely because the neuron activation patterns are more complex and harder for CS to capture at a fixed sparsity level, as illustrated in Figure (b). Furthermore, we study CS on the 70B model in Figure (c): even at global sparsity below 50%, Llama-3-70B-Instruct with contextual sparsity performs worse on GSM8K-COT than the full Llama-3-8B-Instruct, despite using roughly 4X the parameters. This observation shows that CS is not practical for 70B models on complex reasoning tasks.

[1] Dong, H., Chen, B., and Chi, Y. (2024). Prompt-prompted Mixture of Experts for Efficient LLM Generation.

[2] Lee, J.-Y., Lee, D., Zhang, G., Tiwari, M., and Mirhoseini, A. (2024). CATS: Contextually-Aware Thresholding for Sparsity in Large Language Models.

  Sirius Motivation

[Figure: Sirius motivation analysis]

We study in detail the cases where CS models fail. The errors are mostly miscalculations, wrong reasoning paths, or nonsensical statements (refer to the paper for more examples and analysis). These mistakes happen in the middle of the argument but propagate to the final result. We show some examples in Figure (c).

Can the generation be corrected by fixing just these minor mistakes in the middle? We run both the full model and the CS model and contrast their outputs token by token for Llama-3-8B-Instruct and Llama-2-7B-Chat; the results are shown in Figures (a) and (b). We find that only a small fraction of tokens needs correction: modifying roughly 11% of the tokens is enough to recover the full model's performance. This motivates us to develop an efficient correction mechanism that boosts CS models on complex generation tasks requiring reasoning.
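As a rough illustration of how such a token-by-token comparison can be measured, the sketch below greedy-decodes with the sparse model and, at each step, checks whether the full model would emit the same token given the same prefix. This is our simplified reading of the measurement, not the paper's exact protocol; `full_model` and `sparse_model` stand for hypothetical HuggingFace-style causal LMs, and batch size 1 is assumed.

```python
import torch

@torch.no_grad()
def fraction_needing_correction(full_model, sparse_model, input_ids, max_new_tokens=256):
    """Greedy-decode along the sparse model's trajectory and count the positions
    where the full model, given the same prefix, would emit a different token.
    These are the tokens that would need correction."""
    ids = input_ids.clone()          # (1, prompt_len)
    mismatches, steps = 0, 0
    for _ in range(max_new_tokens):
        sparse_next = sparse_model(ids).logits[:, -1, :].argmax(dim=-1)  # (1,)
        full_next = full_model(ids).logits[:, -1, :].argmax(dim=-1)      # (1,)
        mismatches += int((sparse_next != full_next).item())
        steps += 1
        # Keep following the sparse model's own generation.
        ids = torch.cat([ids, sparse_next.unsqueeze(-1)], dim=-1)
    return mismatches / steps
```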

  Sirius and Results

Method Description

[Figure: Sirius method and results overview]

We propose Sirius, an efficient correction mechanism for CS models. For a full description of Sirius's design choices, please refer to the paper. Sirius is a period-based approach in which the full model is called only infrequently (typically once every 16 tokens). During correction, the full model directly rewrites the KV cache, interleaves new tokens, and rolls back unlikely tokens. Although Sirius resembles speculative decoding, we conduct rigorous studies showing that naive speculative decoding suffers significant efficiency limitations in the sparse-correction scenario, because (1) the sparse model is too large to be a cheap draft model, and (2) the correction criterion is too strict. Moreover, we build a hardware-efficient tree to further boost Sirius's efficiency.
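The sketch below illustrates the period-based correction loop as we read it, with heavy simplifications: the sparse model drafts a period of tokens, one full-model forward pass re-scores the draft, the prefix the full model finds likely enough is kept, and the full model's own token is interleaved at the first unlikely position. KV-cache rewriting, rollback bookkeeping, and the hardware-efficient tree from the paper are omitted; the threshold value and model handles are assumptions.

```python
import torch

@torch.no_grad()
def sirius_style_generate(full_model, sparse_model, input_ids,
                          period=16, threshold=0.1, max_new_tokens=256):
    """Minimal period-based draft-then-correct loop (simplified sketch)."""
    ids = input_ids                                  # (1, prompt_len)
    produced = 0
    while produced < max_new_tokens:
        start = ids.shape[1]
        # 1) The sparse model drafts a period of tokens greedily.
        for _ in range(period):
            nxt = sparse_model(ids).logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, nxt], dim=-1)
        draft_len = ids.shape[1]
        # 2) One full-model pass scores every drafted position.
        probs = full_model(ids).logits.softmax(dim=-1)        # (1, draft_len, vocab)
        keep = start
        for pos in range(start, draft_len):
            p_full = probs[0, pos - 1, ids[0, pos]]           # full model's prob of the drafted token
            if p_full < threshold:                            # unlikely token: roll back from here
                break
            keep = pos + 1
        ids = ids[:, :keep]
        if keep < draft_len:
            # Interleave the full model's own token at the first rejected position.
            corrected = probs[0, keep - 1].argmax().view(1, 1)
            ids = torch.cat([ids, corrected], dim=-1)
        produced = ids.shape[1] - input_ids.shape[1]
    return ids
```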

Correction Effectiveness

We show Sirius's effectiveness and efficiency in the tables below. We select GSM8K for arithmetic reasoning, CSQA for commonsense reasoning, and HumanEval for code generation. Under the "Sirius Perf." column, A (B) is shown: A denotes the accuracy after Sirius correction on the evaluated dataset, while (B) is the optimal treewidth selected under the current model-dataset setting. Under the "AAL" column, X/Y is shown, where X is the AAL and Y is the period. "Effective Density" refers to the average density of the overall system. Beyond these tables, we evaluate 6 mainstream LLMs (Llama-3-8B, Llama-2-7B, and Llama-2-13B, each with its instruction-finetuned counterpart) on 8 tasks: GSM8K and AQuA-RAT-COT for arithmetic reasoning; CSQA, StrategyQA, Sports, and Dates for commonsense reasoning; and HumanEval and MBPP+ for coding. Sirius is effective and efficient across all these settings. For details on the evaluation, please refer to the full paper.

Llama-3-8B and Llama-3-8B-Instruct with Sirius on Different Complex Tasks

GSM8K
| Model | Full Perf. | CSparse Perf. | CSparse Density | Sirius Perf. (treewidth) | AAL / Period | Effective Density |
| Llama-3-8B-Instruct | 0.7536 | 0.3844 | 0.65 | 0.7051 (8) | 15.22/16 | 0.706 |
| Llama-3-8B | 0.4966 | 0.2085 | 0.65 | 0.4177 (8) | 15.29/16 | 0.703 |
| Model | Full Perf. | FSparse Perf. | FSparse Density | Sirius Perf. (treewidth) | AAL / Period | Effective Density |
| Llama-3-8B-Instruct | 0.7536 | 0.5868 | 0.76 | 0.7278 (4) | 15.37/16 | 0.807 |
| Llama-3-8B | 0.4966 | 0.3199 | 0.76 | 0.4579 (2) | 15.03/16 | 0.825 |

CSQA
| Model | Full Perf. | CSparse Perf. | CSparse Density | Sirius Perf. (treewidth) | AAL / Period | Effective Density |
| Llama-3-8B-Instruct | 0.7073 | 0.6470 | 0.58 | 0.7076 (8) | 14.76/16 | 0.657 |
| Llama-3-8B | 0.6437 | 0.5585 | 0.58 | 0.6429 (8) | 15.43/16 | 0.628 |
| Model | Full Perf. | FSparse Perf. | FSparse Density | Sirius Perf. (treewidth) | AAL / Period | Effective Density |
| Llama-3-8B-Instruct | 0.7073 | 0.6158 | 0.72 | 0.7043 (8) | 15.66/16 | 0.753 |
| Llama-3-8B | 0.6437 | 0.533 | 0.72 | 0.6388 (1) | 15.00/16 | 0.786 |

HumanEval
| Model | Full Perf. | CSparse Perf. | CSparse Density | Sirius Perf. (treewidth) | AAL / Period | Effective Density |
| Llama-3-8B-Instruct | 0.561 | 0.207 | 0.65 | 0.524 (8) | 14.67/16 | 0.733 |
| Llama-3-8B | 0.262 | 0.067 | 0.65 | 0.243 (8) | 15.11/16 | 0.691 |
| Model | Full Perf. | FSparse Perf. | FSparse Density | Sirius Perf. (treewidth) | AAL / Period | Effective Density |
| Llama-3-8B-Instruct | 0.561 | 0.457 | 0.76 | 0.616 (6) | 15.42/16 | 0.804 |
| Llama-3-8B | 0.262 | 0.189 | 0.76 | 0.293 (6) | 15.54/16 | 0.797 |
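As a rough sanity check on how AAL relates to Effective Density (our approximation; the paper's exact accounting also covers the correction-tree overhead), each period costs `period` sparse forward passes at the sparse density plus one full-model call at density 1.0, and the system advances roughly AAL + 1 tokens:

```python
# Back-of-envelope approximation of Effective Density (ours, ignoring tree overhead).
def effective_density(sparse_density, period, aal):
    return (period * sparse_density + 1.0) / (aal + 1.0)

# First rows of the GSM8K table above:
print(round(effective_density(0.65, 16, 15.22), 3))  # ~0.703 vs. 0.706 reported
print(round(effective_density(0.76, 16, 15.37), 3))  # ~0.804 vs. 0.807 reported
```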

Wallclock Speedup

We show that Sirius delivers its theoretical latency gains in both on-chip and offloading settings. The test dataset is GSM8K-COT. For the on-chip setting, we evaluate Sirius with Llama-3-8B-Instruct on A40, L40, A100, and H100 GPUs; Sirius mostly achieves a 20% reduction in wallclock latency. For the 70B model, offloading part of the weights to CPU memory and transferring them to the GPU only when needed is one of the few viable options for practitioners with limited GPU memory. With a PCIe bandwidth of 25 GB/s, Sirius achieves a 35% reduction in wallclock latency for Llama-3-70B-Instruct.

Llama-3-70B-Instruct with Offloading
| Setting | Sparse | Sirius | Full |
| Performance | 0.7407 | 0.8719 | 0.9014 |
| Latency (s) | 3.57 | 3.68 | 5.72 |
| Ratio to Full | 0.6241 | 0.6434 | 1.0 |

Llama-3-8B-Instruct On-Chip Wallclock Latency Speedup
| Setting | Perf. | A40 Latency (ms) | Ratio to Full | L40 Latency (ms) | Ratio to Full | Perf. | A100 Latency (ms) | Ratio to Full | H100 Latency (ms) | Ratio to Full |
| Coarse-grained Sparsity | 0.3601 | 20.7 | 0.85 | 15.6 | 0.67 | 0.3601 | 9.6 | 0.72 | 6.6 | 0.76 |
| Sirius | 0.7127 | 24.1 | 0.77 | 18.2 | 0.78 | 0.7089 | 11.1 | 0.83 | 7.7 | 0.88 |
| Full | 0.7612 | 30.9 | 1.0 | 23.2 | 1.0 | 0.7612 | 13.3 | 1.0 | 8.6 | 1.0 |
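A back-of-envelope estimate (ours, not the paper's measurement) helps interpret the offloading numbers: assuming fp16 weights and that per-token decode latency is dominated by streaming weights over PCIe, the full-model latency is roughly total weight bytes divided by bandwidth, and a sparse step only needs to stream the active fraction of the weights.

```python
# Rough offloading latency estimate under the assumptions stated above.
params_70b = 70e9
bytes_per_param = 2          # fp16 (assumption)
pcie_bandwidth = 25e9        # 25 GB/s, as in the experiment setup

full_latency = params_70b * bytes_per_param / pcie_bandwidth
print(f"{full_latency:.2f} s/token")         # ~5.6 s, close to the 5.72 s measured above

# A sparse step streams only the active fraction of the weights, so its latency
# roughly tracks the effective parameter density (here ~0.62 from the table).
print(f"{0.62 * full_latency:.2f} s/token")  # ~3.5 s, close to the 3.57 s sparse row
```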

  Conclusion

We observe that contextual sparsity methods degrade significantly on reasoning and deduction tasks. However, we find that this degradation can in principle be recovered by having the original model correct roughly 11% of the tokens. Following this observation, we develop Sirius, which provides an effective solution to the performance degradation of contextual sparsity methods on complex reasoning tasks. By introducing an efficient correction mechanism, Sirius significantly boosts the performance of CS models while maintaining their efficiency gains. This work opens up new possibilities for deploying efficient LLMs in resource-constrained environments without compromising task performance.


BibTeX

@article{hippocampus2024sirius,
  title={Sirius: Contextual Sparsity with Correction for Efficient LLMs},
  author={Hippocampus, David S. and Zhou, Yang and Chen, Zhuoming and Xu, Zhaozhuo and Lin, Victoria and Chen, Beidi},
  journal={arXiv preprint arXiv:2404.11912},
  year={2024}
}