vortex_torch.indexer.output_func

Classes

topK()

Piecewise top-k dispatcher for packed sequences with reserved pages.

class vortex_torch.indexer.output_func.topK

Bases: vOp

Piecewise top-k dispatcher for packed sequences with reserved pages.

The input is treated as a rank-3 tensor

\[X \in \mathbb{R}^{S_{\text{pack}} \times 1 \times 1},\]

where the leading dimension \(S_{\text{pack}}\) is a packed concatenation of \(B\) segments:

\[\begin{split}S_{\text{pack}} = \sum_{b=0}^{B-1} S_b, \qquad X = \begin{bmatrix} X_0 \\ X_1 \\ \vdots \\ X_{B-1} \end{bmatrix},\end{split}\]

with

\[X_b \in \mathbb{R}^{S_b \times 1 \times 1}.\]

For each segment \(b\), the operator selects a subset of pages according to scores stored in x (and additional key/value metadata from Context), but always preserves:

  • the first page_reserved_bos pages in that segment,

  • the last page_reserved_eos pages in that segment, and

  • up to topk_val additional pages chosen by top-k over the remaining candidates in the segment.

Let \(\mathcal{I}_b \subset [0, S_{\text{pack}})\) denote the index range of segment \(b\). The implementation computes a subset \(\mathcal{J}_b \subset \mathcal{I}_b\) such that:

  • all indices corresponding to the reserved prefix (BOS) and suffix (EOS) pages in \(\mathcal{I}_b\) are included, and

  • up to topk_val additional indices are selected by score.

The result is written into a preallocated output tensor o; the exact layout of o is defined by the upstream contract and the implementation.
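The selection rule above can be illustrated with a minimal pure-Python sketch. It is not the library's kernel: the helper name select_pages, the tie-breaking rule (lower index wins), the clamping of the reserved counts for short segments, and the return type (a list of index lists rather than a filled o) are all assumptions made for illustration.

```python
from typing import List, Sequence


def select_pages(
    scores: Sequence[float],
    indptr: Sequence[int],
    topk_val: int,
    page_reserved_bos: int,
    page_reserved_eos: int,
) -> List[List[int]]:
    """Return, per segment b, the sorted packed indices J_b that are kept."""
    selected: List[List[int]] = []
    for b in range(len(indptr) - 1):
        start, end = indptr[b], indptr[b + 1]
        length = end - start
        # Reserved prefix (BOS) and suffix (EOS), clamped so they never overlap
        # in short segments (assumption: the real kernel may handle this differently).
        n_bos = min(page_reserved_bos, length)
        n_eos = min(page_reserved_eos, length - n_bos)
        bos = list(range(start, start + n_bos))
        eos = list(range(end - n_eos, end))
        # Remaining candidates compete by score; ties go to the lower index (assumption).
        middle = range(start + n_bos, end - n_eos)
        ranked = sorted(middle, key=lambda i: (-scores[i], i))
        top = ranked[: max(topk_val, 0)]
        selected.append(sorted(bos + top + eos))
    return selected


# Two packed segments of lengths 5 and 3.
scores = [0.1, 0.9, 0.3, 0.8, 0.2, 0.5, 0.7, 0.6]
kept = select_pages(scores, indptr=[0, 5, 8], topk_val=1,
                    page_reserved_bos=1, page_reserved_eos=1)
```

In the first segment, pages 0 and 4 are kept as reserved pages and page 1 wins the top-1 selection among pages 1 to 3; the second segment has only one non-reserved page, so all three of its pages are kept.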

Key properties

  • Dispatch is keyed only by the input format x._format.

  • The operation is logically out-of-place (x is left unchanged), but the result is written in-place into the preallocated o.

  • profile() only validates and selects the implementation; it does not allocate or return any buffers.

  • execute() performs the per-segment selection using context metadata (indptr arrays, indices, reserved-page counts, and topk_val).

_impl_map

Dispatch table keyed by the input format (x._format). Each entry maps to a callable implementation.

Type:

Dict[FORMAT, Callable]

impl

The resolved implementation selected during profile().

Type:

Optional[Callable]
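The interplay between _impl_map, impl, profile(), and execute() can be sketched with a toy operator. Everything below is hypothetical and only mirrors the two-phase pattern described here: TinyTensor, TinyTopK, the toy kernel, and the format name "RAGGED" are stand-ins, not part of vortex_torch.

```python
from typing import Callable, Dict, Optional


class TinyTensor:
    """Hypothetical stand-in for vTensor: carries only a format tag and data."""

    def __init__(self, fmt: str, data: list):
        self._format = fmt
        self.data = data


class TinyTopK:
    """Illustrative two-phase operator; not the vortex_torch class."""

    # Dispatch table keyed by the input format.
    _impl_map: Dict[str, Callable] = {}

    def __init__(self) -> None:
        self.impl: Optional[Callable] = None

    def profile(self, x: TinyTensor, o: TinyTensor) -> None:
        # Selection only: no buffers are allocated and nothing is returned.
        assert x._format in self._impl_map, f"no implementation for {x._format}"
        self.impl = self._impl_map[x._format]

    def execute(self, x: TinyTensor, o: TinyTensor) -> TinyTensor:
        assert self.impl is not None, "profile() must be called before execute()"
        self.impl(x, o)  # the implementation fills o in place
        return o


def _argmax_kernel(x: TinyTensor, o: TinyTensor) -> None:
    # Toy kernel: writes the index of the largest score into o.
    o.data[0] = max(range(len(x.data)), key=x.data.__getitem__)


# "RAGGED" is a made-up format name used only for this sketch.
TinyTopK._impl_map["RAGGED"] = _argmax_kernel

op = TinyTopK()
x = TinyTensor("RAGGED", [0.2, 0.9, 0.4])
o = TinyTensor("RAGGED", [-1])
op.profile(x, o)
result = op.execute(x, o)
```

Keeping selection in profile() means the format lookup happens once, and execute() is reduced to a single call into the already resolved implementation.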

profile(x, o, ctx)

Validate input/output tensors and select the implementation.

This method checks:

  • that x and o are both rank-3 vTensor objects,

  • that x has shape [S_pack, 1, 1] (one scalar score per page),

  • that a top-k implementation is registered for x._format, and

  • that x and o reside on the same device.

No buffers are allocated here and nothing is returned; this call simply sets impl so that execute() can run.
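The checks listed above can be pictured as a series of assertions. The sketch below is an assumption-laden illustration: FakeTensor and check_topk_inputs are hypothetical names, the device is modelled as a plain string, and the real method may phrase or order its checks differently.

```python
from dataclasses import dataclass
from typing import Collection, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for vTensor with only the attributes being checked."""

    shape: Tuple[int, ...]
    device: str
    _format: str


def check_topk_inputs(x, o, registered_formats: Collection[str]) -> None:
    # Both operands must be the expected tensor type and rank 3.
    assert isinstance(x, FakeTensor) and isinstance(o, FakeTensor), "wrong tensor type"
    assert len(x.shape) == 3 and len(o.shape) == 3, "x and o must be rank-3"
    # One scalar score per page: shape [S_pack, 1, 1].
    assert tuple(x.shape[1:]) == (1, 1), "x must have shape [S_pack, 1, 1]"
    # An implementation must be registered for the input format.
    assert x._format in registered_formats, f"no top-k implementation for {x._format}"
    # Input and output must live on the same device.
    assert x.device == o.device, "x and o must be on the same device"


# "RAGGED" is a made-up format name; the shape of o is arbitrary in this sketch.
x_ok = FakeTensor(shape=(8, 1, 1), device="cuda:0", _format="RAGGED")
o_ok = FakeTensor(shape=(6, 1, 1), device="cuda:0", _format="RAGGED")
check_topk_inputs(x_ok, o_ok, registered_formats={"RAGGED"})
```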

Parameters:
  • x (vTensor) – Input tensor carrying per-page scalar scores, with logical shape [S_pack, 1, 1].

  • o (vTensor) – Preallocated output tensor that will be filled in-place by the top-k implementation. Its shape and semantics are defined by the upstream contract and the implementation.

  • ctx (Context) –

    Execution context providing:

    • dense_kv_indptr and sparse_kv_indptr: segment boundaries in the packed axis,

    • dense_kv_indices: indices into underlying storage, and

    • scalar configuration such as batch_size, num_kv_heads, topk_val, page_reserved_bos, and page_reserved_eos.
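As a rough illustration of how an indptr array can delimit segments in the packed axis, the sketch below assumes the common CSR-style convention, in which segment b occupies the half-open range [indptr[b], indptr[b+1]). That convention, and the helper names build_indptr and segment_ranges, are assumptions; how dense_kv_indptr and sparse_kv_indptr differ is not covered by this sketch.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def build_indptr(segment_lengths: Sequence[int]) -> List[int]:
    """Prefix sums of the segment lengths, with a leading 0."""
    return [0] + list(accumulate(segment_lengths))


def segment_ranges(indptr: Sequence[int]) -> List[Tuple[int, int]]:
    """Half-open (start, end) range of each segment in the packed axis."""
    return list(zip(indptr[:-1], indptr[1:]))


indptr = build_indptr([3, 5, 2])   # [0, 3, 8, 10]
ranges = segment_ranges(indptr)    # [(0, 3), (3, 8), (8, 10)]
s_pack = indptr[-1]                # total packed length, here 10
```

Under this convention the last entry of the array equals the packed length, and the number of segments is len(indptr) - 1.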

Raises:

AssertionError – If types, ranks, shapes, formats, or devices are incompatible, or if no implementation is registered for x._format.

Return type:

None

execute(x, o, ctx)

Run the selected piecewise top-k implementation and return o.

This method invokes the implementation on the packed input x using segment boundaries and configuration provided by Context. For each segment in the packed axis, the kernel:

  • identifies the BOS and EOS reserved pages using page_reserved_bos and page_reserved_eos,

  • selects up to topk_val additional pages by score from the remaining candidates, and

  • writes the resulting indices/values into o.

Expected implementation signature:

impl(
    x,
    dense_kv_indptr,
    sparse_kv_indptr,
    dense_kv_indices,
    o,
    batch_heads,
    topk_val,
    page_reserved_bos,
    page_reserved_eos,
    max_num_pages_per_request,
)
Parameters:
  • x (torch.Tensor) – Input tensor carrying per-page scalar scores, compatible with the shape validated in profile().

  • o (torch.Tensor) – Preallocated output tensor; will be filled in-place by the implementation.

  • ctx (Context) – Execution context providing all auxiliary arrays and scalar parameters needed for segment-wise top-k selection.

Returns:

The tensor o after it has been filled by the top-k implementation.

Return type:

torch.Tensor

Raises:

AssertionError – If profile() has not been called first, so that no implementation has been selected.