vortex_torch.indexer.matmul¶
Classes

| GeMM | General matrix-matrix multiplication (GeMM) dispatcher. |
| GeMV | General matrix-vector multiplication (GEMV) dispatcher. |
- class vortex_torch.indexer.matmul.GeMV[source]¶

  Bases: vOp

  General matrix-vector multiplication (GEMV) dispatcher.
  This operator computes a piecewise batched matrix-vector product. Let

  \[X \in \mathbb{R}^{B \times 1 \times D}, \qquad Y \in \mathbb{R}^{S_{\text{pack}} \times 1 \times D},\]

  where the S-axis of \(Y\) is a concatenation of batch-wise segments

  \[S_{\text{pack}} = \sum_{i=0}^{B-1} S_i, \qquad Y = \begin{bmatrix} Y_0 \\ Y_1 \\ \vdots \\ Y_{B-1} \end{bmatrix},\]

  with

  \[Y_i \in \mathbb{R}^{S_i \times 1 \times D}, \qquad X_i = X[i, 0, :] \in \mathbb{R}^{1 \times D}.\]

  For each batch index \(i \in \{0, \dots, B-1\}\), we define

  \[O_i = Y_i X_i^{\mathsf{T}} \in \mathbb{R}^{S_i \times 1 \times 1},\]

  and the overall output is the concatenation

  \[O = \begin{bmatrix} O_0 \\ O_1 \\ \vdots \\ O_{B-1} \end{bmatrix} \in \mathbb{R}^{S_{\text{pack}} \times 1 \times 1}.\]

  In the runtime, \(S_{\text{pack}}\) is given by ctx.max_num_pages, and the dispatch is keyed by the pair of input formats (x_format, y_format).

  - _impl_map¶

    Dispatch table keyed by (x_format, y_format). Each entry maps to (callable_impl, resolved_output_format).
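The format-pair dispatch described above can be pictured as a plain dictionary lookup. The sketch below is illustrative only: the `FORMAT` members and the kernel name are hypothetical stand-ins, not the library's actual identifiers.

```python
from enum import Enum, auto

# Hypothetical format tags standing in for the library's tensor formats.
class FORMAT(Enum):
    BATCHED = auto()
    PAGED = auto()

def gemv_batched_paged(x, y, output, ctx):
    """Placeholder kernel with the expected impl(x, y, output, ctx) signature."""
    pass

# Dispatch table keyed by (x_format, y_format); each entry maps to
# (callable_impl, resolved_output_format), mirroring _impl_map.
_impl_map = {
    (FORMAT.BATCHED, FORMAT.PAGED): (gemv_batched_paged, FORMAT.PAGED),
}

def select_impl(x_format, y_format):
    """Look up the implementation for a format pair, as profile() would."""
    key = (x_format, y_format)
    assert key in _impl_map, f"no implementation for {key}"
    return _impl_map[key]
```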
  - output_buffer¶

    Preallocated output tensor buffer of shape [S_pack, 1, 1].

    - Type: Optional[torch.Tensor]
  - profile(x, y, ctx)[source]¶

    Validate inputs, select an implementation, allocate the output buffer, and return an as_vtensor() view with the resolved format.

    The method enforces the logical shapes

    - x: [B, 1, D]
    - y: [S_pack, 1, D]

    and checks that the last dimensions match. The output buffer is allocated with shape [S_pack, 1, 1], where S_pack is taken from the runtime context as ctx.max_num_pages.
  - execute(x, y, ctx)[source]¶

    Launch the selected GEMV implementation into the internal output buffer.

    Expected kernel signature: impl(x, y, output, ctx)

    - Parameters:
      - x (torch.Tensor) – Input tensor corresponding to the batched vector(s), with shape [B, 1, D], on the same device as y and the output.
      - y (torch.Tensor) – Input tensor corresponding to the packed matrix rows, with shape [S_pack, 1, D], on the same device as x and the output.
      - ctx (Context) – Execution context passed through to the underlying implementation.
    - Returns:
      The output tensor stored in self.output_buffer, with shape [S_pack, 1, 1].
    - Return type:
      torch.Tensor
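As a correctness reference for the piecewise product defined above (not the dispatcher's actual kernel), the segment-wise GEMV can be computed with NumPy. Note one assumption: the per-batch segment lengths S_i are passed explicitly here, whereas the runtime derives the packed extent from ctx.max_num_pages.

```python
import numpy as np

def gemv_reference(x, y, seg_lens):
    """Piecewise batched GEMV: O_i = Y_i @ X_i^T for each batch i.

    x:        [B, 1, D]       batched vectors
    y:        [S_pack, 1, D]  packed segments, S_pack = sum(seg_lens)
    seg_lens: length-B list of segment sizes S_i
    returns:  [S_pack, 1, 1]
    """
    assert x.shape[2] == y.shape[2], "last dimensions D must match"
    assert sum(seg_lens) == y.shape[0], "segments must cover the packed axis"
    out = np.empty((y.shape[0], 1, 1), dtype=y.dtype)
    start = 0
    for i, s in enumerate(seg_lens):
        seg = y[start:start + s, 0, :]                   # Y_i: [S_i, D]
        out[start:start + s, 0, 0] = seg @ x[i, 0, :]    # Y_i X_i^T: [S_i]
        start += s
    return out
```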
- class vortex_torch.indexer.matmul.GeMM[source]¶

  Bases: vOp

  General matrix-matrix multiplication (GeMM) dispatcher.
  Logically this computes, for each logical S-slice, a matrix-matrix product

  \[O[s] = Y[s] X[s]^{\mathsf{T}}, \quad s = 0, \dots, S-1,\]

  with slice-wise shapes

  \[X[s] \in \mathbb{R}^{N_x \times K}, \quad Y[s] \in \mathbb{R}^{N_y \times K}, \quad O[s] \in \mathbb{R}^{N_y \times N_x}.\]

  In the packed 3D representation used by this dispatcher:

  - Y has logical shape [S, N_y, K].
  - X has logical shape [L_x, N_x, K], where the leading dimension \(L_x\) can represent either:
    - a batch axis \(B\) (when x_format == FORMAT.BATCHED), or
    - the same S axis as Y (when x_format is ragged/paged and already laid out per-page).

    This is why the code comments use X: [B/S, N_x, K]: the first dimension is interpreted as either a batch size \(B\) or an S-like logical page index, depending on the format.
  - The output tensor O has logical shape [S, N_y, N_x].

  At runtime, the logical S is taken from ctx.max_num_pages, and dispatch is keyed by the pair of tensor formats (x_format, y_format).

  - _impl_map¶

    Dispatch table keyed by (x_format, y_format). Each entry maps to (callable_impl, resolved_output_format).
  - output_buffer¶

    Preallocated output tensor buffer of shape [S, N_y, N_x].

    - Type: Optional[torch.Tensor]
  - profile(x, y, ctx)[source]¶

    Validate inputs, select an implementation, allocate the output buffer, and return an as_vtensor() view with the resolved format.

    The method enforces that both inputs are rank-3 tensors and that the inner dimension \(K\) matches:

    - x: [B_or_S, N_x, K] — when x_format == FORMAT.BATCHED, the leading dimension is a batch size \(B\); for ragged/paged formats, it may conceptually coincide with \(S\).
    - y: [S, N_y, K]

    The output buffer is allocated with shape [S, N_y, N_x], where S is taken from the runtime context as ctx.max_num_pages.

    - Parameters:
      - x (vTensor) – Right-hand operand (transposed in the mathematical view), with shape [B_or_S, N_x, K] and a format participating in the (x_format, y_format) dispatch.
      - y (vTensor) – Left-hand operand with shape [S, N_y, K].
      - ctx (Context) – Execution context providing ctx.max_num_pages for the logical S dimension and tracking auxiliary memory.
    - Returns:
      A vTensor view wrapping the allocated output buffer with the resolved output format.
    - Return type:
      vTensor
    - Raises:
      AssertionError – If types are not vTensor, ranks are not 3, the inner dimensions \(K\) do not match, or there is no implementation for the pair (x._format, y._format).
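The assertions listed above amount to a handful of rank and inner-dimension checks followed by the dispatch lookup. A minimal sketch of that validation, operating on plain shape tuples and an arbitrary dispatch dictionary (the helper name and its signature are illustrative, not the library's API):

```python
def check_gemm_inputs(x_shape, y_shape, x_format, y_format, impl_map):
    """Mirror the checks profile() performs before selecting a kernel.

    x_shape: (B_or_S, N_x, K)   y_shape: (S, N_y, K)
    Returns the output-buffer shape (S, N_y, N_x) on success.
    """
    assert len(x_shape) == 3 and len(y_shape) == 3, "inputs must be rank-3"
    assert x_shape[2] == y_shape[2], "inner dimensions K must match"
    key = (x_format, y_format)
    assert key in impl_map, f"no implementation for {key}"
    S, N_y = y_shape[0], y_shape[1]
    N_x = x_shape[1]
    return (S, N_y, N_x)   # shape of the preallocated output buffer
```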
  - execute(x, y, ctx)[source]¶

    Launch the selected GeMM implementation into the internal output buffer.

    Expected kernel signature: impl(x, y, output, ctx)

    - Parameters:
      - x (torch.Tensor) – Right-hand operand (transposed in the mathematical view), with shape [B_or_S, N_x, K], on the same device as y and the output buffer.
      - y (torch.Tensor) – Left-hand operand with shape [S, N_y, K], on the same device as x and the output buffer.
      - ctx (Context) – Execution context passed through to the underlying implementation.
    - Returns:
      The output tensor stored in self.output_buffer, with shape [S, N_y, N_x].
    - Return type:
      torch.Tensor
    - Raises:
      AssertionError – If profile() has not been called (no implementation or buffer), or if there is a device mismatch between x, y, and the output buffer.
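As a correctness reference for the slice-wise product this class dispatches (again, not its actual kernel), both interpretations of X's leading dimension can be covered by one NumPy loop. The `page_of` mapping from page index to batch index is a hypothetical helper for the batched case; the real runtime derives this correspondence from the paged layout.

```python
import numpy as np

def gemm_reference(x, y, page_of=None):
    """Slice-wise O[s] = Y[s] @ X[s]^T, as a NumPy reference.

    y: [S, N_y, K]
    x: [L_x, N_x, K], where L_x is either S (one X slice per page,
       page_of=None) or a batch size B (page_of[s] gives the batch
       index owning page s).
    returns: [S, N_y, N_x]
    """
    assert x.shape[2] == y.shape[2], "inner dimensions K must match"
    S = y.shape[0]
    out = np.empty((S, y.shape[1], x.shape[1]), dtype=y.dtype)
    for s in range(S):
        xs = x[s] if page_of is None else x[page_of[s]]  # X[s]: [N_x, K]
        out[s] = y[s] @ xs.T                             # O[s]: [N_y, N_x]
    return out
```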