vortex_torch.flow.algorithms¶
Classes

- BlockSparseAttention: Block-sparse attention flow with centroid-based routing.
- GQABlockSparseAttention: Grouped-query block-sparse attention flow.
- GQAQuestSparseAttention: GQA-style QUEST sparse attention flow.
- class vortex_torch.flow.algorithms.BlockSparseAttention[source]¶
Bases: vFlow

Block-sparse attention flow with centroid-based routing.
This flow implements a simple block-sparse routing strategy inspired by the block-top-k routing used in Kinetics :cite:`sadhukhan2025kinetics` (arXiv:2506.05333). It maintains a per-request centroid over keys and uses query–centroid similarity to select a sparse set of pages.
High-level behavior¶
During forward_cache(), the flow computes a centroid vector for each request from its key cache cache["k"] and stores the result in cache["centroids"] with shape

\[\text{cache["centroids"]} \in \mathbb{R}^{B \times 1 \times D},\]

where \(B\) is the number of requests and \(D\) is the head dimension.
During forward_indexer(), the flow:

- Averages query tokens per request to obtain a single query summary per request,
- Applies a generalized matrix–vector multiplication GeMV between the query summaries and the cached centroids to obtain a scalar score for each (request, page),
- Uses topK to convert these scores into sparse page indices o of shape

\[o \in \mathbb{R}^{S_{\mathrm{sparse}} \times 1 \times 1}.\]

Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.
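The routing above can be sketched in plain PyTorch. This is a minimal, hypothetical reference: dense per-request tensors stand in for the packed `[S, 1, D]` layout, and regular tensor ops stand in for the `GeMV` and `topK` kernels.

```python
import torch

def route_pages(q, centroids, topk):
    """Score pages by query-centroid similarity and pick the top-k.

    q:         [B, H_q, D] queries for one batch of requests
    centroids: [B, P, D]   one centroid per page (dense stand-in for
                           the packed [S, 1, D] layout)
    Returns:   [B, topk]   indices of the selected pages per request
    """
    # Average query tokens to get one summary vector per request.
    q_mean = q.mean(dim=1)                                   # [B, D]
    # Dense stand-in for GeMV: dot product of the summary with each centroid.
    scores = torch.einsum("bd,bpd->bp", q_mean, centroids)   # [B, P]
    # Dense stand-in for topK: pick the highest-scoring pages.
    return scores.topk(topk, dim=-1).indices                 # [B, topk]

q = torch.randn(2, 4, 8)           # 2 requests, 4 query tokens, head_dim 8
centroids = torch.randn(2, 16, 8)  # 16 pages per request
idx = route_pages(q, centroids, topk=4)
print(idx.shape)  # torch.Size([2, 4])
```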
Cache layout¶
This flow declares a single extra cache tensor via create_cache():

{ "centroids": (1, head_dim) }

The runtime then also allocates "k" and "v" with inner shapes (page_size, head_dim). As per the vFlow contract, each cache tensor has two logical views:

- In forward_indexer() (page-packed view):

  \[\text{cache["centroids"]} \sim \mathbb{R}^{S \times 1 \times D},\]

- In forward_cache() (batch-major view):

  \[\text{cache["centroids"]} \sim \mathbb{R}^{B \times 1 \times D}.\]
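For intuition, the packed page axis is just a concatenation of per-request page tensors. A minimal sketch (the real layout bookkeeping lives in ContextBase; the tensor names here are illustrative):

```python
import torch

D = 8
pages_per_req = [3, 5]  # ragged page counts for two requests

# Per-request centroid tensors, [P_b, 1, D] each.
per_request = [torch.randn(p, 1, D) for p in pages_per_req]

# Packed page axis: concatenate the pages from all requests.
packed = torch.cat(per_request, dim=0)  # [S_pack, 1, D], S_pack = 3 + 5 = 8

# Offsets map a (request, page) pair into the packed axis.
offsets = torch.tensor([0] + pages_per_req).cumsum(0)  # [0, 3, 8]
req, page = 1, 2
assert packed[offsets[req] + page].equal(per_request[req][page])
```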
References
Bibliography
[sadhukhan2025kinetics] Ranajoy Sadhukhan, Zhuoming Chen, Haizhong Zheng, Yang Zhou, Emma Strubell, Beidi Chen. Kinetics: Rethinking Test-Time Scaling Laws. arXiv:2506.05333, 2025.
- forward_indexer(q, o, cache, ctx)[source]¶
Compute sparse page indices from queries and cached centroids.
- Parameters:
q (torch.Tensor) – Query tensor with shape [B, H_q, D] (typically torch.bfloat16), where \(B\) is the batch–head axis, \(H_q\) is the number of query positions per request, and \(D\) is the head dimension.

o (torch.Tensor) – Output tensor for sparse page indices with shape [S_sparse, 1, 1] and integer dtype. It is filled in-place by topK according to the scores computed by GeMV.

cache (Dict[str, torch.Tensor]) – Cache dictionary in the indexer view, where:
- cache["k"] and cache["v"] are page-packed key/value tensors,
- cache["centroids"] is interpreted as [S, 1, D] (page-packed centroids).

ctx (ContextBase) – Runtime context carrying page layout, top-k configuration (topk_val, page_reserved_bos, page_reserved_eos), and other metadata.
Notes
The implementation:
- Computes a per-request query summary

  \[q_{\mathrm{mean}}[b, 0, :] = \frac{1}{H_q} \sum_{h=0}^{H_q-1} q[b, h, :],\]

- Applies GeMV between q_mean and cache["centroids"] to obtain scalar scores per page,
- Uses topK to select a sparse set of pages per request and write the corresponding indices into o in the packed sparse layout.
- forward_cache(cache, loc, ctx)[source]¶
Update cache centroids from the key cache in batch-major view.
- Parameters:
cache (Dict[str, torch.Tensor]) –
Cache dictionary in the batch-major view, where:
- cache["k"] has shape [B, page_size, D],
- cache["centroids"] has shape [B, 1, D].

loc (torch.Tensor) – Positional or layout metadata used by CMean to aggregate keys into centroids (e.g. page boundaries or valid token masks).

ctx (ContextBase) – Runtime context forwarded to the reduction op.
Notes
This method calls CMean with dim=1 so that for each request \(b\) it computes a mean over the key axis and writes it to cache["centroids"][b, 0, :]. The exact handling of padding or invalid positions is controlled by loc and the backend implementation of CMean.
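In dense PyTorch terms, this reduction is roughly a masked mean over the key axis. A sketch, assuming `loc` reduces to a per-request valid-token count (the real `CMean` handles the paged layout on the backend):

```python
import torch

def centroid_update(k, valid_len):
    """Mean over valid key positions for each request.

    k:         [B, page_size, D] key cache in batch-major view
    valid_len: [B] number of valid tokens per request (stand-in for loc)
    Returns:   [B, 1, D] centroids
    """
    B, page_size, D = k.shape
    # Mask out padding positions beyond each request's valid length.
    mask = torch.arange(page_size).unsqueeze(0) < valid_len.unsqueeze(1)
    k_masked = k * mask.unsqueeze(-1)
    # Mean over the key axis (dim=1), normalized by the valid count.
    return k_masked.sum(dim=1, keepdim=True) / valid_len.view(B, 1, 1)

k = torch.randn(2, 16, 8)
centroids = centroid_update(k, torch.tensor([16, 10]))
print(centroids.shape)  # torch.Size([2, 1, 8])
```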
- create_cache(page_size, head_dim)[source]¶
Declare inner shapes for custom cache tensors.
- Parameters:
  - page_size (int) – Number of tokens per page.
  - head_dim (int) – Head dimension \(D\).
- Returns:
  Mapping from cache tensor names to inner shapes (r, c). This flow defines a single extra tensor: "centroids" with inner shape (1, head_dim), which becomes [S, 1, head_dim] in forward_indexer() and [B, 1, head_dim] in forward_cache().
- Return type:
  Dict[str, Tuple[int, int]]
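Based on the description above, the method body is presumably just a shape declaration (a sketch; `page_size` is unused for this flow's extra tensor but is part of the signature):

```python
def create_cache(page_size, head_dim):
    """Declare inner (r, c) shapes for extra cache tensors.

    The runtime expands "centroids" to [S, 1, head_dim] in
    forward_indexer() and [B, 1, head_dim] in forward_cache().
    """
    return {"centroids": (1, head_dim)}

print(create_cache(16, 64))  # {'centroids': (1, 64)}
```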
- class vortex_torch.flow.algorithms.GQABlockSparseAttention[source]¶
Bases: vFlow

Grouped-query block-sparse attention flow.
This flow uses a GQA-style block-sparse routing: queries are grouped, scored against per-request centroids, normalized with a softmax, then aggregated across groups before a top-k over pages is applied.
- Queries q have shape [B, H_q, D].
- Centroids cache cache["centroids"] has inner shape (1, head_dim) and is viewed as [S, 1, D] in forward_indexer() and [B, 1, D] in forward_cache().

Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.
For a design similar in spirit to grouped-query block sparsity, see the GQA sparse attention formulation in:
- forward_indexer(q, o, cache, ctx)[source]¶
Compute sparse page indices from grouped-query scores.
Pipeline¶
1. Apply GeMM between queries and centroids (o = y x^T):
   - q: [B, H_q, D]
   - cache["centroids"] (indexer view): [S, 1, D]
   - score: [S, 1, H_q] (logical [S, Ny, Nx])
2. Apply an in-place softmax over the leading (page) axis with a scaling factor scale:

   \[\mathrm{softmax}(x \cdot \mathrm{scale})\]

3. Aggregate over the query-group dimension with Max (dim=2), yielding a single scalar score per page.
4. Use topK on the aggregated scores to write packed sparse page indices into o with shape [S_sparse, 1, 1].
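A dense stand-in for this pipeline (hypothetical: `GeMM`, the in-place softmax, `Max`, and `topK` are replaced by regular tensor ops, and a per-request page dimension replaces the packed page axis):

```python
import torch

def gqa_route(q, centroids, scale, topk):
    """Grouped-query block-sparse routing.

    q:         [B, H_q, D] grouped queries
    centroids: [B, P, D]   per-page centroids (dense stand-in for [S, 1, D])
    Returns:   [B, topk]   selected page indices per request
    """
    # GeMM stand-in: score every (page, query) pair.
    scores = torch.einsum("bhd,bpd->bph", q, centroids)  # [B, P, H_q]
    # Softmax over the page axis with a scaling factor.
    probs = torch.softmax(scores * scale, dim=1)
    # Aggregate over the query-group dimension (Max with dim=2).
    aggr = probs.max(dim=2).values                       # [B, P]
    # topK stand-in over the aggregated per-page scores.
    return aggr.topk(topk, dim=-1).indices

q = torch.randn(2, 4, 8)
centroids = torch.randn(2, 16, 8)
idx = gqa_route(q, centroids, scale=8 ** -0.5, topk=4)
print(idx.shape)  # torch.Size([2, 4])
```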
- Parameters:
q (torch.Tensor)
o (torch.Tensor)
cache (Dict[str, torch.Tensor])
ctx (ContextBase)
- forward_cache(cache, loc, ctx)[source]¶
Update per-request centroids from the key cache.
- cache["k"]: [B, page_size, D] (batch-major view)
- cache["centroids"]: [B, 1, D]

The CMean operator with dim=1 computes a mean over the key axis (optionally masked/structured via loc) and writes the result into cache["centroids"] in-place.

- Parameters:
  cache (Dict[str, torch.Tensor])
  loc (torch.Tensor)
  ctx (ContextBase)
- class vortex_torch.flow.algorithms.GQAQuestSparseAttention[source]¶
Bases: vFlow

GQA-style QUEST sparse attention flow.
This flow uses query–envelope matching similar to QUEST sparse attention (see https://arxiv.org/abs/2406.10774). For each request, it maintains per-page max and min envelopes of keys and uses them to compute a conservative upper bound on query–key similarity.
Shapes¶
- Queries q: [B, H_q, D] (typically bfloat16).
- Cache entries (inner shapes as declared in create_cache()):
  - cache["max"] and cache["min"]: (1, head_dim) → viewed as [S, 1, D] in forward_indexer() and [B, 1, D] in forward_cache().
  - cache["k"]: standard key cache with inner shape (page_size, head_dim).

Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.
Routing intuition¶
For each query and page envelope:
Compute elementwise products with the max and min envelopes.
Take an elementwise maximum of these two products to form a QUEST-style upper bound.
Sum over the feature dimension and then take a max over the grouped-query axis to get a single scalar score per page.
Feed the resulting per-page scores into topK to obtain sparse page indices.
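The envelope bound can be written densely as follows (a sketch: per-request tensors stand in for the packed layout, and `topk` stands in for the `topK` kernel). Because each key coordinate lies between the page's min and max envelopes, `max(q * k_max, q * k_min)` upper-bounds the per-coordinate contribution to the dot product for any key in the page:

```python
import torch

def quest_route(q, k_max, k_min, topk):
    """QUEST-style routing via per-page key envelopes.

    q:     [B, H_q, D] queries
    k_max: [B, P, D]   per-page elementwise max of keys
    k_min: [B, P, D]   per-page elementwise min of keys
    Returns: [B, topk] selected page indices per request
    """
    # Elementwise products with both envelopes: [B, P, H_q, D].
    s_max = q[:, None, :, :] * k_max[:, :, None, :]
    s_min = q[:, None, :, :] * k_min[:, :, None, :]
    # Per-coordinate upper bound on q . k for any key in the page.
    s = torch.maximum(s_max, s_min)
    # Sum over features, then max over the grouped-query axis.
    score = s.sum(dim=-1)            # [B, P, H_q]
    aggr = score.max(dim=-1).values  # [B, P]
    return aggr.topk(topk, dim=-1).indices

q = torch.randn(2, 4, 8)
k = torch.randn(2, 16, 32, 8)  # 16 pages of 32 keys each
idx = quest_route(q, k.amax(dim=2), k.amin(dim=2), topk=4)
print(idx.shape)  # torch.Size([2, 4])
```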
- forward_indexer(q, o, cache, ctx)[source]¶
Compute sparse page indices using QUEST-style envelope scores.
Pipeline (indexer view)¶
Let:
- q: [B, H_q, D]
- cache["max"]: [S, 1, D]
- cache["min"]: [S, 1, D]

Steps:

1. s_max = q * max_envelope
2. s_min = q * min_envelope
3. s = max(s_max, s_min) (elementwise)
4. score = sum(s, dim=D) → [S, H_q, 1]
5. aggr_score = max(score, dim=H_q) → per-page scalar
6. topK converts aggr_score into sparse page indices o of shape [S_sparse, 1, 1].
- Parameters:
q (torch.Tensor)
o (torch.Tensor)
cache (Dict[str, torch.Tensor])
ctx (ContextBase)
- forward_cache(cache, loc, ctx)[source]¶
Update per-page max/min envelopes from the key cache.
Cache-update view¶
- cache["k"]: [B, page_size, D]
- cache["max"]: [B, 1, D]
- cache["min"]: [B, 1, D]

The CMax and CMin ops (with dim=1) take page-wise maxima and minima over keys (optionally masked/structured via loc) and write the envelopes into cache["max"] and cache["min"].

- Parameters:
  cache (Dict[str, torch.Tensor])
  loc (torch.Tensor)
  ctx (ContextBase)
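Densely, this envelope update amounts to a pair of reductions over the key axis (a sketch; the real `CMax`/`CMin` also respect `loc` and the paged layout):

```python
import torch

def envelope_update(k):
    """Per-page max/min envelopes over the key axis (dim=1).

    k: [B, page_size, D] key cache in batch-major view
    Returns: ([B, 1, D], [B, 1, D]) max and min envelopes
    """
    return k.amax(dim=1, keepdim=True), k.amin(dim=1, keepdim=True)

k = torch.randn(2, 16, 8)
env_max, env_min = envelope_update(k)
print(env_max.shape, env_min.shape)  # torch.Size([2, 1, 8]) torch.Size([2, 1, 8])
```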