import torch
from typing import Dict, Optional
from .context import Context
from ..abs import vTensor, FORMAT, vOp
from ..utils import Schedule
[docs]
class Softmax(vOp):
    r"""
    Scaled softmax evaluated independently inside every packed segment.

    :Math:
        The first axis stacks :math:`B` per-request segments
        :math:`\mathcal{I}_b` end to end (total length
        :math:`S = \sum_b S_b`). For every segment and every channel
        :math:`(d_0, d_1)` the softmax is taken **only** over that segment's
        rows:

        .. math::

            Y_{s,d_0,d_1} = \frac{\exp(\text{scale}\cdot X_{s,d_0,d_1})}
            {\sum_{s'\in\mathcal{I}_b}\exp(\text{scale}\cdot X_{s',d_0,d_1})},
            \qquad s\in\mathcal{I}_b.

    :__init__: ``Softmax(dim=0, scale=1.0)`` — ``dim`` has to be ``0`` (the
        packed S axis); ``scale`` is multiplied into the logits before the
        exponential.
    :__call__: ``y = op(x, ctx=ctx)`` — ``x`` is ``[S, D_0, D_1]`` and ``y``
        has the same shape.
    :Note: Only the ``RAGGED`` format is supported.
    """

    # Input format of ``x`` -> format the output is produced in. Add an entry
    # here (e.g. ``FORMAT.PAGED: FORMAT.PAGED``) once a matching kernel exists.
    _impl_map: Dict[FORMAT, FORMAT] = {
        FORMAT.RAGGED: FORMAT.RAGGED,
    }

    def __init__(self, dim: int = 0, scale: float = 1.0):
        super().__init__()
        self.dim = dim
        self.scale = scale
        # Resolved later by ``profile`` from the input's format.
        self.output_format: Optional[FORMAT] = None
        self.schedule = Schedule.S

        # Only the packed sequence axis is supported; fail at construction time.
        tag = self._prefix()
        assert self.dim in (0,), f"{tag}__init__: dim must be 0, got dim={self.dim}"

    # ---------------- profile ----------------
    def profile(self, x: vTensor, ctx: Context) -> vTensor:
        r"""Trace-time hook for the segmented softmax over ``dim=0``.

        Checks that ``x`` is a 3-D ``vTensor`` ``[S, D_0, D_1]``, resolves the
        output format from ``x._format`` via ``_impl_map``, records this op and
        its freshly created output tensor in ``ctx``, and returns that output.
        """
        tag = self._prefix()

        # --- input validation ---
        assert isinstance(x, vTensor), (
            f"{tag}profile expects x to be vTensor, got {type(x)}"
        )
        rank = x.dim()
        assert rank == 3, (
            f"{tag}expected 3D input [S_pack, D0, D1], "
            f"got ndim={rank} shape={tuple(x.shape)}"
        )

        # --- format dispatch ---
        supported = self._impl_map
        fmt = x._format
        assert fmt in supported, (
            f"{tag}no implementation for x_fmt={fmt}. "
            f"Available keys: {list(supported.keys())}"
        )
        self.output_format = supported[fmt]

        # Output keeps the channel dims of ``x``; the leading dim is created as 0.
        # Its id is its position in ``ctx.tensor_list``, so read the length
        # before appending.
        out = vTensor(
            shape=(0, x.shape[1], x.shape[2]),
            dtype=ctx.vortex_dtype,
            device=x.device,
            _format=self.output_format,
            tensor_id=len(ctx.tensor_list),
        )
        self.output_buffer = out

        # --- graph bookkeeping ---
        # The op's id is its position in ``ctx.op_list``; capture it before
        # appending ``self``.
        ctx.tensor_list.append(out)
        op_index = len(ctx.op_list)
        ctx.output_tensor_to_op_list.append(op_index)
        ctx.op_list.append(self)
        ctx.op_to_input_tensor_list.append([x.tensor_id])
        ctx.op_to_output_tensor_list.append([out.tensor_id])
        return out
