vortex_torch.cache.reduce

Classes

L2Norm([dim])

L2-norm reduction (not RMS) over a single logical axis.

Max([dim])

Max reduction over a single logical axis.

Mean([dim])

Mean reduction over a single logical axis.

Min([dim])

Min reduction over a single logical axis.

Reduce([dim])

Generic reduction dispatcher over one of the last two logical axes.

class vortex_torch.cache.reduce.Reduce(dim=1)

Bases: vOp

Generic reduction dispatcher over one of the last two logical axes.

This dispatcher covers a family of reductions (mean/max/min/L2-norm/sum, etc.) on rank-3 tensors. The input is treated as

\[X \in \mathbb{R}^{B \times N \times D},\]

where:

  • \(B\) is a leading batch-like axis (typically derived from the runtime, e.g. max_new_tokens_per_batch * head_num),

  • \(N\) is a sequence or position dimension, and

  • \(D\) is a feature/channel dimension.

The reduction dimension is chosen by dim:

  • dim == 1 (row-wise reduction over \(N\)):

    \[\text{out} \in \mathbb{R}^{B \times 1 \times D},\]

  • dim == 2 (column-wise reduction over \(D\)):

    \[\text{out} \in \mathbb{R}^{B \times N \times 1}.\]
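The shape rule above can be written as a small helper. This is a minimal sketch in plain Python; expected_output_shape is a hypothetical name used for illustration and is not part of vortex_torch.

```python
from typing import Tuple


def expected_output_shape(shape: Tuple[int, int, int], dim: int) -> Tuple[int, int, int]:
    """Logical output shape for a [B, N, D] input reduced over `dim` (hypothetical helper)."""
    if len(shape) != 3:
        raise ValueError(f"expected a rank-3 shape [B, N, D], got {shape}")
    b, n, d = shape
    if dim == 1:
        return (b, 1, d)  # reduce over N, keep the axis with size 1
    if dim == 2:
        return (b, n, 1)  # reduce over D, keep the axis with size 1
    raise ValueError(f"dim must be 1 or 2, got {dim}")


print(expected_output_shape((8, 128, 64), dim=1))  # (8, 1, 64)
print(expected_output_shape((8, 128, 64), dim=2))  # (8, 128, 1)
```

In both cases the reduced axis is kept with size 1, so the output stays rank 3.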

The exact reduction operation (mean, max, min, L2-norm, sum, etc.) is encoded in reduce_type and interpreted by the implementation.

Dispatch is based on the pair of tensor formats (x_format, o_format) and a registry mapping:

(x_format, o_format) -> (impl, resolved_output_format)

Policy

  • If output is None:

    • profile() selects the implementation registered for (x_format, FORMAT.RAGGED), so the output format is RAGGED.

    • An internal buffer is allocated with logical shape [B, exp_N, exp_D], where:

      • for dim == 1: exp_N = 1, exp_D = D,

      • for dim == 2: exp_N = N, exp_D = 1.

  • If output is provided:

    • profile() requires _impl_map to contain an entry for exactly (x_format, o_format).

    • The shape of output must match the expected (exp_N, exp_D) given dim.

    • Device consistency is enforced between x and output.
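The lookup policy above can be sketched as a dictionary keyed by format pairs. Everything in this block is an assumption made for illustration: the FORMAT enum is a stand-in (only RAGGED is named in the text, BATCHED is invented here), and IMPL_MAP, reduce_to_ragged, and resolve are hypothetical names, not the library's code.

```python
from enum import Enum, auto
from typing import Callable, Dict, Optional, Tuple


class FORMAT(Enum):
    """Hypothetical stand-in; only RAGGED is named in the text above."""
    BATCHED = auto()  # assumed second format, for illustration only
    RAGGED = auto()


def reduce_to_ragged(x):
    """Placeholder implementation; a real kernel would perform the reduction."""
    return x


# (x_format, o_format) -> (impl, resolved_output_format)
IMPL_MAP: Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]] = {
    (FORMAT.BATCHED, FORMAT.RAGGED): (reduce_to_ragged, FORMAT.RAGGED),
    (FORMAT.RAGGED, FORMAT.RAGGED): (reduce_to_ragged, FORMAT.RAGGED),
}


def resolve(x_format: FORMAT, o_format: Optional[FORMAT] = None) -> Tuple[Callable, FORMAT]:
    """Look up an implementation following the policy described above."""
    # No output given: ask for a RAGGED-output implementation.
    wanted = FORMAT.RAGGED if o_format is None else o_format
    key = (x_format, wanted)
    if key not in IMPL_MAP:
        raise AssertionError(f"no implementation registered for {key}")
    return IMPL_MAP[key]


impl, out_format = resolve(FORMAT.BATCHED)  # output is None, so RAGGED is requested
print(out_format)  # FORMAT.RAGGED
```

Keeping the resolved output format in the table value, rather than deriving it from the key, lets an entry report a format that differs from the one requested.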

_impl_map

Dispatch table keyed by (x_format, o_format). Each entry maps to (callable_impl, resolved_output_format).

Type:

Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]]

dim

Reduction dimension in the logical 3D tensor. Must be either:

  • 1 for row-wise reduction over \(N\), or

  • 2 for column-wise reduction over \(D\).

Type:

int

reduce_type

Enum describing which reduction to perform (mean, max, min, L2-norm, sum, etc.).

Type:

Optional[ReduceType]

impl

The resolved implementation selected during profile().

Type:

Optional[Callable]

output_format

The output tensor format as determined in profile().

Type:

Optional[FORMAT]

output_buffer

Internal output buffer allocated when output is None.

Type:

Optional[torch.Tensor]

profile(x, output, loc, ctx)

Validate inputs, resolve the reduction implementation and output format, and optionally allocate an internal output buffer.

The input tensor x is expected to have logical shape [B, N, D]. According to dim, the expected output logical shape is:

  • dim == 1: [B, 1, D]

  • dim == 2: [B, N, 1]

The auxiliary tensor loc carries per-position metadata used by the implementation; its shape and semantics are implementation-defined.

Parameters:
  • x (vTensor) – Input tensor with logical shape [B, N, D].

  • output (Optional[vTensor]) – Optional preallocated output tensor. If None, an internal buffer with shape [B, exp_N, exp_D] is allocated using ctx.max_new_tokens_per_batch * ctx.head_num for B and a RAGGED-output implementation is selected. If not None, this tensor must have rank 3 and shape compatible with dim as described above.

  • loc (torch.Tensor) – Auxiliary tensor carrying metadata required by the reduction implementation.

  • ctx (Context) – Execution context that provides the runtime value of B and is used for auxiliary memory accounting.

Returns:

A vTensor view representing the resolved output: either the provided output or an internally allocated buffer.

Return type:

vTensor

Raises:

AssertionError – If types, ranks, formats, shapes, or devices are incompatible, or if no implementation is found in _impl_map.

execute(x, output, loc, ctx)

Execute the selected reduction implementation.

This method assumes that profile() has already selected an implementation and, if needed, allocated an internal output buffer.

Parameters:
  • x (torch.Tensor) – Plain input tensor with shape compatible with the vTensor validated in profile().

  • output (Optional[torch.Tensor]) – Optional preallocated output tensor. If None, the internal buffer created during profile() will be used.

  • loc (torch.Tensor) – Auxiliary tensor carrying metadata required by the reduction implementation.

  • ctx (Context) – Execution context forwarded to the implementation.

Returns:

The output tensor written by the implementation: either the provided output or the internal buffer.

Return type:

torch.Tensor

Raises:

AssertionError – If profile() has not been called, so that no implementation has been selected or no internal buffer exists when output is None.

Parameters:

dim (int)
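The two-phase contract of profile() and execute() can be illustrated with a toy operator. This is a sketch under stated assumptions: ToyReduce is hypothetical, works on plain Python lists instead of vTensor or torch.Tensor, omits loc and ctx, and only mirrors the ordering and buffer rules described above.

```python
from typing import Callable, List, Optional


class ToyReduce:
    """Hypothetical two-phase operator: profile() plans, execute() runs."""

    def __init__(self) -> None:
        self.impl: Optional[Callable[[List[float]], float]] = None
        self.output_buffer: Optional[List[float]] = None

    def profile(self, n_rows: int, output: Optional[List[float]] = None) -> List[float]:
        # Resolve the implementation once, ahead of execution.
        self.impl = max
        if output is None:
            # No caller-provided output: allocate an internal buffer.
            self.output_buffer = [0.0] * n_rows
            return self.output_buffer
        if len(output) != n_rows:
            raise AssertionError("output has the wrong length")
        return output

    def execute(self, rows: List[List[float]], output: Optional[List[float]] = None) -> List[float]:
        if self.impl is None:
            raise AssertionError("profile() must be called before execute()")
        out = self.output_buffer if output is None else output
        if out is None:
            raise AssertionError("no output given and no internal buffer allocated")
        for i, row in enumerate(rows):
            out[i] = self.impl(row)  # write one reduced value per row
        return out


op = ToyReduce()
op.profile(n_rows=2)                            # plan and allocate
result = op.execute([[1.0, 5.0], [7.0, 3.0]])   # run with the internal buffer
print(result)  # [5.0, 7.0]
```

Separating planning from execution means validation and allocation happen once, while execute() only runs the already selected implementation.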

class vortex_torch.cache.reduce.Mean(dim=1)

Bases: Reduce

Mean reduction over a single logical axis.

This operator reduces a rank-3 tensor

\[X \in \mathbb{R}^{B \times N \times D}\]

by taking the arithmetic mean along one of the inner dimensions, as configured by dim:

  • dim == 1: row-wise mean over \(N\), producing

    \[Y \in \mathbb{R}^{B \times 1 \times D}, \qquad Y[b, 0, d] = \frac{1}{N} \sum_{n=0}^{N-1} X[b, n, d].\]

  • dim == 2: column-wise mean over \(D\), producing

    \[Y \in \mathbb{R}^{B \times N \times 1}, \qquad Y[b, n, 0] = \frac{1}{D} \sum_{d=0}^{D-1} X[b, n, d].\]

Parameters:

dim (int, optional) – Logical reduction dimension. Must be 1 (reduce over \(N\)) or 2 (reduce over \(D\)). Default is 1.
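The two formulas for the mean can be checked with a dependency-free reference. This is a minimal sketch that uses nested Python lists indexed as x[b][n][d] in place of tensors; mean_reduce is a hypothetical helper, not the library's implementation.

```python
from statistics import fmean
from typing import List

Tensor3 = List[List[List[float]]]  # nested lists indexed as x[b][n][d]


def mean_reduce(x: Tensor3, dim: int = 1) -> Tensor3:
    """Reference mean over N (dim=1) or D (dim=2), keeping the reduced axis."""
    if dim == 1:
        # zip(*batch) yields one column per feature d, each holding N values.
        return [[[fmean(col) for col in zip(*batch)]] for batch in x]
    if dim == 2:
        # Each row already holds the D values for one position n.
        return [[[fmean(row)] for row in batch] for batch in x]
    raise ValueError(f"dim must be 1 or 2, got {dim}")


x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # B=1, N=3, D=2
print(mean_reduce(x, dim=1))  # [[[3.0, 4.0]]]
print(mean_reduce(x, dim=2))  # [[[1.5], [3.5], [5.5]]]
```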

class vortex_torch.cache.reduce.Max(dim=1)

Bases: Reduce

Max reduction over a single logical axis.

This operator reduces a rank-3 tensor

\[X \in \mathbb{R}^{B \times N \times D}\]

by taking the maximum along one of the inner dimensions, as configured by dim:

  • dim == 1: row-wise maximum over \(N\), producing

    \[Y \in \mathbb{R}^{B \times 1 \times D}, \qquad Y[b, 0, d] = \max_{0 \le n < N} X[b, n, d].\]

  • dim == 2: column-wise maximum over \(D\), producing

    \[Y \in \mathbb{R}^{B \times N \times 1}, \qquad Y[b, n, 0] = \max_{0 \le d < D} X[b, n, d].\]

Parameters:

dim (int, optional) – Logical reduction dimension. Must be 1 (reduce over \(N\)) or 2 (reduce over \(D\)). Default is 1.
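As a worked example of the maximum, the sketch below reduces a small nested list indexed as x[b][n][d]. max_reduce is a hypothetical reference helper in plain Python, not the library's kernel. With dim == 1, each feature d takes its maximum independently, so the winning values may come from different positions n.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # nested lists indexed as x[b][n][d]


def max_reduce(x: Tensor3, dim: int = 1) -> Tensor3:
    """Reference maximum over N (dim=1) or D (dim=2), keeping the reduced axis."""
    if dim == 1:
        # One maximum per feature column d, taken across all positions n.
        return [[[max(col) for col in zip(*batch)]] for batch in x]
    if dim == 2:
        # One maximum per position n, taken across all features d.
        return [[[max(row)] for row in batch] for batch in x]
    raise ValueError(f"dim must be 1 or 2, got {dim}")


x = [[[1.0, 9.0], [4.0, 2.0], [3.0, 5.0]]]  # B=1, N=3, D=2
print(max_reduce(x, dim=1))  # [[[4.0, 9.0]]]  (4.0 from n=1, 9.0 from n=0)
print(max_reduce(x, dim=2))  # [[[9.0], [4.0], [5.0]]]
```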

class vortex_torch.cache.reduce.Min(dim=1)

Bases: Reduce

Min reduction over a single logical axis.

This operator reduces a rank-3 tensor

\[X \in \mathbb{R}^{B \times N \times D}\]

by taking the minimum along one of the inner dimensions, as configured by dim:

  • dim == 1: row-wise minimum over \(N\), producing

    \[Y \in \mathbb{R}^{B \times 1 \times D}, \qquad Y[b, 0, d] = \min_{0 \le n < N} X[b, n, d].\]

  • dim == 2: column-wise minimum over \(D\), producing

    \[Y \in \mathbb{R}^{B \times N \times 1}, \qquad Y[b, n, 0] = \min_{0 \le d < D} X[b, n, d].\]

Parameters:

dim (int, optional) – Logical reduction dimension. Must be 1 (reduce over \(N\)) or 2 (reduce over \(D\)). Default is 1.
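A useful sanity check for any minimum reduction is the identity min(v) = -max(-v). The sketch below verifies it on nested lists indexed as x[b][n][d]; reduce3 and negate are hypothetical helpers written for this illustration only.

```python
from typing import Callable, Iterable, List

Tensor3 = List[List[List[float]]]  # nested lists indexed as x[b][n][d]


def reduce3(x: Tensor3, dim: int, fn: Callable[[Iterable[float]], float]) -> Tensor3:
    """Apply `fn` over N (dim=1) or D (dim=2), keeping the reduced axis."""
    if dim == 1:
        return [[[fn(col) for col in zip(*batch)]] for batch in x]
    if dim == 2:
        return [[[fn(row)] for row in batch] for batch in x]
    raise ValueError(f"dim must be 1 or 2, got {dim}")


def negate(x: Tensor3) -> Tensor3:
    """Element-wise negation."""
    return [[[-v for v in row] for row in batch] for batch in x]


x = [[[1.0, 9.0], [4.0, 2.0], [3.0, 5.0]]]  # B=1, N=3, D=2
print(reduce3(x, 1, min))  # [[[1.0, 2.0]]]
print(reduce3(x, 2, min))  # [[[1.0], [2.0], [3.0]]]

# min(v) == -max(-v) holds for both reduction dimensions.
for dim in (1, 2):
    assert reduce3(x, dim, min) == negate(reduce3(negate(x), dim, max))
```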

class vortex_torch.cache.reduce.L2Norm(dim=1)

Bases: Reduce

L2-norm reduction (not RMS) over a single logical axis.

This operator reduces a rank-3 tensor

\[X \in \mathbb{R}^{B \times N \times D}\]

by computing an L2 norm along one of the inner dimensions, as configured by dim. The reduction is not normalized by the number of elements (it is an L2 norm, not an RMS):

  • dim == 1: row-wise L2 norm over \(N\), producing

    \[Y \in \mathbb{R}^{B \times 1 \times D}, \qquad Y[b, 0, d] = \sqrt{\sum_{n=0}^{N-1} X[b, n, d]^2}.\]

  • dim == 2: column-wise L2 norm over \(D\), producing

    \[Y \in \mathbb{R}^{B \times N \times 1}, \qquad Y[b, n, 0] = \sqrt{\sum_{d=0}^{D-1} X[b, n, d]^2}.\]

Notes

This is a pure L2 norm over the reduced axis:

\[\|v\|_2 = \sqrt{\sum_i v_i^2},\]

with no division by the number of elements. It should not be confused with RMS (root mean square).

Parameters:

dim (int, optional) – Logical reduction dimension. Must be 1 (reduce over \(N\)) or 2 (reduce over \(D\)). Default is 1.
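The difference between the L2 norm and RMS is easiest to see numerically. The sketch below is a plain-Python reference on nested lists indexed as x[b][n][d]; l2norm_reduce is a hypothetical helper, and the RMS value is computed only for comparison, since the operator described here does not divide by the element count.

```python
import math
from typing import List

Tensor3 = List[List[List[float]]]  # nested lists indexed as x[b][n][d]


def l2norm_reduce(x: Tensor3, dim: int = 1) -> Tensor3:
    """Reference L2 norm over N (dim=1) or D (dim=2), keeping the reduced axis."""
    def norm(values) -> float:
        return math.sqrt(sum(v * v for v in values))  # no division by the count

    if dim == 1:
        return [[[norm(col) for col in zip(*batch)]] for batch in x]
    if dim == 2:
        return [[[norm(row)] for row in batch] for batch in x]
    raise ValueError(f"dim must be 1 or 2, got {dim}")


x = [[[3.0, 0.0], [4.0, 2.0]]]  # B=1, N=2, D=2
l2 = l2norm_reduce(x, dim=1)
print(l2)  # [[[5.0, 2.0]]]

# For comparison only: RMS over N equals the L2 norm divided by sqrt(N).
n = len(x[0])
rms = [[[v / math.sqrt(n) for v in row] for row in batch] for batch in l2]
assert rms[0][0][0] < l2[0][0][0]  # RMS is smaller whenever N > 1 and the norm is nonzero
```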