Source code for vortex_torch.indexer.elementwise_binary

import torch
from typing import Tuple, Dict, Callable, Optional
from .context import Context
from ..abs import vTensor, as_vtensor, FORMAT, vOp
from .triton_kernels.elementwise_binary_impl import elementwise_binary_bpr, elementwise_binary_rrr, elementwise_binary_rpr
from ..utils import ElementwiseBinaryOpType

class Elementwise_Binary(vOp):
    r"""
    Binary elementwise dispatcher for rank-3 logical tensors ``[S, C, D]``.

    This operator dispatches to a binary elementwise implementation based on
    the pair of input formats ``(x._format, y._format)``. The logical output
    shape:

    - keeps the ``S`` axis from the runtime context (``ctx.max_num_pages``), and
    - follows broadcasting over the ``(C, D)`` axes.

    Scalar parameters ``alpha`` and ``beta`` can be used by certain binary
    operations (e.g. an ``axpby``-style op).

    Attributes
    ----------
    _impl_map : Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]]
        Dispatch table keyed by ``(x_format, y_format)``. Each entry maps to
        ``(callable_impl, resolved_output_format)``.
    alpha : float
        Scalar parameter used by some ops. Default is ``1.0``.
    beta : float
        Scalar parameter used by some ops. Default is ``1.0``.
    impl : Optional[Callable]
        The resolved implementation selected during :meth:`profile`.
    op_type : Optional[ElementwiseBinaryOpType]
        The operator type used by the implementation. ``None`` in this base
        class; subclasses assign it in their ``__init__``.
    output_format : Optional[FORMAT]
        The output tensor format as determined in :meth:`profile`.
    output_buffer : Optional[torch.Tensor]
        Preallocated output tensor buffer that stores the binary result.
    """

    # Implementation dispatch table: keyed by (x_format, y_format).
    # Value: (callable_impl, resolved_output_format)
    _impl_map: Dict[Tuple[FORMAT, FORMAT], Tuple[Callable, FORMAT]] = {
        (FORMAT.RAGGED, FORMAT.RAGGED): (elementwise_binary_rrr, FORMAT.RAGGED),
        (FORMAT.BATCHED, FORMAT.PAGED): (elementwise_binary_bpr, FORMAT.RAGGED),
        (FORMAT.RAGGED, FORMAT.PAGED): (elementwise_binary_rpr, FORMAT.RAGGED),
        # Add more pairs as needed.
    }

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        super().__init__()
        self.impl: Optional[Callable] = None
        # Left as None here; concrete subclasses set the operator selector.
        self.op_type: Optional[ElementwiseBinaryOpType] = None
        self.alpha = alpha
        self.beta = beta
        self.output_format: Optional[FORMAT] = None
        self.output_buffer: Optional[torch.Tensor] = None

    @staticmethod
    def _broadcast_dim(a: int, b: int) -> int:
        """Return the broadcast size of two sizes already known to be broadcastable."""
        # A size-1 axis stretches to the other operand's size. Unlike
        # ``max(a, b)``, this keeps a zero-size axis at 0 for the pair (0, 1),
        # matching PyTorch broadcasting.
        return b if a == 1 else a

    # ---------------- profile ----------------
    def profile(self, x: vTensor, y: vTensor, ctx: Context) -> vTensor:
        r"""
        Validate inputs, select implementation, allocate the output buffer,
        and return a ``vTensor`` view with the resolved output format.

        The dispatcher:

        - checks that ``x`` and ``y`` are rank-3 tensors of shape ``[S, C, D]``
        - enforces broadcastability on the ``C`` and ``D`` dimensions
        - selects an implementation using ``(x._format, y._format)``
        - allocates an output buffer with shape ``[S_ctx, C_out, D_out]`` where
          ``C_out`` / ``D_out`` follow broadcasting rules:

          .. math::

              C_{\text{out}} =
              \begin{cases} C_y & \text{if } C_x = 1 \\
                            C_x & \text{otherwise}
              \end{cases}

          (and likewise for ``D``). This equals :math:`\max(C_x, C_y)` except
          when one side is ``0``, where the broadcast size is ``0``.
          ``S_ctx = ctx.max_num_pages``.

        Parameters
        ----------
        x : vTensor
            Left-hand input tensor, rank-3, with logical shape ``[S, C, D]``.
        y : vTensor
            Right-hand input tensor, rank-3, with logical shape ``[S, C, D]``.
        ctx : Context
            Execution context providing the runtime ``S``
            (``ctx.max_num_pages``) and auxiliary-memory accounting.

        Returns
        -------
        vTensor
            A ``vTensor`` view wrapping the allocated output buffer, using the
            resolved output format from the dispatch table.

        Raises
        ------
        AssertionError
            If types are not ``vTensor``, if ranks are not 3, if ``C``/``D``
            are not broadcastable, if formats are unsupported, or if devices of
            ``x`` and ``y`` do not match.
        """
        prefix = self._prefix()

        # Type checks
        assert isinstance(x, vTensor), f"{prefix}profile expects x to be vTensor, got {type(x)}"
        assert isinstance(y, vTensor), f"{prefix}profile expects y to be vTensor, got {type(y)}"

        # Rank & basic shape checks
        assert x.dim() == 3 and y.dim() == 3, (
            f"{prefix}expected 3D inputs [S, C, D]; got x.ndim={x.dim()}, y.ndim={y.dim()}"
        )

        # Broadcastability on C/D
        assert (x.shape[1] == y.shape[1] or x.shape[1] == 1 or y.shape[1] == 1), (
            f"{prefix}dim-1 not broadcastable: x.shape={tuple(x.shape)}, y.shape={tuple(y.shape)}"
        )
        assert (x.shape[2] == y.shape[2] or x.shape[2] == 1 or y.shape[2] == 1), (
            f"{prefix}dim-2 not broadcastable: x.shape={tuple(x.shape)}, y.shape={tuple(y.shape)}"
        )

        # Dispatch
        x_fmt, y_fmt = x._format, y._format
        key = (x_fmt, y_fmt)
        assert key in self._impl_map, (
            f"{prefix}no implementation for (x_fmt={x_fmt}, y_fmt={y_fmt}). "
            f"Available: {list(self._impl_map.keys())}"
        )
        self.impl, self.output_format = self._impl_map[key]

        # Device consistency
        assert x.device == y.device, (
            f"{prefix}x and y must be on the same device "
            f"(x.device={x.device}, y.device={y.device})"
        )

        # Broadcasted output (C, D); see _broadcast_dim for why not max().
        C_out = self._broadcast_dim(x.shape[1], y.shape[1])
        D_out = self._broadcast_dim(x.shape[2], y.shape[2])

        # Allocate output buffer on x.device with x.dtype.
        # S comes from the context, not from the inputs' leading axis.
        S = ctx.max_num_pages
        self.output_buffer = torch.empty(
            (S, C_out, D_out),
            device=x.device,
            dtype=x.dtype,
        )
        ctx.add_aux_memory(self.output_buffer)

        # Each PAGED operand is charged C * D auxiliary flops.
        # NOTE(review): presumably models per-page access overhead — confirm.
        for t in [x, y]:
            if t._format == FORMAT.PAGED:
                ctx.add_aux_flops(t.shape[1] * t.shape[2])

        # Return vTensor view with dispatched output format
        return as_vtensor(self.output_buffer, self.output_format)

    # ---------------- execute ----------------
    def execute(self, x: torch.Tensor, y: torch.Tensor, ctx: Context) -> torch.Tensor:
        r"""
        Execute the selected binary elementwise implementation into the
        internal output buffer and return it.

        Expected implementation signature::

            impl(x, y, output, op_type, alpha, beta, ctx)

        Parameters
        ----------
        x : torch.Tensor
            Left-hand input tensor on the same device as ``y`` and the output
            buffer.
        y : torch.Tensor
            Right-hand input tensor on the same device as ``x`` and the output
            buffer.
        ctx : Context
            Execution context passed through to the underlying implementation.

        Returns
        -------
        torch.Tensor
            The output tensor stored in ``self.output_buffer``.

        Raises
        ------
        AssertionError
            If :meth:`profile` has not been called (no implementation or
            buffer), or if there is a device mismatch between inputs and
            output.
        """
        prefix = self._prefix()
        assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)"
        assert self.output_buffer is not None, f"{prefix}output buffer is None; did profile() run?"
        assert x.device == y.device == self.output_buffer.device, (
            f"{prefix}device mismatch: "
            f"x={x.device}, y={y.device}, o={self.output_buffer.device}"
        )
        # op_type is forwarded as-is; it is None when this base class is used
        # directly rather than through a subclass that sets it.
        self.impl(x, y, self.output_buffer, self.op_type, self.alpha, self.beta, ctx)
        return self.output_buffer
