Deep learning model compression

This post covers model inference optimization or compression in breadth and hopefully depth as of March 2021. This includes engineering topics like model quantization and binarization, more research-oriented topics like knowledge distillation, as well as well-known hacks.

Each year, larger and larger models get better at extracting signal from the noise in machine learning, and language models in particular keep growing. These models are computationally expensive (in both runtime and memory), which can make them costly to serve to customers, or too slow or too large to run in edge environments like a phone.

Researchers and practitioners have come up with many methods for optimizing neural networks to run faster or with less memory usage. In this post I’m going to cover some of the state-of-the-art methods. If you know of another method you think should be included, I’m happy to add it. This has a slight PyTorch bias (haha) because I’m most familiar with it.

Quantization

Quantization generally refers to taking a model with parameters trained at high precision (32 or 64 bits) and reducing the number of bits that each weight takes (for example down to 16, 8, or even fewer). In practice, this usually leads to a speedup of 2-4x (highest for nets with convolutions, in my experience).

Why does this work? It turns out that for deep networks to work, we don’t need highly precise values for the network’s weights. With proper hardware support, processing deep learning kernels (a fancy term for mathematical operations) using fewer bits can be faster and more memory efficient simply because there are fewer bits to compute (torch.qint8 is 8 bits and torch.float32 is 32 bits, so 4x smaller). Downsides: depending on the level of quantization attempted, you might find that an operation you need (for example, a particular convolutional op or even something as simple as transpose) is not implemented. Of course, as with all methods, you might find that accuracy drops off too much to be useful.
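
To make the 4x figure concrete, here is a minimal NumPy sketch of 8-bit affine quantization, written from scratch rather than with PyTorch's quantized tensors. The step size `scale` is a hypothetical fixed value chosen for illustration; real toolkits derive it from the data. Each float32 weight becomes one int8 value, and for values inside the representable range the round-trip error is at most half a step.

```python
import numpy as np


def quantize(x, scale, zero_point=0):
    """Affine int8 quantization: q = clip(round(x / scale) + zero_point, -128, 127)."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize(q, scale, zero_point=0):
    """Map int8 codes back to (approximate) float32 values."""
    return (q.astype(np.float32) - zero_point) * scale


rng = np.random.default_rng(0)
weights = rng.normal(0.0, 1.0, size=10_000).astype(np.float32)

scale = 0.05  # hypothetical fixed step; int8 then covers [-6.4, 6.35]
q = quantize(weights, scale)
restored = dequantize(q, scale)

print(weights.nbytes, q.nbytes)  # 40000 10000
print(np.abs(weights - restored).max())  # no more than scale / 2 here
```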

From the Tensorflow docs:

We generally recommend 16-bit floats for GPU acceleration and 8-bit integer for CPU execution.


PyTorch has support for special quantized tensors, which in their case corresponds to storing data in 8 or 16 bits. It’s important to understand one specific detail about how this works. If your network has a special structure such that at some point all of the outputs are between 0 and 1 (e.g. from a sigmoid), then you might be able to choose a better, more specific quantization. More generally, this means that quantization needs to collect some data about how your network runs on representative inputs. In particular, most quantization maps a value x to something like round(x / scale) + zero_point, where scale and zero_point are estimated from the observed range of values (akin to BatchNorm’s running statistics) rather than fixed in advance.
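
The sketch below shows, roughly, what recording the range of values buys you. It computes an asymmetric uint8 scale and zero point from an observed minimum and maximum, similar in spirit to a min/max observer (the helper `affine_qparams` is hypothetical, and PyTorch's observers add details this ignores, such as reduced ranges and histogram-based clipping). Sigmoid outputs only span [0, 1], so calibrating on them gives a much finer step than a generic range picked without looking at the data.

```python
import numpy as np


def affine_qparams(x_min, x_max, qmin=0, qmax=255):
    """Hypothetical helper: scale and zero point mapping [x_min, x_max] onto [qmin, qmax]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # the range must include 0
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point


rng = np.random.default_rng(0)
# "Calibration": record the sigmoid outputs on representative inputs.
pre_activations = rng.normal(0.0, 3.0, size=5_000)
sigmoid_outputs = 1.0 / (1.0 + np.exp(-pre_activations))

generic_scale, generic_zp = affine_qparams(-10.0, 10.0)  # guessed without data
observed_scale, observed_zp = affine_qparams(sigmoid_outputs.min(), sigmoid_outputs.max())

print(f"generic:  scale={generic_scale:.4f}, zero_point={generic_zp}")
print(f"observed: scale={observed_scale:.4f}, zero_point={observed_zp}")
```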

Quantization occasionally has gotchas: accumulating in higher-precision data types is often more stable than using lower-precision values. Picking the right precision for each operation can be nonobvious, so PyTorch has a torch.cuda.amp package to help you automatically cast different parts of your network to half precision (torch.float16) where it’s possible. If you want to do this manually, there are some helpful tips on that page.
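
As a small illustration of why the accumulator's precision matters, this NumPy sketch adds the same half-precision value 10,000 times, once into a float16 accumulator and once into a float32 one. Once the float16 total reaches a few hundred, the gap between neighboring float16 values exceeds twice the increment, so each addition rounds back to the same number and the sum stalls; the float32 total stays essentially exact. This is one reason autocast keeps some ops, such as sums and softmax, in float32.

```python
import numpy as np

increment = np.float16(0.1)  # 0.1 rounded to half precision
n_steps = 10_000

acc_fp16 = np.float16(0.0)
acc_fp32 = np.float32(0.0)
for _ in range(n_steps):
    acc_fp16 = np.float16(acc_fp16 + increment)              # half-precision accumulator
    acc_fp32 = np.float32(acc_fp32 + np.float32(increment))  # same inputs, float32 accumulator

print("float16 accumulator:", float(acc_fp16))
print("float32 accumulator:", float(acc_fp32))
print("exact sum:          ", n_steps * float(increment))
```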

If you want more control or want to deploy to a non-CUDA environment, there are 3 levels of manual quantization (under the label “eager mode quantization”) that you can try, depending on why you’re trying to quantize and how much you’re willing to sweat:

  1. Dynamic quantization: This is the easiest method. The weights of the network are stored in the specified quantized type ahead of time; at run time, each layer quantizes its incoming activations on the fly, combines them with the (quantized) weights, and writes the result back to memory at full precision. The next layer then quantizes those outputs, combines them with its own quantized weights, and so on. Why go back to full precision in between? My understanding is that the scale for each batch of activations is computed on the fly from the data itself, so no calibration data is needed up front, which makes this a data-free method.

How do we do this in PyTorch? It’s short enough that we can write it down here:

import torch
import torch.nn as nn
# quantize the LSTM and Linear parts of `model` (your trained float model)
# and use the torch.qint8 type to quantize
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)

There are many more knobs you can turn to make this better for your model. See more details in this blog post.
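
Here is a self-contained version of the same call that you can run on CPU, using a small stand-in model (the two-layer `nn.Sequential` is a hypothetical example, not anything from a real project). It compares the serialized size of the float and dynamically quantized models and the largest difference between their outputs; exact numbers depend on your PyTorch version and quantization backend.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for "your network": a small MLP in eval mode.
model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10)).eval()

# Linear weights are converted to int8 once; activation scales are computed
# on the fly for each batch, so no calibration data is needed.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)


def serialized_size(module: nn.Module) -> int:
    """Bytes needed to save the module's state_dict."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes


x = torch.randn(8, 512)
with torch.no_grad():
    y_float = model(x)
    y_quant = quantized_model(x)

print("float32 model bytes:", serialized_size(model))
print("int8 model bytes:   ", serialized_size(quantized_model))
print("max abs output difference:", (y_float - y_quant).abs().max().item())
```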

  2. Static quantization: Converting activations to full precision and back at run time is expensive. We can avoid that if we know the distribution of activations ahead of time (by recording real data flowing through the network, as mentioned above). When you have access to data flowing through your network, PyTorch can also inspect your model and apply extra optimizations such as quantized operator fusion. Here’s an example of fusing modules, setting up the observers, running some data through them, and then converting to a new statically quantized model:

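import torch
# eager mode expects `model` to wrap its forward in QuantStub/DeQuantStub,
# and conv + batchnorm fusion requires eval mode
model.eval()
# this is a default quantization config for mobile-based inference (ARM)
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
# this chain (conv + batchnorm + relu) is one of a few sequences
# that are supported by the model fuser
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
# insert observers
model_with_observers = torch.quantization.prepare(model_fused)
# calibrate on representative data (`calibration_loader` is a placeholder)
with torch.no_grad():
    for inputs in calibration_loader:
        model_with_observers(inputs)
quantized_model = torch.quantization.convert(model_with_observers)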
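
For a version that runs end to end, here is a sketch with a toy model (`TinyConvNet` is hypothetical). In eager mode, static quantization needs the model itself to mark where tensors enter and leave the quantized region with `QuantStub` and `DeQuantStub`, and here random tensors stand in for real calibration data. It uses the default qconfig for whichever quantized engine your PyTorch build is currently using rather than hard-coding 'qnnpack'.

```python
import torch
import torch.nn as nn


class TinyConvNet(nn.Module):
    """Hypothetical toy model: quant stub -> conv -> bn -> relu -> dequant stub."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)


torch.manual_seed(0)
model = TinyConvNet().eval()  # fusing conv + bn requires eval mode

# Use the default config for the quantized engine this PyTorch build runs on.
engine = torch.backends.quantized.engine
model.qconfig = torch.quantization.get_default_qconfig(engine)

model_fused = torch.quantization.fuse_modules(model, [["conv", "bn", "relu"]])
model_with_observers = torch.quantization.prepare(model_fused)

# Calibration: random tensors stand in for representative real inputs.
with torch.no_grad():
    for _ in range(10):
        model_with_observers(torch.randn(4, 3, 16, 16))

quantized_model = torch.quantization.convert(model_with_observers)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    y_float = model(x)
    y_quant = quantized_model(x)
print("mean abs difference:", (y_float - y_quant).abs().mean().item())
```
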
  3. Quantization-aware Training (QAT): If you’re familiar with neural network training, you know where this is going. If you tell the training method some fact about how the network is used, the network will adapt to this information. How does this work? During training, the model’s weights and activations are rounded (“fake-quantized”) to the chosen quantization in the forward pass. This means the loss, and therefore the gradients, are computed from the rounded values, so the model “adjusts” to its limited precision. Very importantly, however, the weight updates themselves (i.e. the gradient descent on the weights) happen in full precision.

I’m leaving out the code example because this is a more involved method, but you can find a full example here. There are again many knobs.
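
To make the rounding idea concrete without the full QAT workflow, here is a hand-rolled sketch of fake quantization with a straight-through estimator: the forward pass uses rounded weights, the backward pass treats the rounding as the identity, and the optimizer updates full-precision weights. The `fake_quantize` helper and the toy loss are assumptions for illustration; PyTorch's own QAT (`torch.quantization.prepare_qat`) inserts `FakeQuantize` modules that follow the same general idea, with scales computed by observers instead of a fixed step.

```python
import torch


def fake_quantize(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Round to a grid of step `scale` in the forward pass; pass gradients
    straight through in the backward pass (straight-through estimator)."""
    rounded = torch.round(x / scale) * scale
    # The value equals `rounded`, but autograd only sees the identity on `x`.
    return x + (rounded - x).detach()


# Full-precision "shadow" weights that the optimizer actually updates.
w = torch.tensor([0.12, -0.37, 0.51], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

w_q = fake_quantize(w, scale=0.25)  # what the forward pass uses: [0.00, -0.25, 0.50]
loss = (w_q ** 2).sum()             # stand-in for a real training loss
loss.backward()                     # gradients computed from the rounded values
optimizer.step()                    # update applied to the full-precision w
print(w_q.detach(), w.grad, w.detach())
```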

Note! See the helpful tips under Model Preparation for Quantization here before using PyTorch quantization.

Quantization in other frameworks

PyTorch-based quantization might not work in other production environments. In particular, when converting to Apple’s CoreML format, you need to use CoreML’s own quantization (which might be limited to 16-bit quantization). When using edge devices, check that quantization is actually supported (in Apple’s case the GPU already computes everything in fp16, so you might only save the memory of the network’s weights).

Tensorflow has a similar set of steps as above, though the examples are focused on TFLite. Essentially, static and dynamic quantization are explained in the Post-training quantization page, and there’s a QAT page. I think the tradeoffs are very similar, though there’s always some feature mismatch between PyTorch and TF.

How far can we go?

Apparently down to 1 bit! There have been several attempts over the years to create binary neural networks if you want the most extreme version of the accuracy vs speed tradeoff. For the most part, these are still research projects rather than usable ideas, though XNOR-Net++ seems to have been implemented in PyTorch.
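
For a feel of what 1-bit weights mean, here is a NumPy sketch of a simplified, per-tensor version of the weight binarization used by XNOR-Net: each weight is replaced by its sign, and a single scaling factor, the mean absolute weight, is kept in full precision (XNOR-Net itself uses one factor per filter and also binarizes activations, which this skips). Given the signs, the mean absolute weight is the least-squares optimal scale, and packing the signs 8 per byte gives a 32x storage reduction relative to float32.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)).astype(np.float32)  # stand-in float32 weights

alpha = np.abs(w).mean()     # one full-precision scale for the whole tensor
bits = w >= 0                # 1 bit per weight: True -> +1, False -> -1
packed = np.packbits(bits)   # 8 weights per byte (the array is flattened)
w_binary = alpha * np.where(bits, 1.0, -1.0)

print(w.nbytes, packed.nbytes)  # 16384 512
```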


Pruning

Pruning is removing some weights (i.e. connections) or entire neurons from a neural network after or during training. In practice we can often remove 90% of the parameters in large deep neural networks without significantly affecting model performance.

Why does this work? Let’s imagine that your model is a fully connected neural network with just one hidden layer, such that the input is size 1024, the hidden size is 100, and the output is 20 dimensions. Then the number of parameters (without bias) is 1024 × 100 + 100 × 20 = 104400. If there’s a neuron in the hidden layer that never fires (or is ignored downstream), then removing it from the network saves 1044 parameters (its 1024 incoming and 20 outgoing weights). Why not just train the smaller network right away? The most compelling explanation is something called the lottery ticket hypothesis:

Any large network that trains successfully contains a subnetwork that is initialized such that - when trained in isolation - it can match the accuracy of the original network in at most the same number of training iterations.
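
To make the parameter arithmetic above concrete, here is a NumPy sketch of removing one hidden neuron from the 1024-100-20 network. Neuron 7 is made to be ignored downstream by zeroing its outgoing weights (an arbitrary, hypothetical choice); structured pruning then deletes its row of the first weight matrix and its column of the second, which shrinks the matrices without changing the output.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden, d_out = 1024, 100, 20
W1 = rng.normal(size=(d_hidden, d_in))   # input -> hidden weights
W2 = rng.normal(size=(d_out, d_hidden))  # hidden -> output weights
dead = 7                                 # hypothetical neuron to prune
W2[:, dead] = 0.0                        # neuron 7 is ignored downstream


def forward(x, W1, W2):
    """One ReLU hidden layer, no biases."""
    return W2 @ np.maximum(W1 @ x, 0.0)


keep = np.arange(d_hidden) != dead
W1_small = W1[keep, :]  # drop the neuron's 1024 incoming weights (a row)
W2_small = W2[:, keep]  # drop its 20 outgoing weights (a column)

x = rng.normal(size=d_in)
print(W1.size + W2.size, "->", W1_small.size + W2_small.size)  # 104400 -> 103356
print(np.allclose(forward(x, W1, W2), forward(x, W1_small, W2_small)))  # True
```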

Downside: Removing neurons or choosing a subnetwork is what I (and others) consider structured pruning. However, a lot of methods (including Tensorflow’s tensorflow_model_optimization toolkit at this time and PyTorch’s torch.nn.utils.prune) are focused on sparsifying model weights so that they are more compressible (what some call unstructured pruning). This means the matrices are the same size, but some values are set to 0. It’s currently unclear to me if this means that larger models can use less GPU memory (I think essentially not), but it can save you disk space. When sparse model support fully lands in the various frameworks (i.e. you can multiply a sparse vector and a sparse matrix faster than the dense ones), you might be able to speed up inference as well.

For that reason, I’m not going to spend much time on unstructured pruning because it doesn’t seem that useful, but essentially you can prune during or after training, and you pick a certain target sparsity (e.g. 80% of the weights of your network will be zeroed out). However, there’s a lot of confusion in this area, which makes it hard to recommend anything. Tensorflow has a few guides on pruning both during and after training, and PyTorch has a tutorial on pruning using a set of heuristics after training.
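
For reference, here is what unstructured pruning looks like with PyTorch's `torch.nn.utils.prune` on a single, arbitrarily sized `nn.Linear` layer: 80% of the weights with the smallest magnitude are zeroed, but the weight matrix keeps its shape, which is why this alone doesn't shrink dense compute or memory.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(100, 50)  # arbitrary example layer: 50 x 100 weight matrix

# Zero the 80% of weights with the smallest absolute value. Internally this
# stores `weight_orig` plus a `weight_mask` buffer and recomputes `weight`.
prune.l1_unstructured(layer, name="weight", amount=0.8)

# Make the pruning permanent: `weight` becomes a plain parameter again,
# and `weight_orig` / `weight_mask` are removed.
prune.remove(layer, "weight")

sparsity = (layer.weight == 0).float().mean().item()
print(tuple(layer.weight.shape), round(sparsity, 3))  # (50, 100) 0.8
```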

In the space of structured pruning, there’s still active research and no clear API. We can pick a metric to compute a relevance score for each neuron, and then remove the ones that have the least information content. Metrics that might be useful here are the Shapley value, a Taylor approximation of the loss function’s sensitivity to a neuron’s activation, or even picking neurons at random. The TorchPruner library implements some of these automatically for nn.Linear and convolution modules (nn.Conv1d, nn.Conv2d, etc.). Another library, Torch-Pruning, supports a few more operations. One of the most well-known older works in this area prunes filters from a convnet using the L1 norm of each filter’s weights.
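
Structured pruning, by contrast, actually shrinks the layers. Below is a sketch of the L1-norm filter pruning idea for two consecutive convolutions: keep the filters of the first layer with the largest L1 norm, and drop the matching input channels of the next layer. The helper `prune_filters_l1` is hypothetical; it assumes default stride, dilation, and groups, and it ignores BatchNorm layers and residual connections, which tools such as Torch-Pruning aim to handle for you.

```python
import torch
import torch.nn as nn


def prune_filters_l1(conv: nn.Conv2d, next_conv: nn.Conv2d, n_keep: int):
    """Keep the `n_keep` filters of `conv` with the largest L1 norm and drop the
    matching input channels of `next_conv` (hypothetical helper)."""
    scores = conv.weight.detach().abs().sum(dim=(1, 2, 3))  # one L1 norm per filter
    keep = torch.argsort(scores, descending=True)[:n_keep].sort().values

    new_conv = nn.Conv2d(conv.in_channels, n_keep, conv.kernel_size, padding=conv.padding)
    new_next = nn.Conv2d(n_keep, next_conv.out_channels, next_conv.kernel_size,
                         padding=next_conv.padding)
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight[keep])
        new_conv.bias.copy_(conv.bias[keep])
        new_next.weight.copy_(next_conv.weight[:, keep])  # matching input channels
        new_next.bias.copy_(next_conv.bias)
    return new_conv, new_next


def count_params(*modules: nn.Module) -> int:
    return sum(p.numel() for m in modules for p in m.parameters())


torch.manual_seed(0)
conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
small1, small2 = prune_filters_l1(conv1, conv2, n_keep=8)

print(count_params(conv1, conv2), "->", count_params(small1, small2))  # 5088 -> 2560
x = torch.randn(1, 3, 8, 8)
print(small2(torch.relu(small1(x))).shape)  # torch.Size([1, 32, 8, 8])
```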

Fine tuning

In both cases, it’s standard to retrain the network after applying the pruning. The best method I know of is basically to reset the learning rate (learning rate rewinding) and start retraining the network. If you’d like, you can use weight rewinding, which is resetting the weights for the unpruned parts of the network to their value earlier in training (e.g. 1/3 trained weights). My intuition on this is that it’s essentially training the lottery ticket subnetwork now that we’ve identified it.
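
The difference between the two rewinding variants is only in the starting point, which a small sketch makes explicit. It assumes a hypothetical 90-epoch step schedule and a checkpoint saved at epoch 30 (one third of the way through); as I understand it, both variants keep the pruned weights at zero and replay the original learning-rate schedule from the rewind point to the end, while weight rewinding also resets the surviving weights to the checkpoint. The helpers `lr_at` and `retraining_plan` are illustrative, not from any library.

```python
import numpy as np


def lr_at(epoch, base_lr=0.1, milestones=(30, 60), gamma=0.1):
    """Hypothetical step schedule: multiply the LR by gamma at each milestone."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)


def retraining_plan(w_final, w_rewind, mask, rewind_epoch, total_epochs, mode):
    """Starting weights and per-epoch LRs for retraining a pruned network.

    mode="lr":     learning rate rewinding, keep the fully trained weights.
    mode="weight": weight rewinding, reset surviving weights to `w_rewind`.
    """
    if mode == "lr":
        w_start = w_final * mask
    elif mode == "weight":
        w_start = w_rewind * mask
    else:
        raise ValueError(f"unknown mode: {mode}")
    # Both variants replay the original schedule from the rewind point onward.
    lrs = [lr_at(epoch) for epoch in range(rewind_epoch, total_epochs)]
    return w_start, lrs


rng = np.random.default_rng(0)
w_epoch30 = rng.normal(size=8)              # checkpoint saved 1/3 of the way in
w_epoch90 = w_epoch30 + rng.normal(size=8)  # stand-in for the final weights
mask = np.abs(w_epoch90) >= np.median(np.abs(w_epoch90))  # keep the larger half

w_start, lrs = retraining_plan(w_epoch90, w_epoch30, mask,
                               rewind_epoch=30, total_epochs=90, mode="weight")
print(len(lrs), lrs[0], lrs[-1])
```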

DeepSpeed & ZeRO-Offload

I’ll cover this in more detail once I have more experience with it, but essentially: DeepSpeed is a library that helps train large to extremely large models (e.g. 1bn+ parameters) faster and using less GPU memory. This works by exploiting smart parallelism and better caching. It comes in the form of an extension to PyTorch.

Knowledge distillation

Knowledge distillation is a method for creating smaller and more efficient models from large models. In NLP this has also been referred to as teacher-student methods, because the large (teacher) model is used to train the smaller (student) model. The reference work in this area is (Hinton et al., 2015).

In practice, suppose we have a classification task. Suppose our smaller student model is $f_\theta$, where $\theta$ is the set of parameters. We take either a large model or an ensemble of models (possibly even the same model trained with different initializations), and call it $F$ (we won’t worry about its parameters). Then we train the student network with the following loss:

$$\mathcal{L} = \sum_{i = 1}^n \operatorname{KL}\left(F(x_i), f_\theta(x_i)\right)$$

where $F(x_i)$ is the probability distribution over the labels created by passing example $x_i$ through the network. If you want, you can mix in a little bit of the regular cross entropy loss using the proper labels:

$$\mathcal{L} = \sum_{i = 1}^n \left(\operatorname{KL}\left(F(x_i), f_\theta(x_i)\right) - \beta \cdot \sum_{k = 1}^K y_i[k] \log f_\theta(x_i)[k]\right)$$

Note that this second term is just the KL divergence from the “true” distribution (i.e. the one-hot distribution from the labels) to the student model, since $y_i$ is one-hot.
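
In PyTorch, the loss above could be written roughly as follows, assuming both models output raw logits. This is a sketch, not a reference implementation: it sums over the batch to match the formula, and it omits the temperature that (Hinton et al., 2015) use to soften both distributions. Note that `F.kl_div` takes the student's log-probabilities as its first argument and the teacher's probabilities as its target, and computes KL(target ‖ student).

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, beta=0.0):
    """sum_i KL(F(x_i) || f_theta(x_i)) + beta * sum_i CE(y_i, f_theta(x_i))."""
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_probs = F.softmax(teacher_logits, dim=-1).detach()  # no gradient to teacher
    # F.kl_div(input=log q, target=p) computes sum p * (log p - log q) = KL(p || q)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="sum")
    # Cross entropy with the true labels (the optional beta-weighted term)
    ce = F.cross_entropy(student_logits, labels, reduction="sum")
    return kl + beta * ce


torch.manual_seed(0)
student_logits = torch.randn(4, 3, requires_grad=True)  # batch of 4, 3 classes
teacher_logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 0])

loss = distillation_loss(student_logits, teacher_logits, labels, beta=0.5)
loss.backward()  # gradients flow only into the student
print(loss.item())
```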

Why does this work? There’s no consensus, as far as I know. The most compelling explanation I’ve read so far is that distillation is a form of rough data augmentation. I can recommend this paper: Towards Understanding Ensemble, Knowledge Distillation and Self-Distillation in Deep Learning, which is focused on the idea of multiple views. At the risk of being long-winded, here’s a thought experiment that might explain it:

Distillation thought experiment: Let’s say that we have a large teacher model that is trained to classify images (e.g. CIFAR-100). This model implicitly has a bunch of “feature-detectors” built-in, e.g. a set of convolutional filters that fire when pointy ears are seen, which increase the probability of a label like “cat”. Let’s say that there’s a training image of a Batman mask, labeled “mask”. The teacher model’s pointy ears filters might still fire, telling us that the model thinks that this looks 10% like a cat.

When the student model is trained to match the probability distribution of the teacher, because the distribution is 0.1 cat, it will still get a small signal that this image is catlike, which might help the student model recognize cats better than it could otherwise. If the student model was trained on just the true labels, it would have no idea that this Batman mask looks a bit like a cat.
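
The signal in this thought experiment can be seen directly in the gradient. For a softmax output, the gradient of the cross entropy (or KL divergence) to a target distribution p with respect to the logits is softmax(z) - p. The toy NumPy example below uses made-up probabilities for three classes: with the hard "mask" label, gradient descent pushes the cat logit down, while with the teacher's 10% cat target it pushes the cat logit up.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


CAT = 0                                   # class order: cat, mask, dog
hard_label = np.array([0.0, 1.0, 0.0])    # the dataset only says "mask"
teacher = np.array([0.10, 0.85, 0.05])    # made-up teacher: "10% like a cat"

# A student that currently thinks the mask is only 2% catlike.
student_logits = np.log(np.array([0.02, 0.95, 0.03]))
q = softmax(student_logits)

# Gradient of the cross entropy / KL to a target p with respect to the logits.
grad_hard = q - hard_label
grad_soft = q - teacher

# Gradient descent moves each logit opposite to its gradient.
for name, g in [("hard label", grad_hard), ("teacher", grad_soft)]:
    direction = "down" if g[CAT] > 0 else "up"
    print(f"{name}: cat logit gradient {g[CAT]:+.2f}, cat logit moves {direction}")
```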

A similar, but slightly different idea explains why ensembles of models (even the same architecture) might work well:

Ensembling thought experiment: Let’s say there are 3 pictures of a cat in a dataset we’re using for image classification. Let’s say that image 1 has a cat with feature A (e.g. pointed ears), image 2 has feature B (e.g. whiskers), and image 3 has both A and B.

Then, let’s say the neural network learns feature A (e.g. by seeing image 1). When it sees image 3, that set of convolution filters will fire, and so the image will be correctly classified. So there’ll be little or no gradient that tunes the net to recognize feature B, even though a good net would learn that.

Once a neural network has become good enough, the signal it gets from some data points shrinks. A second network with a different random initialization might happen to learn feature B first, so an ensemble of the two can recognize cats by either feature.

Distillation in practice

Knowledge distillation is a very deep and wide research area, touching adversarial attacks, knowledge transfer, and privacy. Unfortunately I can’t cover those in any real detail, so I’ll leave them for a future day.

In practice, the method I’ve described above is called response-based distillation. There are also other forms of distillation, including feature-based and relation-based knowledge distillation, which are entire subfields based on which parts of the student and teacher models (or which computations on them) we should tie together.

Furthermore, there’s a division between offline distillation (i.e. train the student after the teacher), online distillation (train the student and teacher together), and self-distillation (where the teacher model has the same architecture as the student). Together, this makes it hard to keep track of what works in practice; a set of ad hoc, model-specific techniques might be the best general recommendation.

In fact, (Cho & Hariharan, 2019) found that when the student model’s capacity is too low, knowledge distillation actually hurts training. They also noted that knowledge distillation papers rarely evaluate on ImageNet, and that the methods often don’t work well on such difficult problems. Perplexingly, that paper and (Mirzadeh et al., 2019) found that better teacher models don’t always mean better distillation, and that the farther apart the student’s and teacher’s capacities are, the less effective distillation is. You can find a recent investigation in (Tang et al., 2021).

All in all, my understanding so far is that distillation is fairly difficult. You might be able to get some free performance points by training a student with a slightly smaller capacity and then using vanilla response-based offline distillation.


However, deep learning researchers have spent a lot of time distilling large models using model-specific methods, and if you need to gain some performance, you might be able to find a pre-trained distilled version of the large model you’re currently using. For example, in NLP, Hugging Face makes it easy to access both DistilBERT and TinyBERT. In computer vision, Facebook Research’s d2go has a bunch of pretrained mobile-ready models, and they’ve specialized some distillation methods in DeiT.

Well-Read Students Learn Better: On the Importance of Pre-training Compact Models makes a recommendation (with high quality ablation experiments) that for training BERT architectures, the best approach is:

  1. Pre-train a compact model architecture on the masked language model (MLM) objective developed by the original BERT papers (Devlin et al., 2018).
  2. Take a large task-specific teacher model (e.g. if the task is NLI, the output is a distribution over the 3 classes (entailment, contradiction, neutral)), and perform basic response-based offline distillation on the pre-trained compact model from step 1.
  3. Finally, if required, fine-tune the compact model from step 2 on the task-specific data (e.g. if the task is NER, train over the CoNLL 2003 dataset).

One of the biggest advantages of this method (which they call Pre-trained Distillation (PD)) is that it’s architecture-agnostic. I think if you are going to use a compact NLP model in practice, it’s worth skimming the paper, especially section 6.