Training ROI = Average downstream accuracy / Total training FLOPs.
This lets you compare “how much accuracy you buy per unit of compute.”
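As a quick sanity check of this definition, the snippet below recomputes the normalized ROI from the 1B-scale numbers reported in the results table further down in this post (since both models see the same training tokens, per-token GFLOPs stand in for total training FLOPs):

```python
# Normalized training ROI, using the numbers from the 1B-scale results table below.
baseline_acc, baseline_gflops = 55.82, 3.00   # average downstream accuracy, GFLOPs/token
stem_acc, stem_gflops = 56.63, 2.83

baseline_roi = baseline_acc / baseline_gflops
stem_roi = stem_acc / stem_gflops

print(stem_roi / baseline_roi)   # ~1.075, i.e. the ~1.08x normalized ROI reported below
```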
TL;DR: We introduce STEM, a static, token-indexed sparse architecture that swaps the FFN up-projection for a layer-local embedding lookup. Despite extreme sparsity, STEM trains stably, stores more parametric knowledge, and speeds up FFN layers by 3x. More interestingly, STEM improves interpretability over the dense and existing sparse baselines and strengthens long-context scaling; across 350M–1B models, it yields up to ~3–4% accuracy gains on knowledge-based and reasoning downstream tasks including ARC-Challenge, OpenBookQA, GSM8K, MMLU, and Big-Bench Hard.
Fine-grained sparsity promises higher parametric capacity without a proportional increase in per-token compute, but it often suffers from training instability, load-balancing issues, and communication overhead. This is a major obstacle to the practical deployment of fine-grained sparsity. Moreover, the sparse components tend to be hard to interpret, making it difficult to understand the role of each micro-expert.
Static, token-indexed sparsity has emerged as a potential solution to this problem. It keeps the compute path predictable (no runtime routing), enables prefetching and CPU offloading, and decouples capacity from per-token compute and cross-device communication, with comparatively better training stability. Furthermore, the token-indexed nature helps the sparse model localize its stored knowledge into the respective micro-experts. The trade-off is that token-indexed selection cannot capture contextual dependencies. Consequently, in spite of adding a large number of micro-experts, model performance may not improve unless the micro-experts are placed carefully in the model.
Based on our ablation studies, we introduce STEM, a static, token-indexed sparse architecture that swaps the FFN up-projection for a layer-local embedding lookup. Crucially, the gating path in the FFN is left unchanged to preserve the contextual ability of the model.
With the additional parametric knowledge, STEM is expected to excel at knowledge-based downstream tasks such as ARC-Challenge, OpenBookQA, and MMLU. Interestingly, STEM also outperforms the dense baseline on reasoning-heavy tasks such as GSM8K and Big-Bench Hard, and demonstrates better long-context scaling on tasks like Needle-in-a-Haystack and LongBench. Most importantly, STEM achieves these improvements while being more efficient and more interpretable: it localizes knowledge into the respective micro-experts, as our knowledge-injection experiments illustrate.
STEM Architecture
A standard gated FFN at layer \( \ell \) uses dense projections to expand into \( d_{\mathrm{ff}} \) dimensions and then project back to \( d \). Let \( \mathbf{W}_{\ell}^{(u)} \in \mathbb{R}^{d_{\mathrm{ff}}\times d} \) be the up-projection, \( \mathbf{W}_{\ell}^{(g)} \in \mathbb{R}^{d_{\mathrm{ff}}\times d} \) the gate projection, and \( \mathbf{W}_{\ell}^{(d)} \in \mathbb{R}^{d\times d_{\mathrm{ff}}} \) the down-projection. The baseline computes:

\[ \mathrm{FFN}_{\ell}(\mathbf{x}_{\ell}) = \mathbf{W}_{\ell}^{(d)} \left( \sigma\!\left(\mathbf{W}_{\ell}^{(g)} \mathbf{x}_{\ell}\right) \odot \left(\mathbf{W}_{\ell}^{(u)} \mathbf{x}_{\ell}\right) \right) \]
STEM replaces the dense up-projection \( \mathbf{W}_{\ell}^{(u)} \mathbf{x}_{\ell} \) with a token-indexed, layer-local embedding lookup. For each layer \( \ell \), STEM introduces an embedding table \( \mathbf{U}_{\ell} \in \mathbb{R}^{V \times d_{\mathrm{ff}}} \). Given token id \( t \), the layer fetches \( \mathbf{U}_{\ell}[t] \in \mathbb{R}^{d_{\mathrm{ff}}} \) and combines it with the dense gate path:

\[ \mathrm{FFN}^{\mathrm{STEM}}_{\ell}(\mathbf{x}_{\ell}, t) = \mathbf{W}_{\ell}^{(d)} \left( \sigma\!\left(\mathbf{W}_{\ell}^{(g)} \mathbf{x}_{\ell}\right) \odot \mathbf{U}_{\ell}[t] \right) \]
where \( \odot \) denotes elementwise multiplication and \( \sigma \) the gating activation. In other words, STEM keeps the gate and down-projection dense but replaces the FFN up-projection with a lookup into \( \mathbf{U}_{\ell} \), activating \( d_{\mathrm{ff}} \) parameters conditioned directly on token \( t \).
STEM Architecture: Token-indexed embedding lookup replaces the up-projection matrix
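To make the layer concrete, here is a minimal PyTorch sketch of a STEM FFN block. It assumes a SwiGLU-style gate with SiLU activation; the module and argument names are illustrative, not the reference implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class STEMFFN(nn.Module):
    """Gated FFN where the dense up-projection is replaced by a
    layer-local, token-indexed embedding table U_l of shape (V, d_ff)."""
    def __init__(self, d_model: int, d_ff: int, vocab_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)   # W_g (dense, unchanged)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)   # W_d (dense, unchanged)
        self.up_embed = nn.Embedding(vocab_size, d_ff)          # U_l (token-indexed lookup)

    def forward(self, x: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model); token_ids: (batch, seq)
        gate = F.silu(self.gate_proj(x))    # contextual gate path
        up = self.up_embed(token_ids)       # lookup replaces W_u @ x
        return self.down_proj(gate * up)    # elementwise combine, then project down
```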
Improved Performance
Interestingly, STEM manages to be aggressively sparse without inheriting the instability typical of sparse architectures. In the training plot, the baseline and STEM loss curves stay smooth, while the Hash-layer Mixture-of-Experts baseline shows noticeably bumpier curves with loss spikes. STEM avoids these spikes and keeps training well-behaved even as sparsity is pushed harder.
What makes this especially compelling is the scaling trend: replacing more FFN up-projections with STEM not only reduces training FLOPs, it can also improve ROI. For example, at 1B pretraining, STEM reaches 1.08× the baseline ROI while using fewer GFLOPs per token.
Loss vs. Training Tokens (350M)
Loss vs. Training FLOPs (350M)
A useful way to think about FFNs is as a key–value memory: the FFN first computes an “address” (which memory slots to read), and the down-projection then mixes the corresponding values. In a dense SwiGLU FFN, this address is produced by a learned affine map (the up-projection) and then shaped by gating.
STEM changes the story: the "address" is no longer entirely synthesized from the hidden state—each token brings its own layer-local address vector via the embedding table. Crucially, STEM embeddings show a large angular spread (low pairwise cosine similarity), which reduces cross-talk between memory slots and makes stored information easier to retrieve reliably. This geometry is one reason STEM can translate extra parameters into real knowledge gains rather than noisy redundancy.
Angular spread: Up vs. STEM, and Gate*Up vs. Gate*STEM
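As a rough illustration of how such an angular-spread comparison can be computed, the sketch below measures the mean absolute pairwise cosine similarity over a random sample of rows (lower means larger angular spread); the sample size and inputs are placeholders:

```python
import torch
import torch.nn.functional as F

def mean_abs_cosine(rows: torch.Tensor, sample: int = 4096) -> float:
    """Mean absolute pairwise cosine similarity over a random sample of rows.
    Lower values = larger angular spread, i.e. less cross-talk between memory slots."""
    idx = torch.randperm(rows.size(0))[: min(sample, rows.size(0))]
    x = F.normalize(rows[idx].float(), dim=-1)
    sims = x @ x.T              # pairwise cosine similarities
    n = sims.size(0)
    sims.fill_diagonal_(0)      # drop self-similarity terms
    return sims.abs().sum().item() / (n * (n - 1))

# e.g. compare the rows of a layer's STEM table U_l against the
# corresponding dense up-projection outputs for the same tokens
```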
On standard downstream suites, STEM is particularly strong on knowledge-intensive benchmarks (e.g., ARC-Challenge, OpenBookQA), while also improving on reasoning-centric tasks like GSM8K and MMLU after mid-training. And when the context window is extended, STEM continues to hold up, often improving as context grows, which is exactly what we hope for from a design that activates more distinct parameters over longer sequences.
| Model | #Total Params (B) | #Active Params (B) | ARC-E | ARC-C | BoolQ | PIQA | SIQA | HSwag | OBQA | Wino | Avg | GFLOPs/token | ROI (norm.) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Baseline | 1.50 | 1.50 | 66.98 | 41.88 | 64.21 | 73.44 | 44.09 | 59.65 | 39.84 | 56.48 | 55.82 | 3.00 | 1.00× |
| STEM | 6.75 | 1.41 | 65.95 | 42.03 | 61.66 | 75.00 | 44.78 | 60.37 | 45.90 | 57.34 | 56.63 | 2.83 | 1.08× |
ROI is normalized to the baseline in the paper.
| Model | ARC-E | ARC-C | BoolQ | PIQA | SIQA | HellaSwag | OBQA | Wino | Avg | GSM8K | MMLU |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Baseline | 70.78 | 42.11 | 65.84 | 72.95 | 47.13 | 60.39 | 42.97 | 57.81 | 57.50 | 44.2 | 29.92 |
| STEM | 69.78 | 44.22 | 68.54 | 74.69 | 45.65 | 61.90 | 45.70 | 57.42 | 58.49 | 46.4 | 32.38 |
| Model | BBH | MuSR | LongBench Multi-hop (<4k) | LongBench Multi-hop (4–8k) | LongBench Multi-hop (≥8k) | LongBench Code (<4k) | LongBench Code (4–8k) | LongBench Code (≥8k) |
|---|---|---|---|---|---|---|---|---|
| Baseline | 24.87 | 35.85 | 5.72 | 6.20 | 6.19 | 45.37 | 44.64 | 41.30 |
| STEM | 27.55 | 36.38 | 10.20 | 8.63 | 7.82 | 52.68 | 52.53 | 49.60 |
| Model | 0–2k | 2–4k | 4–6k | 6–8k | 8–10k | 10–12k | 12k+ |
|---|---|---|---|---|---|---|---|
| Baseline | 24.0 | 23.8 | 22.1 | 22.3 | 21.9 | 21.1 | 23.5 |
| STEM | 27.6 | 27.6 | 24.4 | 22.7 | 23.0 | 21.7 | 24.2 |
STEM embeddings enable direct knowledge editing by replacing token embeddings while keeping the input text unchanged. This allows us to study how the model's factual knowledge is stored and can be manipulated.
When source and target entities have the same tokenization length, editing is straightforward: replace each source token's STEM embedding with the corresponding target token's embedding.
Example: Replacing "Spain" → "Germany" in the prompt "Country: Spain. Capital:" causes the model to generate a paragraph about Berlin instead of Madrid, even though the input text still says "Spain".
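A minimal sketch of this same-length edit is shown below; the module paths (`model.layers`, `layer.ffn.up_embed`) are hypothetical names for the per-layer STEM tables, and a Hugging Face-style tokenizer is assumed:

```python
import torch

@torch.no_grad()
def swap_stem_rows(model, tokenizer, source: str, target: str):
    """Same-length case: overwrite each source token's STEM row with the
    corresponding target token's row in every layer's table, so the unchanged
    input text now retrieves the target entity's knowledge."""
    src_ids = tokenizer.encode(source, add_special_tokens=False)
    tgt_ids = tokenizer.encode(target, add_special_tokens=False)
    assert len(src_ids) == len(tgt_ids), "different lengths need another strategy"
    for layer in model.layers:                 # hypothetical module layout
        table = layer.ffn.up_embed.weight      # U_l: (V, d_ff), hypothetical accessor
        for s, t in zip(src_ids, tgt_ids):
            table[s] = table[t].clone()
```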
For source and target entities with different tokenization lengths, we use several strategies; the examples below illustrate the resulting behavior.
Example: United States → Czech Republic
Prompt: "Country: United States of America. Capital:"
After STEM replacement: Model generates a paragraph about Prague, describing it as "Czechia's capital and Europe's 10th-largest city..."
Example: United States → United Kingdom
Prompt: "Country: United States of America. Capital:"
After STEM replacement: Model generates text about London as "the world's largest financial center, home to the British Parliament..."
Example: Country → State Transfer
Source: United States (Country) → Target: California (State)
Result: Model describes Sacramento as "California's political, cultural, and economic center..."
This interpretability demonstrates that STEM embeddings act as localized knowledge stores that can be directly manipulated, providing insights into how the model organizes and retrieves factual information.
STEM is designed to scale parametric memory without paying the usual inference tax. By replacing the dense up-projection with a token-indexed embedding lookup (while keeping the gate and down-projection dense), STEM removes roughly one-third of the dense FFN weights, making both memory traffic and compute cheaper.
In a standard gated FFN, decoding is often memory-bound: throughput is limited by how fast FFN weights can be streamed from HBM. STEM reduces this pressure by eliminating one-third of the dense FFN weights, directly cutting the "parameter loading cost" per layer during decoding.
Layer-level savings (SwiGLU):
• Decoding parameter load: 3·d·dff → 2·d·dff (≈ 33% reduction)
• Prefill/training FLOPs: reduced by replacing dense up-proj compute with a lookup
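For a rough sense of scale, here is the parameter-load arithmetic with purely illustrative dimensions (d = 2048, d_ff = 8192; not the actual model configs):

```python
# Per-layer dense FFN weights streamed from HBM during decoding (SwiGLU layout).
d, d_ff = 2048, 8192                        # illustrative dimensions only
dense_ffn = 3 * d * d_ff                    # gate + up + down: 50,331,648 weights
stem_ffn = 2 * d * d_ff                     # gate + down only: 33,554,432 weights
print(f"{1 - stem_ffn / dense_ffn:.3f}")    # 0.333 -> one-third fewer weights to load per token
```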
STEM uses large layer-local embedding tables, which can be offloaded to CPU memory during inference. Crucially, because embeddings are indexed purely by token IDs, their access pattern is predictable—so they can be prefetched asynchronously and overlapped with GPU compute.
This is exactly what makes STEM practical: you get "big memory" without permanently parking it in HBM.
Prefill is the easy case for overlap: the full prompt tokens are known upfront, so the runtime can prefetch embeddings ahead of time.
Two implementation tricks, detailed in the paper, make this even faster.
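Below is a minimal sketch of the basic prefetch-and-overlap path for prefill (not those specific tricks), assuming the table sits in pinned host memory and a dedicated CUDA copy stream:

```python
import torch

def prefetch_prefill_rows(u_cpu: torch.Tensor, prompt_ids: torch.Tensor,
                          copy_stream: torch.cuda.Stream):
    """Gather the prompt's unique STEM rows on the host and launch an async
    host-to-device copy on a side stream, overlapping the transfer with GPU
    compute on the default stream.

    u_cpu: (V, d_ff) layer-local table in host memory.
    Returns (unique_ids, rows_gpu); the caller must sync copy_stream before use.
    """
    unique_ids = torch.unique(prompt_ids)              # prompt is fully known at prefill time
    rows = u_cpu[unique_ids].pin_memory()              # contiguous pinned staging buffer
    with torch.cuda.stream(copy_stream):
        rows_gpu = rows.to("cuda", non_blocking=True)  # asynchronous copy
    return unique_ids, rows_gpu
```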
Decoding is trickier due to autoregressivity: you only know the next token after completing the current step, so prefetching has less lookahead.
STEM addresses this with a simple but effective observation: token frequencies follow a Zipfian distribution, so a small number of tokens dominate accesses. This enables a memory-efficient LFU cache with >80% hit rate, greatly reducing CPU→GPU transfers during generation.
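A toy version of such a frequency-based cache might look like the following; the capacity, eviction details, and the `fetch_from_cpu` hook are all illustrative:

```python
from collections import Counter

class LFUEmbeddingCache:
    """Keeps the STEM rows of the most frequent token ids resident on the GPU;
    misses fall back to a (slower) fetch from the CPU-resident table."""

    def __init__(self, capacity: int, fetch_from_cpu):
        self.capacity = capacity
        self.fetch_from_cpu = fetch_from_cpu   # callable: token_id -> GPU row
        self.freq = Counter()                  # access counts (Zipfian in practice)
        self.rows = {}                         # token_id -> cached GPU row

    def get(self, token_id: int):
        self.freq[token_id] += 1
        if token_id in self.rows:              # hit: no CPU->GPU transfer needed
            return self.rows[token_id]
        row = self.fetch_from_cpu(token_id)    # miss: copy one d_ff-sized row up
        if len(self.rows) >= self.capacity:    # evict the least frequently used entry
            victim = min(self.rows, key=self.freq.__getitem__)
            del self.rows[victim]
        self.rows[token_id] = row
        return row
```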
In this work, we introduced STEM, a static, token-indexed architecture that scales parametric capacity in a tractable manner: it maintains training stability, improves inference efficiency, and enhances interpretability. The additional parameters help STEM outperform the dense baseline not only on knowledge-intensive but also on reasoning-heavy downstream tasks. Performance and interpretability are usually at odds, but STEM manages to achieve both, making it a step toward a more interpretable and scalable architecture.
Looking forward, we see STEM as a building block that can be combined with other architectural innovations such as Mixture-of-Experts (MoE). We also want to study the efficacy of STEM at larger model scales, which will help us understand how much of the parametric capacity of large-scale models can be offloaded to statically indexed, more interpretable STEM embeddings.
@misc{sadhukhan2026stemscalingtransformersembedding,
title={STEM: Scaling Transformers with Embedding Modules},
author={Ranajoy Sadhukhan and Sheng Cao and Harry Dong and Changsheng Zhao and Attiano Purpura-Pontoniere and Yuandong Tian and Zechun Liu and Beidi Chen},
year={2026},
eprint={2601.10639},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2601.10639},
}