vortex_torch.cache.matmul

Classes

GeMM()

Per-block matrix–matrix product, \(O_b = Y_b X_b^{\top}\) (cache side).

class GeMM

Bases: vOp

Per-block matrix–matrix product, \(O_b = Y_b X_b^{\top}\) (cache side).

Math:

For \(X\in\mathbb{R}^{B\times N_x\times K}\), \(Y\in\mathbb{R}^{B\times N_y\times K}\):

\[O_{b,a,c} = \sum_{k=0}^{K-1} Y_{b,a,k}\,X_{b,c,k}, \qquad O\in\mathbb{R}^{B\times N_y\times N_x}.\]

__init__:
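
The formula above can be checked against a dense reference. This is an illustrative NumPy sketch, not the library's cache-side kernel. The name gemm_reference is hypothetical, and the sketch ignores the PAGED/RAGGED layouts and the loc/ctx arguments entirely. It contracts over the shared last axis k and carries the batch axis b through, which is the same as a batched Y @ X^T.

```python
import numpy as np


def gemm_reference(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Hypothetical dense reference: O[b, a, c] = sum_k Y[b, a, k] * X[b, c, k].

    x has shape [B, N_x, K], y has shape [B, N_y, K]; the result is [B, N_y, N_x].
    """
    if x.ndim != 3 or y.ndim != 3:
        raise ValueError("x and y must be 3-D arrays of shape [B, N, K]")
    if x.shape[0] != y.shape[0] or x.shape[2] != y.shape[2]:
        raise ValueError("x and y must share the batch size B and inner size K")
    # Sum over k; a indexes rows of Y, c indexes rows of X.
    return np.einsum("bak,bck->bac", y, x)


rng = np.random.default_rng(0)
B, N_x, N_y, K = 2, 3, 4, 5
x = rng.standard_normal((B, N_x, K))
y = rng.standard_normal((B, N_y, K))
o = gemm_reference(x, y)
print(o.shape)  # (2, 4, 3)
# Equivalent to a batched Y @ X^T.
assert np.allclose(o, y @ x.transpose(0, 2, 1))
```

For a single-element check, x = [[[1, 2]]] and y = [[[3, 4]]] give O = 3·1 + 4·2 = 11.
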

GeMM() takes no arguments.

__call__:

o = op(x, y, output, loc=loc, ctx=ctx)

x: [B, N_x, K] and y: [B, N_y, K], with matching K; o: [B, N_y, N_x]. The result o is PAGED if and only if a PAGED output is supplied; otherwise it is RAGGED.
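
The shape and layout rules of __call__ can be summarized as a small contract check. This is a hypothetical pure-Python helper: gemm_contract and its output_layout parameter are illustrative names, not part of vortex_torch. It models only the documented shapes and the PAGED/RAGGED rule, treats layouts as plain strings, and never calls the real operator.

```python
from typing import Optional, Tuple

Shape = Tuple[int, int, int]


def gemm_contract(
    x_shape: Shape,
    y_shape: Shape,
    output_layout: Optional[str] = None,
) -> Tuple[Shape, str]:
    """Hypothetical check of the documented GeMM call contract.

    x_shape is (B, N_x, K) and y_shape is (B, N_y, K).
    output_layout is the layout of a caller-supplied output ("PAGED" or
    "RAGGED"), or None when no output is supplied.
    Returns the expected shape of o, (B, N_y, N_x), and its layout.
    """
    b_x, n_x, k_x = x_shape
    b_y, n_y, k_y = y_shape
    if k_x != k_y:
        raise ValueError(f"K mismatch: x has K={k_x}, y has K={k_y}")
    if b_x != b_y:
        raise ValueError(f"B mismatch: x has B={b_x}, y has B={b_y}")
    # PAGED only when a PAGED output is supplied; every other case is RAGGED.
    layout = "PAGED" if output_layout == "PAGED" else "RAGGED"
    return (b_x, n_y, n_x), layout


print(gemm_contract((2, 3, 5), (2, 4, 5)))           # ((2, 4, 3), 'RAGGED')
print(gemm_contract((2, 3, 5), (2, 4, 5), "PAGED"))  # ((2, 4, 3), 'PAGED')
```

Note that the output shape puts N_y before N_x, matching O = Y X^T rather than X Y^T.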