[docs]
class Normalize(vOp):
    r"""
    :math:`L_2` normalization computed separately inside every packed segment.

    :Math:
        The first axis stacks :math:`B` per-request segments
        :math:`\mathcal{I}_b` end to end. For every segment and every channel
        :math:`(d_0, d_1)`, each value is divided by the :math:`L_2` norm of
        that channel taken over the segment's rows only:

        .. math::

            Y_{s,d_0,d_1} = \frac{X_{s,d_0,d_1}}
            {\sqrt{\sum_{s'\in\mathcal{I}_b} X_{s',d_0,d_1}^2}},
            \qquad s\in\mathcal{I}_b.

    :__init__: ``Normalize(dim=0)`` — ``dim`` has to be ``0`` (the packed S
        axis).
    :__call__: ``y = op(x, ctx=ctx)`` — ``x`` is ``[S, D_0, D_1]`` and ``y``
        has the same shape.
    :Note: Only the ``RAGGED`` format is supported.
    """

    # Input format of ``x`` -> format the output is produced in. New kernels
    # for other formats (e.g. ``FORMAT.PAGED``) register themselves here.
    _impl_map: Dict[FORMAT, FORMAT] = {
        FORMAT.RAGGED: FORMAT.RAGGED,
    }

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim
        # Resolved later by ``profile`` from the input's format.
        self.output_format: Optional[FORMAT] = None
        self.schedule = Schedule.S

        # Reject unsupported axes as soon as the op is built.
        tag = self._prefix()
        assert self.dim in (0,), f"{tag}__init__: dim must be 0, got dim={self.dim}"

    # ---------------- profile ----------------
    def profile(self, x: vTensor, ctx: Context) -> vTensor:
        r"""Trace-time hook for the segmented :math:`L_2` normalization.

        Verifies ``x`` is a 3-D ``vTensor`` ``[S, D_0, D_1]``, maps
        ``x._format`` through ``_impl_map`` to pick the output format, records
        the op and its new output tensor in ``ctx``, and returns the output.
        """
        tag = self._prefix()

        assert isinstance(x, vTensor), (
            f"{tag}profile expects x to be vTensor, got {type(x)}"
        )
        ndim = x.dim()
        assert ndim == 3, (
            f"{tag}expected 3D input [S_pack, D0, D1], "
            f"got ndim={ndim} shape={tuple(x.shape)}"
        )

        in_fmt = x._format
        table = self._impl_map
        assert in_fmt in table, (
            f"{tag}no implementation for x_fmt={in_fmt}. "
            f"Available keys: {list(table.keys())}"
        )
        self.output_format = table[in_fmt]

        # New tensor id == its index in ``ctx.tensor_list`` (taken pre-append).
        result = vTensor(
            shape=(0, x.shape[1], x.shape[2]),
            dtype=ctx.vortex_dtype,
            device=x.device,
            _format=self.output_format,
            tensor_id=len(ctx.tensor_list),
        )
        self.output_buffer = result

        # Record tensor/op relationships; this op's id is its index in
        # ``ctx.op_list``, read before ``self`` is appended.
        ctx.tensor_list.append(result)
        my_op_id = len(ctx.op_list)
        ctx.output_tensor_to_op_list.append(my_op_id)
        ctx.op_list.append(self)
        ctx.op_to_input_tensor_list.append([x.tensor_id])
        ctx.op_to_output_tensor_list.append([result.tensor_id])
        return result
