Source code for vortex_torch.indexer.save_load

import torch
from typing import Dict, Optional
from .context import Context
from ..abs import vTensor, FORMAT, vOp
from ..utils import Schedule

class Save(vOp):
    r"""
    Write side of the cross-decode-step persistent-state pattern: copy a
    per-page indexer value into a preallocated cache field so that it can be
    read back later with :class:`Load`.

    :Math:
        .. math::

            O \leftarrow X

        (a format/layout copy ``RAGGED`` → ``PAGED``; no arithmetic).

    :__init__:
        ``Save()`` — no arguments.

    :__call__:
        ``op(x, o, ctx=ctx)`` — ``x`` ``[S, D_0, D_1]`` (``RAGGED``) is
        written **in place** into the preallocated ``o`` (``PAGED``, matching
        ``D_0`` / ``D_1``). Returns nothing.

    :Note:
        A flow that uses ``Save`` requires the engine to set
        ``disable_radix_cache=True``.
    """

    # Supported input formats, each mapped to the output format it resolves to.
    _impl_map: Dict[FORMAT, FORMAT] = {
        FORMAT.RAGGED: FORMAT.PAGED,
        # Add more entries if you support other formats.
    }

    def __init__(self):
        super().__init__()
        # Resolved lazily by ``profile`` from the input tensor's format.
        self.output_format: Optional[FORMAT] = None
        self.schedule = Schedule.W

    # ---------------- profile ----------------
    def profile(self, x: vTensor, o: vTensor, ctx: Context) -> vTensor:
        r"""Trace-time validation and registration.

        Checks that ``x`` and ``o`` are rank-3 ``vTensor`` objects whose
        ``D_0`` / ``D_1`` dimensions agree, that ``x``'s format has an entry in
        the dispatch table, that ``o`` already carries the resolved output
        format, and that both live on the same device. The op is then appended
        to ``ctx`` and ``o`` itself is returned (no buffer is allocated).

        Raises ``AssertionError`` when any of the checks fails.
        """
        prefix = self._prefix()

        # --- operand types and ranks ---
        assert isinstance(x, vTensor), f"{prefix}profile expects x to be vTensor, got {type(x)}"
        assert isinstance(o, vTensor), f"{prefix}profile expects o to be vTensor, got {type(o)}"
        assert x.dim() == 3, f"{prefix}expected 3D x [S, D0, D1], got {tuple(x.shape)}"
        assert o.dim() == 3, f"{prefix}expected 3D o [S, D0, D1], got {tuple(o.shape)}"

        # --- inner dimensions must agree; the leading S is not compared
        #     because it may differ between layouts ---
        assert x.shape[1] == o.shape[1], (
            f"{prefix}expected matching D0: x.shape[1]={x.shape[1]} vs o.shape[1]={o.shape[1]}"
        )
        assert x.shape[2] == o.shape[2], (
            f"{prefix}expected matching D1: x.shape[2]={x.shape[2]} vs o.shape[2]={o.shape[2]}"
        )

        # --- resolve the output format from the input format ---
        x_fmt = x._format
        assert x_fmt in self._impl_map, (
            f"{prefix}no implementation for x_fmt={x_fmt}. "
            f"Available keys: {list(self._impl_map.keys())}"
        )
        resolved_format = self._impl_map[x_fmt]
        self.output_format = resolved_format

        # The caller-provided destination must already be in that format.
        assert o._format == resolved_format, (
            f"{prefix}output format mismatch. Expected {self.output_format}, got {o._format}"
        )

        # --- both operands on one device ---
        assert x.device == o.device, (
            f"{prefix}x and o must be on the same device "
            f"(x.device={x.device}, o.device={o.device})"
        )

        # Save is a *side-effect writer* into a caller-provided cache field,
        # so it deliberately does NOT register itself as the producer of
        # ``o.tensor_id`` in ``output_tensor_to_op_list``. A Load elsewhere in
        # the graph that reads the same cache field must observe the
        # previous-step value rather than Save's update; taking over the
        # producer slot would introduce a Load → Save cycle in the DAG.
        #
        # The op id is recorded in ``side_effect_op_ids`` instead: the
        # compiler seeds its op-DFS from that set so Save survives DCE, and
        # the target tensor is promoted to a final output so the subgraph
        # emits a ``tl.store`` for it.
        ctx.op_list.append(self)
        ctx.op_to_input_tensor_list.append([x.tensor_id])
        ctx.op_to_output_tensor_list.append([o.tensor_id])
        # This op now occupies the last slot of ``op_list``.
        ctx.side_effect_op_ids.append(len(ctx.op_list) - 1)

        return o
