vortex_torch.cache.matmul¶
Classes
GeMM – General matrix-matrix multiplication dispatcher for page/token-tiled layouts.
- class vortex_torch.cache.matmul.GeMM[source]¶
Bases: vOp
General matrix-matrix multiplication dispatcher for page/token-tiled layouts.
This operator computes a batched GEMM of the form
\[O_b = Y_b X_b^\top, \qquad b = 0, \dots, B-1,\]
where, for each batch index \(b\),
\(X_b \in \mathbb{R}^{N_x \times K}\),
\(Y_b \in \mathbb{R}^{N_y \times K}\), and
\(O_b \in \mathbb{R}^{N_y \times N_x}\).
In the logical 3D layout used by this dispatcher, the tensors have shapes
\[\begin{split}X &\in \mathbb{R}^{B \times N_x \times K}, \\ Y &\in \mathbb{R}^{B \times N_y \times K}, \\ O &\in \mathbb{R}^{B \times N_y \times N_x},\end{split}\]
where the leading dimension \(B\) is a batch-like axis typically derived from the runtime (for example, max_new_tokens_per_batch * head_num in an attention-style kernel).
Dispatch is based on the triplet of tensor formats (x_format, y_format, o_format) and a registry mapping:
(x_format, y_format, o_format) -> (impl, resolved_output_format)
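In the simplest dense case, the batched product above is just a plain PyTorch batched matmul. The following sketch illustrates the math only (the shapes are arbitrary), not the dispatcher itself:

```python
import torch

B, N_x, N_y, K = 2, 4, 6, 8
X = torch.randn(B, N_x, K)
Y = torch.randn(B, N_y, K)

# O_b = Y_b @ X_b^T for every batch index b
O = torch.matmul(Y, X.transpose(-1, -2))
assert O.shape == (B, N_y, N_x)
```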
Policy¶
- If output is None: an internal output buffer of shape [B, N_y, N_x] is allocated and a RAGGED-output implementation is selected.
- If output is provided:
  - profile() requires an exact implementation key for (x_fmt, y_fmt, o_fmt).
  - The shape of output must be rank-3 with last two dimensions (N_y, N_x).
  - Device consistency is enforced across x, y and output.
Additionally, the shared inner dimension \(K\) must match:
\[K_x = x.\text{shape}[2], \quad K_y = y.\text{shape}[2], \quad K_x = K_y.\]
- _impl_map¶
Dispatch table keyed by (x_format, y_format, o_format). Each entry maps to (callable_impl, resolved_output_format).
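As a rough illustration of how such a table can be consumed (the format names "DENSE" and the implementation below are hypothetical, not the library's actual registry contents):

```python
import torch

def _dense_gemm(x, y, out):
    # O = Y @ X^T, written into the preallocated buffer
    torch.matmul(y, x.transpose(-1, -2), out=out)

# Hypothetical dispatch table:
# (x_format, y_format, o_format) -> (callable_impl, resolved_output_format)
_impl_map = {
    ("DENSE", "DENSE", "DENSE"): (_dense_gemm, "DENSE"),
}

def resolve(x_fmt, y_fmt, o_fmt):
    key = (x_fmt, y_fmt, o_fmt)
    assert key in _impl_map, f"no GEMM implementation for {key}"
    return _impl_map[key]

impl, out_fmt = resolve("DENSE", "DENSE", "DENSE")
```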
- output_buffer¶
Internal output buffer allocated when output is None.
- Type:
Optional[torch.Tensor]
- profile(x, y, output, loc, ctx)[source]¶
Validate inputs, resolve the GEMM implementation and output format, and optionally allocate an internal output buffer.
The logical shapes are:
- x: [B, N_x, K]
- y: [B, N_y, K]
- output (if provided): [B_out, N_y, N_x]
with the constraint that the inner dimension \(K\) matches:
\[x.\text{shape}[2] = y.\text{shape}[2].\]
The auxiliary tensor loc carries per-position or per-tile metadata used by the implementation (for example, page indices or tiling information); its shape and semantics are kernel-defined.
- Parameters:
- x (vTensor) – Right-hand operand in Y @ X^T, with logical shape [B, N_x, K].
- y (vTensor) – Left-hand operand in Y @ X^T, with logical shape [B, N_y, K].
- output (Optional[vTensor]) – Optional preallocated output tensor. If None, an internal buffer with shape [B, N_y, N_x] is allocated using ctx.max_new_tokens_per_batch * ctx.head_num for the leading dimension, and a RAGGED-output implementation is selected. If not None, this tensor must have rank 3 and last two dimensions (N_y, N_x), with a format compatible with _impl_map.
- loc (torch.Tensor) – Auxiliary tensor carrying metadata required by the GEMM implementation.
- ctx (Context) – Execution context that provides the runtime value of B and is used for auxiliary memory accounting.
- Returns:
A vTensor view representing the resolved output: either the provided output or an internally allocated buffer.
- Return type:
vTensor
- Raises:
AssertionError – If types, ranks, inner dimensions, formats, shapes, or devices are incompatible, or if no implementation is found in _impl_map.
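The documented shape and device checks can be restated on plain tensors; the sketch below mirrors them (it operates on torch.Tensor only, whereas the real profile() also validates vTensor formats against _impl_map):

```python
import torch

def check_gemm_inputs(x, y, output=None):
    # profile()-style constraints, restated on plain tensors
    assert x.dim() == 3 and y.dim() == 3, "x and y must be rank-3"
    assert x.shape[0] == y.shape[0], "batch dimension B must match"
    assert x.shape[2] == y.shape[2], "inner dimension K must match"
    if output is not None:
        assert output.dim() == 3, "output must be rank-3"
        # last two dimensions must be (N_y, N_x)
        assert output.shape[-2:] == (y.shape[1], x.shape[1])
        assert x.device == y.device == output.device
```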
- execute(x, y, output, loc, ctx)[source]¶
Execute the selected GEMM implementation.
This method assumes that profile() has already selected an implementation and, if needed, allocated an internal output buffer.
- Parameters:
- x (torch.Tensor) – Plain right-hand operand tensor, with shape compatible with the vTensor validated in profile().
- y (torch.Tensor) – Plain left-hand operand tensor, with shape compatible with the vTensor validated in profile().
- output (Optional[torch.Tensor]) – Optional preallocated output tensor. If None, the internal buffer created during profile() is used.
- loc (torch.Tensor) – Auxiliary tensor carrying metadata required by the GEMM implementation.
- ctx (Context) – Execution context forwarded to the implementation.
- Returns:
The output tensor written by the implementation: either the provided output or the internal buffer.
- Return type:
torch.Tensor
- Raises:
AssertionError – If profile() has not been called and no implementation or internal buffer is available.
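The profile()/execute() split can be emulated on plain tensors to show the intended call order (a toy sketch under stated assumptions: vTensor, loc, and Context are omitted, and the class name is illustrative, not the library's API):

```python
import torch

class TwoPhaseGemm:
    """Toy emulation of the profile()/execute() protocol: profile()
    allocates the output buffer once; execute() reuses it per call."""

    def __init__(self):
        self.output_buffer = None

    def profile(self, x, y, output=None):
        assert x.shape[2] == y.shape[2], "inner dimension K must match"
        if output is None:
            B, N_y, N_x = x.shape[0], y.shape[1], x.shape[1]
            self.output_buffer = torch.empty(B, N_y, N_x)
        return output if output is not None else self.output_buffer

    def execute(self, x, y, output=None):
        out = output if output is not None else self.output_buffer
        assert out is not None, "profile() must be called before execute()"
        torch.matmul(y, x.transpose(-1, -2), out=out)
        return out

op = TwoPhaseGemm()
x, y = torch.randn(2, 4, 8), torch.randn(2, 6, 8)
op.profile(x, y)           # allocates the internal [2, 6, 4] buffer
result = op.execute(x, y)  # writes Y @ X^T into it
```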