vortex_torch.indexer.kron

Classes

Kron([dim])

Kronecker product over a configurable subset of the inner axes.

class Kron(dim=(1, 2))

Bases: vOp


The inner axes listed in dim are Kronecker-expanded. The remaining inner axis, if any, is multiplied elementwise (with broadcasting), and the leading \(S\) axis is always multiplied elementwise.

Math:

For \(X\in\mathbb{R}^{S\times x_1\times x_2}\) and \(Y\in\mathbb{R}^{S\times y_1\times y_2}\):

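\[
\begin{aligned}
\text{dim}=(1,2):\quad & O_{s,\,i\,y_1+j,\,k\,y_2+l} = X_{s,i,k}\,Y_{s,j,l}, \\
\text{dim}=(1,):\quad & O_{s,\,i\,y_1+j,\,d} = X_{s,i,d}\,Y_{s,j,d}, \\
\text{dim}=(2,):\quad & O_{s,\,c,\,k\,y_2+l} = X_{s,c,k}\,Y_{s,c,l}.
\end{aligned}
\]

Indices are 0-based, with \(i<x_1\), \(j<y_1\), \(k<x_2\), \(l<y_2\). On a non-expanded axis (\(d\) or \(c\)), an operand of size 1 is read at index 0 (broadcasting).
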
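The index formulas above can be checked against a small pure-Python reference. The sketch below is not the vortex_torch implementation. The names kron_reference and _axis_map are hypothetical, tensors are assumed to be nested lists indexed x[s][i][k], and the leading S axis is assumed to have equal size in both inputs (no broadcasting on S).

```python
def _axis_map(a_x, a_y, expand):
    """Return (out_size, f), where f maps an output index o to (x index, y index)."""
    if expand:
        # Kronecker axis: o = i * a_y + j, so i = o // a_y and j = o % a_y.
        return a_x * a_y, lambda o: (o // a_y, o % a_y)
    if a_x != a_y and 1 not in (a_x, a_y):
        raise ValueError(f"sizes {a_x} and {a_y} are not broadcastable")
    # Elementwise axis with broadcasting: a size-1 operand is always read at index 0.
    return max(a_x, a_y), lambda o: (o if a_x > 1 else 0, o if a_y > 1 else 0)


def kron_reference(x, y, dim=(1, 2)):
    """Hypothetical reference for Kron on nested lists x[S][x_1][x_2], y[S][y_1][y_2]."""
    dims = {dim} if isinstance(dim, int) else set(dim)
    if not dims <= {1, 2}:
        raise ValueError("dim entries must be 1 or 2")
    S = len(x)
    if len(y) != S:
        # Assumption: the leading S axis must match exactly.
        raise ValueError("leading S axis must match")
    n1, f1 = _axis_map(len(x[0]), len(y[0]), 1 in dims)
    n2, f2 = _axis_map(len(x[0][0]), len(y[0][0]), 2 in dims)
    out = []
    for s in range(S):  # S is always elementwise
        rows = []
        for p in range(n1):
            i, j = f1(p)
            rows.append([x[s][i][k] * y[s][j][l] for k, l in map(f2, range(n2))])
        out.append(rows)
    return out
```

For \(S=1\) and dim=(1, 2), each slice reduces to the ordinary matrix Kronecker product: kron_reference([[[1, 2], [3, 4]]], [[[0, 5], [6, 7]]]) gives the single slice [[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]].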
__init__:

Kron(dim=(1, 2)): dim names the inner axis or axes to Kronecker-expand, each 1 or 2. Any inner axis not listed must have equal sizes in x and y, or be broadcastable (one of the two sizes is 1).
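The constructor contract for dim can be illustrated with a small normalization helper. This is a sketch only: normalize_dim is a hypothetical name, and rejecting an empty or repeated dim is an assumption that the source does not state.

```python
def normalize_dim(dim=(1, 2)):
    """Hypothetical helper: coerce dim (int or iterable of ints) to a sorted tuple of axes."""
    axes = (dim,) if isinstance(dim, int) else tuple(dim)
    if not axes:
        # Assumption: at least one inner axis must be expanded.
        raise ValueError("dim must name at least one inner axis")
    bad = [a for a in axes if a not in (1, 2)]
    if bad:
        raise ValueError(f"dim entries must be 1 or 2, got {bad}")
    if len(set(axes)) != len(axes):
        # Assumption: repeated axes are rejected rather than silently merged.
        raise ValueError("dim entries must be unique")
    return tuple(sorted(axes))
```

With this helper, normalize_dim(2) returns (2,) and normalize_dim([2, 1]) returns (1, 2), so an int and a one-element iterable describe the same expansion.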

__call__:

o = op(x, y, ctx=ctx), with x of shape [S, x_1, x_2] and y of shape [S, y_1, y_2]. An expanded axis a has output size x.shape[a] * y.shape[a]; a broadcast axis has output size max(x.shape[a], y.shape[a]). The output is BATCHED if and only if both inputs are.
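The output-shape rule above can be written as a short shape calculator. This is a sketch that works on plain tuples rather than tensors. kron_output_shape is a hypothetical name, and requiring the leading S sizes to match exactly is an assumption.

```python
def kron_output_shape(x_shape, y_shape, dim=(1, 2)):
    """Hypothetical helper: output shape of Kron for x [S, x_1, x_2] and y [S, y_1, y_2]."""
    axes = {dim} if isinstance(dim, int) else set(dim)
    if len(x_shape) != 3 or len(y_shape) != 3:
        raise ValueError("expected 3-D shapes [S, a_1, a_2]")
    if x_shape[0] != y_shape[0]:
        # Assumption: the leading S axis is not broadcast.
        raise ValueError("leading S axis must match")
    out = [x_shape[0]]
    for a in (1, 2):
        xa, ya = x_shape[a], y_shape[a]
        if a in axes:
            out.append(xa * ya)        # Kronecker-expanded axis
        elif xa == ya or 1 in (xa, ya):
            out.append(max(xa, ya))    # elementwise axis with broadcasting
        else:
            raise ValueError(f"axis {a}: sizes {xa} and {ya} are not broadcastable")
    return tuple(out)
```

For example, kron_output_shape((4, 2, 1), (4, 5, 7), dim=(1,)) returns (4, 10, 7): axis 1 is expanded to 2 * 5 and axis 2 broadcasts from 1 to 7.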

Parameters:

dim (int | Iterable[int]) – inner axis or axes to Kronecker-expand, each 1 or 2; defaults to (1, 2).