class Maximum(Elementwise_Binary):
    r"""
    Pointwise maximum of two operands.

    Computes

    .. math::

        \operatorname{out}(x, y) = \max(x, y)

    with ``(C, D)`` broadcasting handled by :class:`Elementwise_Binary`.

    Parameters
    ----------
    alpha : float, optional
        Stored on the operator and forwarded to the binary kernel; the
        maximum operation itself does not use it. Defaults to ``1.0``.
    beta : float, optional
        Stored on the operator and forwarded to the binary kernel; the
        maximum operation itself does not use it. Defaults to ``1.0``.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        super().__init__(alpha=alpha, beta=beta)
        # The base initializer resets ``op_type`` to ``None``, so the
        # selector has to be assigned after it runs.
        self.op_type = ElementwiseBinaryOpType.Maximum
class Minimum(Elementwise_Binary):
    r"""
    Pointwise minimum of two operands.

    Computes

    .. math::

        \operatorname{out}(x, y) = \min(x, y)

    with ``(C, D)`` broadcasting handled by :class:`Elementwise_Binary`.

    Parameters
    ----------
    alpha : float, optional
        Stored on the operator and forwarded to the binary kernel; the
        minimum operation itself does not use it. Defaults to ``1.0``.
    beta : float, optional
        Stored on the operator and forwarded to the binary kernel; the
        minimum operation itself does not use it. Defaults to ``1.0``.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        super().__init__(alpha=alpha, beta=beta)
        # The base initializer resets ``op_type`` to ``None``, so the
        # selector has to be assigned after it runs.
        self.op_type = ElementwiseBinaryOpType.Minimum
