vortex_torch.flow.flow¶
Classes

| vFlow | Base class for flow-style sparse attention modules. |

- class vortex_torch.flow.flow.vFlow[source]¶
Bases: ABC

Base class for flow-style sparse attention modules.
This abstraction is conceptually similar to torch.nn.Module, but specialized for sparse attention flows that:

- maintain a structured key/value cache,
- define how to index into sparse pages (top-k style routing), and
- define how to update / summarize that cache as new pages arrive.
Query tensor¶
The query tensor q passed to forward_indexer() has logical shape

\[q \in \mathbb{R}^{B \times H_q \times D},\]

where

- \(B\) is a batch-like axis (commonly batch_size * num_heads),
- \(H_q\) is the number of query positions per batch/head, and
- \(D\) is the head dimension.

In practice, q is typically stored in torch.bfloat16.

Sparse index tensor¶
The sparse index tensor o produced by forward_indexer() has logical shape

\[o \in \mathbb{R}^{S_{\text{sparse}} \times 1 \times 1},\]

and stores integer page indices. The packed sparse length is

\[S_{\text{sparse}} = \sum_{i=0}^{B-1} S_{\text{sparse}, i},\]

where for each request \(i\) with \(S_i\) candidate pages,

\[S_{\text{sparse}, i} = \min\Bigl( S_i,\; \text{topk_val} + \text{page_reserved_bos} + \text{page_reserved_eos} \Bigr).\]

Here:

- topk_val is the number of pages selected by the indexer,
- page_reserved_bos is the number of always-kept pages at the beginning (BOS region),
- page_reserved_eos is the number of always-kept pages at the end (EOS region),

and these values are typically provided by the runtime context.
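To make the packed-length formula concrete, here is a small worked example in plain Python (the page counts and runtime parameters are illustrative values, not taken from any real configuration):

```python
# Illustrative computation of the packed sparse length S_sparse.
# Per-request candidate page counts S_i (hypothetical values):
candidate_pages = [3, 50, 120]

# Runtime parameters normally supplied by the context (hypothetical values):
topk_val = 16
page_reserved_bos = 2
page_reserved_eos = 2

# Per-request packed length: S_sparse_i = min(S_i, topk + bos + eos).
budget = topk_val + page_reserved_bos + page_reserved_eos  # 20
per_request = [min(s_i, budget) for s_i in candidate_pages]

# The total packed length is the sum over all requests.
s_sparse = sum(per_request)
print(per_request, s_sparse)  # [3, 20, 20] 43
```

Note that a short request (here, 3 candidate pages) contributes fewer than `budget` entries, which is why \(S_{\text{sparse}}\) must be computed per request rather than as \(B \cdot \text{budget}\).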
Cache tensors: two logical views¶
Each cache entry cache[key] (including the standard keys "k" and "v" plus any extra entries declared by create_cache()) is a rank-3 tensor that is viewed in two different logical layouts:

Indexer view (page-packed) — used in forward_indexer():

\[\text{cache[key]} \sim \mathbb{R}^{S \times r \times c},\]

where \((r, c)\) is the per-key inner shape declared via create_cache(), or implicitly for "k"/"v". Here \(S\) is the leading page axis. Internally it is a packed axis (often denoted \(S_{\mathrm{pack}}\)), obtained by concatenating the pages from all requests. As a user, you can simply think of \(S\) as “the number of pages for this request”; the vFlow kernels and ContextBase will take care of mapping between per-request page counts and the packed layout automatically.

Cache-update view (batch-major) — used in forward_cache():

\[\text{cache[key]} \sim \mathbb{R}^{B \times r \times c}.\]

The leading axis is the request/batch index \(B\), while the inner shape \((r, c)\) is the same as in the indexer view.

The runtime (via ContextBase) is responsible for mapping between these two views using indptr arrays and layout metadata.

Cache metadata¶
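The relationship between the packed page axis and per-request page ranges can be sketched with a CSR-style indptr array. The helper below is a hypothetical illustration of the bookkeeping that ContextBase performs internally, not part of the library API:

```python
# Hypothetical sketch: a CSR-style indptr array maps request indices to
# contiguous ranges on the packed page axis S_pack.
def pages_for_request(indptr, request_idx):
    """Return the packed-axis positions belonging to one request."""
    start, end = indptr[request_idx], indptr[request_idx + 1]
    return list(range(start, end))

# Three requests with 2, 4, and 1 pages -> packed axis of length 7.
page_counts = [2, 4, 1]
indptr = [0]
for n in page_counts:
    indptr.append(indptr[-1] + n)
# indptr == [0, 2, 6, 7]

print(pages_for_request(indptr, 1))  # packed positions [2, 3, 4, 5]
```

Slicing the packed axis with `indptr[b]:indptr[b+1]` recovers the \(S\)-view pages of request `b`; stacking one entry per request recovers the \(B\)-major view.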
Subclasses declare only extra cache tensors via create_cache(), e.g.:

    {
        "centroids": (1, head_dim),
        "my_aux_tensor": (page_size, head_dim),
        ...
    }

The helper get_cache_meta_info() then injects the standard entries:

\[\begin{split}\text{k} &: (\text{page_size}, \text{head_dim}), \\ \text{v} &: (\text{page_size}, \text{head_dim}),\end{split}\]

so subclasses must not add "k" or "v" themselves.

Token ratio¶
get_token_ratio() computes a simple proxy for how much cache storage is used (per head) relative to one k/v page:

\[\text{token_ratio} = \sum_{\text{key}} \frac{r_{\text{key}} \cdot c_{\text{key}}} {\text{page_size} \cdot \text{head_dim}}.\]

This ignores the leading dimension (whether \(B\) or \(S\)) and compares only inner shapes to the baseline (page_size, head_dim).

Subclass responsibilities¶
Concrete flows must implement:

- forward_indexer(q, o, cache, ctx): compute sparse page indices (or routing scores) from queries, using cache in the \(S\) view.
- forward_cache(cache, loc, ctx): update cache tensors using the \(B\)-major view and positional metadata.
- create_cache(page_size, head_dim): declare inner shapes \((r, c)\) for all extra cache tensors (excluding "k" and "v").
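A minimal subclass might be shaped like the sketch below. This is a hypothetical illustration only: a plain ABC stands in for the real base class, Python lists stand in for torch tensors, ctx is a plain dict, and the "centroids" entry and mean-pooling update are assumed strategies, not the library's actual kernels:

```python
from abc import ABC, abstractmethod

class SketchFlow(ABC):
    """Stand-in for vFlow (hypothetical; the real base class lives in
    vortex_torch.flow.flow and operates on torch tensors)."""

    @abstractmethod
    def create_cache(self, page_size, head_dim): ...

    @abstractmethod
    def forward_indexer(self, q, o, cache, ctx): ...

    @abstractmethod
    def forward_cache(self, cache, loc, ctx): ...

class CentroidFlow(SketchFlow):
    """Toy flow that scores each page against a per-page centroid."""

    def create_cache(self, page_size, head_dim):
        # Declare extra entries only; "k" and "v" are injected by the
        # runtime helper (get_cache_meta_info).
        return {"centroids": (1, head_dim)}

    def forward_indexer(self, q, o, cache, ctx):
        # S view: cache["centroids"] has one (1, head_dim) summary per page.
        # Score each page by a dot product with the query vector, then
        # keep the top-k page indices and write them into o in-place.
        scores = [sum(a * b for a, b in zip(q, page[0]))
                  for page in cache["centroids"]]
        topk = sorted(range(len(scores)),
                      key=lambda i: -scores[i])[:ctx["topk_val"]]
        o[:len(topk)] = topk

    def forward_cache(self, cache, loc, ctx):
        # B-major view: summarize each request's key rows into a centroid
        # by mean pooling over the token axis.
        for b, rows in enumerate(cache["k"]):
            d = len(rows[0])
            cache["centroids"][b] = [
                [sum(r[j] for r in rows) / len(rows) for j in range(d)]
            ]
```

The same cache key ("centroids") is read page-packed in forward_indexer and written batch-major in forward_cache, mirroring the two logical views described above.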
- abstractmethod forward_indexer(q, o, cache, ctx)[source]¶
Compute sparse page indices (or equivalent routing information) from queries and cache.
Canonical shapes¶

- q (queries):

  \[q \in \mathbb{R}^{B \times H_q \times D},\]

  typically stored in torch.bfloat16.

- o (sparse indices):

  \[o \in \mathbb{R}^{S_{\text{sparse}} \times 1 \times 1},\]

  integer dtype (e.g. torch.int32 or torch.int64). The packed length \(S_{\text{sparse}}\) is defined in the class docstring.

- cache[key] (indexer view):

  \[\text{cache[key]} \sim \mathbb{R}^{S \times r \times c},\]

  where \((r, c)\) are the per-key inner dimensions obtained from get_cache_meta_info().

- ctx: an instance of ContextBase carrying page layout, indptr arrays, and configuration such as topk_val, page_reserved_bos, and page_reserved_eos.
Contract¶
Implementations should:
- interpret cache in the \(S\) view,
- use q and relevant cache tensors to score/select pages,
- respect per-request bounds derived from ctx,
- write the resulting sparse indices (or routing representation) into o in-place.

The exact semantics of the integers stored in o (e.g. absolute page indices vs. offsets) are defined by the runtime convention and must be consistent with downstream kernels.

- Parameters:
q (torch.Tensor)
o (torch.Tensor)
ctx (ContextBase)
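The BOS/EOS reservation behavior implied by the contract can be sketched in plain Python (lists stand in for tensors, and the specific selection rule below is an illustrative assumption, not the library's kernel):

```python
# Hypothetical per-request page selection that honors reserved
# BOS/EOS regions, as described in the class docstring.
def select_pages(scores, topk_val, bos, eos):
    """Always keep the first `bos` and last `eos` pages, then fill the
    remaining budget with the best-scoring middle pages."""
    n = len(scores)
    keep = list(range(min(bos, n)))               # BOS region
    tail = list(range(max(n - eos, bos), n))      # EOS region
    middle = [i for i in range(n) if i not in keep and i not in tail]
    best = sorted(middle, key=lambda i: -scores[i])[:topk_val]
    return sorted(keep + best + tail)

# Six pages, keep 2 scored pages plus one BOS and one EOS page.
chosen = select_pages([0.1, 0.9, 0.3, 0.8, 0.2, 0.5],
                      topk_val=2, bos=1, eos=1)
print(chosen)  # [0, 1, 3, 5]
```

When the request has fewer pages than the total budget, every page is kept, matching the \(\min(S_i, \ldots)\) bound on \(S_{\text{sparse}, i}\).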
- abstractmethod forward_cache(cache, loc, ctx)[source]¶
Update or recompute cache tensors in the batch-major view.
Canonical shapes¶
- cache[key] (cache-update view):

  \[\text{cache[key]} \sim \mathbb{R}^{B \times r \times c},\]

  where \(B\) is the number of requests and \((r, c)\) are the same inner dimensions as in the indexer view.

- loc: positional / layout metadata (for example, page indices or token positions) used to decide how to aggregate over pages or tokens when producing per-request summaries.

- ctx: execution context (same instance type as in forward_indexer()), carrying runtime parameters and layout information.
Contract¶
Typical operations include recomputing per-request summaries such as:
- averaging or pooling cache["k"] into a tensor cache["centroids"] of shape [B, r, c],
- maintaining auxiliary statistics needed by the indexer stage.

Implementations may update any entries in cache in-place, as long as they respect the shapes announced by get_cache_meta_info().

- Parameters:
loc (torch.Tensor)
ctx (ContextBase)
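As an illustration of the centroid-style update above, here is a list-based sketch that mean-pools each request's keys, using a hypothetical loc that carries the number of valid tokens per request (the name "centroids", the meaning of loc, and the shapes are all assumptions for the sketch):

```python
# Hypothetical batch-major cache update: pool only the valid key rows
# of each request into a [1, head_dim]-shaped centroid.
def update_centroids(cache, loc):
    """loc[b] is assumed to give the number of valid token rows for
    request b; padded rows beyond loc[b] are ignored."""
    for b, rows in enumerate(cache["k"]):
        n = loc[b]
        d = len(rows[0])
        cache["centroids"][b] = [
            [sum(rows[t][j] for t in range(n)) / n for j in range(d)]
        ]
    return cache

cache = {
    # One request; the last row is padding and must not affect the mean.
    "k": [[[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]]],
    "centroids": [[[0.0, 0.0]]],
}
update_centroids(cache, loc=[2])
print(cache["centroids"])  # [[[3.0, 1.0]]]
```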
- abstractmethod create_cache(page_size, head_dim)[source]¶
Declare inner shapes for non-"k" / non-"v" cache tensors.

This method does not allocate tensors. It only declares the per-key inner dimensions \((r, c)\); the runtime will attach the appropriate leading axis (\(B\) or \(S\)) depending on whether the cache is used in forward_cache() or forward_indexer().

- Parameters:
- Returns:
A mapping from cache tensor names (excluding "k" and "v") to inner shapes (r, c). For example:

    { "centroids": (1, head_dim), }
- Return type:
Notes
The keys "k" and "v" are reserved and must not be present in the returned dictionary. They are added automatically by get_cache_meta_info() with inner shape (page_size, head_dim).
- get_cache_meta_info(page_size, head_dim)[source]¶
Get full cache inner-shape metadata, including "k" and "v".

This wraps create_cache() and injects standard inner shapes for key/value tensors. The resulting dictionary includes:

- all user-defined entries from create_cache(), and
- two additional entries:
  - "k": (page_size, head_dim)
  - "v": (page_size, head_dim)
- Parameters:
- Returns:
Mapping from cache tensor names to inner shapes (r, c). The runtime will later prepend either a batch axis B or a packed-page axis S when materializing the tensors.
- Return type:
- Raises:
AssertionError – If create_cache() tries to define entries for the reserved keys "k" or "v".
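The wrap-and-inject behavior, including the AssertionError for reserved keys, can be sketched as follows (a stand-in function, with create_cache passed in as a callable rather than resolved from a subclass):

```python
# Hypothetical sketch of get_cache_meta_info: take the user-declared
# entries and inject the standard "k"/"v" inner shapes.
def get_cache_meta_info(create_cache, page_size, head_dim):
    meta = dict(create_cache(page_size, head_dim))
    # Reserved keys must not be declared by the subclass.
    assert "k" not in meta and "v" not in meta, \
        'create_cache() must not define "k" or "v"'
    meta["k"] = (page_size, head_dim)
    meta["v"] = (page_size, head_dim)
    return meta

meta = get_cache_meta_info(lambda p, h: {"centroids": (1, h)}, 64, 128)
print(sorted(meta))  # ['centroids', 'k', 'v']
```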
- get_token_ratio(page_size, head_dim)[source]¶
Compute the relative cache size in “tokens” compared to a k/v page.
Using get_cache_meta_info(), this computes a simple ratio that measures how many equivalent tokens the entire cache consumes per head, relative to a baseline page of shape (page_size, head_dim).

Let \(\mathcal{C}\) be the set of cache tensors with inner shapes \((r_{\text{key}}, c_{\text{key}})\). This method returns

\[\text{token_ratio} = \sum_{\text{key} \in \mathcal{C}} \frac{r_{\text{key}} \cdot c_{\text{key}}} {\text{page_size} \cdot \text{head_dim}}.\]

The leading dimension (\(B\) or \(S\)) is not included in this ratio on purpose; it is a per-page normalization.
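A worked example of the ratio, using hypothetical inner shapes:

```python
# token_ratio for a hypothetical cache with page_size=64, head_dim=128.
page_size, head_dim = 64, 128
meta = {
    "k": (64, 128),         # baseline page -> contributes 1.0
    "v": (64, 128),         # baseline page -> contributes 1.0
    "centroids": (1, 128),  # one summary row -> contributes 1/64
}
token_ratio = sum(r * c for r, c in meta.values()) / (page_size * head_dim)
print(token_ratio)  # 1.0 + 1.0 + 0.015625 = 2.015625
```

So this cache costs about 2.016 "pages" of storage per head: the standard k/v pages plus a small overhead for the centroid summaries.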