vortex_torch.cache.elementwise_binary

Classes

Add([alpha, beta])

Weighted sum (AXPBY-style) of two tensors.

Elementwise_Binary([alpha, beta])

Binary elementwise operator dispatcher (e.g. Maximum / Minimum / Add (AXPBY) / Multiply).

Maximum([alpha, beta])

Elementwise maximum of two tensors.

Minimum([alpha, beta])

Elementwise minimum of two tensors.

Multiply([alpha, beta])

Elementwise product of two tensors.

class vortex_torch.cache.elementwise_binary.Elementwise_Binary(alpha=1.0, beta=1.0)[source]

Bases: vOp

Binary elementwise operator dispatcher (e.g. Maximum / Minimum / Add (AXPBY) / Multiply).

This class dispatches a family of binary elementwise operations on rank-3 tensors. The inputs are treated as

\[X, Y \in \mathbb{R}^{B \times N \times D},\]

where:

  • \(B\) is a leading batch-like axis (typically derived from the runtime context, e.g. max_new_tokens_per_batch * head_num),

  • \(N\) is a sequence or position dimension, and

  • \(D\) is a feature/channel dimension.

Broadcasting is supported on the last two dimensions:

  • The \(N\) dimensions are compatible if x.shape[1] == y.shape[1], or if one of them equals 1.

  • The \(D\) dimensions are compatible if x.shape[2] == y.shape[2], or if one of them equals 1.

For a given operation type op_type, the dispatcher applies a scalar function

\[f(x, y; \alpha, \beta, \text{op_type})\]

pointwise to produce

\[Z[b, n, d] = f\bigl(X[b, n', d'], Y[b, n'', d'']; \alpha, \beta, \text{op_type}\bigr),\]

where \((n', d')\) and \((n'', d'')\) are the broadcasted indices derived from \((n, d)\).
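The broadcast rule above can be made concrete with a small pure-Python reference (an illustration only; the real kernels operate on tensors and are selected via the dispatch table described below). A size-1 input dimension is pinned to index 0, which is exactly the mapping from \((n, d)\) to \((n', d')\):

```python
def broadcast_index(n, d, shape):
    """Map an output index (n, d) onto an input of logical shape (N, D).

    A size-1 dimension always reads index 0 (it is broadcast)."""
    N, D = shape
    return (0 if N == 1 else n, 0 if D == 1 else d)

def elementwise_binary(x, y, f):
    """Reference semantics for one batch slice: x, y are nested lists of
    shape [N, D]; f is the scalar function applied pointwise."""
    Nx, Dx = len(x), len(x[0])
    Ny, Dy = len(y), len(y[0])
    # The documented compatibility conditions on N and D:
    assert Nx == Ny or 1 in (Nx, Ny), "N dimensions not broadcastable"
    assert Dx == Dy or 1 in (Dx, Dy), "D dimensions not broadcastable"
    N_out, D_out = max(Nx, Ny), max(Dx, Dy)
    out = []
    for n in range(N_out):
        row = []
        for d in range(D_out):
            nx, dx = broadcast_index(n, d, (Nx, Dx))
            ny, dy = broadcast_index(n, d, (Ny, Dy))
            row.append(f(x[nx][dx], y[ny][dy]))
        out.append(row)
    return out

x = [[1.0, 5.0], [3.0, 2.0]]   # N_x = 2, D_x = 2
y = [[4.0]]                    # N_y = 1, D_y = 1: broadcasts everywhere
print(elementwise_binary(x, y, max))   # [[4.0, 5.0], [4.0, 4.0]]
```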

Dispatch is based on the triplet of tensor formats (x_format, y_format, o_format) and a registry mapping:

(x_format, y_format, o_format) -> (impl, resolved_output_format)

Policy

  • If output is None:

    • profile() selects an implementation with o_format == FORMAT.RAGGED, i.e. a key (x_fmt, y_fmt, FORMAT.RAGGED) in _impl_map.

    • An internal buffer of shape [B, N_out, D_out] is allocated, where

      \[N_{\text{out}} = \max(N_x, N_y), \quad D_{\text{out}} = \max(D_x, D_y),\]

      and \(B\) is derived from the runtime context.

  • If output is provided:

    • profile() requires an exact implementation key for (x_fmt, y_fmt, o_fmt).

    • The shape of output must match the broadcasted (N_out, D_out).

    • Device consistency is enforced for x, y and output.
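The dispatch policy can be sketched as a table lookup. The FORMAT values, kernel names, and table entries below are hypothetical stand-ins; only the key/value shape and the "output is None forces a RAGGED-output key" rule come from the documentation above:

```python
from enum import Enum, auto

class FORMAT(Enum):      # hypothetical stand-in for the library's FORMAT enum
    RAGGED = auto()
    PADDED = auto()

# Hypothetical dispatch table with the documented key/value shape:
# (x_format, y_format, o_format) -> (impl, resolved_output_format)
_impl_map = {
    (FORMAT.RAGGED, FORMAT.RAGGED, FORMAT.RAGGED): ("ragged_kernel", FORMAT.RAGGED),
    (FORMAT.RAGGED, FORMAT.PADDED, FORMAT.PADDED): ("mixed_kernel", FORMAT.PADDED),
}

def resolve(x_fmt, y_fmt, o_fmt=None):
    """Mirror the documented policy: no output provided -> require a
    RAGGED-output implementation; explicit output -> require an exact key."""
    key = (x_fmt, y_fmt, FORMAT.RAGGED if o_fmt is None else o_fmt)
    assert key in _impl_map, f"no implementation for {key}"
    return _impl_map[key]

impl, out_fmt = resolve(FORMAT.RAGGED, FORMAT.RAGGED)   # output is None
assert out_fmt is FORMAT.RAGGED
```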

_impl_map

Dispatch table keyed by (x_format, y_format, o_format). Each entry maps to (callable_impl, resolved_output_format).

Type:

Dict[Tuple[FORMAT, FORMAT, FORMAT], Tuple[Callable, FORMAT]]

alpha

Scalar parameter used by certain binary ops (for example, as a multiplicative or additive coefficient).

Type:

float

beta

Scalar parameter used by certain binary ops.

Type:

float

op_type

Enum value describing the specific binary operation (e.g. maximum, minimum, AXPBY, multiply).

Type:

Optional[ElementwiseBinaryOpType]

impl

The resolved implementation selected during profile().

Type:

Optional[Callable]

output_format

The output tensor format as determined in profile().

Type:

Optional[FORMAT]

output_buffer

Internal output buffer allocated when output is None.

Type:

Optional[torch.Tensor]

profile(x, y, output, loc, ctx)[source]

Validate inputs, resolve the implementation and output format, and optionally allocate an internal output buffer.

The input tensors x and y are expected to have logical shape [B, N_x, D_x] and [B, N_y, D_y] respectively, with broadcasting allowed on N and D:

\[N_{\text{out}} = \max(N_x, N_y), \quad D_{\text{out}} = \max(D_x, D_y).\]

The auxiliary tensor loc carries per-position metadata used by the implementation (for example, indices or segment offsets); its shape and semantics are kernel-defined.

Parameters:
  • x (vTensor) – Left-hand operand with logical shape [B, N_x, D_x].

  • y (vTensor) – Right-hand operand with logical shape [B, N_y, D_y].

  • output (Optional[vTensor]) – Optional preallocated output tensor. If None, an internal buffer with shape [B, N_out, D_out] is allocated, using ctx.max_new_tokens_per_batch * ctx.head_num for B, and an implementation with RAGGED output format is selected. If not None, this tensor must have rank 3, the broadcasted shape [B, N_out, D_out], and a format for which a key exists in _impl_map.

  • loc (torch.Tensor) – Auxiliary tensor carrying per-position metadata used by the implementation.

  • ctx (Context) – Execution context that provides the runtime value of B and is used for auxiliary memory accounting.

Returns:

A vTensor view representing the resolved output: either the provided output or an internally allocated buffer.

Return type:

vTensor

Raises:

AssertionError – If types, ranks, broadcast conditions, formats, shapes, or devices are incompatible, or if no implementation is found in _impl_map.

execute(x, y, output, loc, ctx)[source]

Execute the selected binary elementwise implementation.

This method assumes that profile() has already selected an implementation and, if needed, allocated an internal output buffer.

Parameters:
  • x (torch.Tensor) – Plain left-hand operand tensor, with shape compatible with the vTensor validated in profile().

  • y (torch.Tensor) – Plain right-hand operand tensor, with shape compatible with the vTensor validated in profile().

  • output (Optional[torch.Tensor]) – Optional preallocated output tensor. If None, the internal buffer created during profile() will be used.

  • loc (torch.Tensor) – Auxiliary tensor carrying per-position metadata required by the implementation.

  • ctx (Context) – Execution context forwarded to the implementation.

Returns:

The output tensor written by the implementation: either the provided output or the internal buffer.

Return type:

torch.Tensor

Raises:

AssertionError – If profile() has not been called, so that no implementation (or internal output buffer, when one is required) is available.
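Taken together, profile() and execute() form a two-phase contract: profile() validates and allocates once, execute() runs the selected kernel and assumes profile() already succeeded. A minimal pure-Python sketch of that contract, simplified to 1-D data (vTensor, Context, loc, and the real dispatch are omitted; names here are illustrative):

```python
class MaximumRef:
    """Two-phase sketch: profile() validates and allocates the internal
    buffer; execute() asserts profile() ran first, mirroring the documented
    AssertionError."""

    def __init__(self):
        self.output_buffer = None

    def profile(self, x, y):
        # Validate broadcastability on the (single) data dimension.
        assert len(x) == len(y) or 1 in (len(x), len(y))
        n_out = max(len(x), len(y))
        self.output_buffer = [0.0] * n_out   # allocated once, output is None
        return self.output_buffer

    def execute(self, x, y):
        assert self.output_buffer is not None, "profile() must run first"
        for n in range(len(self.output_buffer)):
            xn = x[0] if len(x) == 1 else x[n]
            yn = y[0] if len(y) == 1 else y[n]
            self.output_buffer[n] = max(xn, yn)
        return self.output_buffer

op = MaximumRef()
op.profile([1.0, 4.0, 2.0], [3.0])            # allocate once
print(op.execute([1.0, 4.0, 2.0], [3.0]))     # [3.0, 4.0, 3.0]
```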

Parameters:
  • alpha (float, optional) – Scalar coefficient forwarded to implementations that use it (e.g. the AXPBY operation). Default is 1.0.

  • beta (float, optional) – Scalar coefficient forwarded to implementations that use it. Default is 1.0.
class vortex_torch.cache.elementwise_binary.Maximum(alpha=1.0, beta=1.0)[source]

Bases: Elementwise_Binary

Elementwise maximum of two tensors.

This operator applies, pointwise, the scalar function

\[f(x, y) = \max(x, y).\]

Given two input tensors

\[X, Y \in \mathbb{R}^{B \times N \times D},\]

with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by

\[Z[b, n, d] = \max\bigl(X[b, n', d'], Y[b, n'', d'']\bigr),\]

where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).

Parameters:
  • alpha (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.

  • beta (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.
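For reference, the pointwise contract of Maximum coincides with torch.maximum under standard broadcasting. A plain-torch illustration (not the vortex_torch kernel path, which additionally handles formats, loc, and the execution context):

```python
import torch

B, N, D = 2, 3, 4
x = torch.randn(B, N, D)
y = torch.randn(B, 1, D)     # N_y = 1 broadcasts against N_x = 3

z = torch.maximum(x, y)      # Z[b, n, d] = max(X[b, n, d], Y[b, 0, d])
assert z.shape == (B, N, D)
# Every output row n compares against the single broadcast row of y:
assert torch.equal(z[:, 1], torch.maximum(x[:, 1], y[:, 0]))
```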

class vortex_torch.cache.elementwise_binary.Minimum(alpha=1.0, beta=1.0)[source]

Bases: Elementwise_Binary

Elementwise minimum of two tensors.

This operator applies, pointwise, the scalar function

\[f(x, y) = \min(x, y).\]

Given two input tensors

\[X, Y \in \mathbb{R}^{B \times N \times D},\]

with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by

\[Z[b, n, d] = \min\bigl(X[b, n', d'], Y[b, n'', d'']\bigr),\]

where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).

Parameters:
  • alpha (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.

  • beta (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.

class vortex_torch.cache.elementwise_binary.Add(alpha=1.0, beta=1.0)[source]

Bases: Elementwise_Binary

Weighted sum (AXPBY-style) of two tensors.

This operator applies, pointwise, the scalar function

\[f(x, y; \alpha, \beta) = \alpha x + \beta y.\]

Given two input tensors

\[X, Y \in \mathbb{R}^{B \times N \times D},\]

with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by

\[Z[b, n, d] = \alpha \, X[b, n', d'] + \beta \, Y[b, n'', d''],\]

where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).

Parameters:
  • alpha (float, optional) – Coefficient \(\alpha\) applied to the first input tensor. Default is 1.

  • beta (float, optional) – Coefficient \(\beta\) applied to the second input tensor. Default is 1.
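A plain-torch reference for the AXPBY semantics (illustration only, not the vortex_torch kernel path):

```python
import torch

alpha, beta = 2.0, -1.0
x = torch.tensor([[[1.0, 2.0]]])   # logical shape [B=1, N=1, D=2]
y = torch.tensor([[[3.0, 4.0]]])

z = alpha * x + beta * y           # Z = alpha * X + beta * Y
assert torch.equal(z, torch.tensor([[[-1.0, 0.0]]]))
```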

class vortex_torch.cache.elementwise_binary.Multiply(alpha=1.0, beta=1.0)[source]

Bases: Elementwise_Binary

Elementwise product of two tensors.

This operator applies, pointwise, the scalar function

\[f(x, y) = x \cdot y.\]

Given two input tensors

\[X, Y \in \mathbb{R}^{B \times N \times D},\]

with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by

\[Z[b, n, d] = X[b, n', d'] \cdot Y[b, n'', d''],\]

where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).

Parameters:
  • alpha (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.

  • beta (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.