vortex_torch.flow.flow

Classes

vFlow()

Base class for flow-style sparse attention modules.

class vortex_torch.flow.flow.vFlow[source]

Bases: ABC

Base class for flow-style sparse attention modules.

This abstraction is conceptually similar to torch.nn.Module, but specialized for sparse attention flows that:

  • maintain a structured key/value cache,

  • define how to index into sparse pages (top-k style routing), and

  • define how to update / summarize that cache as new pages arrive.

Query tensor

The query tensor q passed to forward_indexer() has logical shape

\[q \in \mathbb{R}^{B \times H_q \times D},\]

where

  • \(B\) is a batch-like axis (commonly batch_size * num_heads),

  • \(H_q\) is the number of query positions per batch/head, and

  • \(D\) is the head dimension.

In practice q is typically stored in torch.bfloat16.

Sparse index tensor

The sparse index tensor o produced by forward_indexer() has logical shape

\[o \in \mathbb{Z}^{S_{\text{sparse}} \times 1 \times 1},\]

and stores integer page indices. The packed sparse length is

\[S_{\text{sparse}} = \sum_{i=0}^{B-1} S_{\text{sparse}, i},\]

where for each request \(i\) with \(S_i\) candidate pages,

\[S_{\text{sparse}, i} = \min\Bigl( S_i,\; \text{topk_val} + \text{page_reserved_bos} + \text{page_reserved_eos} \Bigr).\]

Here:

  • topk_val is the number of pages selected by the indexer,

  • page_reserved_bos is the number of always-kept pages at the beginning (BOS region),

  • page_reserved_eos is the number of always-kept pages at the end (EOS region),

and these values are typically provided by the runtime context.
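As a concrete illustration, the packed length can be computed directly from the formula above. This is a sketch in plain Python; the function name and the page_counts argument (holding each request's \(S_i\)) are illustrative, not part of the library API.

```python
def packed_sparse_length(page_counts, topk_val, page_reserved_bos, page_reserved_eos):
    # Per-request budget: top-k pages plus the always-kept BOS/EOS pages.
    budget = topk_val + page_reserved_bos + page_reserved_eos
    # S_sparse = sum_i min(S_i, budget); requests with fewer candidate
    # pages than the budget contribute all of their pages.
    return sum(min(s_i, budget) for s_i in page_counts)

# Two requests: one shorter than the budget (kept whole), one clipped.
print(packed_sparse_length([3, 40], topk_val=8,
                           page_reserved_bos=1, page_reserved_eos=1))  # → 13
```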

Cache tensors: two logical views

Each cache entry cache[key] (including the standard keys "k" and "v" plus any extra entries declared by create_cache()) is a rank-3 tensor that is viewed in two different logical layouts:

  1. Indexer view (page-packed) — used in forward_indexer():

    \[\text{cache[key]} \sim \mathbb{R}^{S \times r \times c},\]

    \((r, c)\) is the per-key inner shape declared via create_cache() or implicitly for "k"/"v".

    Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.

  2. Cache-update view (batch-major) — used in forward_cache():

    \[\text{cache[key]} \sim \mathbb{R}^{B \times r \times c}.\]

    The leading axis is the request/batch index \(B\), while the inner shape \((r, c)\) is the same as in the indexer view.

The runtime (via ContextBase) is responsible for mapping between these two views using indptr arrays and layout metadata.
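The relationship between the two views can be illustrated with a simple CSR-style indptr, as in the sketch below. The names and the layout convention are illustrative; the actual ContextBase metadata may differ.

```python
# Illustrative CSR-style mapping between the two cache views; the real
# ContextBase layout metadata may differ.
page_counts = [2, 3, 1]        # S_i: candidate pages per request (B = 3)

indptr = [0]                   # indptr[i]..indptr[i+1] bounds request i's pages
for s_i in page_counts:
    indptr.append(indptr[-1] + s_i)

S_pack = indptr[-1]            # packed page axis: S_pack = sum(S_i) = 6

def pages_of(request_idx):
    # Slice of the packed page axis belonging to one request.
    return list(range(indptr[request_idx], indptr[request_idx + 1]))

print(indptr)       # → [0, 2, 5, 6]
print(pages_of(1))  # → [2, 3, 4]
```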

Cache metadata

Subclasses declare only extra cache tensors via create_cache(), e.g.:

{
    "centroids": (1, head_dim),
    "my_aux_tensor": (page_size, head_dim),
    ...
}

The helper get_cache_meta_info() then injects the standard entries:

\[\begin{split}\text{k} &: (\text{page_size}, \text{head_dim}), \\ \text{v} &: (\text{page_size}, \text{head_dim}),\end{split}\]

so subclasses must not add "k" or "v" themselves.

Token ratio

get_token_ratio() computes a simple proxy for how much cache storage is used (per head) relative to one k/v page:

\[\text{token_ratio} = \sum_{\text{key}} \frac{r_{\text{key}} \cdot c_{\text{key}}} {\text{page_size} \cdot \text{head_dim}}.\]

This ignores the leading dimension (whether \(B\) or \(S\)) and compares only inner shapes to the baseline (page_size, head_dim).

Subclass responsibilities

Concrete flows must implement:

  • forward_indexer(q, o, cache, ctx): compute sparse page indices (or routing scores) from queries, using cache in the \(S\) view.

  • forward_cache(cache, loc, ctx): update cache tensors using the \(B\)-major view and positional metadata.

  • create_cache(page_size, head_dim): declare inner shapes \((r, c)\) for all extra cache tensors (excluding "k" and "v").
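Put together, a toy subclass might look like the sketch below. It is standalone (it does not inherit the real vFlow base class), stubs ctx as a plain dict, assumes a single request (B = 1) in the indexer, and uses one mean-pooled centroid per page as the routing summary; all of this is illustrative, not the library's implementation.

```python
import torch

class MeanCentroidFlow:
    """Standalone sketch mirroring the vFlow interface.

    One centroid summarizes each page; queries are routed to the pages
    whose centroids score highest under a dot product.
    """

    def create_cache(self, page_size, head_dim):
        # Extra tensors only; "k" and "v" are injected by get_cache_meta_info().
        return {"centroids": (1, head_dim)}

    def forward_cache(self, cache, loc, ctx):
        # Batch-major view: cache["k"] is [B, page_size, head_dim].
        cache["centroids"] = cache["k"].mean(dim=1, keepdim=True)

    def forward_indexer(self, q, o, cache, ctx):
        # Indexer view: cache["centroids"] is [S, 1, head_dim]; q is [B, H_q, D].
        # Single-request case (B = 1) for brevity.
        scores = torch.einsum("bhd,spd->bs", q.float(), cache["centroids"].float())
        topk = scores[0].topk(ctx["topk_val"]).indices
        o[: topk.numel(), 0, 0] = topk.to(o.dtype)  # write indices in place
```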

abstractmethod forward_indexer(q, o, cache, ctx)[source]

Compute sparse page indices (or equivalent routing information) from queries and cache.

Canonical shapes

  • q (queries):

    \[q \in \mathbb{R}^{B \times H_q \times D},\]

    typically stored in torch.bfloat16.

  • o (sparse indices):

    \[o \in \mathbb{Z}^{S_{\text{sparse}} \times 1 \times 1},\]

    integer dtype (e.g. torch.int32 or torch.int64). The packed length \(S_{\text{sparse}}\) is defined in the class docstring.

  • cache[key] (indexer view):

    \[\text{cache[key]} \sim \mathbb{R}^{S \times r \times c},\]

    \((r, c)\) are the per-key inner dimensions obtained from get_cache_meta_info().

  • ctx:

    An instance of ContextBase carrying page layout, indptr arrays, and configuration such as topk_val, page_reserved_bos, and page_reserved_eos.

Contract

Implementations should:

  • interpret cache in the \(S\) view,

  • use q and relevant cache tensors to score/select pages,

  • respect per-request bounds derived from ctx,

  • write the resulting sparse indices (or routing representation) into o in-place.

The exact semantics of the integers stored in o (e.g. absolute page indices vs. offsets) are defined by the runtime convention and must be consistent with downstream kernels.

Parameters:
  • q (torch.Tensor) – Query tensor of shape [B, H_q, D], typically torch.bfloat16.

  • o (torch.Tensor) – Integer tensor of shape [S_sparse, 1, 1]; sparse page indices are written into it in-place.

  • cache (Dict[str, torch.Tensor]) – Cache tensors in the indexer (page-packed) view.

  • ctx (ContextBase) – Runtime context carrying page layout, indptr arrays, and configuration.
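The per-request-bounds part of the contract can be sketched as follows, assuming precomputed per-page scores and a CSR-style indptr (both illustrative; the real layout comes from ContextBase, and reserved BOS/EOS pages are omitted for brevity):

```python
import torch

def write_sparse_indices(scores, indptr, topk_val, o):
    # scores: [S_pack] per-page scores; indptr bounds request i's pages.
    written = 0
    for i in range(len(indptr) - 1):
        lo, hi = indptr[i], indptr[i + 1]
        k = min(topk_val, hi - lo)                # per-request bound
        sel = scores[lo:hi].topk(k).indices + lo  # absolute packed indices
        o[written : written + k, 0, 0] = sel.to(o.dtype)  # in-place write
        written += k
    return written                                # packed length S_sparse

scores = torch.tensor([0.0, 3.0, 1.0, 5.0, 2.0, 4.0])
indptr = [0, 2, 6]            # request 0: pages 0-1, request 1: pages 2-5
o = torch.zeros(4, 1, 1, dtype=torch.int64)
print(write_sparse_indices(scores, indptr, topk_val=2, o=o))  # → 4
print(o[:, 0, 0].tolist())                                    # → [1, 0, 3, 5]
```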
abstractmethod forward_cache(cache, loc, ctx)[source]

Update or recompute cache tensors in the batch-major view.

Canonical shapes

  • cache[key] (cache-update view):

    \[\text{cache[key]} \sim \mathbb{R}^{B \times r \times c},\]

    where \(B\) is the number of requests and \((r, c)\) are the same inner dimensions as in the indexer view.

  • loc:

    Positional / layout metadata (for example, page indices or token positions) used to decide how to aggregate over pages or tokens when producing per-request summaries.

  • ctx:

    Execution context (same instance type as in forward_indexer()), carrying runtime parameters and layout information.

Contract

Typical operations include recomputing per-request summaries such as:

  • averaging or pooling cache["k"] into a tensor cache["centroids"] of shape [B, r, c],

  • maintaining auxiliary statistics needed by the indexer stage.

Implementations may update any entries in cache in-place, as long as they respect the shapes announced by get_cache_meta_info().

Parameters:
  • cache (Dict[str, torch.Tensor]) – Cache tensors in the batch-major view, shaped [B, r, c] per key.

  • loc – Positional / layout metadata (for example, page indices or token positions).

  • ctx (ContextBase) – Execution context carrying runtime parameters and layout information.
abstractmethod create_cache(page_size, head_dim)[source]

Declare inner shapes for non-"k" / non-"v" cache tensors.

This method does not allocate tensors. It only declares the per-key inner dimensions \((r, c)\); the runtime will attach the appropriate leading axis (\(B\) or \(S\)) depending on whether the cache is used in forward_cache() or forward_indexer().

Parameters:
  • page_size (int) – Number of tokens per page. For the standard "k" and "v" entries, this will be the first dimension.

  • head_dim (int) – Head dimension. For the standard "k" and "v" entries, this will be the second dimension.

Returns:

A mapping from cache tensor names (excluding "k" and "v") to inner shapes (r, c). For example:

{
    "centroids": (1, head_dim),
}

Return type:

Dict[str, Tuple[int, int]]

Notes

The keys "k" and "v" are reserved and must not be present in the returned dictionary. They are added automatically by get_cache_meta_info() with inner shape (page_size, head_dim).

get_cache_meta_info(page_size, head_dim)[source]

Get full cache inner-shape metadata, including "k" and "v".

This wraps create_cache() and injects standard inner shapes for key/value tensors. The resulting dictionary includes:

  • all user-defined entries from create_cache(), and

  • two additional entries:

    • "k" : (page_size, head_dim)

    • "v" : (page_size, head_dim)

Parameters:
  • page_size (int) – Number of tokens per page.

  • head_dim (int) – Head dimension.

Returns:

Mapping from cache tensor names to inner shapes (r, c). The runtime will later prepend either a batch axis B or a packed-page axis S when materializing the tensors.

Return type:

Dict[str, Tuple[int, int]]

Raises:

AssertionError – If create_cache() tries to define entries for the reserved keys "k" or "v".
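The merge this helper performs can be mirrored in a few lines. The following is an illustrative reimplementation of the documented behavior, not the library source:

```python
def cache_meta_info(user_entries, page_size, head_dim):
    # Reserved keys must come from the helper, not the subclass.
    assert "k" not in user_entries and "v" not in user_entries, \
        '"k" and "v" are reserved'
    meta = dict(user_entries)           # user-defined inner shapes
    meta["k"] = (page_size, head_dim)   # injected standard entries
    meta["v"] = (page_size, head_dim)
    return meta

meta = cache_meta_info({"centroids": (1, 64)}, page_size=16, head_dim=64)
print(meta)  # → {'centroids': (1, 64), 'k': (16, 64), 'v': (16, 64)}
```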

get_token_ratio(page_size, head_dim)[source]

Compute the relative cache size in “tokens” compared to a k/v page.

Using get_cache_meta_info(), this computes a simple ratio that measures how many equivalent tokens the entire cache consumes per head, relative to a baseline page of shape (page_size, head_dim).

Let \(\mathcal{C}\) be the set of cache tensors with inner shapes \((r_{\text{key}}, c_{\text{key}})\). This method returns

\[\text{token_ratio} = \sum_{\text{key} \in \mathcal{C}} \frac{r_{\text{key}} \cdot c_{\text{key}}} {\text{page_size} \cdot \text{head_dim}}.\]

The leading dimension (\(B\) or \(S\)) is not included in this ratio on purpose; it is a per-page normalization.

Parameters:
  • page_size (int) – Number of tokens per page.

  • head_dim (int) – Head dimension.

Returns:

Total cache size expressed in units of page_size * head_dim (including the standard "k" and "v" entries).

Return type:

float
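A worked example of the ratio, using an illustrative helper rather than the library method:

```python
def token_ratio(meta, page_size, head_dim):
    # Sum of inner-shape areas, normalized by one k/v page.
    return sum(r * c for r, c in meta.values()) / (page_size * head_dim)

meta = {
    "k": (16, 64),          # one page of keys    → contributes 1.0
    "v": (16, 64),          # one page of values  → contributes 1.0
    "centroids": (1, 64),   # one row per page    → contributes 1/16
}
print(token_ratio(meta, page_size=16, head_dim=64))  # → 2.0625
```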

run_indexer_virtual(group_size, page_size, head_dim)[source]
Parameters:
  • group_size (int)

  • page_size (int)

  • head_dim (int)