import torch
from typing import Optional
from .context import Context
from ..abs import vTensor, FORMAT, vOp
from ..utils import ReduceType, Schedule
[docs]
class Reduce(vOp):
    r"""
    Generic 1-D reduction over one axis of a rank-3 logical tensor.

    :Math:
        For input :math:`X\in\mathbb{R}^{N\times D_0\times D_1}` and a
        per-axis reduction :math:`\rho` (mean / max / min / L2-norm / sum,
        fixed by the subclass):

        .. math::

            \begin{aligned}
            (\text{dim}=1):\quad & Y_{n,0,d} = \rho_{\,0 \le i < D_0}\, X_{n,i,d}, \\
            (\text{dim}=2):\quad & Y_{n,d,0} = \rho_{\,0 \le j < D_1}\, X_{n,d,j}.
            \end{aligned}

        ``dim=0`` collapses the packed leading axis to one row per
        ``(batch, kv_head)``; the trailing ``[D_0, D_1]`` extent is kept.

    :__init__:
        ``Reduce(dim=1)`` — logical axis to reduce, one of ``0`` / ``1`` / ``2``.

    :__call__:
        ``y = op(x, ctx=ctx)`` — ``x`` is ``[N, D_0, D_1]``. For
        ``dim ∈ {1, 2}`` the reduced axis is kept with size 1 and the output
        is ``BATCHED`` iff the input is; ``dim=0`` requires ``RAGGED`` input
        and returns ``BATCHED``.

    :Note:
        Use a concrete subclass — :class:`Max`, :class:`Min`, :class:`Mean`,
        :class:`L2Norm`, :class:`Sum`. ``profile`` on a bare ``Reduce``
        (whose ``reduce_type`` is still ``None``) fails an assertion.
    """

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim
        # Set by the concrete subclass right after this constructor returns.
        self.reduce_type: Optional[ReduceType] = None
        # Both are filled in by ``profile``.
        self.output_format: Optional[FORMAT] = None
        self.output_buffer: Optional[vTensor] = None
        # dim==0 reduces across the packed leading axis; the result is one
        # summary per (batch, kv_head), so it can't fuse into a per-block
        # workload kernel — schedule it standalone.
        self.schedule = Schedule.S if dim == 0 else Schedule.W
        prefix = self._prefix()
        assert self.dim in (0, 1, 2), (
            f"{prefix}__init__: dim must be 0, 1, or 2, got dim={self.dim}"
        )

    # ---------------- profile ----------------
    def profile(self, x: vTensor, ctx: Context) -> vTensor:
        r"""Trace-time: validate ``x`` (``[N, D_0, D_1]``), resolve the output
        format, register the op, and return a ``vTensor`` view of the reduced
        output (see the class docstring for shapes).

        Side effects on ``ctx``: appends the output buffer to
        ``ctx.tensor_list``, this op to ``ctx.op_list``, and the matching
        entries to ``ctx.output_tensor_to_op_list``,
        ``ctx.op_to_input_tensor_list`` and ``ctx.op_to_output_tensor_list``.

        Raises:
            AssertionError: if ``x`` is not a rank-3 ``vTensor``, if this is a
                bare ``Reduce`` with no ``reduce_type``, or if ``dim=0`` is
                used with non-``RAGGED`` input.
        """
        prefix = self._prefix()
        # Type & rank checks
        assert isinstance(x, vTensor), f"{prefix}profile expects x to be vTensor, got {type(x)}"
        assert x.dim() == 3, (
            f"{prefix}expected 3D input [N, D0, D1], "
            f"got ndim={x.dim()} shape={tuple(x.shape)}"
        )
        # Refuse to register a reduction with no kind: checked before any
        # ctx mutation so a failed trace leaves the graph untouched.
        assert self.reduce_type is not None, (
            f"{prefix}reduce_type is unset; use a concrete subclass "
            f"(Max, Min, Mean, L2Norm, Sum) instead of bare Reduce"
        )
        D0, D1 = x.shape[1], x.shape[2]
        if self.dim == 0:
            # Cross-row reduction: collapse the packed leading axis into one
            # summary per (batch, kv_head). Input must be RAGGED (per-page or
            # per-token); the compiler allocates a BATCHED buffer with leading
            # dim ``ctx.max_bs * ctx.num_kv_heads`` (see indexer interface).
            assert x._format == FORMAT.RAGGED, (
                f"{prefix}dim=0 reduce requires RAGGED input, got {x._format}"
            )
            self.output_format = FORMAT.BATCHED
            out_D0, out_D1 = D0, D1
        else:
            # Output is BATCHED iff the input is BATCHED; otherwise RAGGED.
            self.output_format = (
                FORMAT.BATCHED if x._format == FORMAT.BATCHED else FORMAT.RAGGED
            )
            # The reduced axis is kept with size 1.
            out_D0 = 1 if self.dim == 1 else D0
            out_D1 = 1 if self.dim == 2 else D1
        # Pure-metadata vTensor — no torch.empty allocation needed. The
        # leading extent is left as 0 here; NOTE(review): presumably the
        # real row count is resolved later by the compiler — confirm.
        self.output_buffer = vTensor(
            shape=(0, out_D0, out_D1),
            dtype=ctx.vortex_dtype,
            device=x.device,
            _format=self.output_format,
            tensor_id=len(ctx.tensor_list),
        )
        # Track auxiliary memory and graph structure in the context
        ctx.tensor_list.append(self.output_buffer)  # Track the output buffer in the context
        ctx.output_tensor_to_op_list.append(len(ctx.op_list))  # Map the output tensor to this operation
        ctx.op_list.append(self)  # Track this operation in the context
        ctx.op_to_input_tensor_list.append([x.tensor_id])  # Map this op to its input tensors
        ctx.op_to_output_tensor_list.append([self.output_buffer.tensor_id])  # Map this op to its output tensor
        # Return vTensor view carrying the dispatched output format
        return self.output_buffer