[docs]
class Conv1d(vOp):
    r"""
    Depth-wise causal 1-D convolution applied segment by segment along the
    packed sequence axis.

    :Math:
        The first axis stacks :math:`B` per-request segments end to end.
        Inside each segment, every channel :math:`(d_0, d_1)` is filtered by
        its own :math:`K`-tap causal kernel
        :math:`W\in\mathbb{R}^{K\times D_0\times D_1}`:

        .. math::

            Y_{s,d_0,d_1} = \sum_{k=0}^{K-1} W_{k,d_0,d_1}\, X_{s-k,\,d_0,d_1},

        where :math:`X_{s-k}` is treated as zero whenever :math:`s-k` falls
        before the segment start. The op only covers the mid-range
        :math:`[b_{\text{bos}},\,S_b-b_{\text{eos}})`
        (``ctx.block_reserved_bos`` / ``block_reserved_eos``); the reserved
        BOS/EOS rows are neither read nor written.

    :__init__: ``Conv1d(weight, dim=0, dtype=torch.bfloat16, device=None)`` —
        ``weight`` is a nested Python list shaped ``[K, D_0, D_1]`` (the
        kernel size :math:`K` is ``len(weight)``); ``dim`` has to be ``0``.
    :__call__: ``y = op(x, ctx=ctx)`` — ``x`` is ``[S, D_0, D_1]`` and ``y``
        has the same shape; the trailing dims of ``weight`` must equal
        ``(D_0, D_1)``.
    :Note: Only the ``RAGGED`` format is supported.
    """

    # Input format of ``x`` -> format the output is produced in.
    _impl_map: Dict[FORMAT, FORMAT] = {
        FORMAT.RAGGED: FORMAT.RAGGED,
    }

    def __init__(
        self,
        weight: list,
        dim: int = 0,
        dtype: torch.dtype = torch.bfloat16,
        device: Optional[torch.device] = None,
    ):
        super().__init__()

        # The filter is supplied as nested Python lists and turned into a
        # tensor with the requested dtype/device right away.
        assert isinstance(weight, list), (
            f"Conv1d: weight must be a Python list, got {type(weight)}"
        )
        taps = torch.tensor(weight, dtype=dtype, device=device)
        assert taps.dim() == 3, (
            "Conv1d: weight must be a 3D nested list [K, D0, D1], "
            f"got shape {tuple(taps.shape)}"
        )

        self.dim = dim
        self.weight = taps
        # Number of taps = length of the outermost list.
        self.kernel_size = taps.shape[0]
        # Resolved later by ``profile`` from the input's format.
        self.output_format: Optional[FORMAT] = None
        self.schedule = Schedule.S

        tag = self._prefix()
        assert self.dim in (0,), f"{tag}__init__: dim must be 0, got dim={self.dim}"

    # ---------------- profile ----------------
    def profile(self, x: vTensor, ctx: Context) -> vTensor:
        r"""Trace-time hook for the segmented causal convolution.

        Validates ``x`` as a 3-D ``vTensor`` ``[S, D_0, D_1]`` whose trailing
        dims match the filter, moves ``weight`` onto ``x``'s device (making it
        contiguous), resolves the output format from ``x._format``, records
        the op and its new output tensor in ``ctx``, and returns the output.
        """
        tag = self._prefix()

        # --- input validation ---
        assert isinstance(x, vTensor), (
            f"{tag}profile expects x to be vTensor, got {type(x)}"
        )
        rank = x.dim()
        assert rank == 3, (
            f"{tag}expected 3D input [S_pack, D0, D1], "
            f"got ndim={rank} shape={tuple(x.shape)}"
        )
        w_shape = self.weight.shape
        assert w_shape[1] == x.shape[1] and w_shape[2] == x.shape[2], (
            f"{tag}weight inner dims {tuple(w_shape[1:])} "
            f"must match input inner dims {tuple(x.shape[1:])}"
        )

        # --- weight placement ---
        # ``__init__`` builds the filter on the caller-given device (the torch
        # default when none was passed), while the generated kernel reads it on
        # the inference device. ``profile`` is the first point where that device
        # is known, so migrate here; repeated profiles leave it untouched.
        placed = self.weight
        if placed.device != x.device:
            placed = placed.to(device=x.device)
        if not placed.is_contiguous():
            placed = placed.contiguous()
        if placed is not self.weight:
            self.weight = placed

        # --- format dispatch ---
        supported = self._impl_map
        fmt = x._format
        assert fmt in supported, (
            f"{tag}no implementation for x_fmt={fmt}. "
            f"Available keys: {list(supported.keys())}"
        )
        self.output_format = supported[fmt]

        # New tensor id == its index in ``ctx.tensor_list`` (read pre-append).
        out = vTensor(
            shape=(0, x.shape[1], x.shape[2]),
            dtype=ctx.vortex_dtype,
            device=x.device,
            _format=self.output_format,
            tensor_id=len(ctx.tensor_list),
        )
        self.output_buffer = out

        # --- graph bookkeeping ---
        ctx.tensor_list.append(out)
        op_index = len(ctx.op_list)
        ctx.output_tensor_to_op_list.append(op_index)
        ctx.op_list.append(self)
        ctx.op_to_input_tensor_list.append([x.tensor_id])
        ctx.op_to_output_tensor_list.append([out.tensor_id])
        return out