vortex_torch.flow.algorithms

Classes

BlockSparseAttention()

Block-sparse attention flow with centroid-based routing.

GQABlockSparseAttention()

Grouped-query block-sparse attention flow.

GQAQuestSparseAttention()

GQA-style QUEST sparse attention flow.

class vortex_torch.flow.algorithms.BlockSparseAttention[source]

Bases: vFlow

Block-sparse attention flow with centroid-based routing.

This flow implements a simple block-sparse routing strategy inspired by the block-top-k routing used in Kinetics [sadhukhan2025kinetics] (arXiv:2506.05333). It maintains a per-request centroid over keys and uses query–centroid similarity to select a sparse set of pages.

High-level behavior

  • During forward_cache(), the flow computes a centroid vector for each request from its key cache cache["k"] and stores the result in cache["centroids"] with shape

    \[\text{cache["centroids"]} \in \mathbb{R}^{B \times 1 \times D},\]

    where \(B\) is the number of requests and \(D\) is the head dimension.

  • During forward_indexer(), the flow:

    1. Averages query tokens per request to obtain a single query summary per request,

    2. Applies a generalized matrix–vector multiplication GeMV between the query summaries and the cached centroids to obtain a scalar score for each (request, page),

    3. Uses topK to convert these scores into sparse page indices o of shape

      \[o \in \mathbb{R}^{S \times 1 \times 1}.\]

      Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.
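The two phases above can be sketched end-to-end in plain PyTorch. This is a toy reference, not the flow's implementation: it uses a dense `[B, S, page_size, D]` layout instead of the runtime's packed page layout, and `block_sparse_route` is a hypothetical name; `GeMV` and `topK` are approximated here by `einsum` and `topk`.

```python
import torch

def block_sparse_route(q, k_pages, topk):
    """Toy sketch of centroid-based page routing.

    q       : [B, H_q, D]             queries for one request batch
    k_pages : [B, S, page_size, D]    per-request key pages (dense toy layout)
    topk    : number of pages to keep per request
    """
    # forward_cache(): one centroid per page = mean over the token axis
    centroids = k_pages.mean(dim=2)                          # [B, S, D]
    # forward_indexer() step 1: average query tokens per request
    q_mean = q.mean(dim=1, keepdim=True)                     # [B, 1, D]
    # step 2: query-centroid similarity (the GeMV)
    scores = torch.einsum("bod,bsd->bs", q_mean, centroids)  # [B, S]
    # step 3: top-k page indices per request
    return scores.topk(topk, dim=-1).indices                 # [B, topk]
```

In the real flow the same scores are written over the packed \(S_{\mathrm{pack}}\) axis and the kernels handle the per-request-to-packed mapping.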

Cache layout

This flow declares a single extra cache tensor via create_cache():

{
    "centroids": (1, head_dim)
}

The runtime then also allocates "k" and "v" with inner shapes (page_size, head_dim). As per the vFlow contract, each cache tensor has two logical views:

  • In forward_indexer() (page-packed view):

    \[\text{cache["centroids"]} \sim \mathbb{R}^{S \times 1 \times D},\]
  • In forward_cache() (batch-major view):

    \[\text{cache["centroids"]} \sim \mathbb{R}^{B \times 1 \times D}.\]

References

Bibliography

[sadhukhan2025kinetics]

Ranajoy Sadhukhan, Zhuoming Chen, Haizhong Zheng, Yang Zhou, Emma Strubell, Beidi Chen. Kinetics: Rethinking Test-Time Scaling Laws. arXiv:2506.05333, 2025.

forward_indexer(q, o, cache, ctx)[source]

Compute sparse page indices from queries and cached centroids.

Parameters:
  • q (torch.Tensor) – Query tensor with shape [B, H_q, D] (typically torch.bfloat16), where \(B\) is the batch–head axis, \(H_q\) is the number of query positions per request, and \(D\) is the head dimension.

  • o (torch.Tensor) – Output tensor for sparse page indices with shape [S_sparse, 1, 1] and integer dtype. It is filled in-place by topK according to the scores computed by GeMV.

  • cache (Dict[str, torch.Tensor]) –

    Cache dictionary in the indexer view, where:

    • cache["k"] and cache["v"] are page-packed key/value tensors,

    • cache["centroids"] is interpreted as [S, 1, D] (page-packed centroids).

  • ctx (ContextBase) – Runtime context carrying page layout, top-k configuration (topk_val, page_reserved_bos, page_reserved_eos), and other metadata.

Notes

The implementation:

  1. Computes a per-request query summary

    \[q_{\mathrm{mean}}[b, 0, :] = \frac{1}{H_q} \sum_{h=0}^{H_q-1} q[b, h, :],\]
  2. Applies GeMV between q_mean and cache["centroids"] to obtain scalar scores per page,

  3. Uses topK to select a sparse set of pages per request and write the corresponding indices into o in the packed sparse layout.
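A small check of why step 1 is safe to do before step 2: because the dot product is linear, scoring the query mean against a centroid equals averaging the per-query scores. The snippet below verifies this with random tensors (plain `einsum` standing in for GeMV):

```python
import torch

# <mean_h q_h, c> == mean_h <q_h, c> by linearity of the dot product
q = torch.randn(3, 5, 8)   # [B, H_q, D]
c = torch.randn(3, 8)      # one centroid per request, [B, D]

score_from_mean = torch.einsum("bd,bd->b", q.mean(dim=1), c)
mean_of_scores = torch.einsum("bhd,bd->bh", q, c).mean(dim=1)
assert torch.allclose(score_from_mean, mean_of_scores, atol=1e-5)
```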

forward_cache(cache, loc, ctx)[source]

Update cache centroids from the key cache in batch-major view.

Parameters:
  • cache (Dict[str, torch.Tensor]) –

    Cache dictionary in the batch-major view, where:

    • cache["k"] has shape [B, page_size, D],

    • cache["centroids"] has shape [B, 1, D].

  • loc (torch.Tensor) – Positional or layout metadata used by CMean to aggregate keys into centroids (e.g. page boundaries or valid token masks).

  • ctx (ContextBase) – Runtime context forwarded to the reduction op.

Notes

This method calls CMean with dim=1 so that for each request \(b\) it computes a mean over the key axis and writes it to cache["centroids"][b, 0, :]. The exact handling of padding or invalid positions is controlled by loc and the backend implementation of CMean.
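As a rough illustration of what CMean with dim=1 computes, here is a toy stand-in in plain PyTorch. The function name and the use of a per-page valid-length tensor in place of `loc` are assumptions for illustration; the real op's masking semantics are backend-defined.

```python
import torch

def cmean_update(cache_k, valid_len):
    """Toy stand-in for CMean(dim=1): masked mean over the key axis.

    cache_k   : [B, page_size, D]  batch-major key pages
    valid_len : [B]  number of valid tokens per page (plays the role of `loc`)
    """
    B, page_size, D = cache_k.shape
    # mask out padded positions before averaging
    mask = torch.arange(page_size)[None, :] < valid_len[:, None]  # [B, page_size]
    summed = (cache_k * mask[:, :, None]).sum(dim=1)              # [B, D]
    centroids = summed / valid_len[:, None].clamp(min=1)          # [B, D]
    return centroids.unsqueeze(1)                                 # [B, 1, D]
```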

create_cache(page_size, head_dim)[source]

Declare inner shapes for custom cache tensors.

Parameters:
  • page_size (int) – Number of tokens per page (unused here but part of the generic vFlow contract).

  • head_dim (int) – Head dimension \(D\). Used as the second dimension of the centroid tensor.

Returns:

Mapping from cache tensor names to inner shapes (r, c). This flow defines a single extra tensor: "centroids" with inner shape (1, head_dim).

Return type:

Dict[str, Tuple[int, int]]
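A minimal sketch of what an implementation satisfying this contract looks like, assuming only the signature and return value documented above:

```python
from typing import Dict, Tuple

def create_cache(page_size: int, head_dim: int) -> Dict[str, Tuple[int, int]]:
    # page_size is part of the generic vFlow signature but unused by this flow
    return {"centroids": (1, head_dim)}
```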

class vortex_torch.flow.algorithms.GQABlockSparseAttention[source]

Bases: vFlow

Grouped-query block-sparse attention flow.

This flow uses a GQA-style block-sparse routing: queries are grouped, scored against per-request centroids, normalized with a softmax, then aggregated across groups before a top-k over pages is applied.

  • Queries q have shape [B, H_q, D].

  • Centroids cache cache["centroids"] has inner shape (1, head_dim) and is viewed as [S, 1, D] in forward_indexer() (page-packed view) and [B, 1, D] in forward_cache() (batch-major view).

    Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.

For a design similar in spirit to grouped-query block sparsity, see the GQA sparse attention formulation in:

forward_indexer(q, o, cache, ctx)[source]

Compute sparse page indices from grouped-query scores.

Pipeline

  1. Apply GeMM between queries and centroids (computing o = y x^T):

    • q: [B, H_q, D]

    • cache["centroids"] (indexer view): [S, 1, D]

    • score: [S, 1, H_q] (logical [S, Ny, Nx])

  2. Apply in-place softmax over the leading (page) axis with a scaling factor scale:

    \[\mathrm{softmax}(x \cdot \mathrm{scale})\]
  3. Aggregate over the query-group dimension with Max (dim=2), yielding a single scalar score per page.

  4. Use topK on the aggregated scores to write packed sparse page indices into o with shape [S_sparse, 1, 1].
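The four steps above can be sketched in plain PyTorch on a dense toy layout (packed indexing omitted; `gqa_block_sparse_route` is a hypothetical name, and `einsum`/`softmax`/`max`/`topk` stand in for GeMM, the in-place softmax, Max, and topK):

```python
import torch

def gqa_block_sparse_route(q, centroids, topk, scale):
    """Toy sketch of the grouped-query routing pipeline.

    q         : [B, H_q, D]  grouped queries
    centroids : [B, S, D]    one centroid per page
    """
    # 1. GeMM: score every (page, query) pair
    score = torch.einsum("bhd,bsd->bsh", q, centroids)  # [B, S, H_q]
    # 2. softmax over the page axis with a scaling factor
    score = torch.softmax(score * scale, dim=1)
    # 3. aggregate across the query group with a max
    aggr = score.max(dim=2).values                      # [B, S]
    # 4. top-k pages per request
    return aggr.topk(topk, dim=-1).indices              # [B, topk]
```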

forward_cache(cache, loc, ctx)[source]

Update per-request centroids from the key cache.

  • cache["k"]: [B, page_size, D] (batch-major view)

  • cache["centroids"]: [B, 1, D]

The CMean operator with dim=1 computes a mean over the key axis (optionally masked/structured via loc) and writes the result into cache["centroids"] in-place.

create_cache(page_size, head_dim)[source]

Declare inner shapes for custom cache tensors.

Parameters:
  • page_size (int) – Number of tokens per page (not used directly here).

  • head_dim (int) – Head dimension D for centroids.

Returns:

Custom cache metadata. This flow defines:

  • "centroids": inner shape (1, head_dim).

Return type:

Dict[str, Tuple[int, int]]

class vortex_torch.flow.algorithms.GQAQuestSparseAttention[source]

Bases: vFlow

GQA-style QUEST sparse attention flow.

This flow uses query–envelope matching similar to QUEST sparse attention (see https://arxiv.org/abs/2406.10774). For each request, it maintains per-page max and min envelopes of keys and uses them to compute a conservative upper bound on query–key similarity.

Shapes

  • Queries q: [B, H_q, D] (typically bfloat16).

  • Cache entries (inner shapes as declared in create_cache()):

    • cache["max"] and cache["min"]: inner shape (1, head_dim), viewed as [S, 1, D] in the indexer view and [B, 1, D] in the batch-major cache-update view.

    • cache["k"]: standard key cache with inner shape (page_size, head_dim).

    Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.

Routing intuition

For each query and page envelope:

  1. Compute elementwise products with the max and min envelopes.

  2. Take an elementwise maximum of these two products to form a QUEST-style upper bound.

  3. Sum over the feature dimension and then take a max over the grouped-query axis to get a single scalar score per page.

  4. Feed the resulting per-page scores into topK to obtain sparse page indices.

forward_indexer(q, o, cache, ctx)[source]

Compute sparse page indices using QUEST-style envelope scores.

Pipeline (indexer view)

Let:

  • q: [B, H_q, D]

  • cache["max"]: [S, 1, D]

  • cache["min"]: [S, 1, D]

Steps:

  1. s_max = q * max_envelope

  2. s_min = q * min_envelope

  3. s = max(s_max, s_min) (elementwise)

  4. score = sum(s, dim=D) → [S, H_q, 1]

  5. aggr_score = max(score, dim=H_q) → per-page scalar

  6. topK converts aggr_score into sparse page indices o of shape [S_sparse, 1, 1].
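Steps 1–5 can be sketched in plain PyTorch on a dense toy layout. `quest_route` is a hypothetical name, and the elementwise ops stand in for the flow's fused kernels; the key property is that the returned score upper-bounds the true query–key dot product for every key inside the page envelope:

```python
import torch

def quest_route(q, max_env, min_env, topk):
    """Toy sketch of QUEST-style envelope scoring.

    q       : [B, H_q, D]
    max_env : [B, S, D]  per-page elementwise max of keys
    min_env : [B, S, D]  per-page elementwise min of keys
    """
    # 1-2. elementwise products with both envelopes: [B, S, H_q, D]
    s_max = q[:, None, :, :] * max_env[:, :, None, :]
    s_min = q[:, None, :, :] * min_env[:, :, None, :]
    # 3. per-coordinate upper bound: whichever product is larger
    s = torch.maximum(s_max, s_min)
    # 4. sum over the feature dim, 5. max over the grouped-query axis
    score = s.sum(dim=-1)                   # [B, S, H_q]
    aggr = score.max(dim=-1).values         # [B, S]
    # 6. top-k pages per request
    return aggr.topk(topk, dim=-1).indices  # [B, topk]
```

The bound holds because for each coordinate \(d\) with \(\min_d \le k_d \le \max_d\), \(q_d k_d \le \max(q_d \min_d, q_d \max_d)\) regardless of the sign of \(q_d\).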

forward_cache(cache, loc, ctx)[source]

Update per-page max/min envelopes from the key cache.

Cache-update view

  • cache["k"]: [B, page_size, D]

  • cache["max"]: [B, 1, D]

  • cache["min"]: [B, 1, D]

The CMax and CMin ops (with dim=1) take page-wise maxima and minima over keys (optionally masked/structured via loc) and write the envelopes into cache["max"] and cache["min"].
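A toy stand-in for the envelope update (the function name is hypothetical, and masking via `loc` is omitted for brevity):

```python
import torch

def envelope_update(cache_k):
    """Toy stand-in for CMax/CMin(dim=1): per-page key envelopes.

    cache_k : [B, page_size, D]  batch-major key pages
    """
    max_env = cache_k.max(dim=1, keepdim=True).values  # [B, 1, D]
    min_env = cache_k.min(dim=1, keepdim=True).values  # [B, 1, D]
    return max_env, min_env
```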

create_cache(page_size, head_dim)[source]

Declare inner shapes for custom cache tensors.

Parameters:
  • page_size (int) – Number of tokens per page (unused here but part of the vFlow contract).

  • head_dim (int) – Head dimension D used by the envelopes.

Returns:

Custom cache metadata:

  • "max": inner shape (1, head_dim)

  • "min": inner shape (1, head_dim)

Return type:

Dict[str, Tuple[int, int]]