class Load(vOp):
    r"""
    Read side of the cross-decode-step persistent-state pattern: fetch a
    per-page value that was previously persisted by :class:`Save`.

    :Math:
        .. math::

            Y \leftarrow X

        (a format/layout copy ``PAGED`` → ``RAGGED``; no arithmetic).

    :__init__:
        ``Load()`` — no arguments.

    :__call__:
        ``y = op(x, ctx=ctx)`` — ``x`` ``[S, D_0, D_1]`` (``PAGED``); returns a
        freshly-allocated ``RAGGED`` view of the same inner shape.

    :Note:
        Counterpart of :class:`Save` in the persistent-state pattern.
    """

    # Supported input formats, each mapped to the output format it resolves to.
    _impl_map: Dict[FORMAT, FORMAT] = {
        FORMAT.PAGED: FORMAT.RAGGED,
        # Add more entries if you support other formats.
    }

    def __init__(self):
        super().__init__()
        # Both attributes are filled in by ``profile``.
        self.output_format: Optional[FORMAT] = None
        self.output_buffer: Optional[torch.Tensor] = None
        self.schedule = Schedule.W

    # ---------------- profile ----------------
    def profile(self, x: vTensor, ctx: Context) -> vTensor:
        r"""Trace-time validation and registration.

        Checks that ``x`` is a rank-3 ``vTensor`` whose format has an entry in
        the dispatch table, creates a metadata ``vTensor`` of shape
        ``(0, D_0, D_1)`` in the resolved format (dtype ``ctx.vortex_dtype``,
        device of ``x``), records both the tensor and this op in ``ctx``, and
        returns the new ``vTensor``.

        Raises ``AssertionError`` when any of the checks fails.
        """
        prefix = self._prefix()

        # --- operand type and rank ---
        assert isinstance(x, vTensor), f"{prefix}profile expects x to be vTensor, got {type(x)}"
        assert x.dim() == 3, f"{prefix}expected 3D x [S, D0, D1], got {tuple(x.shape)}"

        # --- resolve the output format from the input format ---
        x_fmt = x._format
        assert x_fmt in self._impl_map, (
            f"{prefix}no implementation for x_fmt={x_fmt}. "
            f"Available keys: {list(self._impl_map.keys())}"
        )
        self.output_format = self._impl_map[x_fmt]

        # Metadata-only vTensor with a fresh tensor_id (its id is the next
        # free slot in ``ctx.tensor_list``), mirroring the Softmax pattern so
        # the compiler graph can track the buffer.
        inner_shape = (x.shape[1], x.shape[2])
        loaded = vTensor(
            shape=(0, *inner_shape),
            dtype=ctx.vortex_dtype,
            device=x.device,
            _format=self.output_format,
            tensor_id=len(ctx.tensor_list),
        )
        self.output_buffer = loaded

        # Register in the context graph (same convention as Softmax / Conv1d).
        # The producer index is taken *before* this op is appended, so it is
        # the slot the op is about to occupy in ``op_list``.
        ctx.tensor_list.append(loaded)
        ctx.output_tensor_to_op_list.append(len(ctx.op_list))
        ctx.op_list.append(self)
        ctx.op_to_input_tensor_list.append([x.tensor_id])
        ctx.op_to_output_tensor_list.append([loaded.tensor_id])

        return loaded