Knowledge Distillation

How do you shrink a giant model into a small one without a significant loss in performance? Knowledge Distillation (KD) provides a powerful solution. The core idea is to use a large, accurate pre-trained model (the "teacher") to train a smaller, faster model (the "student") that mimics the teacher’s behavior.

This technique has proven highly effective in practice. For example, DistilBERT, distilled from BERT-base, has 40% fewer parameters while retaining about 97% of BERT's language-understanding performance. Similarly, the MobileNet paper used distillation to train a compact face-attribute classifier that mimics the outputs of a much larger model.

Today, Knowledge Distillation is widely deployed for a variety of tasks, including:

The Core Idea: Learning from "Soft" Predictions

Instead of training the student model solely on the hard ground-truth labels (e.g., "this image is a cat"), the student also learns from the "soft" predictions produced by the teacher: the full probability distribution over all classes that the teacher outputs. This distribution carries richer information about how the teacher "thinks." A prediction such as "90% cat, 8% dog, 2% car" tells the student that this image resembles a dog far more than a car, which a one-hot label cannot express.

To make this information more useful, the teacher's output logits are typically softened using a temperature scaling parameter, $T$.

$$ \text{soft targets} = \text{softmax}\left(\frac{\text{logits}}{T}\right) $$

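The temperature $T$ controls the smoothness of the output distribution. $T = 1$ recovers the ordinary softmax; $T > 1$ flattens the distribution, raising the small probabilities of the non-target classes so their relative sizes become visible; and as $T \to \infty$ the distribution approaches uniform. Dividing by a positive $T$ never changes the ranking of the classes.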
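A minimal sketch of temperature scaling in plain Python. The function name `softmax_with_temperature` and the three teacher logits for the classes [cat, dog, car] are hypothetical, chosen only to show how the distribution flattens as $T$ grows.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T), computed stably by subtracting the max."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for the classes [cat, dog, car].
teacher_logits = [5.0, 2.5, -1.0]

for T in (1.0, 4.0, 20.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:g}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As $T$ increases, the "cat" probability shrinks while the "dog" and "car" probabilities grow, and their order stays the same.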

Using a higher temperature exposes the similarity structure the teacher has learned between classes, giving the student a richer training signal than one-hot labels alone. This can be especially helpful when labeled data is scarce or the labels are noisy.

The Distillation Loss Function

The student is trained by optimizing a composite loss function that combines two objectives. This encourages the student to match both the ground truth and the teacher's soft predictions.

  1. Classification Loss: A standard cross-entropy loss between the student's predictions and the hard ground-truth labels.
  2. Distillation Loss: A loss that measures the difference between the student's and teacher's soft predictions. The Kullback-Leibler (KL) divergence is typically used for this.

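The final loss is a weighted combination of these two components, controlled by a hyperparameter $\alpha \in [0, 1]$:

$$ \mathcal{L} = \alpha \cdot \mathcal{L}_{CE}\left(\text{student logits}, \text{true labels}\right) + (1 - \alpha) \cdot T^2 \cdot \mathcal{L}_{KL}\left(\frac{\text{teacher logits}}{T}, \frac{\text{student logits}}{T}\right) $$

The KL term is conventionally scaled by $T^2$: the gradients produced by the softened targets shrink roughly as $1/T^2$, so without this factor raising the temperature would quietly down-weight the distillation signal relative to the hard-label loss.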
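The composite loss can be sketched for a single example in plain Python. This is an illustrative implementation, not any library's API: the function names, the defaults `alpha=0.5` and `T=4.0`, and the example logits are assumptions. It multiplies the KL term by $T^2$, a common convention that keeps the soft-target gradients on a scale comparable to the hard-label ones.

```python
import math


def log_softmax(logits, T=1.0):
    """Numerically stable log(softmax(logits / T))."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_sum for s in scaled]


def distillation_loss(student_logits, teacher_logits, true_label, alpha=0.5, T=4.0):
    """alpha * CE(student, label) + (1 - alpha) * T^2 * KL(teacher_T || student_T), one example."""
    # Classification loss: cross-entropy of the student's ordinary (T = 1) softmax vs. the hard label.
    ce = -log_softmax(student_logits)[true_label]

    # Distillation loss: KL(p_t || p_s) = sum_i p_t[i] * (log p_t[i] - log p_s[i]),
    # with both distributions softened by the same temperature T.
    log_p_t = log_softmax(teacher_logits, T)
    log_p_s = log_softmax(student_logits, T)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))

    # T^2 keeps the soft-target gradients comparable in size to the hard-label gradients.
    return alpha * ce + (1 - alpha) * T ** 2 * kl


# Hypothetical logits for the classes [cat, dog, car]; the true label is "cat" (index 0).
student = [2.0, 1.0, 0.1]
teacher = [5.0, 2.5, -1.0]
print(distillation_loss(student, teacher, true_label=0, alpha=0.5, T=4.0))
```

In a real training loop the loss would be averaged over a batch and minimized with respect to the student's parameters only; the teacher stays frozen.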

Types of Knowledge Distillation

Response-Based Distillation (Logit Distillation)

This is the most common and straightforward form of KD: only the teacher's final output logits are used to train the student. Although it ignores the teacher's intermediate representations, its simplicity and effectiveness have made it very popular. DistilBERT's training objective is built on this loss (combined with a masked-language-modeling loss and a cosine loss that aligns hidden states), and the MobileNet paper's distillation experiment likewise trained the student to match the teacher's outputs. The distillation loss is the KL divergence between the temperature-scaled output distributions of the teacher ($z_t$) and the student ($z_s$), with the teacher's distribution as the reference:

$$ \mathcal{L}_{KD} = KL\left(\text{softmax}\left(\frac{z_t}{T}\right) \,\middle\|\, \text{softmax}\left(\frac{z_s}{T}\right)\right) $$
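The role of the temperature becomes clearer from the gradient of this loss. A short derivation, writing $p_t = \text{softmax}(z_t/T)$ and $p_s = \text{softmax}(z_s/T)$ over $N$ classes and taking the divergence as $KL(p_t \,\|\, p_s)$: only the cross-entropy term $-\sum_j p_{t,j} \log p_{s,j}$ depends on the student, and $\partial \log p_{s,j} / \partial z_{s,i} = (\delta_{ij} - p_{s,i})/T$, so, using $\sum_j p_{t,j} = 1$,

$$ \frac{\partial \mathcal{L}_{KD}}{\partial z_{s,i}} = -\sum_j p_{t,j}\,\frac{\delta_{ij} - p_{s,i}}{T} = \frac{1}{T}\left(p_{s,i} - p_{t,i}\right) $$

If $T$ is large compared with the logits and both sets of logits have zero mean, then $e^{x} \approx 1 + x$ gives $p_{s,i} \approx (1 + z_{s,i}/T)/N$ (and likewise for $p_{t,i}$), so

$$ \frac{\partial \mathcal{L}_{KD}}{\partial z_{s,i}} \approx \frac{1}{N T^2}\left(z_{s,i} - z_{t,i}\right) $$

In this regime, matching softened outputs approximates matching the logits directly, and the gradient shrinks as $1/T^2$, which is why implementations commonly multiply the distillation term by $T^2$ when combining it with the hard-label loss.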

Feature-Based Distillation

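This approach goes deeper by forcing the student to mimic the teacher's intermediate feature representations from one or more layers, not just the final output. Because the student's layers are usually narrower than the teacher's, the student features are typically passed through a small learned projection before being compared with the teacher features, for example with a mean-squared error.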
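A minimal sketch of a feature-matching loss in the spirit of FitNets-style "hints": a student feature vector is mapped to the teacher's dimensionality by a linear projection and compared using mean-squared error. The projection matrix `W`, the feature values, and the helper names are hypothetical; in practice the features would come from chosen layers of real networks, and `W` would be a learned layer trained jointly with the student.

```python
def project(W, x):
    """Matrix-vector product W @ x, with W given as a list of rows."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def feature_distillation_loss(student_feat, teacher_feat, W):
    """Mean-squared error between the teacher's features and the projected student features."""
    projected = project(W, student_feat)
    if len(projected) != len(teacher_feat):
        raise ValueError("W must map student features to the teacher's dimensionality")
    return sum((p - t) ** 2 for p, t in zip(projected, teacher_feat)) / len(teacher_feat)


# Hypothetical features: a 2-dim student layer mimicking a 3-dim teacher layer.
student_feat = [1.0, -0.5]
teacher_feat = [0.8, 0.1, -0.4]

# Hypothetical 3 x 2 projection; in real training this would be a learned layer,
# optimized together with the student while the teacher stays frozen.
W = [[1.0, 0.0],
     [0.0, 1.0],
     [0.5, 0.5]]

print(feature_distillation_loss(student_feat, teacher_feat, W))
```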

Relation-Based Distillation

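Here, the student learns from the relationships the teacher induces between different data points or layers, rather than from the absolute values of its outputs. For example, the student can be trained so that the pairwise distances between its embeddings of the samples in a batch mirror those of the teacher.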
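A minimal sketch of one relation-based loss, modeled on the distance-wise term of Relational Knowledge Distillation: pairwise Euclidean distances within a batch are normalized by their mean, and the student's distance structure is matched to the teacher's. This sketch swaps the Huber loss used there for a plain mean-squared error, and the embeddings and helper names are made-up examples. Student and teacher embeddings may have different dimensionalities, since only distances are compared.

```python
import math


def pairwise_distances(embeddings):
    """Euclidean distance for every unordered pair (i < j) of embeddings in a batch."""
    n = len(embeddings)
    return [math.dist(embeddings[i], embeddings[j]) for i in range(n) for j in range(i + 1, n)]


def normalized(distances):
    """Divide by the mean distance so that only the relative geometry matters."""
    if not distances:
        raise ValueError("need at least two embeddings")
    mean = sum(distances) / len(distances)
    if mean == 0:
        raise ValueError("all embeddings are identical")
    return [d / mean for d in distances]


def relational_distance_loss(student_emb, teacher_emb):
    """MSE between the normalized pairwise-distance structures of student and teacher."""
    s = normalized(pairwise_distances(student_emb))
    t = normalized(pairwise_distances(teacher_emb))
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)


# Hypothetical batch of three samples: 3-D teacher embeddings, 2-D student embeddings.
teacher_emb = [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
student_emb = [[0.0, 0.0], [1.5, 0.0], [0.0, 2.0]]
print(relational_distance_loss(student_emb, teacher_emb))
```

Here the student's triangle is half the size of the teacher's, yet the loss is zero (up to floating-point rounding): after normalization only the shape of the distance structure is compared, so the student is free to use its own scale and dimensionality.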

Self-Distillation

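An interesting variant in which a model effectively teaches itself. Typically, a trained network serves as the teacher for a new student with the exact same architecture, or the deeper layers of a single network supervise its own shallower layers. Perhaps surprisingly, the student can end up outperforming its teacher and being more robust.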