vortex_torch.indexer.matmul

Classes

GeMM()

Per-page matrix–matrix product, \(O[s] = Y[s]\,X[s]^{\top}\).

GeMV()

Per-request batched matrix–vector product, \(O = Y X^{\top}\).

class GeMV

Bases: vOp

Per-request batched matrix–vector product, \(O = Y X^{\top}\).

Math:

Batched query \(X\in\mathbb{R}^{B\times 1\times D}\), packed pages \(Y\in\mathbb{R}^{S\times 1\times D}\); for page \(s\) in request \(i(s)\),

\[O_{s,0,0} = \sum_{d=0}^{D-1} Y_{s,0,d}\,X_{i(s),0,d} = \langle Y_s,\, X_{i(s)} \rangle, \qquad O\in\mathbb{R}^{S\times 1\times 1}.\]

__init__:

GeMV() — no arguments.

__call__:

o = op(x, y, ctx=ctx)

x is [B, 1, D], y is [S, 1, D] (matching D); returns o [S, 1, 1]. Output is BATCHED iff both inputs are, else RAGGED.
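The per-page dot product in the GeMV formula can be sketched as a NumPy reference that does not depend on vortex_torch. It only illustrates the math and is not the library's kernel. The function name gemv_reference and the explicit page_to_request array are assumptions: the array stands in for the page-to-request mapping \(i(s)\), which the real op presumably obtains from ctx.

```python
import numpy as np


def gemv_reference(x, y, page_to_request):
    """Reference for O[s, 0, 0] = <Y[s], X[i(s)]>.

    x: [B, 1, D] batched queries
    y: [S, 1, D] packed pages
    page_to_request: length-S integer array giving i(s) (hypothetical argument)
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    idx = np.asarray(page_to_request)
    if x.shape[-1] != y.shape[-1]:
        raise ValueError("x and y must share the last dimension D")
    x_per_page = x[idx]  # gather each page's query: [S, 1, D]
    # Contract over D; the two singleton axes form the trailing [1, 1].
    return np.einsum("sad,sbd->sab", y, x_per_page)


x = np.array([[[1, 0, 2]], [[0, 1, 1]]])               # B = 2, D = 3
y = np.array([[[1, 1, 1]], [[2, 0, 1]], [[3, 4, 5]]])  # S = 3
o = gemv_reference(x, y, page_to_request=[0, 0, 1])
print(o.shape)    # (3, 1, 1)
print(o.ravel())  # [3. 4. 9.]
```

Pages 0 and 1 belong to request 0 and page 2 to request 1, so each page is scored against its own request's query.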

class GeMM

Bases: vOp

Per-page matrix–matrix product, \(O[s] = Y[s]\,X[s]^{\top}\).

Math:

\(Y\in\mathbb{R}^{S\times N_y\times K}\), \(X\in\mathbb{R}^{(B\text{ or }S)\times N_x\times K}\); per page \(s\) this is \(O_s = Y_s X_s^{\top}\) (i.e. GeMM(x, y) = y xᵀ):

\[O_{s,a,b} = \sum_{k=0}^{K-1} Y_{s,a,k}\,X_{s,b,k}, \qquad O\in\mathbb{R}^{S\times N_y\times N_x}.\]

__init__:

GeMM() — no arguments.

__call__:

o = op(x, y, ctx=ctx)

x is [B|S, N_x, K], y is [S, N_y, K] (matching K); returns o [S, N_y, N_x]. Output is BATCHED iff both inputs are, else RAGGED.
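The GeMM index formula maps directly onto an einsum contraction over K. The sketch below is a NumPy reference for the case where x is already per-page ([S, N_x, K]). The name gemm_reference is hypothetical, and the sketch does not model the batched-x case or the ctx argument.

```python
import numpy as np


def gemm_reference(x, y):
    """Reference for O[s] = Y[s] @ X[s].T with per-page x.

    x: [S, N_x, K], y: [S, N_y, K]; returns [S, N_y, N_x].
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape[0] != y.shape[0] or x.shape[-1] != y.shape[-1]:
        raise ValueError("x and y must agree on S and K")
    # O[s, a, b] = sum_k Y[s, a, k] * X[s, b, k]
    return np.einsum("sak,sbk->sab", y, x)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))  # S = 2, N_x = 3, K = 4
y = rng.standard_normal((2, 5, 4))  # N_y = 5
o = gemm_reference(x, y)
print(o.shape)  # (2, 5, 3)
# Same result as a batched matmul against the transposed x.
print(np.allclose(o, y @ x.transpose(0, 2, 1)))  # True
```

Note the argument order: x comes first in the call, but y supplies the rows of each output matrix, so the result is [S, N_y, N_x].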