vortex_torch.indexer.output_func
Classes
| topK | Piecewise top-k dispatcher for packed sequences with reserved pages. |
- class vortex_torch.indexer.output_func.topK
Bases: vOp

Piecewise top-k dispatcher for packed sequences with reserved pages.
The input is treated as a rank-3 tensor
\[X \in \mathbb{R}^{S_{\text{pack}} \times 1 \times 1},\]
where the leading dimension \(S_{\text{pack}}\) is a packed concatenation of \(B\) segments:
\[\begin{split}S_{\text{pack}} = \sum_{b=0}^{B-1} S_b, \qquad X = \begin{bmatrix} X_0 \\ X_1 \\ \vdots \\ X_{B-1} \end{bmatrix},\end{split}\]
with
\[X_b \in \mathbb{R}^{S_b \times 1 \times 1}.\]
For each segment \(b\), the operator selects a subset of pages according to the scores stored in x (and additional key/value metadata from Context). It always keeps:
- the first page_reserved_bos pages of the segment,
- the last page_reserved_eos pages of the segment, and
- up to topk_val additional pages, chosen by top-k score among the remaining candidates in the segment.
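The selection rule for a single segment can be sketched in plain Python. select_segment_pages is a hypothetical reference helper, not part of vortex_torch. It assumes that the reserved ranges are clipped to the segment length when a segment is shorter than page_reserved_bos + page_reserved_eos, that ties in score go to the lower page index, and that kept indices are returned in ascending order; the actual kernel may behave differently on each of these points.

```python
from typing import List, Sequence


def select_segment_pages(
    scores: Sequence[float],
    page_reserved_bos: int,
    page_reserved_eos: int,
    topk_val: int,
) -> List[int]:
    """Return the sorted local page indices kept for one segment (hypothetical helper)."""
    n = len(scores)
    # First page_reserved_bos pages, clipped to the segment length.
    bos = set(range(min(page_reserved_bos, n)))
    # Last page_reserved_eos pages, clipped; may overlap BOS in short segments.
    eos = set(range(max(n - page_reserved_eos, 0), n))
    reserved = bos | eos
    # Remaining pages compete by score; ties go to the lower index (assumption).
    candidates = [i for i in range(n) if i not in reserved]
    ranked = sorted(candidates, key=lambda i: (-scores[i], i))
    chosen = ranked[: max(topk_val, 0)]
    return sorted(reserved | set(chosen))
```

For example, select_segment_pages([0.1, 0.9, 0.3, 0.8, 0.2, 0.5], 1, 1, 2) keeps page 0 (BOS), page 5 (EOS), and pages 1 and 3 (the two highest-scoring middle pages), returning [0, 1, 3, 5].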
Let \(\mathcal{I}_b \subset [0, S_{\text{pack}})\) denote the index range of segment \(b\). The implementation computes a subset \(\mathcal{J}_b \subseteq \mathcal{I}_b\) such that:
- all indices of the reserved prefix (BOS) and suffix (EOS) pages in \(\mathcal{I}_b\) are included, and
- up to topk_val additional indices are selected by score.
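The same rule can be written as a set definition. The notation is introduced here rather than taken from the implementation: \(\mathcal{I}_b = [s_b, e_b)\) with \(e_b - s_b = S_b\), \(x_i\) is the score at packed index \(i\), and \(r_{\text{bos}}, r_{\text{eos}}, k\) stand for page_reserved_bos, page_reserved_eos, and topk_val. Clipping the reserved ranges to the segment is an assumption about short segments that the description above does not spell out.

\[\mathcal{B}_b = [s_b,\ s_b + \min(r_{\text{bos}}, S_b)), \qquad \mathcal{E}_b = [e_b - \min(r_{\text{eos}}, S_b),\ e_b), \qquad \mathcal{C}_b = \mathcal{I}_b \setminus (\mathcal{B}_b \cup \mathcal{E}_b),\]
\[\mathcal{T}_b \subseteq \mathcal{C}_b, \qquad |\mathcal{T}_b| = \min(k, |\mathcal{C}_b|), \qquad x_i \ge x_j \ \ \forall\, i \in \mathcal{T}_b,\ j \in \mathcal{C}_b \setminus \mathcal{T}_b,\]
\[\mathcal{J}_b = \mathcal{B}_b \cup \mathcal{E}_b \cup \mathcal{T}_b, \qquad |\mathcal{J}_b| \le \min(S_b,\ r_{\text{bos}} + r_{\text{eos}} + k).\]

The size bound follows because \(|\mathcal{B}_b| \le r_{\text{bos}}\), \(|\mathcal{E}_b| \le r_{\text{eos}}\), \(|\mathcal{T}_b| \le k\), and \(\mathcal{J}_b \subseteq \mathcal{I}_b\).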
The result is written into a preallocated output tensor o; the exact layout of o is defined by the upstream contract and the implementation.

Key properties
- Dispatch is keyed only by the input format x._format.
- The operation is logically out-of-place, but writes into o in-place.
- profile() only validates and selects the implementation; it does not allocate or return any buffers.
- execute() performs the per-segment selection using context metadata (indptr arrays, indices, reserved-page counts, and topk_val).
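The two-phase contract in the list above (select once in profile(), then run the selected callable in execute(), which fills a caller-owned buffer) can be sketched as follows. PiecewiseTopKSketch is a hypothetical, simplified stand-in, not the vortex_torch class: its profile() takes only a format string instead of (x, o, ctx), and its buffers are Python lists rather than tensors.

```python
from typing import Any, Callable, Dict, MutableSequence, Optional, Sequence


class PiecewiseTopKSketch:
    """Hypothetical, simplified stand-in for the profile/execute pattern."""

    def __init__(self, impl_map: Dict[str, Callable[..., Any]]) -> None:
        # Dispatch table keyed only by the input format.
        self._impl_map = impl_map
        self.impl: Optional[Callable[..., Any]] = None

    def profile(self, x_format: str) -> None:
        # Validate and select; no buffer is allocated and nothing is returned.
        assert x_format in self._impl_map, f"no top-k implementation for {x_format!r}"
        self.impl = self._impl_map[x_format]

    def execute(
        self, x: Sequence[float], o: MutableSequence[int]
    ) -> MutableSequence[int]:
        # Refuse to run before profile() has selected an implementation.
        assert self.impl is not None, "profile() must be called before execute()"
        # The selected callable fills the caller's preallocated buffer in place.
        self.impl(x, o)
        return o
```

Because the caller preallocates the buffer and the callable only fills it, execute() returns the very object it received, which is what "logically out-of-place, but writes into o in-place" amounts to.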
- _impl_map
Dispatch table keyed by x._format. Each entry maps to a callable implementation.
- Type:
Dict[FORMAT, Callable]
- profile(x, o, ctx)
Validate the input/output tensors and select the implementation.
This method checks that:
- x and o are both rank-3 vTensor objects,
- x has shape [S_pack, 1, 1] (one scalar score per page),
- a top-k implementation is registered for x._format, and
- x and o reside on the same device.
No buffers are allocated here and nothing is returned; this call only sets impl so that execute() can run.
- Parameters:
x (vTensor) – Input tensor carrying per-page scalar scores, with logical shape [S_pack, 1, 1].
o (vTensor) – Preallocated output tensor that will be filled in-place by the top-k implementation. Its shape and semantics are defined by the upstream contract and the implementation.
ctx (Context) – Execution context providing:
- dense_kv_indptr and sparse_kv_indptr: segment boundaries in the packed axis,
- dense_kv_indices: indices into the underlying storage, and
- scalar configuration such as batch_size, num_kv_heads, topk_val, page_reserved_bos, and page_reserved_eos.
- Raises:
AssertionError – If types, ranks, shapes, formats, or devices are incompatible, or if no implementation is registered for x._format.
- Return type:
None
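The checks performed by profile() can be mirrored with a small stand-in. TensorStub and check_topk_inputs are hypothetical: the stub carries only the attributes the checks read (shape, device, _format), so it cannot reproduce the vTensor type check. Unlike profile(), which stores the selection in impl and returns None, the helper returns the selected callable so it can be inspected.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass
class TensorStub:
    """Hypothetical stand-in for vTensor holding only what the checks read."""

    shape: Tuple[int, ...]
    device: str
    _format: str


def check_topk_inputs(
    x: TensorStub, o: TensorStub, impl_map: Dict[str, Callable]
) -> Callable:
    """Mirror the documented profile() checks; return the selected callable."""
    # Both tensors must be rank-3 (the vTensor type check is not reproduced).
    assert len(x.shape) == 3 and len(o.shape) == 3, "x and o must be rank-3"
    # One scalar score per page: x is [S_pack, 1, 1].
    assert tuple(x.shape[1:]) == (1, 1), f"x must be [S_pack, 1, 1], got {x.shape}"
    # An implementation must be registered for the input format.
    assert x._format in impl_map, f"no top-k implementation for {x._format!r}"
    # Input and output must live on the same device.
    assert x.device == o.device, f"device mismatch: {x.device} vs {o.device}"
    return impl_map[x._format]
```

Only the rank of o is checked here, since its full shape is left to the upstream contract.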
- execute(x, o, ctx)
Run the selected piecewise top-k implementation and return o.
This method invokes the implementation on the packed input x, using the segment boundaries and configuration provided by Context. For each segment in the packed axis, the kernel:
- identifies the BOS and EOS reserved pages using page_reserved_bos and page_reserved_eos,
- selects up to topk_val additional pages by score, and
- writes the resulting indices/values into o.
Expected implementation signature:

    impl(
        x,
        dense_kv_indptr,
        sparse_kv_indptr,
        dense_kv_indices,
        o,
        batch_heads,
        topk_val,
        page_reserved_bos,
        page_reserved_eos,
        max_num_pages_per_request,
    )
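The following pure-Python sketch follows the expected impl signature to make the data flow concrete. piecewise_topk_reference is hypothetical, and because the layout of o is left to the upstream contract, several semantics it relies on are assumptions: batch_heads is taken as the number of segments, segment b spans dense_kv_indptr[b] to dense_kv_indptr[b + 1] in a flat x, the values written to o are storage page ids looked up in dense_kv_indices, and segment b's output starts at sparse_kv_indptr[b]. max_num_pages_per_request is accepted but unused.

```python
from typing import MutableSequence, Sequence


def piecewise_topk_reference(
    x: Sequence[float],
    dense_kv_indptr: Sequence[int],
    sparse_kv_indptr: Sequence[int],
    dense_kv_indices: Sequence[int],
    o: MutableSequence[int],
    batch_heads: int,
    topk_val: int,
    page_reserved_bos: int,
    page_reserved_eos: int,
    max_num_pages_per_request: int,
) -> None:
    """Hypothetical reference; the segment and output layout are assumptions."""
    for b in range(batch_heads):
        start, end = dense_kv_indptr[b], dense_kv_indptr[b + 1]
        n = end - start
        scores = x[start:end]
        # Reserved BOS/EOS pages, clipped to the segment length.
        reserved = set(range(min(page_reserved_bos, n)))
        reserved |= set(range(max(n - page_reserved_eos, 0), n))
        # Top-k among the remaining pages; ties go to the lower index.
        candidates = [i for i in range(n) if i not in reserved]
        candidates.sort(key=lambda i: (-scores[i], i))
        keep = sorted(reserved | set(candidates[: max(topk_val, 0)]))
        # Write storage page ids into this segment's slice of o (assumed layout).
        out_start = sparse_kv_indptr[b]
        assert out_start + len(keep) <= sparse_kv_indptr[b + 1], "output segment too small"
        for j, local in enumerate(keep):
            o[out_start + j] = dense_kv_indices[start + local]
```

With scores [0.1, 0.9, 0.3, 0.8, 0.2, 0.5, 0.4, 0.7], dense_kv_indptr = [0, 6, 8], storage ids 100 to 107, and one reserved page at each end with topk_val = 2, the first segment keeps local pages 0, 1, 3, and 5, and the second keeps both of its pages because they are all reserved.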
- Parameters:
x (torch.Tensor) – Input tensor carrying per-page scalar scores, compatible with the shape validated in profile().
o (torch.Tensor) – Preallocated output tensor; filled in-place by the implementation.
ctx (Context) – Execution context providing all auxiliary arrays and scalar parameters needed for segment-wise top-k selection.
- Returns:
The tensor o after it has been filled by the top-k implementation.
- Return type:
torch.Tensor
- Raises:
AssertionError – If profile() has not been called and no implementation is available.