vortex_torch.cache.matmul

Classes

GeMM()

General matrix-matrix multiplication dispatcher for page/token-tiled layouts.

class vortex_torch.cache.matmul.GeMM

Bases: vOp

General matrix-matrix multiplication dispatcher for page/token-tiled layouts.

This operator computes a batched GEMM of the form

\[O_b = Y_b X_b^\top, \qquad b = 0, \dots, B-1,\]

where, for each batch index \(b\),

  • \(X_b \in \mathbb{R}^{N_x \times K}\),

  • \(Y_b \in \mathbb{R}^{N_y \times K}\), and

  • \(O_b \in \mathbb{R}^{N_y \times N_x}\).

In the logical 3D layout used by this dispatcher, the tensors have shapes

\[\begin{split}X &\in \mathbb{R}^{B \times N_x \times K}, \\ Y &\in \mathbb{R}^{B \times N_y \times K}, \\ O &\in \mathbb{R}^{B \times N_y \times N_x},\end{split}\]

where the leading dimension \(B\) is a batch-like axis typically derived from the runtime (for example, max_new_tokens_per_batch * head_num in an attention-style kernel).
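To make the index convention concrete, here is a minimal pure-Python reference for what each batch slice computes. It assumes tensors are plain nested lists, and `reference_gemm_nt` is a hypothetical helper, not part of vortex_torch. It ignores page/token tiling, formats and `loc` entirely and only pins down the math: entry (i, j) of O_b is the dot product of row i of Y_b with row j of X_b.

```python
# Pure-Python reference for the batched "NT" GEMM O_b = Y_b @ X_b^T.
# Hypothetical helper; tensors are nested lists with logical shapes
#   x: [B, N_x, K], y: [B, N_y, K]  ->  out: [B, N_y, N_x]
# For dense PyTorch tensors the same result is torch.bmm(y, x.transpose(1, 2))
# or torch.einsum("bik,bjk->bij", y, x).
from typing import List

Matrix3 = List[List[List[float]]]


def reference_gemm_nt(x: Matrix3, y: Matrix3) -> Matrix3:
    assert len(x) == len(y), "batch sizes must match"
    out: Matrix3 = []
    for xb, yb in zip(x, y):
        k = len(xb[0]) if xb else 0
        # Every row of X_b and Y_b must share the inner dimension K.
        assert all(len(row) == k for row in xb + yb), "inner dimension K must match"
        # out[b][i][j] = sum_k y[b][i][k] * x[b][j][k]
        out.append(
            [[sum(yi * xj for yi, xj in zip(y_row, x_row)) for x_row in xb] for y_row in yb]
        )
    return out
```

For example, with X_0 = [[1, 0], [0, 1]] and Y_0 = [[3, 4]], the result is [[[3, 4]]]: one row (N_y = 1) and two columns (N_x = 2).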

Dispatch is based on the triplet of tensor formats (x_format, y_format, o_format) and a registry mapping:

(x_format, y_format, o_format) -> (impl, resolved_output_format)
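A format-keyed registry of this kind can be sketched as a plain dictionary lookup. The sketch below is hypothetical: `Fmt`, its `PAGED` member, and the placeholder kernels are illustrative stand-ins, not vortex_torch's `FORMAT` enum or its real kernels. It only shows how a (x_format, y_format, o_format) triple resolves to (impl, resolved_output_format), and how a missing key surfaces as an AssertionError.

```python
# Hypothetical sketch of a format-keyed dispatch table.
# Fmt, gemm_paged_x and gemm_ragged_out are illustrative stand-ins,
# not vortex_torch names.
from enum import Enum, auto
from typing import Callable, Dict, Tuple


class Fmt(Enum):
    RAGGED = auto()
    PAGED = auto()  # hypothetical second layout


def gemm_paged_x(*args):  # placeholder kernel for a paged x operand
    return "paged-x kernel"


def gemm_ragged_out(*args):  # placeholder kernel for all-ragged operands
    return "ragged kernel"


ImplMap = Dict[Tuple[Fmt, Fmt, Fmt], Tuple[Callable, Fmt]]

IMPL_MAP: ImplMap = {
    (Fmt.PAGED, Fmt.RAGGED, Fmt.RAGGED): (gemm_paged_x, Fmt.RAGGED),
    (Fmt.RAGGED, Fmt.RAGGED, Fmt.RAGGED): (gemm_ragged_out, Fmt.RAGGED),
}


def lookup(x_fmt: Fmt, y_fmt: Fmt, o_fmt: Fmt) -> Tuple[Callable, Fmt]:
    key = (x_fmt, y_fmt, o_fmt)
    # An unregistered triple is a hard error, matching the documented AssertionError.
    assert key in IMPL_MAP, f"no GEMM implementation for {key}"
    return IMPL_MAP[key]
```

Keying on the full triple keeps each kernel specialized to one combination of memory layouts. Supporting a new layout then means registering another entry rather than branching inside an existing kernel.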

Policy

  • If output is None:

    • profile() selects an implementation with o_format == FORMAT.RAGGED, i.e. a key (x_fmt, y_fmt, FORMAT.RAGGED) in _impl_map.

    • An internal buffer is allocated with logical shape [B, N_y, N_x] on the same device and with the same dtype as x.

  • If output is provided:

    • profile() requires an exact implementation key for (x_fmt, y_fmt, o_fmt).

    • output must have rank 3, with last two dimensions (N_y, N_x).

    • Device consistency is enforced across x, y and output.

Additionally, the shared inner dimension \(K\) must match:

\[K_x = x.\text{shape}[2], \quad K_y = y.\text{shape}[2], \quad K_x = K_y.\]
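The policy and the inner-dimension check above can be summarized as a small validation routine. This is a sketch under simplifying assumptions: shapes are tuples, devices and formats are strings, and `plan_output` is a hypothetical name. It mirrors the documented rules but does not reproduce the actual profile() code, and it omits dtype handling (the internal buffer would take the dtype of x).

```python
# Hypothetical sketch of the profile()-time policy, using plain tuples for
# shapes and strings for devices/formats. Not the vortex_torch implementation.
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def plan_output(
    x_shape: Shape,
    y_shape: Shape,
    x_device: str,
    y_device: str,
    batch: int,
    out_shape: Optional[Shape] = None,
    out_device: Optional[str] = None,
    out_fmt: Optional[str] = None,
) -> Tuple[str, Shape]:
    """Return (o_format used for the _impl_map lookup, logical output shape)."""
    assert len(x_shape) == 3 and len(y_shape) == 3, "x and y must be rank 3"
    assert x_shape[2] == y_shape[2], "inner dimension K must match"
    assert x_device == y_device, "x and y must share a device"
    n_x, n_y = x_shape[1], y_shape[1]
    if out_shape is None:
        # No output given: a RAGGED-output implementation is required and an
        # internal [B, N_y, N_x] buffer is planned.
        return "RAGGED", (batch, n_y, n_x)
    # Output given: check rank, trailing dims and device, then use its own
    # format for an exact (x_fmt, y_fmt, o_fmt) lookup.
    assert len(out_shape) == 3, "output must be rank 3"
    assert tuple(out_shape[1:]) == (n_y, n_x), "output must end in (N_y, N_x)"
    assert out_device == x_device, "output must be on the same device as x and y"
    assert out_fmt is not None, "a provided output needs a format for the exact key"
    return out_fmt, tuple(out_shape)
```

Note that when an output is provided, only its last two dimensions are checked against x and y; its leading dimension is taken as given.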

_impl_map

Dispatch table keyed by (x_format, y_format, o_format). Each entry maps to (callable_impl, resolved_output_format).

Type:

Dict[Tuple[FORMAT, FORMAT, FORMAT], Tuple[Callable, FORMAT]]

impl

The resolved implementation selected during profile().

Type:

Optional[Callable]

output_format

The output tensor format as determined in profile().

Type:

Optional[FORMAT]

output_buffer

Internal output buffer allocated by profile() when output is None.

Type:

Optional[torch.Tensor]

profile(x, y, output, loc, ctx)

Validate inputs, resolve the GEMM implementation and output format, and optionally allocate an internal output buffer.

The logical shapes are:

  • x: [B, N_x, K]

  • y: [B, N_y, K]

  • output (if provided): [B_out, N_y, N_x]

with the constraint that the inner dimension \(K\) matches:

\[x.\text{shape}[2] = y.\text{shape}[2].\]

The auxiliary tensor loc carries per-position or per-tile metadata used by the implementation (for example, page indices or tiling information); its shape and semantics are kernel-defined.

Parameters:
  • x (vTensor) – Right-hand operand in Y @ X^T, with logical shape [B, N_x, K].

  • y (vTensor) – Left-hand operand in Y @ X^T, with logical shape [B, N_y, K].

  • output (Optional[vTensor]) – Optional preallocated output tensor. If None, an internal buffer with shape [B, N_y, N_x] is allocated, using ctx.max_new_tokens_per_batch * ctx.head_num as the leading dimension B, and a RAGGED-output implementation is selected. If not None, this tensor must have rank 3 and last two dimensions (N_y, N_x), and its format, together with those of x and y, must form a key in _impl_map.

  • loc (torch.Tensor) – Auxiliary tensor carrying metadata required by the GEMM implementation.

  • ctx (Context) – Execution context that provides the runtime value of B and is used for auxiliary memory accounting.

Returns:

A vTensor view representing the resolved output: either the provided output or an internally allocated buffer.

Return type:

vTensor

Raises:

AssertionError – If the types, ranks, formats, shapes, or devices of the inputs are incompatible, if the inner dimensions of x and y differ, or if no matching implementation is found in _impl_map.
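When output is None, the leading dimension of the internal buffer comes from the context. The worked example below uses a hypothetical `FakeContext` that carries only the two fields named above. It assumes a 2-byte element type (such as float16) purely to illustrate the size arithmetic; how vortex_torch actually accounts for this memory is not shown here.

```python
# Worked example of the internal-buffer size when output is None.
# FakeContext is a hypothetical stand-in for Context with only the two
# fields the profile() docs mention.
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class FakeContext:
    max_new_tokens_per_batch: int
    head_num: int


def internal_buffer_shape(ctx: FakeContext, n_y: int, n_x: int) -> Tuple[int, int, int]:
    # B = max_new_tokens_per_batch * head_num, as documented for the output=None path.
    b = ctx.max_new_tokens_per_batch * ctx.head_num
    return (b, n_y, n_x)


def buffer_nbytes(shape: Tuple[int, ...], element_size: int) -> int:
    # Bytes needed for a dense buffer of this logical shape.
    return math.prod(shape) * element_size
```

With max_new_tokens_per_batch = 4 and head_num = 8, B = 32, so for N_y = 1 and N_x = 16 the buffer holds 32 · 1 · 16 = 512 elements, or 1024 bytes at 2 bytes per element.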

execute(x, y, output, loc, ctx)

Execute the selected GEMM implementation.

This method assumes that profile() has already selected an implementation and, if needed, allocated an internal output buffer.

Parameters:
  • x (torch.Tensor) – Plain right-hand operand tensor with shape compatible with the vTensor validated in profile().

  • y (torch.Tensor) – Plain left-hand operand tensor with shape compatible with the vTensor validated in profile().

  • output (Optional[torch.Tensor]) – Optional preallocated output tensor. If None, the internal buffer created during profile() will be used.

  • loc (torch.Tensor) – Auxiliary tensor carrying metadata required by the GEMM implementation.

  • ctx (Context) – Execution context forwarded to the implementation.

Returns:

The output tensor written by the implementation: either the provided output or the internal buffer.

Return type:

torch.Tensor

Raises:

AssertionError – If profile() has not been called first, so that no implementation (or, when output is None, no internal buffer) is available.
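The split between profile() and execute() follows a two-phase pattern: resolve and allocate once, then run many times without re-validating. The toy class below is hypothetical (`ToyGeMM` and `gemm_nt_into` operate on nested lists, not on vTensor or torch.Tensor). It only models that contract, including the AssertionError raised when execute() runs before profile().

```python
# Toy model of the profile()/execute() contract: profile() resolves the
# implementation and, if no output will be passed, allocates a buffer;
# execute() only reuses what profile() prepared. Hypothetical names throughout.
from typing import Callable, List, Optional

Matrix3 = List[List[List[float]]]


def gemm_nt_into(x: Matrix3, y: Matrix3, out: Matrix3) -> Matrix3:
    # Writes out[b][i][j] = sum_k y[b][i][k] * x[b][j][k] in place.
    for b, (xb, yb) in enumerate(zip(x, y)):
        for i, y_row in enumerate(yb):
            for j, x_row in enumerate(xb):
                out[b][i][j] = sum(p * q for p, q in zip(y_row, x_row))
    return out


class ToyGeMM:
    def __init__(self) -> None:
        self.impl: Optional[Callable[[Matrix3, Matrix3, Matrix3], Matrix3]] = None
        self.output_buffer: Optional[Matrix3] = None

    def profile(self, batch: int, n_y: int, n_x: int, have_output: bool) -> None:
        # Resolve the implementation (a single one in this toy).
        self.impl = gemm_nt_into
        if not have_output:
            # Preallocate [batch, n_y, n_x] once, outside the execute() path.
            self.output_buffer = [[[0.0] * n_x for _ in range(n_y)] for _ in range(batch)]

    def execute(self, x: Matrix3, y: Matrix3, output: Optional[Matrix3] = None) -> Matrix3:
        assert self.impl is not None, "call profile() before execute()"
        if output is None:
            assert self.output_buffer is not None, "no internal buffer was allocated"
            output = self.output_buffer
        return self.impl(x, y, output)
```

A typical call sequence is `op.profile(batch=1, n_y=1, n_x=2, have_output=False)` followed by `op.execute(x, y)`, which writes into and returns `op.output_buffer`.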