vortex_torch.cache.reduce

Classes

L2Norm([dim])

L2-norm reduction (not RMS) over one inner axis (a Reduce).

Max([dim])

Max reduction over one inner axis (a Reduce).

Mean([dim])

Mean reduction over one inner axis (a Reduce).

Min([dim])

Min reduction over one inner axis (a Reduce).

Reduce([dim])

Generic 1-D reduction over one inner axis of a rank-3 cache tensor.

class Reduce(dim=1)[source]

Bases: vOp

Generic 1-D reduction over one inner axis of a rank-3 cache tensor.

Math:

For input \(X\in\mathbb{R}^{B\times N\times D}\) and a per-axis reduction \(\rho\) (mean / max / min / L2-norm, fixed by the subclass):

\[\begin{split}\begin{aligned} (\text{dim}=1):\quad & Y_{b,0,d} = \rho_{\,0 \le i < N}\, X_{b,i,d}, \\ (\text{dim}=2):\quad & Y_{b,n,0} = \rho_{\,0 \le j < D}\, X_{b,n,j}. \end{aligned}\end{split}\]
__init__:

Reduce(dim=1): inner axis to reduce, either 1 (over \(N\)) or 2 (over \(D\)).

__call__:

op(x, output, loc=loc, ctx=ctx): runs once per page in forward_cache. x has shape [B, N, D], and the reduced axis is kept with size 1. The result is PAGED if and only if a PAGED output is supplied; otherwise it is RAGGED.

Note:

Use a concrete subclass: Mean, Max, Min, or L2Norm. Cache-side reductions support \(\text{dim} \in \{1, 2\}\) only.

Parameters:

dim (int)
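
As a rough illustration of the shape semantics described above, the following pure-Python sketch applies a reduction \(\rho\) along one inner axis of a nested list standing in for a [B, N, D] tensor. The helper name reduce_inner and the nested-list representation are assumptions made for illustration only; this is not the vortex_torch implementation and it does not model paging, loc, or ctx.

```python
from typing import Callable, List, Sequence

# Nested lists standing in for a [B, N, D] tensor (illustrative assumption).
Tensor3 = List[List[List[float]]]


def reduce_inner(x: Tensor3, rho: Callable[[Sequence[float]], float], dim: int = 1) -> Tensor3:
    """Apply rho along one inner axis, keeping that axis with size 1."""
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2")
    out = []
    for page in x:  # each page has shape [N, D]
        if dim == 1:
            # Reduce over N: one value per column d, giving shape [1, D].
            out.append([[rho(col) for col in zip(*page)]])
        else:
            # Reduce over D: one value per row n, giving shape [N, 1].
            out.append([[rho(row)] for row in page])
    return out


x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # B=1, N=3, D=2
y1 = reduce_inner(x, max, dim=1)  # shape [1, 1, 2]
y2 = reduce_inner(x, max, dim=2)  # shape [1, 3, 1]
```

Passing max, min, or a mean or norm function as rho reproduces the four concrete reductions; in every case the reduced axis is kept with size 1 rather than dropped.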

class Mean(dim=1)[source]

Bases: Reduce

Mean reduction over one inner axis (a Reduce).

Math:
\[\begin{split}\begin{aligned} (\text{dim}=1):\quad & Y_{b,0,d} = \frac{1}{N}\sum_{n=0}^{N-1} X_{b,n,d}, \\ (\text{dim}=2):\quad & Y_{b,n,0} = \frac{1}{D}\sum_{d=0}^{D-1} X_{b,n,d}. \end{aligned}\end{split}\]
__init__:

Mean(dim=1): axis to reduce (1 reduces over \(N\), 2 over \(D\)).

Parameters:

dim (int)
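
To make the two cases of the mean formula concrete, here is a small reference sketch on nested lists. The function name mean_reduce is hypothetical and the code only mirrors the math above; it is not how the library computes the result.

```python
def mean_reduce(x, dim=1):
    """Reference mean over axis 1 (N) or 2 (D) of a nested-list [B, N, D] input."""
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2")
    if dim == 1:
        # Average each column d over the N rows, giving [B, 1, D].
        return [[[sum(col) / len(col) for col in zip(*page)]] for page in x]
    # Average each row n over its D entries, giving [B, N, 1].
    return [[[sum(row) / len(row)] for row in page] for page in x]


x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # B=1, N=3, D=2
print(mean_reduce(x, dim=1))  # [[[3.0, 4.0]]]
print(mean_reduce(x, dim=2))  # [[[1.5], [3.5], [5.5]]]
```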

class Max(dim=1)[source]

Bases: Reduce

Max reduction over one inner axis (a Reduce).

Math:
\[\begin{split}\begin{aligned} (\text{dim}=1):\quad & Y_{b,0,d} = \max_{0 \le n < N} X_{b,n,d}, \\ (\text{dim}=2):\quad & Y_{b,n,0} = \max_{0 \le d < D} X_{b,n,d}. \end{aligned}\end{split}\]
__init__:

Max(dim=1): axis to reduce (1 reduces over \(N\), 2 over \(D\)).

Parameters:

dim (int)
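
The max formula can be checked by hand on a small batch. The sketch below uses a hypothetical helper max_reduce over nested lists with B=2; it is meant only to show that each batch entry is reduced independently, not to reflect the library's kernel.

```python
def max_reduce(x, dim=1):
    """Reference max over axis 1 (N) or 2 (D) of a nested-list [B, N, D] input."""
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2")
    if dim == 1:
        # Largest entry in each column d, giving [B, 1, D].
        return [[[max(col) for col in zip(*page)]] for page in x]
    # Largest entry in each row n, giving [B, N, 1].
    return [[[max(row)] for row in page] for page in x]


# B=2, N=2, D=2
x = [
    [[1.0, -2.0], [0.5, 7.0]],
    [[-3.0, -1.0], [-4.0, -9.0]],
]
print(max_reduce(x, dim=1))  # [[[1.0, 7.0]], [[-3.0, -1.0]]]
print(max_reduce(x, dim=2))  # [[[1.0], [7.0]], [[-1.0], [-4.0]]]
```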

class Min(dim=1)[source]

Bases: Reduce

Min reduction over one inner axis (a Reduce).

Math:
\[\begin{split}\begin{aligned} (\text{dim}=1):\quad & Y_{b,0,d} = \min_{0 \le n < N} X_{b,n,d}, \\ (\text{dim}=2):\quad & Y_{b,n,0} = \min_{0 \le d < D} X_{b,n,d}. \end{aligned}\end{split}\]
__init__:

Min(dim=1): axis to reduce (1 reduces over \(N\), 2 over \(D\)).

Parameters:

dim (int)
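
The following sketch checks the output shapes implied by the min formula: reducing over \(N\) yields [B, 1, D] and reducing over \(D\) yields [B, N, 1]. The helpers min_reduce and shape are hypothetical names operating on nested lists, used here purely for illustration.

```python
def min_reduce(x, dim=1):
    """Reference min over axis 1 (N) or 2 (D) of a nested-list [B, N, D] input."""
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2")
    if dim == 1:
        # Smallest entry in each column d, giving [B, 1, D].
        return [[[min(col) for col in zip(*page)]] for page in x]
    # Smallest entry in each row n, giving [B, N, 1].
    return [[[min(row)] for row in page] for page in x]


def shape(t):
    """Shape of a rectangular rank-3 nested list."""
    return (len(t), len(t[0]), len(t[0][0]))


x = [[[1.0, -2.0, 4.0], [0.5, 7.0, -1.0]]]  # B=1, N=2, D=3
y_n = min_reduce(x, dim=1)  # [[[0.5, -2.0, -1.0]]]
y_d = min_reduce(x, dim=2)  # [[[-2.0], [-1.0]]]
print(shape(y_n), shape(y_d))  # (1, 1, 3) (1, 2, 1)
```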

class L2Norm(dim=1)[source]

Bases: Reduce

L2-norm reduction (not RMS) over one inner axis (a Reduce).

Math:
\[\begin{split}\begin{aligned} (\text{dim}=1):\quad & Y_{b,0,d} = \Big(\sum_{n=0}^{N-1} X_{b,n,d}^2\Big)^{1/2}, \\ (\text{dim}=2):\quad & Y_{b,n,0} = \Big(\sum_{d=0}^{D-1} X_{b,n,d}^2\Big)^{1/2}. \end{aligned}\end{split}\]
__init__:

L2Norm(dim=1): axis to reduce (1 reduces over \(N\), 2 over \(D\)).

Note:

This is a pure \(L_2\) norm (no division by the element count), not RMS.

Parameters:

dim (int)
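
The distinction between the \(L_2\) norm and RMS is easy to see numerically. The sketch below uses a hypothetical helper l2norm_reduce on nested lists and then derives the RMS value separately for comparison; the RMS computation is included only to show the difference and is not something this class provides.

```python
import math


def l2norm_reduce(x, dim=1):
    """Reference L2 norm over axis 1 (N) or 2 (D) of a nested-list [B, N, D] input."""
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2")

    def norm(v):
        # Square root of the sum of squares, with no division by len(v).
        return math.sqrt(sum(e * e for e in v))

    if dim == 1:
        return [[[norm(col) for col in zip(*page)]] for page in x]
    return [[[norm(row)] for row in page] for page in x]


x = [[[3.0, 0.0], [4.0, 0.0]]]  # B=1, N=2, D=2
l2 = l2norm_reduce(x, dim=1)
print(l2)  # [[[5.0, 0.0]]]

# For comparison only: RMS over the same N=2 entries divides by sqrt(N).
rms = l2[0][0][0] / math.sqrt(2)
print(rms < l2[0][0][0])  # True
```

For the column (3, 4) the \(L_2\) norm is 5, whereas the RMS would be \(5/\sqrt{2}\), roughly 3.54.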