[docs]
class Max(Reduce):
    r"""
    Max reduction over one logical axis (a :class:`Reduce`).

    :Math:
        .. math::

            \begin{aligned}
            (\text{dim}=1):\quad & Y_{n,0,d} = \max_{0 \le i < D_0} X_{n,i,d}, \\
            (\text{dim}=2):\quad & Y_{n,d,0} = \max_{0 \le j < D_1} X_{n,d,j}.
            \end{aligned}

    :__init__: ``Max(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
        ``2`` → :math:`D_1`, ``0`` → the packed leading axis, one row per
        ``(batch, kv_head)``; ``dim=0`` requires ``RAGGED`` input).
    :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
        (``dim=1``) or ``[N, D_0, 1]`` (``dim=2``); for ``dim=0`` the output
        is ``BATCHED`` with the ``[D_0, D_1]`` extent kept.
    """

    def __init__(self, dim: int = 1):
        super().__init__(dim)
        # Assigned after the base constructor, which resets reduce_type to None.
        self.reduce_type = ReduceType.Max
[docs]
class Min(Reduce):
    r"""
    Min reduction over one logical axis (a :class:`Reduce`).

    :Math:
        .. math::

            \begin{aligned}
            (\text{dim}=1):\quad & Y_{n,0,d} = \min_{0 \le i < D_0} X_{n,i,d}, \\
            (\text{dim}=2):\quad & Y_{n,d,0} = \min_{0 \le j < D_1} X_{n,d,j}.
            \end{aligned}

    :__init__: ``Min(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
        ``2`` → :math:`D_1`, ``0`` → the packed leading axis, one row per
        ``(batch, kv_head)``; ``dim=0`` requires ``RAGGED`` input).
    :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
        (``dim=1``) or ``[N, D_0, 1]`` (``dim=2``); for ``dim=0`` the output
        is ``BATCHED`` with the ``[D_0, D_1]`` extent kept.
    """

    def __init__(self, dim: int = 1):
        super().__init__(dim)
        # Assigned after the base constructor, which resets reduce_type to None.
        self.reduce_type = ReduceType.Min
