S2FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity

1Carnegie Mellon University 2Georgia Tech 3Caltech 4Rutgers 5UNC-Chapel Hill
xinyuya2, beidic@andrew.cmu.edu

Introduction

We introduce Structured Sparse Fine-Tuning (S2FT), the first PEFT method for LLMs that simultaneously achieves high quality, efficient training, and scalable serving. S2FT accomplishes this by “selecting sparsely and computing densely”. It selects a few heads in the MHA modules and a few channels in the FFN modules of each Transformer block. Next, it co-permutes the weight matrices on both sides of the coupled structures in LLMs so that the selected components in each layer form a dense submatrix. Finally, S2FT performs in-place gradient updates on all submatrices. Both theoretical analysis and empirical results show that our method prevents overfitting and forgetting, delivers SOTA performance on both commonsense and arithmetic reasoning with 4.6% and 1.3% average improvements over LoRA, and surpasses full FT by 11.5% when generalizing to various domains after instruction tuning. Using our partial back-propagation algorithm, S2FT saves training memory and improves latency by 1.5-2.7× compared to full FT, while delivering an average 10% improvement over LoRA on both metrics. We further demonstrate that the weight updates in S2FT can be decoupled into adapters, enabling effective fusion, fast switch, and efficient parallelism when serving multiple fine-tuned LLMs.

Why S2FT?

            High Quality      Efficient Training      Scalable Serving
Method      ID      OOD       Time      Memory        Fusion    Switch    Parallelism
Full FT     ✔       ✔
LoRA
S2FT        ✔       ✔         ✔         ✔             ✔         ✔         ✔

Compared to LoRA, S2FT offers several key advantages: (i) improved OOD performance, (ii) enhanced training efficiency (time & memory), and (iii) better serving scalability (adapter fusion/switch/parallelism). These features are particularly valuable in real-world PEFT scenarios, where the goal is to effectively combine knowledge from various domains with the base model's capabilities using limited resources.

Observation

When generalizing to complex reasoning tasks, the performance ranking emerges as: Sparse Fine-tuning (SpFT) > Full FT > LoRA. SpFT effectively transfers reasoning abilities to commonsense domains, while LoRA exhibits significant performance drops in far-OOD generalization. This indicates that (i) freezing a larger fraction of base model parameters retains more pre-trained abilities, and (ii) approximating high-dimensional gradients with a low-rank decomposition may overfit the fine-tuning data and hinder generalization. Since LLMs are typically pre-trained on high-quality data, SpFT emerges as the preferred choice for fine-tuning on domain-specific data of varying quality.

Robustness and Scalability

Our findings are further supported by a counterintuitive observation when selecting trainable channels in S²FT with different metrics (weight, activation, and gradient). Surprisingly, selecting the channels with the smallest activations improves performance, while selecting those with the largest activations or gradients degrades it. This suggests that for LLMs, it is crucial to preserve the task-relevant skills acquired during pre-training while injecting knowledge from the fine-tuning data.

Task         S²FT-R   S²FT-W                 S²FT-A                 S²FT-S                 S²FT-G
                      Large      Small       Large      Small       Large      Small       Large      Small
Knowledge    86.6     85.9(-0.7) 85.3(-1.3)  84.7(-1.9) 87.3(+0.7)  85.1(-1.5) 87.2(+0.6)  85.4(-1.2) 86.2(-0.4)
Arithmetic   79.6     78.4(-1.2) 78.4(-1.2)  77.1(-2.5) 80.0(+0.4)  76.8(-2.8) 79.8(+0.2)  77.8(-1.8) 79.5(-0.1)
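
To make the activation-based variant concrete, the sketch below scores each FFN intermediate channel by its mean absolute activation on a small calibration set and keeps the lowest-scoring channels as trainable (the "Small" setting of S²FT-A above). The function and loader names here are illustrative assumptions, not the released S2FT code.

import torch

# Illustrative sketch of activation-based channel selection: score channels on
# calibration data, then keep the SMALLEST-activation channels as trainable.
@torch.no_grad()
def collect_channel_scores(ffn_act_fn, calib_loader, d_ff, device="cuda"):
    scores = torch.zeros(d_ff, device=device)
    n_tokens = 0
    for batch in calib_loader:
        acts = ffn_act_fn(batch.to(device))  # (n_tokens, d_ff) intermediate activations
        scores += acts.abs().sum(dim=0)
        n_tokens += acts.shape[0]
    return scores / n_tokens                 # mean |activation| per channel

def select_channels(scores, n_train, smallest=True):
    # smallest=True mirrors the counterintuitive finding above: training the
    # low-activation channels preserves the skills carried by high-activation ones.
    order = torch.argsort(scores, descending=not smallest)
    return order[:n_train]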

S2FT and Results

Method Description

[Figure: S²FT method overview]

First, we sparsely select a few attention heads and channels within the coupled structures of the MHA and FFN modules (see the definition of coupled structures in our paper) as the trainable parameters. Next, we co-permute the weight matrices on both sides of these structures, enabling dense gradient computation only for the selected components. While we demonstrate S2FT by selecting the same heads/channels on both sides for clarity, our approach also supports asymmetric selection strategies.
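
To illustrate “selecting sparsely and computing densely” on a single FFN block, here is a minimal PyTorch sketch, assuming a simplified two-matrix FFN with SiLU activation. The class name, random channel selection, and gradient-masking hook are simplifications for exposition; the actual S2FT implementation relies on its partial back-propagation algorithm to update only the dense submatrix in place instead of masking a full gradient.

import torch
import torch.nn as nn
import torch.nn.functional as F

class S2FTFFNSketch(nn.Module):
    """Minimal sketch of S2FT's select-permute-update pipeline for one FFN block."""
    def __init__(self, d_model=4096, d_ff=11008, n_train=64):
        super().__init__()
        self.up = nn.Linear(d_model, d_ff, bias=False)    # rows of W_up are FFN channels
        self.down = nn.Linear(d_ff, d_model, bias=False)  # columns of W_down are FFN channels
        # 1) Sparsely select a few intermediate channels (random here; the paper
        #    also studies weight-, activation-, and gradient-based selection).
        selected = torch.randperm(d_ff)[:n_train]
        keep = torch.ones(d_ff, dtype=torch.bool)
        keep[selected] = False
        # 2) Co-permute both sides of the coupled structure so the selected channels
        #    become one contiguous block. Permuting W_up rows and W_down columns by
        #    the same permutation leaves the FFN function unchanged.
        perm = torch.cat([torch.arange(d_ff)[keep], selected])
        with torch.no_grad():
            self.up.weight.copy_(self.up.weight[perm])         # permute rows
            self.down.weight.copy_(self.down.weight[:, perm])  # permute columns
        # 3) Compute densely on the selected block only: freeze W_up and keep
        #    gradients just for the trailing n_train columns of W_down.
        self.n_train = n_train
        self.up.weight.requires_grad_(False)
        self.down.weight.register_hook(self._keep_selected_grad)

    def _keep_selected_grad(self, grad):
        grad = grad.clone()
        grad[:, :-self.n_train] = 0  # only the dense selected submatrix is updated
        return grad

    def forward(self, x):
        return self.down(F.silu(self.up(x)))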

Results: High Quality

S2FT outperforms Full FT, LoRA, and DoRA on both commonsense and arithmetic reasoning tasks.

Fine-tuning LLaMA-3-8B on Commonsense Reasoning Tasks
Method #Param (%) BoolQ PIQA SIQA HellaSwag WinoGrande ARC-e ARC-c OBQA Avg. ↑
Full FT 100 73.9 86.2 79.1 93.1 85.8 88.1 78.2 84.0 83.6
LoRA 0.70 70.8 85.2 79.7 92.5 84.9 88.9 78.7 84.4 82.5
DoRA 0.71 74.6 89.3 79.9 95.5 85.6 90.5 80.4 85.8 85.2
S2FT 0.70 75.0 89.0 80.7 96.5 88.0 92.5 83.4 87.8 86.6
Fine-tuning LLaMA-3-8B on Arithmetic Reasoning Tasks
Method #Param (%) MultiArith GSM8K AddSub AQuA SingleEq SVAMP MAWPS Avg. ↑
Full FT 100 99.2 62.0 93.9 26.8 96.7 74.0 91.2 77.7
LoRA 0.70 99.5 61.6 92.7 25.6 96.3 73.8 90.8 77.2
DoRA 0.71 98.8 62.7 92.2 26.8 96.9 74.0 91.2 77.5
S2FT 0.70 99.7 65.8 93.7 31.5 97.8 76.0 92.4 79.6

Results: Efficient Training

S2FT delivers an average 10% improvement over LoRA on both training memory and time.

[Figure: training efficiency results]
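
One way to see where the savings come from: if only the selected (contiguous after co-permutation) slice of a weight is an nn.Parameter, gradients and optimizer states are allocated for that slice alone. The wrapper below is a hypothetical sketch of this idea, not the paper's partial back-propagation implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PartiallyTrainableLinear(nn.Module):
    """Keep the frozen block as a buffer and only the selected columns as a parameter,
    so gradient and Adam-state memory scale with the selected slice, not the full matrix."""
    def __init__(self, weight: torch.Tensor, n_train: int):
        super().__init__()
        self.n_train = n_train
        # Frozen block: registered as a buffer, so autograd never materializes
        # a weight gradient for it.
        self.register_buffer("w_frozen", weight[:, :-n_train].clone())
        # Trainable block: the only parameter the optimizer ever sees.
        self.w_train = nn.Parameter(weight[:, -n_train:].clone())

    def forward(self, x):
        # Split the input to match the frozen/trainable column blocks.
        x_frozen, x_train = x[..., :-self.n_train], x[..., -self.n_train:]
        return F.linear(x_frozen, self.w_frozen) + F.linear(x_train, self.w_train)

# Usage sketch: wrap a down-projection after co-permutation and optimize only the small slice.
# layer = PartiallyTrainableLinear(down_proj_weight, n_train=64)
# optimizer = torch.optim.AdamW([layer.w_train], lr=1e-4)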

Results: Scalable Serving

S2FT modifies orthogonal low-rank subspaces for different tasks, enabling effective adapter fusion.

Fusing Commonsense and Arithmetic Adapters for LLaMA-3-8B
Task          LoRA                                    S²FT
              Adapter 1   Adapter 2   Fused           Adapter 1   Adapter 2   Fused
Commonsense   83.1        32.1        79.8(-3.3)      86.6        42.3        84.0(-2.6)
Arithmetic    12.0        77.2        71.6(-5.6)      12.8        79.6        75.3(-4.3)
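
A minimal sketch of this decoupling, assuming each S2FT adapter is stored as the selected channel indices plus the dense weight delta on those channels (function names are illustrative): fusion simply adds the deltas back into the base weight, and deltas on disjoint channels do not interfere.

import torch

def extract_adapter(w_base, w_finetuned, channel_idx):
    # Only the selected columns were updated, so the delta is dense and small.
    return {"idx": channel_idx, "delta": (w_finetuned - w_base)[:, channel_idx].clone()}

def apply_adapters(w_base, adapters, scale=1.0):
    w = w_base.clone()
    for a in adapters:
        # Adapters that touch disjoint channels occupy different columns of the
        # weight, so their updates fuse without interfering with each other.
        w[:, a["idx"]] += scale * a["delta"]
    return w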

S2FT enables scalable and efficient adapter switch and parallelism by reducing matmul operations.

[Figure: scalable serving results]
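
Under the same adapter format as above, the sketch below contrasts the per-request compute: an S2FT adapter adds one small matmul over the selected input channels, whereas a LoRA adapter adds two low-rank matmuls. This is an illustrative decode-time path, not the paper's serving system.

import torch.nn.functional as F

def s2ft_forward(x, w_base, adapter):
    # Base path shared by all requests.
    y = F.linear(x, w_base)
    # Adapter path: gather the few selected input channels, then one small matmul.
    return y + F.linear(x[..., adapter["idx"]], adapter["delta"])

def lora_forward(x, w_base, lora_A, lora_B):
    # LoRA needs two extra matmuls per adapter: (x @ A^T) @ B^T.
    return F.linear(x, w_base) + F.linear(F.linear(x, lora_A), lora_B)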

Conclusion and Future Work

This work introduces S2FT, a novel PEFT family that is generalizable, efficient, and scalable. Compared to LoRA, S2FT improves generalization on downstream tasks while reducing training time and memory by 10%. Furthermore, S2FT enables scalable serving of thousands of adapters simultaneously. These comprehensive improvements in quality, efficiency, and scalability make S2FT particularly valuable for the large-scale, real-world deployment of foundation models in various domains. Future research directions include exploring controllability in S2FT, which enables the separation of domain-specific knowledge into distinct parameters. This capability could significantly advance sparse training and inference techniques for LLMs, particularly in MoE architecture design.

BibTeX

@inproceedings{yang2024s2ft,
  title={S2FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity},
  author={Yang, Xinyu and Leng, Jixuan and Guo, Geyang and Zhao, Jiawei and Nakada, Ryumei and Zhang, Linjun and Yao, Huaxiu and Chen, Beidi},
  booktitle={The 38th Conference on Neural Information Processing Systems (NeurIPS)},
  year={2024}
}