class Add(Elementwise_Binary):
    r"""
    Weighted sum (affine combination) of two operands.

    Computes

    .. math::

        \operatorname{out}(x, y) = \alpha x + \beta y

    Leaving both weights at their default of ``1`` gives plain elementwise
    addition, :math:`\operatorname{out}(x, y) = x + y`.

    ``(C, D)`` broadcasting is handled by :class:`Elementwise_Binary`.

    Parameters
    ----------
    alpha : float, optional
        Weight applied to :math:`x`. Defaults to ``1.0``.
    beta : float, optional
        Weight applied to :math:`y`. Defaults to ``1.0``.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        super().__init__(alpha=alpha, beta=beta)
        # The base initializer resets ``op_type`` to ``None``, so the
        # selector has to be assigned after it runs.
        self.op_type = ElementwiseBinaryOpType.Add
class Multiply(Elementwise_Binary):
    r"""
    Pointwise product of two operands.

    Computes

    .. math::

        \operatorname{out}(x, y) = x \cdot y

    with ``(C, D)`` broadcasting handled by :class:`Elementwise_Binary`.

    Parameters
    ----------
    alpha : float, optional
        Stored on the operator and forwarded to the binary kernel; the plain
        multiplication does not use it. Defaults to ``1.0``.
    beta : float, optional
        Stored on the operator and forwarded to the binary kernel; the plain
        multiplication does not use it. Defaults to ``1.0``.
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        super().__init__(alpha=alpha, beta=beta)
        # The base initializer resets ``op_type`` to ``None``, so the
        # selector has to be assigned after it runs.
        self.op_type = ElementwiseBinaryOpType.Mul