This is a brief explanation and a cookbook for using numpy.einsum, which lets us use Einstein notation to evaluate operations on multi-dimensional arrays. The focus here is mostly on einsum's explicit mode (with -> and output dimensions explicitly specified in the subscript string) and use cases common in ML papers, though I'll also briefly touch upon other patterns.

Basic use case - matrix multiplication

Let's start with a basic demonstration: matrix multiplication using einsum. Throughout this post, A and B will be these matrices:

>>> A = np.arange(6).reshape(2,3)

>>> A
array([[0, 1, 2],
       [3, 4, 5]])

>>> B = np.arange(12).reshape(3,4)+1

>>> B
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12]])

The shapes of A and B let us multiply A @ B to get a (2,4) matrix. This can also be done with einsum, as follows:

>>> np.einsum('ij,jk->ik', A, B)
array([[ 23,  26,  29,  32],
       [ 68,  80,  92, 104]])

The first parameter to einsum is the subscript string, which describes the operation to perform on the following operands. Its format is a comma-separated list of inputs, followed by -> and an output. An arbitrary number of positional operands follows; they match the comma-separated inputs specified in the subscript. For each input, its shape is a sequence of dimension labels like i (any single letter).

In our example, ij refers to the matrix A - denoting its shape as (i,j), and jk refers to the matrix B - denoting its shape as (j,k). While in the subscript these dimension labels are symbolic, they become concrete when einsum is invoked with actual operands. This is because the shapes of the operands are known at that point.

The following is a simplified mental model of what einsum does (for a more complete description, read An instructional implementation of einsum):

  • The output part of the subscript specifies the shape of the output array, expressed in terms of the input dimension labels.
  • Whenever a dimension label is repeated in the input and absent in the output - it is contracted (summed). In our example, j is repeated (and doesn't appear in the output), so it's contracted: each output element [ik] is a dot product of the i'th row of the first input with the k'th column of the second input.
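To make this mental model concrete, here's a small check (a sketch that reuses the A and B defined above; `out` is just an illustrative name). It confirms that each element of the einsum result is the dot product of a row of A with a column of B, and that the whole result matches A @ B:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4) + 1

out = np.einsum('ij,jk->ik', A, B)

# j is repeated in the inputs and absent from the output, so it's summed over:
# out[i, k] is the dot product of row i of A with column k of B.
for i in range(A.shape[0]):
    for k in range(B.shape[1]):
        assert out[i, k] == np.dot(A[i, :], B[:, k])

assert np.array_equal(out, A @ B)
```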

We can easily transpose the output, by flipping its shape:

>>> np.einsum('ij,jk->ki', A, B)
array([[ 23,  68],
       [ 26,  80],
       [ 29,  92],
       [ 32, 104]])

This is equivalent to (A @ B).T.

When reading ML papers, I find that even for such simple cases as basic matrix multiplication, the einsum notation is often preferred to the plain @ (or its function form like np.dot and np.matmul). This is likely because the einsum approach is self-documenting, helping the writer reason through the dimensions more explicitly.

Batched matrix multiplication

Using einsum instead of @ for matmuls as a documentation prop starts making even more sense when the ndim [1] of the inputs grows. For example, we may want to perform matrix multiplication on a whole batch of inputs within a single operation. Suppose we have these arrays:

>>> Ab = np.arange(6*6).reshape(6,2,3)
>>> Bb = np.arange(6*12).reshape(6,3,4)

Here 6 is the batch dimension. We're multiplying a batch of six (2,3) matrices by a batch of six (3,4) matrices; each matrix in Ab is multiplied by a corresponding matrix in Bb. The result is shaped (6,2,4).

We can perform batched matmul by doing Ab @ Bb - in Numpy this just works: the contraction happens between the last dimension of the first array and the penultimate dimension of the second array. This is repeated for all the dimensions preceding the last two. The shape of the output is (6,2,4), as expected.

With the einsum notation, we can do the same, but in a way that's more self-documenting:

>>> np.einsum('bmd,bdn->bmn', Ab, Bb)

This is equivalent to Ab @ Bb, but the subscript string lets us name the dimensions with single letters, which makes it easier to follow what's going on. For example, in this case b may stand for batch, m and n may stand for sequence lengths and d could be some sort of model dimension/depth.

Note: while b is repeated in the inputs of the subscript, it also appears in the output; therefore it's not contracted.
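One way to convince yourself that b is a "pass-through" label is to compare the einsum against an explicit Python loop over the batch (a sketch; the `looped` name is only for illustration):

```python
import numpy as np

Ab = np.arange(6*6).reshape(6, 2, 3)
Bb = np.arange(6*12).reshape(6, 3, 4)

out = np.einsum('bmd,bdn->bmn', Ab, Bb)

# b appears in the output, so it is not summed over: each batch entry is an
# independent (2,3) @ (3,4) matmul. Only d is contracted.
looped = np.stack([Ab[n] @ Bb[n] for n in range(Ab.shape[0])])

assert out.shape == (6, 2, 4)
assert np.array_equal(out, looped)
assert np.array_equal(out, Ab @ Bb)
```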

Ordering output dimensions

The order of output dimensions in the subscript of einsum allows us to do more than just matrix multiplications; we can also transpose arbitrary dimensions:

>>> Bb.shape
(6, 3, 4)

>>> np.einsum('ijk->kij', Bb).shape
(4, 6, 3)
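A pure reordering like this is the same permutation that np.transpose takes as its axes argument. The sketch below checks that correspondence and spells out the element-wise meaning (the index values are arbitrary):

```python
import numpy as np

Bb = np.arange(6*12).reshape(6, 3, 4)

# 'ijk->kij': output axis 0 is input axis 2 (k), axis 1 is input axis 0 (i),
# and axis 2 is input axis 1 (j); hence the permutation (2, 0, 1).
t = np.einsum('ijk->kij', Bb)
assert np.array_equal(t, np.transpose(Bb, (2, 0, 1)))

# Element-wise: t[k, i, j] == Bb[i, j, k]
assert t[3, 5, 2] == Bb[5, 2, 3]
```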

This capability is commonly combined with matrix multiplication to specify exactly the order of dimensions in a multi-dimensional batched array multiplication. The following is an example taken directly from the Fast Transformer Decoding paper by Noam Shazeer.

In the section on batched multi-head attention, the paper defines the following arrays:

  • M: a tensor with shape (b,m,d) (batch, sequence length, model depth)
  • P_k: a tensor with shape (h,d,k) (number of heads, model depth, head size for keys)

Let's define some dimension size constants and random arrays:

>>> m = 4; d = 3; k = 6; h = 5; b = 10
>>> Pk = np.random.randn(h, d, k)
>>> M = np.random.randn(b, m, d)

The paper performs an einsum to calculate all the keys in one operation:

>>> np.einsum('bmd,hdk->bhmk', M, Pk).shape
(10, 5, 4, 6)

Note that this involves both contraction (of the d dimension) and ordering of the outputs so that batch comes before heads. Theoretically, we could reverse this order by doing:

>>> np.einsum('bmd,hdk->hbmk', M, Pk).shape
(5, 10, 4, 6)

And indeed, we could have the output in any order. Obviously, bhmk is the one that makes sense for the specific operation at hand. It's important to highlight the readability of the einsum approach as opposed to a simple M @ Pk, where the dimensions involved are much less clear [2].
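For comparison, here's one way to get the same bhmk result without einsum (a sketch, not taken from the paper). It inserts singleton axes so that matmul's batch broadcasting pairs every batch entry with every head. It works, but the intent is much harder to see than in 'bmd,hdk->bhmk'. The seeded generator is an assumption made only for reproducibility:

```python
import numpy as np

m, d, k, h, b = 4, 3, 6, 5, 10
rng = np.random.default_rng(0)
Pk = rng.standard_normal((h, d, k))
M = rng.standard_normal((b, m, d))

keys = np.einsum('bmd,hdk->bhmk', M, Pk)

# M[:, None] has shape (b, 1, m, d) and Pk[None] has shape (1, h, d, k);
# matmul contracts d and broadcasts the leading (b, 1) and (1, h) axes to (b, h).
keys_mm = M[:, None] @ Pk[None]

assert keys.shape == (b, h, m, k)
assert np.allclose(keys, keys_mm)
```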

Contraction over multiple dimensions

More than one dimension can be contracted in a single einsum, as demonstrated by another example from the same paper:

>>> b = 10; n = 4; d = 3; v = 6; h = 5
>>> O = np.random.randn(b, h, n, v)
>>> Po = np.random.randn(h, d, v)
>>> np.einsum('bhnv,hdv->bnd', O, Po).shape
(10, 4, 3)

Both h and v appear in both inputs of the subscript but not in the output. Therefore, both these dimensions are contracted - each element of the output is a sum across both the h and v dimensions. This would be much more cumbersome to achieve without einsum!
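Without einsum, one option is np.tensordot, which takes the axes to contract as explicit integer positions. The sketch below (seeded random data, assumed only for reproducibility) checks it against the einsum and also computes a single output element by hand:

```python
import numpy as np

b, n, d, v, h = 10, 4, 3, 6, 5
rng = np.random.default_rng(0)
O = rng.standard_normal((b, h, n, v))
Po = rng.standard_normal((h, d, v))

out = np.einsum('bhnv,hdv->bnd', O, Po)

# Same contraction with tensordot: sum O's axes 1 (h) and 3 (v) against
# Po's axes 0 (h) and 2 (v). The leftover axes come out as (b, n) then (d).
out_td = np.tensordot(O, Po, axes=([1, 3], [0, 2]))
assert np.allclose(out, out_td)

# One element by hand: fix b=0, n=1, d=2 and sum over both h and v.
assert np.isclose(out[0, 1, 2], np.sum(O[0, :, 1, :] * Po[:, 2, :]))
```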

Transposing inputs

When specifying the inputs to einsum, we can transpose them by reordering the dimensions. Recall our matrix A with shape (2,3); we can't multiply A by itself - the shapes don't match, but we can multiply it by its own transpose as in A @ A.T. With einsum, we can do this as follows:

>>> np.einsum('ij,kj->ik', A, A)
array([[ 5, 14],
       [14, 50]])

Note the order of dimensions in the second input of the subscript: kj instead of jk as before. Since j is still the label repeated in inputs but omitted in the output, it's the one being contracted.
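A quick check of this, plus the mirror-image case where the contracted label sits in the first axis of both operands, which gives A.T @ A instead (a sketch; `gram` is just an illustrative name):

```python
import numpy as np

A = np.arange(6).reshape(2, 3)

# 'kj' reads the second operand with its axes swapped, so this is A @ A.T.
gram = np.einsum('ij,kj->ik', A, A)
assert np.array_equal(gram, A @ A.T)

# Contracting over the first axis of both operands gives A.T @ A, a (3,3) result.
assert np.array_equal(np.einsum('ji,jk->ik', A, A), A.T @ A)
```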

More than two arguments

einsum supports an arbitrary number of inputs; suppose we want to chain-multiply our A and B with this array C:

>>> C = np.arange(20).reshape(4, 5)

We get:

>>> A @ B @ C
array([[ 900, 1010, 1120, 1230, 1340],
       [2880, 3224, 3568, 3912, 4256]])

With einsum, we do it like this:

>>> np.einsum('ij,jk,kp->ip', A, B, C)
array([[ 900, 1010, 1120, 1230, 1340],
       [2880, 3224, 3568, 3912, 4256]])

Here as well, I find the explicit dimension names a nice self-documentation feature.
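To see what a three-operand subscript computes, here's a deliberately naive sketch. It aligns all operands on a shared (i, j, k, p) grid using singleton axes, multiplies element-wise, then sums over the labels missing from the output (j and k):

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4) + 1
C = np.arange(20).reshape(4, 5)

# Axes of the grid: 0=i, 1=j, 2=k, 3=p
grid = A[:, :, None, None] * B[None, :, :, None] * C[None, None, :, :]
# j and k appear in the inputs but not in the output 'ip', so sum them out.
naive = grid.sum(axis=(1, 2))

assert np.array_equal(naive, np.einsum('ij,jk,kp->ip', A, B, C))
assert np.array_equal(naive, A @ B @ C)
```

This materializes a (2,3,4,5) intermediate, which is fine here but wasteful for large arrays. It's only meant to illustrate the semantics.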

An instructional implementation of einsum

The simplified mental model of how einsum works presented above is not entirely correct, though it's definitely sufficient to understand the most common use cases.

I read a lot of "how einsum works" documents online, and unfortunately they all suffer from similar issues; to put it generously, at the very least they're incomplete.

What I found is that implementing a basic version of einsum is easy; and that, moreover, this implementation serves as a much better explanation and mental model of how einsum works than other attempts [3]. So let's get to it.

We'll use the basic matrix multiplication as a guiding example: 'ij,jk->ik'.

This calculation has two inputs; so let's start by writing a function that takes two arguments [4]:

def calc(__a, __b):

The labels in the subscript specify the dimensions of these inputs, so let's define the dimension sizes explicitly (and also assert that sizes are compatible when a label is repeated in multiple inputs):

i_size = __a.shape[0]
j_size = __a.shape[1]
assert j_size == __b.shape[0]
k_size = __b.shape[1]

The output shape is (i,k), so we can create an empty output array:

out = np.zeros((i_size, k_size))

And generate a loop over every element of the output:

for i in range(i_size):
    for k in range(k_size):
        ...
return out

Now, what goes into this loop? It's time to look at the inputs in the subscript. Since there's a contraction on the j label, this means summation over this dimension:

for i in range(i_size):
    for k in range(k_size):
        for j in range(j_size):
            out[i, k] += __a[i, j] * __b[j, k]
return out

Note how we access out, __a and __b in the loop body; this is derived directly from the subscript 'ij,jk->ik'. In fact, this is how the einsum came to be from Einstein notation - more on this later on.
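Putting the pieces above together gives a complete, runnable function. The only addition here is the check against np.einsum at the end:

```python
import numpy as np

def calc(__a, __b):
    """Loop-nest equivalent of np.einsum('ij,jk->ik', __a, __b)."""
    # One size per label; j is shared, so its sizes must agree.
    i_size = __a.shape[0]
    j_size = __a.shape[1]
    assert j_size == __b.shape[0]
    k_size = __b.shape[1]

    out = np.zeros((i_size, k_size))
    # Outer loops: one per output label.
    for i in range(i_size):
        for k in range(k_size):
            # Inner loop: one per contracted label.
            for j in range(j_size):
                out[i, k] += __a[i, j] * __b[j, k]
    return out

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4) + 1
assert np.array_equal(calc(A, B), np.einsum('ij,jk->ik', A, B))
```

Note that np.zeros produces a float64 output regardless of the input dtype, whereas np.einsum follows the dtype of its inputs. That difference doesn't matter for illustration purposes.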

As another example of how to reason about einsum using this approach, consider the subscript from Contraction over multiple dimensions:

'bhnv,hdv->bnd'

Straight away, we can write out the assignment to the output, following the subscript:

out[b, n, d] += __a[b, h, n, v] * __b[h, d, v]

All that's left is to figure out the loops. As discussed earlier, the outer loops are over the output dimensions, with two additional inner loops for the contracted dimensions in the input (v and h in this case). Therefore, the full implementation (omitting the assignments of *_size variables and dimension checks) is:

for b in range(b_size):
    for n in range(n_size):
        for d in range(d_size):
            for v in range(v_size):
                for h in range(h_size):
                    out[b, n, d] += __a[b, h, n, v] * __b[h, d, v]
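For completeness, here's the same loop nest with the size assignments and dimension checks filled in, verified against np.einsum on small random arrays. The array sizes and the seed are arbitrary choices for the check:

```python
import numpy as np

def calc(__a, __b):
    """Loop-nest equivalent of np.einsum('bhnv,hdv->bnd', __a, __b)."""
    b_size, h_size, n_size, v_size = __a.shape
    # h and v are shared between the operands, so their sizes must agree.
    assert __b.shape[0] == h_size and __b.shape[2] == v_size
    d_size = __b.shape[1]

    out = np.zeros((b_size, n_size, d_size))
    for b in range(b_size):                  # output labels...
        for n in range(n_size):
            for d in range(d_size):
                for v in range(v_size):      # ...then contracted labels
                    for h in range(h_size):
                        out[b, n, d] += __a[b, h, n, v] * __b[h, d, v]
    return out

rng = np.random.default_rng(0)
O = rng.standard_normal((2, 5, 4, 6))
Po = rng.standard_normal((5, 3, 6))
assert np.allclose(calc(O, Po), np.einsum('bhnv,hdv->bnd', O, Po))
```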

What happens when the einsum subscript doesn't have any contracted dimension? In this case, there's no summation loop; the outer loops (assigning each element of the output array) are simply assigning the product of the appropriate input elements. Here's an example: 'i,j->ij'. As before, we start by setting up dimension sizes and the output array, and then a loop over each output element:

def calc(__a, __b):
    i_size = __a.shape[0]
    j_size = __b.shape[0]

    out = np.zeros((i_size, j_size))

    for i in range(i_size):
        for j in range(j_size):
            out[i, j] = __a[i] * __b[j]
    return out

Since there's no dimension in the input that doesn't appear in the output, there's no summation. The result of this computation is the outer product between two 1D input arrays.
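The sketch below checks this against NumPy's own outer product, and against the equivalent broadcasting expression, using a pair of small made-up vectors:

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([10, 20])

# No label is contracted, so each output element is a single product a[i] * b[j].
outer = np.einsum('i,j->ij', a, b)
assert np.array_equal(outer, np.outer(a, b))
assert np.array_equal(outer, a[:, None] * b[None, :])
```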

I placed a well-documented implementation of this translation on GitHub. The function translate_einsum takes an einsum subscript and emits the text for a Python function that implements it.
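The following is an independent, minimal sketch of the same idea. It is not the GitHub code, and the names gen_einsum_source and einsum_via_codegen are hypothetical, chosen for illustration. It assumes explicit-mode subscripts with single-letter labels and no '...' ellipsis. It emits one loop per output label, then one per contracted label, and builds the loop body straight from the subscript:

```python
import numpy as np

def gen_einsum_source(subscript, func_name="calc"):
    """Return Python source for a loop-nest implementation of an
    explicit-mode subscript such as 'ij,jk->ik' (hypothetical helper)."""
    lhs, output = subscript.replace(" ", "").split("->")
    inputs = lhs.split(",")
    args = [f"__op{n}" for n in range(len(inputs))]

    def index(labels):
        # An empty label list means a 0-d array, which is indexed with ().
        return ", ".join(labels) or "()"

    lines = [f"def {func_name}({', '.join(args)}):"]
    # Take each label's size from its first occurrence; assert on the others.
    seen = set()
    for arg, labels in zip(args, inputs):
        for pos, label in enumerate(labels):
            if label in seen:
                lines.append(f"    assert {label}_size == {arg}.shape[{pos}]")
            else:
                lines.append(f"    {label}_size = {arg}.shape[{pos}]")
                seen.add(label)

    shape = "".join(f"{label}_size, " for label in output)
    lines.append(f"    out = np.zeros(({shape}))")

    # Loop over output labels first, then over contracted labels
    # (labels that appear in the inputs but not in the output).
    contracted = [c for c in dict.fromkeys(lhs.replace(",", "")) if c not in output]
    indent = "    "
    for label in list(output) + contracted:
        lines.append(f"{indent}for {label} in range({label}_size):")
        indent += "    "
    product = " * ".join(f"{arg}[{index(labels)}]" for arg, labels in zip(args, inputs))
    lines.append(f"{indent}out[{index(output)}] += {product}")
    lines.append("    return out")
    return "\n".join(lines)

def einsum_via_codegen(subscript, *operands):
    # Compile the generated source; the function needs np in its globals.
    namespace = {"np": np}
    exec(gen_einsum_source(subscript), namespace)
    return namespace["calc"](*operands)

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4) + 1
print(gen_einsum_source("ij,jk->ik"))
assert np.array_equal(einsum_via_codegen("ij,jk->ik", A, B), A @ B)
```

Because the emitted code indexes each operand exactly as the subscript spells it, a label repeated within one operand (as in 'ii->' for a trace) also comes out right with this scheme.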

Einstein notation

This notation is named after Albert Einstein because he introduced it to physics in his seminal 1916 paper on general relativity. Einstein was dealing with cumbersome nested sums to express operations on tensors and used this notation for brevity.

In physics, tensors typically have both subscripts and superscripts (for covariant and contravariant components), and it's common to encounter systems of equations like this:

\[\begin{align*} B^1=a_{11}A^1+a_{12}A^2+a_{13}A^3=\sum_{j=1}^{3} a_{1j}A^j\\ B^2=a_{21}A^1+a_{22}A^2+a_{23}A^3=\sum_{j=1}^{3} a_{2j}A^j\\ B^3=a_{31}A^1+a_{32}A^2+a_{33}A^3=\sum_{j=1}^{3} a_{3j}A^j\\ \end{align*}\]

We can collapse this into a single sum, using a variable i:

\[B^{i}=\sum_{j=1}^{3} a_{ij}A^j\]

And observe that since j is duplicated inside the sum (once in a subscript and once in a superscript), we can write this as:

\[B^{i}=a_{ij}A^j\]

Where the sum is implied; this is the core of Einstein notation. An observant reader will notice that the original system of equations can easily be expressed as matrix-vector multiplication, but keep a couple of things in mind:

  1. Matrix notation only became popular in physics after Einstein's work on general relativity (in fact, it was Werner Heisenberg who first introduced it in 1925).
  2. Einstein notation extends to any number of dimensions. Matrix notation is useful for 2D, but much more difficult to visualize and work with in higher dimensions. In 2D, matrix notation is equivalent to Einstein's.

It should be easy to see the equivalence between this notation and the einsum subscripts discussed in this post. The implicit mode of einsum is even closer to Einstein notation conceptually.
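As a small illustration with arbitrary made-up values, the expression B^i = a_ij A^j maps directly onto an einsum subscript. NumPy has no notion of upper and lower index placement, so only the letters carry over. The names `a`, `A_vec` and `B_vec` are chosen here only to avoid clashing with the matrix A used elsewhere in this post:

```python
import numpy as np

a = np.arange(9.0).reshape(3, 3)   # the coefficients a_ij
A_vec = np.array([1.0, 2.0, 3.0])  # the components A^j

# B^i = a_ij A^j: j is repeated, so it's summed; i is free.
B_vec = np.einsum('ij,j->i', a, A_vec)
assert np.allclose(B_vec, a @ A_vec)

# Implicit mode drops the '->i', much like Einstein notation drops the sum.
assert np.allclose(np.einsum('ij,j', a, A_vec), B_vec)
```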

Implicit mode einsum

In implicit mode einsum, the output specification (-> and the labels following it) doesn't exist. Instead, the output shape is inferred from the input labels. For example, here's 2D matrix multiplication:

>>> np.einsum('ij,jk', A, B)
array([[ 23,  26,  29,  32],
       [ 68,  80,  92, 104]])

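In implicit mode, the output dimensions are the labels that appear exactly once across the inputs, arranged in lexicographic (alphabetical) order. So the choice of label letters matters, because it determines the order of dimensions in the output. For example, if we want to compute (A @ B).T, we can do: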

>>> np.einsum('ij,jh', A, B)
array([[ 23,  68],
       [ 26,  80],
       [ 29,  92],
       [ 32, 104]])

Since h precedes i in lexicographic order, this is equivalent to the explicit subscript 'ij,jh->hi', whereas the original implicit matmul subscript is equivalent to 'ij,jk->ik'.
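A few assertions, reusing A and B from earlier, spell out the implicit-mode rule as the NumPy documentation describes it. The output consists of the labels that appear exactly once, sorted alphabetically, and repeated labels are summed. The square matrix X is an extra example introduced here:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4) + 1

# Implicit output = labels appearing exactly once, in alphabetical order.
assert np.array_equal(np.einsum('ij,jk', A, B), np.einsum('ij,jk->ik', A, B))
assert np.array_equal(np.einsum('ij,jh', A, B), np.einsum('ij,jh->hi', A, B))
assert np.array_equal(np.einsum('ij,jh', A, B), (A @ B).T)

# The same rule makes a lone 'ji' a transpose and a lone 'ii' a trace.
X = np.arange(9).reshape(3, 3)
assert np.array_equal(np.einsum('ji', X), X.T)
assert np.einsum('ii', X) == np.trace(X)
```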

Implicit mode isn't used much in ML code and papers, as far as I can tell. From my POV, compared to explicit mode it loses a lot of readability while saving very little typing (just the output labels).


[1] In the sense of numpy.ndim: the number of dimensions in the array. This is sometimes also called rank, but that's confusing because rank already names something else in linear algebra.
[2] I personally believe that one of the biggest downsides of Numpy and all derived libraries (like JAX, PyTorch and TensorFlow) is that there's no way to annotate and check the shapes of operations. This makes some code much less readable than it could be. einsum mitigates this to some extent.
[3] First seen in this StackOverflow answer.
[4] The reason we use underscores here is to avoid collisions with potential dimension labels named a and b. Since we're doing code generation here, variable shadowing is a common issue; see hygienic macros for additional fun.