vortex_torch.indexer.matmul

Classes

GeMM()

General matrix-matrix multiplication (GeMM) dispatcher.

GeMV()

General matrix-vector multiplication (GeMV) dispatcher.

class vortex_torch.indexer.matmul.GeMV[source]

Bases: vOp

General matrix-vector multiplication (GeMV) dispatcher.

This operator computes a piecewise batched matrix-vector product. Let

\[X \in \mathbb{R}^{B \times 1 \times D}, \qquad Y \in \mathbb{R}^{S_{\text{pack}} \times 1 \times D},\]

where the S-axis of \(Y\) is a concatenation of batch-wise segments

\[\begin{split}S_{\text{pack}} = \sum_{i=0}^{B-1} S_i, \qquad Y = \begin{bmatrix} Y_0 \\ Y_1 \\ \vdots \\ Y_{B-1} \end{bmatrix},\end{split}\]

with

\[Y_i \in \mathbb{R}^{S_i \times 1 \times D}, \qquad X_i = X[i, 0, :] \in \mathbb{R}^{1 \times D}.\]

For each batch index \(i \in \{0,\dots,B-1\}\), we define

\[O_i = Y_i X_i^{\mathsf{T}} \in \mathbb{R}^{S_i \times 1 \times 1},\]

and the overall output is the concatenation

\[\begin{split}O = \begin{bmatrix} O_0 \\ O_1 \\ \vdots \\ O_{B-1} \end{bmatrix} \in \mathbb{R}^{S_{\text{pack}} \times 1 \times 1}.\end{split}\]

In the runtime, \(S_{\text{pack}}\) is given by ctx.max_num_pages and the dispatch is keyed by the pair of input formats (x_format, y_format).
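The piecewise product above can be sketched in plain PyTorch as a non-fused reference (the `seq_lens` argument listing the per-batch segment lengths \(S_i\) is a hypothetical input for illustration, not part of the dispatcher's API):

```python
import torch

def gemv_reference(x: torch.Tensor, y: torch.Tensor, seq_lens: list) -> torch.Tensor:
    """Non-fused reference for the piecewise batched GEMV.

    x: [B, 1, D]       -- one vector per batch element
    y: [S_pack, 1, D]  -- concatenation of per-batch segments Y_i
    seq_lens: the S_i, with sum(seq_lens) == S_pack (hypothetical input)
    Returns O: [S_pack, 1, 1] with O_i = Y_i @ X_i^T.
    """
    outs, offset = [], 0
    for i, s in enumerate(seq_lens):
        y_i = y[offset:offset + s]                      # [S_i, 1, D]
        x_i = x[i, 0, :]                                # [D]
        outs.append((y_i * x_i).sum(-1, keepdim=True))  # [S_i, 1, 1]
        offset += s
    return torch.cat(outs, dim=0)
```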

_impl_map

Dispatch table keyed by (x_format, y_format). Each entry maps to (callable_impl, resolved_output_format).

Type:

Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]]
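The shape of the dispatch table can be illustrated with a minimal sketch; the FORMAT members and the kernel below are hypothetical placeholders, not the library's actual values:

```python
from enum import Enum, auto
from typing import Callable, Dict, Tuple

class FORMAT(Enum):   # hypothetical stand-in for the library's FORMAT enum
    BATCHED = auto()
    PAGED = auto()

def paged_gemv_kernel(x, y, output, ctx):
    """Placeholder; real entries follow the impl(x, y, output, ctx) signature."""

# (x_format, y_format) -> (callable_impl, resolved_output_format)
_impl_map: Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]] = {
    (FORMAT.BATCHED, FORMAT.PAGED): (paged_gemv_kernel, FORMAT.PAGED),
}

# profile() performs a lookup of this shape:
impl, output_format = _impl_map[(FORMAT.BATCHED, FORMAT.PAGED)]
```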

impl

The resolved implementation selected during profile().

Type:

Optional[Callable]

output_format

The output tensor format as determined in profile().

Type:

Optional[FORMAT]

output_buffer

Preallocated output tensor buffer of shape [S_pack, 1, 1].

Type:

Optional[torch.Tensor]

profile(x, y, ctx)[source]

Validate inputs, select an implementation, allocate the output buffer, and return an as_vtensor() view with the resolved format.

The method enforces the logical shapes

  • x: [B, 1, D]

  • y: [S_pack, 1, D]

and checks that the last dimensions match. The output buffer is allocated with shape [S_pack, 1, 1], where S_pack is taken from the runtime context as ctx.max_num_pages.

Parameters:
  • x (vTensor) – Input tensor corresponding to the batched vector(s), with shape [B, 1, D].

  • y (vTensor) – Input tensor corresponding to the packed matrix rows, with shape [S_pack, 1, D].

  • ctx (Context) – Execution context providing ctx.max_num_pages for the S_pack dimension.

Returns:

A vTensor view wrapping the allocated output buffer with the resolved output format.

Return type:

vTensor

execute(x, y, ctx)[source]

Launch the selected GEMV implementation into the internal output buffer.

Expected kernel signature:

impl(x, y, output, ctx)

Parameters:
  • x (torch.Tensor) – Input tensor corresponding to the batched vector(s), with shape [B, 1, D] and on the same device as y and the output.

  • y (torch.Tensor) – Input tensor corresponding to the packed matrix rows, with shape [S_pack, 1, D] and on the same device as x and the output.

  • ctx (Context) – Execution context passed through to the underlying implementation.

Returns:

The output tensor stored in self.output_buffer with shape [S_pack, 1, 1].

Return type:

torch.Tensor
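A kernel plugged into the dispatch table must follow the impl(x, y, output, ctx) signature above. A minimal sketch, assuming a hypothetical ctx.page_to_batch index that maps each packed row of y to its batch element in x (the real Context may expose this mapping differently):

```python
import torch

class Ctx:
    """Hypothetical minimal context; the real Context carries more state."""
    def __init__(self, page_to_batch: torch.Tensor):
        self.page_to_batch = page_to_batch  # [S_pack] long tensor, row -> batch index

def naive_gemv_impl(x: torch.Tensor, y: torch.Tensor,
                    output: torch.Tensor, ctx: Ctx) -> None:
    # Gather the batch vector for each packed row: [S_pack, 1, D]
    x_rows = x[ctx.page_to_batch]
    # Row-wise dot products, written into the preallocated buffer: [S_pack, 1, 1]
    output.copy_((x_rows * y).sum(-1, keepdim=True))
```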

class vortex_torch.indexer.matmul.GeMM[source]

Bases: vOp

General matrix-matrix multiplication (GeMM) dispatcher.

Logically this computes, for each logical S-slice, a matrix-matrix product

\[O[s] = Y[s] X[s]^{\mathsf{T}}, \quad s = 0, \dots, S-1,\]

with slice-wise shapes

\[X[s] \in \mathbb{R}^{N_x \times K}, \quad Y[s] \in \mathbb{R}^{N_y \times K}, \quad O[s] \in \mathbb{R}^{N_y \times N_x}.\]

In the packed 3D representation used by this dispatcher:

  • Y has logical shape [S, N_y, K].

  • X has logical shape [L_x, N_x, K], where the leading dimension \(L_x\) can represent either:

    • a batch axis \(B\) (when x_format == FORMAT.BATCHED), or

    • the same S axis as Y (when x_format is ragged/paged and already laid out per-page).

    Hence the shape annotation X: [B/S, N_x, K] in the code comments: the first dimension is interpreted either as a batch size \(B\) or as an S-like logical page index, depending on the format.

  • The output tensor O has logical shape [S, N_y, N_x].

At runtime, the logical S is taken from ctx.max_num_pages, and dispatch is keyed by the pair of tensor formats (x_format, y_format).
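For the case where x is already laid out per-page (leading dimension equal to \(S\)), the slice-wise product reduces to a single batched matmul; a non-fused reference sketch:

```python
import torch

def gemm_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Reference slice-wise GeMM for per-page x.

    x: [S, N_x, K], y: [S, N_y, K]
    Returns O: [S, N_y, N_x] with O[s] = Y[s] @ X[s].T.
    """
    return torch.bmm(y, x.transpose(1, 2))
```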

_impl_map

Dispatch table keyed by (x_format, y_format). Each entry maps to (callable_impl, resolved_output_format).

Type:

Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]]

impl

The resolved implementation selected during profile().

Type:

Optional[Callable]

output_format

The output tensor format as determined in profile().

Type:

Optional[FORMAT]

output_buffer

Preallocated output tensor buffer of shape [S, N_y, N_x].

Type:

Optional[torch.Tensor]

profile(x, y, ctx)[source]

Validate inputs, select an implementation, allocate the output buffer, and return an as_vtensor() view with the resolved format.

The method enforces that both inputs are rank-3 tensors and that the inner dimension \(K\) matches:

  • x: [B_or_S, N_x, K]

    When x_format == FORMAT.BATCHED, the leading dimension is a batch size \(B\). For ragged/paged formats, it may conceptually coincide with \(S\).

  • y: [S, N_y, K]

The output buffer is allocated with shape [S, N_y, N_x], where S is taken from the runtime context as ctx.max_num_pages.

Parameters:
  • x (vTensor) – Right-hand operand (transposed in the mathematical view), with shape [B_or_S, N_x, K] and a format participating in the (x_format, y_format) dispatch.

  • y (vTensor) – Left-hand operand with shape [S, N_y, K].

  • ctx (Context) – Execution context providing ctx.max_num_pages for the logical S dimension and tracking auxiliary memory.

Returns:

A vTensor view wrapping the allocated output buffer with the resolved output format.

Return type:

vTensor

Raises:

AssertionError – If types are not vTensor, ranks are not 3, the inner dimensions \(K\) do not match, or there is no implementation for the pair (x._format, y._format).
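The validations described above can be sketched as a standalone helper (hypothetical; the actual checks live inside profile(), and in the runtime the logical S comes from ctx.max_num_pages rather than y.shape[0]):

```python
import torch

def check_gemm_shapes(x: torch.Tensor, y: torch.Tensor) -> tuple:
    """Mirror the rank and inner-dimension assertions made by profile()."""
    assert x.dim() == 3 and y.dim() == 3, "both operands must be rank-3"
    assert x.shape[-1] == y.shape[-1], "inner dimension K must match"
    S, N_y, _ = y.shape   # in the runtime, S is taken from ctx.max_num_pages
    _, N_x, _ = x.shape
    return (S, N_y, N_x)  # shape of the output buffer to allocate
```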

execute(x, y, ctx)[source]

Launch the selected GeMM implementation into the internal buffer.

Expected kernel signature:

impl(x, y, output, ctx)

Parameters:
  • x (torch.Tensor) – Right-hand operand (transposed in the mathematical view), with shape [B_or_S, N_x, K] on the same device as y and the output buffer.

  • y (torch.Tensor) – Left-hand operand with shape [S, N_y, K] on the same device as x and the output buffer.

  • ctx (Context) – Execution context passed through to the underlying implementation.

Returns:

The output tensor stored in self.output_buffer with shape [S, N_y, N_x].

Return type:

torch.Tensor

Raises:

AssertionError – If profile() has not been called (no implementation or buffer), or if there is a device mismatch between x, y and the output buffer.