Convolutions are an important tool in modern deep neural networks (DNNs). This post is going to discuss some common types of convolutions, specifically regular and depthwise separable convolutions. My focus will be on the implementation of these operations, showing from-scratch Numpy-based code to compute them and diagrams that explain how things work.

Note that my main goal here is to explain how depthwise separable convolutions differ from regular ones; if you're completely new to convolutions I suggest reading some more introductory resources first.

The code here is compatible with TensorFlow's definition of convolutions in the tf.nn module. After reading this post, the documentation of TensorFlow's convolution ops should be easy to decipher.

Basic 2D convolution

The basic idea behind a 2D convolution is sliding a small window (usually called a "filter") over a larger 2D array, and performing a dot product between the filter elements and the corresponding input array elements at every position.
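
To make the per-position computation concrete, here's a minimal Numpy sketch that computes a single output element at an interior position (the 6x6 input and 3x3 averaging filter are just illustrative choices):

import numpy as np

x = np.arange(36, dtype=np.float64).reshape(6, 6)  # example 6x6 input
w = np.ones((3, 3)) / 9.0                          # example 3x3 averaging filter

i, j = 2, 3  # an interior position, so the 3x3 window stays in bounds
window = x[i - 1:i + 2, j - 1:j + 2]
out_ij = np.sum(window * w)  # elementwise product, then sum: the dot product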

Here's a diagram demonstrating the application of a 3x3 convolution filter to a 6x6 array, in 3 different positions. W is the filter, and the yellow-ish array on the right is the result; the red square shows which element in the result array is being computed.

Single-channel 2D convolution

The topmost diagram shows the important concept of padding: what should we do when the window goes "out of bounds" on the input array. There are several options, with the following two being most common in DNNs:

  • Valid padding: in which only valid, in-bounds windows are considered. This also makes the output smaller than the input, because border elements can't be in the center of a filter (unless the filter is 1x1); the resulting output sizes are sketched just after this list.
  • Same padding: in which we assume there's some constant value outside the bounds of the input (usually 0) and the filter is applied to every element. In this case the output array has the same size as the input array. The diagrams above depict same padding, which I'll keep using throughout the post.
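
As a quick sanity check on the sizes (a minimal sketch, assuming a square S x S input, an F x F filter and a stride of 1):

S, F = 6, 3
valid_out = S - F + 1   # 4: only fully in-bounds windows contribute
same_out = S            # 6: zero-padding keeps the spatial size unchanged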

There are other options for the basic 2D convolution case. For example, the filter can move over the input in jumps of more than 1, thus not centering on all elements. This is called stride, and in this post I'm always using a stride of 1. Convolutions can also be dilated (or atrous), wherein the filter is expanded with gaps between every element. In this post I'm not going to discuss dilated convolutions and other options - there are plenty of resources on these topics online.

Implementing the 2D convolution

Here is a full Python implementation of the simple 2D convolution. It's called "single channel" to distinguish it from the more general case in which the input has more than two dimensions; we'll get to that shortly.

This implementation is fully self-contained, and only needs Numpy to work. All the loops are fully explicit - I deliberately avoided vectorizing them for efficiency, in order to keep the code clear:

import numpy as np


def conv2d_single_channel(input, w):
    """Two-dimensional convolution of a single channel.

    Uses SAME padding with 0s, a stride of 1 and no dilation.

    input: input array with shape (height, width)
    w: filter array with shape (fd, fd) with odd fd.

    Returns a result with the same shape as input.
    """
    assert w.shape[0] == w.shape[1] and w.shape[0] % 2 == 1

    # SAME padding with zeros: creating a new padded array to simplify index
    # calculations and to avoid checking boundary conditions in the inner loop.
    # padded_input is like input, but padded on all sides with
    # half-the-filter-width of zeros.
    padded_input = np.pad(input,
                          pad_width=w.shape[0] // 2,
                          mode='constant',
                          constant_values=0)

    output = np.zeros_like(input)
    for i in range(output.shape[0]):
        for j in range(output.shape[1]):
            # This inner double loop computes every output element, by
            # multiplying the corresponding window into the input with the
            # filter.
            for fi in range(w.shape[0]):
                for fj in range(w.shape[1]):
                    output[i, j] += padded_input[i + fi, j + fj] * w[fi, fj]
    return output
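
Here's a quick usage sketch (the averaging filter is just an illustrative choice):

x = np.arange(36, dtype=np.float64).reshape(6, 6)
w = np.ones((3, 3)) / 9.0   # 3x3 averaging (smoothing) filter
out = conv2d_single_channel(x, w)
print(out.shape)  # (6, 6) - same shape as the input, due to SAME padding

Note that the output inherits its dtype from the input (via np.zeros_like), so pass a float array if the filter has float values.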

Convolutions in 3 and 4 dimensions

The convolution computed above works in two dimensions; yet, most convolutions used in DNNs are 4-dimensional. For example, TensorFlow's tf.nn.conv2d op takes a 4D input tensor and a 4D filter tensor. How come?

The two additional dimensions in the input tensor are channel and batch. A canonical example of channels is color images in RGB format. Each pixel has a value for red, green and blue - three channels overall. So instead of seeing it as a matrix of triples, we can see it as a 3D tensor where one dimension is height, another width and another channel (also called the depth dimension).

Batch is somewhat different. ML training - with stochastic gradient descent - is often done in batches for performance; we train the model not on a single sample at a time, but a "batch" of samples, usually some power of two. Performing all the operations in tandem on a batch of data makes it easier to leverage the SIMD capabilities of modern processors. So it doesn't have any mathematical significance here - it can be seen as an outer loop over all operations, performing them for a set of inputs and producing a corresponding set of outputs.

For filters, the 4 dimensions are height, width, input channel and output channel. Input channel is the same as the input tensor's; output channel collects multiple filters, each of which can be different.
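
As a quick shape check with TensorFlow itself (assuming TensorFlow 2.x is installed; the tensor values here are random and arbitrary):

import tensorflow as tf

x = tf.random.normal((1, 8, 8, 3))   # batch=1, height=8, width=8, 3 channels
f = tf.random.normal((3, 3, 3, 4))   # 3x3 filters, 3 input channels, 4 output channels
y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding='SAME')
print(y.shape)  # (1, 8, 8, 4)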

This can be slightly difficult to grasp from text, so here's a diagram:

Multi-channel 2D convolution

In the diagram and the implementation I'm going to ignore the batch dimension, since it's not really mathematically interesting. So the input image has three dimensions - in this diagram height and width are 8 and depth is 3. The filter is 3x3 with depth 3. In each step, the filter is slid over the input in two dimensions, and all of its elements are multiplied with the corresponding elements in the input. That's 3x3x3=27 multiplications added into the output element.

Note that this is different from a 3D convolution, where a filter is moved across the input in all 3 dimensions; true 3D convolutions are not widely used in DNNs at this time.

So, to reiterate: to compute the multi-channel convolution shown in the diagram above, we compute each of the 64 output elements as a dot product of the filter with the relevant part of the input tensor. This produces a single output channel. To produce additional output channels, we perform the convolution with additional filters. So if our filter has dimensions (3, 3, 3, 4), this means 4 different 3x3x3 filters. The output will thus have spatial dimensions 8x8 and depth 4.

Here's the Numpy implementation of this algorithm:

def conv2d_multi_channel(input, w):
    """Two-dimensional convolution with multiple channels.

    Uses SAME padding with 0s, a stride of 1 and no dilation.

    input: input array with shape (height, width, in_depth)
    w: filter array with shape (fd, fd, in_depth, out_depth) with odd fd.
       in_depth is the number of input channels, and has to be the same as
       input's in_depth; out_depth is the number of output channels.

    Returns a result with shape (height, width, out_depth).
    """
    assert w.shape[0] == w.shape[1] and w.shape[0] % 2 == 1

    padw = w.shape[0] // 2
    padded_input = np.pad(input,
                          pad_width=((padw, padw), (padw, padw), (0, 0)),
                          mode='constant',
                          constant_values=0)

    height, width, in_depth = input.shape
    assert in_depth == w.shape[2]
    out_depth = w.shape[3]
    output = np.zeros((height, width, out_depth))

    for out_c in range(out_depth):
        # For each output channel, perform 2d convolution summed across all
        # input channels.
        for i in range(height):
            for j in range(width):
                # Now the inner loop also works across all input channels.
                for c in range(in_depth):
                    for fi in range(w.shape[0]):
                        for fj in range(w.shape[1]):
                            w_element = w[fi, fj, c, out_c]
                            output[i, j, out_c] += (
                                padded_input[i + fi, j + fj, c] * w_element)
    return output
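
As a sanity check, each output channel is just a sum of per-input-channel 2D convolutions, so we can cross-check against conv2d_single_channel (the random shapes here are arbitrary):

inp = np.random.rand(8, 8, 3)
w = np.random.rand(3, 3, 3, 4)
out = conv2d_multi_channel(inp, w)

# Output channel 0, recomputed as a sum of single-channel convolutions.
manual = sum(conv2d_single_channel(inp[:, :, c], w[:, :, c, 0])
             for c in range(3))
assert np.allclose(out[:, :, 0], manual)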

An interesting point to note here concerns TensorFlow's tf.nn.conv2d op: if you read its documentation you'll see discussion of layout or data format, which is NHWC by default. NHWC simply means the order of dimensions in a 4D tensor is:

  • N: batch
  • H: height (spatial dimension)
  • W: width (spatial dimension)
  • C: channel (depth)

NHWC is the default layout for TensorFlow; another commonly used layout is NCHW, because it's the format preferred by NVIDIA's DNN libraries. The code samples here follow the default.
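
For illustration, converting between the two layouts is just a transpose of the tensor's dimensions (a minimal sketch):

nhwc = np.random.rand(2, 8, 8, 3)        # batch, height, width, channel
nchw = np.transpose(nhwc, (0, 3, 1, 2))  # batch, channel, height, width
print(nchw.shape)                        # (2, 3, 8, 8)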

Depthwise convolution

Depthwise convolutions are a variation on the operation discussed so far. In the regular 2D convolution performed over multiple input channels, the filter is as deep as the input and lets us freely mix channels to generate each element in the output. Depthwise convolutions don't do that - each channel is kept separate - hence the name depthwise. Here's a diagram to help explain how that works:

Depthwise 2D convolution

There are three conceptual stages here:

  1. Split the input into channels, and split the filter into channels (the number of channels between input and filter must match).
  2. For each of the channels, convolve the input with the corresponding filter, producing an output tensor (2D).
  3. Stack the output tensors back together (the sketch right after this list expresses these stages directly in code).
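
Since we already have conv2d_single_channel, these stages map directly to code (a minimal sketch, not the implementation we'll use below):

def depthwise_conv2d_staged(input, w):
    # Stage 1: split input and filter into channels; stage 2: convolve each
    # channel separately; stage 3: stack the results back together.
    channels = [conv2d_single_channel(input[:, :, c], w[:, :, c])
                for c in range(input.shape[2])]
    return np.stack(channels, axis=-1)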

Here's the code implementing it:

def depthwise_conv2d(input, w):
    """Two-dimensional depthwise convolution.

    Uses SAME padding with 0s, a stride of 1 and no dilation. A single output
    channel is used per input channel (channel_multiplier=1).

    input: input array with shape (height, width, in_depth)
    w: filter array with shape (fd, fd, in_depth)

    Returns a result with shape (height, width, in_depth).
    """
    assert w.shape[0] == w.shape[1] and w.shape[0] % 2 == 1

    padw = w.shape[0] // 2
    padded_input = np.pad(input,
                          pad_width=((padw, padw), (padw, padw), (0, 0)),
                          mode='constant',
                          constant_values=0)

    height, width, in_depth = input.shape
    assert in_depth == w.shape[2]
    output = np.zeros((height, width, in_depth))

    for c in range(in_depth):
        # For each input channel separately, apply its corresponding filter
        # to the input.
        for i in range(height):
            for j in range(width):
                for fi in range(w.shape[0]):
                    for fj in range(w.shape[1]):
                        w_element = w[fi, fj, c]
                        output[i, j, c] += (
                            padded_input[i + fi, j + fj, c] * w_element)
    return output

In TensorFlow, the corresponding op is tf.nn.depthwise_conv2d; this op has the notion of channel multiplier which lets us compute multiple outputs for each input channel (somewhat like the number of output channels concept in conv2d).
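
For example (assuming TensorFlow 2.x; the shapes are arbitrary), a channel multiplier of 2 turns 3 input channels into 6 output channels:

import tensorflow as tf

x = tf.random.normal((1, 8, 8, 3))
# Filter shape is (fd, fd, in_depth, channel_multiplier).
f = tf.random.normal((3, 3, 3, 2))
y = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding='SAME')
print(y.shape)  # (1, 8, 8, 6) - each input channel produced 2 output channels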

Depthwise separable convolution

The depthwise convolution shown above is more commonly used in combination with an additional step to mix in the channels - depthwise separable convolution [1]:

Depthwise separable convolution

After completing the depthwise convolution, an additional step is performed: a 1x1 convolution across channels. This is exactly the same operation as the "convolution in 3 dimensions" discussed earlier - just with a 1x1 spatial filter. This step can be repeated multiple times for different output channels; each output channel takes the output of the depthwise step and mixes it with a different 1x1 convolution. Here's the implementation:

def separable_conv2d(input, w_depth, w_pointwise):
    """Depthwise separable convolution.

    Performs 2d depthwise convolution with w_depth, and then applies a pointwise
    1x1 convolution with w_pointwise on the result.

    Uses SAME padding with 0s, a stride of 1 and no dilation. A single output
    channel is used per input channel (channel_multiplier=1) in w_depth.

    input: input array with shape (height, width, in_depth)
    w_depth: depthwise filter array with shape (fd, fd, in_depth)
    w_pointwise: pointwise filter array with shape (in_depth, out_depth)

    Returns a result with shape (height, width, out_depth).
    """
    # First run the depthwise convolution. Its result has the same shape as
    # input.
    depthwise_result = depthwise_conv2d(input, w_depth)

    height, width, in_depth = depthwise_result.shape
    assert in_depth == w_pointwise.shape[0]
    out_depth = w_pointwise.shape[1]
    output = np.zeros((height, width, out_depth))

    for out_c in range(out_depth):
        for i in range(height):
            for j in range(width):
                for c in range(in_depth):
                    w_element = w_pointwise[c, out_c]
                    output[i, j, out_c] += depthwise_result[i, j, c] * w_element
    return output
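
Since the pointwise step is just a 1x1 multi-channel convolution, we can sanity-check separable_conv2d against the building blocks we already have (the random shapes are arbitrary):

inp = np.random.rand(8, 8, 3)
wd = np.random.rand(3, 3, 3)
wp = np.random.rand(3, 4)

out = separable_conv2d(inp, wd, wp)
# Reshape the pointwise filter to a (1, 1, in_depth, out_depth) conv filter.
ref = conv2d_multi_channel(depthwise_conv2d(inp, wd), wp.reshape(1, 1, 3, 4))
assert np.allclose(out, ref)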

In TensorFlow, this op is called tf.nn.separable_conv2d. Similarly to our implementation it takes two different filter parameters: depthwise_filter for the depthwise step and pointwise_filter for the mixing step.

Depthwise separable convolutions have become popular in DNN models recently, for two reasons:

  1. They have fewer parameters than "regular" convolutional layers, and thus are less prone to overfitting.
  2. With fewer parameters, they also require fewer operations to compute, and thus are cheaper and faster.

Let's examine the difference between the number of parameters first. We'll start with some definitions:

  • S: spatial dimension - width and height, assuming square inputs.
  • F: filter width and height, assuming square filter.
  • inC: number of input channels.
  • outC: number of output channels.

We also assume SAME padding as discussed above, so that the spatial size of the output matches the input.

In a regular convolution there are F*F*inC*outC parameters, because every filter is 3D and there's one such filter per output channel.

In depthwise separable convolutions there are F*F*inC parameters for the depthwise part, and then inC*outC parameters for the mixing part. It should be obvious that for a non-trivial outC, the sum of these two is significantly smaller than F*F*inC*outC.

Now on to computational cost. For a regular convolution, we perform F*F*inC operations at each position of the input (to compute the 2D convolution over 3 dimensions). For the whole input, the number of computations is thus F*F*inC*S*S and taking all the output channels we get F*F*inC*S*S*outC.

For depthwise separable convolutions we need F*F*inC*S*S operations for the depthwise part; then we need S*S*inC*outC operations for the mixing part. Let's use some real numbers to get a feel for the difference:

We'll assume S=128, F=3, inC=3, outC=16. For regular convolution:

  • Parameters: 3*3*3*16 = 432
  • Computation cost: 3*3*3*128*128*16 = ~7e6

For depthwise separable convolution:

  • Parameters: 3*3*3+3*16 = 75
  • Computation cost: 3*3*3*128*128+128*128*3*16 = ~1.2e6
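
These numbers are easy to reproduce (a quick sketch of the arithmetic):

S, F, inC, outC = 128, 3, 3, 16

regular_params = F * F * inC * outC                        # 432
regular_ops = F * F * inC * S * S * outC                   # 7,077,888 ~= 7e6

separable_params = F * F * inC + inC * outC                # 75
separable_ops = F * F * inC * S * S + S * S * inC * outC   # 1,228,800 ~= 1.2e6

So in this configuration, the depthwise separable convolution uses roughly a sixth of the parameters and a sixth of the operations of the regular convolution.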

[1] The term separable comes from image processing, where spatially separable convolutions are sometimes used to save on computation resources. A spatial convolution is separable when the 2D convolution filter can be expressed as an outer product of two vectors. This lets us compute some 2D convolutions more cheaply. In the case of DNNs, the spatial filter is not necessarily separable but the channel dimension is separable from the spatial dimensions.