vortex_torch.cache.elementwise_binary¶
Classes

- Add — Weighted sum (AXPBY-style) of two tensors.
- Elementwise_Binary — Binary elementwise operator dispatcher (e.g. Maximum / Minimum / AXPBY / Mul).
- Maximum — Elementwise maximum of two tensors.
- Minimum — Elementwise minimum of two tensors.
- Multiply — Elementwise product of two tensors.
- class vortex_torch.cache.elementwise_binary.Elementwise_Binary(alpha=1.0, beta=1.0)[source]¶
Bases: vOp
Binary elementwise operator dispatcher (e.g. Maximum / Minimum / AXPBY / Mul).
This class dispatches a family of binary elementwise operations on rank-3 tensors. The inputs are treated as
\[X, Y \in \mathbb{R}^{B \times N \times D},\]
where:
- \(B\) is a leading batch-like axis (typically derived from the runtime context, e.g. max_new_tokens_per_batch * head_num),
- \(N\) is a sequence or position dimension, and
- \(D\) is a feature/channel dimension.
Broadcasting is supported on the last two dimensions:
- \(N\) is broadcastable if x.shape[1] == y.shape[1], or one of them equals 1.
- \(D\) is broadcastable if x.shape[2] == y.shape[2], or one of them equals 1.
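As a minimal sketch, the broadcast rule above can be expressed as a shape-compatibility check; the helper name here is illustrative and not part of the library's API:

```python
def broadcast_shape(x_shape, y_shape):
    """Resolve the broadcasted shape of two rank-3 operands.

    The batch dimension B must match exactly; N and D may differ
    only when one side equals 1, as described above.
    """
    B_x, N_x, D_x = x_shape
    B_y, N_y, D_y = y_shape
    assert B_x == B_y, "batch dimension must match exactly"
    for a, b in ((N_x, N_y), (D_x, D_y)):
        assert a == b or a == 1 or b == 1, f"cannot broadcast {a} against {b}"
    # The broadcasted extent is the larger of the two (the other is 1).
    return (B_x, max(N_x, N_y), max(D_x, D_y))
```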
For a given operation type op_type, the dispatcher applies a scalar function
\[f(x, y; \alpha, \beta, \text{op\_type})\]
pointwise to produce
\[Z[b, n, d] = f\bigl(X[b, n', d'], Y[b, n'', d'']; \alpha, \beta, \text{op\_type}\bigr),\]
where \((n', d')\) and \((n'', d'')\) are the broadcasted indices derived from \((n, d)\).
Dispatch is based on the triplet of tensor formats (x_format, y_format, o_format) and a registry mapping:

(x_format, y_format, o_format) -> (impl, resolved_output_format)
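A toy version of such a registry might look like the following; the FORMAT values and the stand-in implementation are hypothetical, chosen only to mirror the keying scheme described above:

```python
from enum import Enum, auto

class FORMAT(Enum):
    DENSE = auto()
    RAGGED = auto()

def dense_add(x, y):
    # Stand-in implementation: pointwise add over flat lists.
    return [xi + yi for xi, yi in zip(x, y)]

# Registry keyed by (x_format, y_format, o_format); each entry maps to
# (callable_impl, resolved_output_format).
_impl_map = {
    (FORMAT.DENSE, FORMAT.DENSE, FORMAT.RAGGED): (dense_add, FORMAT.RAGGED),
}

# A lookup like this is what dispatch performs for a given format triplet.
impl, o_fmt = _impl_map[(FORMAT.DENSE, FORMAT.DENSE, FORMAT.RAGGED)]
```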
Policy¶
If output is None:
- profile() selects an implementation with o_format == FORMAT.RAGGED, i.e. a key (x_fmt, y_fmt, FORMAT.RAGGED) in _impl_map.
- An internal buffer of shape [B, N_out, D_out] is allocated, where
\[N_{\text{out}} = \max(N_x, N_y), \quad D_{\text{out}} = \max(D_x, D_y),\]
and \(B\) is derived from the runtime context.

If output is provided:
- profile() requires an exact implementation key for (x_fmt, y_fmt, o_fmt).
- The shape of output must match the broadcasted (N_out, D_out).
- Device consistency is enforced for x, y, and output.
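The policy above can be sketched as a small helper. The ctx field names follow the text; resolve_output itself is a hypothetical stand-in for logic inside profile(), operating on shape tuples rather than tensors:

```python
def resolve_output(x_shape, y_shape, output_shape, ctx):
    """Sketch of the output-resolution policy for rank-3 operands."""
    # Broadcasted extents on the last two dimensions.
    N_out = max(x_shape[1], y_shape[1])
    D_out = max(x_shape[2], y_shape[2])
    if output_shape is None:
        # No output given: size an internal buffer, deriving B from the
        # runtime context as described above.
        B = ctx["max_new_tokens_per_batch"] * ctx["head_num"]
        return ("allocated", (B, N_out, D_out))
    # Output given: its trailing dims must match the broadcasted shape.
    assert output_shape[1:] == (N_out, D_out), "output must match broadcasted shape"
    return ("provided", output_shape)
```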
- _impl_map¶
Dispatch table keyed by (x_format, y_format, o_format). Each entry maps to (callable_impl, resolved_output_format).
- alpha¶
Scalar parameter used by certain binary ops (for example, as a multiplicative or additive coefficient).
- Type:
float
- op_type¶
Enum value describing the specific binary operation (e.g. maximum, minimum, AXPBY, multiply).
- Type:
Optional[ElementwiseBinaryOpType]
- output_buffer¶
Internal output buffer allocated when output is None.
- Type:
Optional[torch.Tensor]
- profile(x, y, output, loc, ctx)[source]¶
Validate inputs, resolve the implementation and output format, and optionally allocate an internal output buffer.
The input tensors x and y are expected to have logical shapes [B, N_x, D_x] and [B, N_y, D_y] respectively, with broadcasting allowed on N and D:
\[N_{\text{out}} = \max(N_x, N_y), \quad D_{\text{out}} = \max(D_x, D_y).\]
The auxiliary tensor loc carries per-position metadata used by the implementation (for example, indices or segment offsets); its shape and semantics are kernel-defined.
- Parameters:
x (vTensor) – Left-hand operand with logical shape [B, N_x, D_x].
y (vTensor) – Right-hand operand with logical shape [B, N_y, D_y].
output (Optional[vTensor]) – Optional preallocated output tensor. If None, an internal buffer with shape [B, N_out, D_out] is allocated, using ctx.max_new_tokens_per_batch * ctx.head_num for B, and a RAGGED-output implementation is selected. If not None, this tensor must have rank 3, the broadcasted shape [B_out, N_out, D_out], and a format compatible with _impl_map.
loc (torch.Tensor) – Auxiliary tensor carrying per-position metadata used by the implementation.
ctx (Context) – Execution context that provides the runtime value of B and is used for auxiliary memory accounting.
- Returns:
A vTensor view representing the resolved output: either the provided output or an internally allocated buffer.
- Return type:
vTensor
- Raises:
AssertionError – If types, ranks, broadcast conditions, formats, shapes, or devices are incompatible, or if no implementation is found in _impl_map.
- execute(x, y, output, loc, ctx)[source]¶
Execute the selected binary elementwise implementation.
This method assumes that profile() has already selected an implementation and, if needed, allocated an internal output buffer.
- Parameters:
x (torch.Tensor) – Plain left-hand operand tensor, with shape compatible with the vTensor validated in profile().
y (torch.Tensor) – Plain right-hand operand tensor, with shape compatible with the vTensor validated in profile().
output (Optional[torch.Tensor]) – Optional preallocated output tensor. If None, the internal buffer created during profile() will be used.
loc (torch.Tensor) – Auxiliary tensor carrying per-position metadata required by the implementation.
ctx (Context) – Execution context forwarded to the implementation.
- Returns:
The output tensor written by the implementation: either the provided output or the internal buffer.
- Return type:
torch.Tensor
- Raises:
AssertionError – If profile() has not been called and no implementation or internal buffer is available.
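The profile()/execute() contract can be illustrated with a minimal two-phase sketch, using plain Python lists in place of tensors. TwoPhaseOp is a hypothetical stand-in; it mirrors only the buffer handling and omits format, device, and loc validation:

```python
class TwoPhaseOp:
    """Sketch of the two-phase contract: profile() resolves the
    implementation and (optionally) allocates a buffer; execute()
    then reuses whatever profile() set up."""

    def __init__(self):
        self.output_buffer = None
        self.impl = None

    def profile(self, x, y, output):
        # Stand-in "implementation selection": pointwise add.
        self.impl = lambda a, b: [ai + bi for ai, bi in zip(a, b)]
        if output is None:
            # Allocate an internal buffer sized to the output.
            self.output_buffer = [0] * len(x)
        return output if output is not None else self.output_buffer

    def execute(self, x, y, output):
        assert self.impl is not None, "profile() must run first"
        out = output if output is not None else self.output_buffer
        out[:] = self.impl(x, y)
        return out
```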
- class vortex_torch.cache.elementwise_binary.Maximum(alpha=1.0, beta=1.0)[source]¶
Bases: Elementwise_Binary
Elementwise maximum of two tensors.
This operator applies, pointwise, the scalar function
\[f(x, y) = \max(x, y).\]
Given two input tensors
\[X, Y \in \mathbb{R}^{B \times N \times D},\]
with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by
\[Z[b, n, d] = \max\bigl(X[b, n', d'], Y[b, n'', d'']\bigr),\]
where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).
- Parameters:
alpha (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.
beta (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.
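A reference sketch of the broadcasted maximum on rank-3 nested lists; illustrative only, since the real operator runs an accelerated kernel selected via _impl_map:

```python
def elementwise_max(X, Y):
    """Pointwise max over [B, N, D] nested lists with N/D broadcasting.

    Indices n % N_x and d % D_x reduce to 0 when the source extent is 1,
    which is exactly the broadcasted index n' (resp. d') from the text.
    """
    B, N_x, D_x = len(X), len(X[0]), len(X[0][0])
    N_y, D_y = len(Y[0]), len(Y[0][0])
    N, D = max(N_x, N_y), max(D_x, D_y)
    return [[[max(X[b][n % N_x][d % D_x], Y[b][n % N_y][d % D_y])
              for d in range(D)] for n in range(N)] for b in range(B)]
```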
- class vortex_torch.cache.elementwise_binary.Minimum(alpha=1.0, beta=1.0)[source]¶
Bases: Elementwise_Binary
Elementwise minimum of two tensors.
This operator applies, pointwise, the scalar function
\[f(x, y) = \min(x, y).\]
Given two input tensors
\[X, Y \in \mathbb{R}^{B \times N \times D},\]
with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by
\[Z[b, n, d] = \min\bigl(X[b, n', d'], Y[b, n'', d'']\bigr),\]
where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).
- Parameters:
alpha (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.
beta (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.
- class vortex_torch.cache.elementwise_binary.Add(alpha=1.0, beta=1.0)[source]¶
Bases: Elementwise_Binary
Weighted sum (AXPBY-style) of two tensors.
This operator applies, pointwise, the scalar function
\[f(x, y; \alpha, \beta) = \alpha x + \beta y.\]
Given two input tensors
\[X, Y \in \mathbb{R}^{B \times N \times D},\]
with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by
\[Z[b, n, d] = \alpha \, X[b, n', d'] + \beta \, Y[b, n'', d''],\]
where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).
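A reference sketch of the AXPBY rule on rank-3 nested lists, with the same N/D broadcasting; illustrative only, not the accelerated implementation:

```python
def axpby(X, Y, alpha=1.0, beta=1.0):
    """Compute Z[b, n, d] = alpha * X[b, n', d'] + beta * Y[b, n'', d'']
    over [B, N, D] nested lists; % maps broadcast extents of 1 to index 0."""
    B, N_x, D_x = len(X), len(X[0]), len(X[0][0])
    N_y, D_y = len(Y[0]), len(Y[0][0])
    N, D = max(N_x, N_y), max(D_x, D_y)
    return [[[alpha * X[b][n % N_x][d % D_x] + beta * Y[b][n % N_y][d % D_y]
              for d in range(D)] for n in range(N)] for b in range(B)]
```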
- class vortex_torch.cache.elementwise_binary.Multiply(alpha=1.0, beta=1.0)[source]¶
Bases: Elementwise_Binary
Elementwise product of two tensors.
This operator applies, pointwise, the scalar function
\[f(x, y) = x \cdot y.\]
Given two input tensors
\[X, Y \in \mathbb{R}^{B \times N \times D},\]
with broadcasting allowed on the N and D dimensions, the output tensor \(Z\) is defined by
\[Z[b, n, d] = X[b, n', d'] \cdot Y[b, n'', d''],\]
where \((n', d')\) and \((n'', d'')\) are the broadcasted indices corresponding to \((n, d)\).
- Parameters:
alpha (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.
beta (float, optional) – Unused for this operation. Present only to match the common Elementwise_Binary interface. Default is 1.