[docs]
class Mean(Reduce):
    r"""
    Mean reduction over one logical axis (a :class:`Reduce`).

    :Math:
        .. math::

            \begin{aligned}
            (\text{dim}=1):\quad & Y_{n,0,d} = \frac{1}{D_0}\sum_{i=0}^{D_0-1} X_{n,i,d}, \\
            (\text{dim}=2):\quad & Y_{n,d,0} = \frac{1}{D_1}\sum_{j=0}^{D_1-1} X_{n,d,j}.
            \end{aligned}

    :__init__: ``Mean(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
        ``2`` → :math:`D_1`, ``0`` → the packed leading axis, one row per
        ``(batch, kv_head)``; ``dim=0`` requires ``RAGGED`` input).
    :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
        (``dim=1``) or ``[N, D_0, 1]`` (``dim=2``); for ``dim=0`` the output
        is ``BATCHED`` with the ``[D_0, D_1]`` extent kept.
    """

    def __init__(self, dim: int = 1):
        super().__init__(dim)
        # Assigned after the base constructor, which resets reduce_type to None.
        self.reduce_type = ReduceType.Mean
[docs]
class L2Norm(Reduce):
    r"""
    L2-norm reduction over one logical axis (a :class:`Reduce`).

    :Math:
        .. math::

            \begin{aligned}
            (\text{dim}=1):\quad & Y_{n,0,d} = \Big(\sum_{i=0}^{D_0-1} X_{n,i,d}^2\Big)^{1/2}, \\
            (\text{dim}=2):\quad & Y_{n,d,0} = \Big(\sum_{j=0}^{D_1-1} X_{n,d,j}^2\Big)^{1/2}.
            \end{aligned}

    :__init__: ``L2Norm(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
        ``2`` → :math:`D_1`, ``0`` → the packed leading axis, one row per
        ``(batch, kv_head)``; ``dim=0`` requires ``RAGGED`` input).
    :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
        (``dim=1``) or ``[N, D_0, 1]`` (``dim=2``); for ``dim=0`` the output
        is ``BATCHED`` with the ``[D_0, D_1]`` extent kept.
    """

    def __init__(self, dim: int = 1):
        super().__init__(dim)
        # Assigned after the base constructor, which resets reduce_type to None.
        self.reduce_type = ReduceType.L2Norm
[docs]
class Sum(Reduce):
    r"""
    Sum reduction over one logical axis (a :class:`Reduce`).

    :Math:
        .. math::

            \begin{aligned}
            (\text{dim}=1):\quad & Y_{n,0,d} = \sum_{i=0}^{D_0-1} X_{n,i,d}, \\
            (\text{dim}=2):\quad & Y_{n,d,0} = \sum_{j=0}^{D_1-1} X_{n,d,j}.
            \end{aligned}

    :__init__: ``Sum(dim=1)`` — axis to reduce (``1`` → :math:`D_0`,
        ``2`` → :math:`D_1`, ``0`` → the packed leading axis, one row per
        ``(batch, kv_head)``; ``dim=0`` requires ``RAGGED`` input).
    :__call__: ``y = op(x, ctx=ctx)`` — ``[N, D_0, D_1]`` → ``[N, 1, D_1]``
        (``dim=1``) or ``[N, D_0, 1]`` (``dim=2``); for ``dim=0`` the output
        is ``BATCHED`` with the ``[D_0, D_1]`` extent kept.
    """

    def __init__(self, dim: int = 1):
        super().__init__(dim)
        # Assigned after the base constructor, which resets reduce_type to None.
        self.reduce_type = ReduceType.Sum