Pruning & Knowledge Distillation

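Large models are accurate but expensive to store and serve. Pruning removes weights that contribute little to the output. Knowledge distillation trains a small "student" model to mimic a large "teacher". Used together, they can shrink a model substantially. Reductions on the order of 10× with about 5% accuracy loss are possible, although the actual trade-off depends on the model and the task.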

📖 Covers: Unstructured Pruning · Structured Pruning · Magnitude Pruning · Knowledge Distillation · DistilBERT

Model Pruning

Neural networks are often over-parameterised: many weights contribute little to the output. Pruning identifies these redundant weights and either zeroes them out or removes them entirely, which can make the model smaller and faster.

Unstructured Pruning

Remove individual weights, typically those with the smallest magnitude. The result is a sparse weight matrix with the same shape as the original.

Pros: High compression ratio, minimal accuracy loss

Cons: Sparse operations are hard to accelerate on standard GPUs
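
To make "based on magnitude" concrete, here is a minimal framework-free sketch. The function name `magnitude_prune` and the toy weight matrix are assumptions made up for illustration; real libraries implement the same idea on tensors.

```python
def magnitude_prune(matrix, amount):
    """Return a copy of `matrix` with the `amount` fraction of
    smallest-magnitude entries set to zero (hypothetical helper)."""
    magnitudes = sorted(abs(w) for row in matrix for w in row)
    k = round(amount * len(magnitudes))   # how many weights to zero
    if k == 0:
        return [row[:] for row in matrix]
    threshold = magnitudes[k - 1]         # largest magnitude that is still pruned
    # Note: ties at the threshold are all pruned, so slightly more than k
    # weights can be removed if several weights share that magnitude.
    return [[0.0 if abs(w) <= threshold else w for w in row] for row in matrix]


# Toy 3x3 weight matrix (illustrative values only)
weights = [
    [0.80, -0.05, 0.30],
    [-0.02, 0.60, -0.10],
    [0.01, -0.70, 0.20],
]

pruned = magnitude_prune(weights, amount=0.3)
for row in pruned:
    print(row)
```

The matrix keeps its 3×3 shape; only the three smallest entries become zero. That is why unstructured pruning needs sparse storage or sparse kernels before it saves memory or time.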

Structured Pruning

Remove entire neurons, attention heads, or layers. The resulting model is smaller and still dense.

Pros: Real speedup on any hardware

Cons: Coarser granularity, more accuracy loss per parameter removed
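
As a rough sketch of structured pruning in PyTorch, the example below uses `torch.nn.utils.prune.ln_structured` to zero whole output neurons, then copies the surviving rows into a smaller dense layer. The layer sizes, the 50% amount, and the manual slicing step are assumptions chosen for illustration, not a prescribed workflow.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # hypothetical layer: 8 inputs, 4 output neurons

# Zero the 50% of output neurons (rows of the weight matrix, dim=0)
# that have the smallest L2 norm (n=2).
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)
prune.remove(layer, "weight")  # bake the mask into the weight tensor

# ln_structured only zeroes rows; the tensor still has shape (4, 8).
keep = layer.weight.abs().sum(dim=1) != 0  # True for surviving neurons

# Build a genuinely smaller dense layer from the surviving rows.
small = nn.Linear(8, int(keep.sum()))
with torch.no_grad():
    small.weight.copy_(layer.weight[keep])
    small.bias.copy_(layer.bias[keep])

x = torch.randn(3, 8)
same = torch.allclose(small(x), layer(x)[:, keep])
print(tuple(layer.weight.shape), "->", tuple(small.weight.shape), "| outputs match:", same)
```

In a full network, the layer that consumes these outputs would also need its matching input columns removed, which is part of why structured pruning takes more effort.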

Python · PyTorch Magnitude Pruning
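import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Minimal stand-in model so the example runs; replace with your own
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 10)

    def forward(self, x):
        return self.fc1(x)

model = MyModel()

# Prune 30% of connections in a layer by magnitude
prune.l1_unstructured(model.fc1, name='weight', amount=0.3)

# Check sparsity
total = model.fc1.weight.nelement()
zeros = torch.sum(model.fc1.weight == 0).item()
print(f"Sparsity: {zeros/total:.1%}")  # → Sparsity: 30.0%

# Make pruning permanent (remove the mask, keep the zeroed weights)
prune.remove(model.fc1, 'weight')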
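
The "remove the mask" step is easier to follow once you see what `torch.nn.utils.prune` attaches to the module. The sketch below inspects a standalone layer (a hypothetical stand-in for `model.fc1`) before and after `prune.remove`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 10)  # hypothetical stand-in for model.fc1

prune.l1_unstructured(layer, name="weight", amount=0.3)

# While pruning is active, the trainable tensor is `weight_orig`, the binary
# mask is stored as the buffer `weight_mask`, and `weight` is their product.
params_during = [name for name, _ in layer.named_parameters()]
buffers_during = [name for name, _ in layer.named_buffers()]
mask_applied = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
print(params_during, buffers_during, mask_applied)

# prune.remove() drops the mask and stores the masked values as `weight` again.
prune.remove(layer, "weight")
params_after = [name for name, _ in layer.named_parameters()]
sparsity_after = (layer.weight == 0).float().mean().item()
print(params_after, f"sparsity={sparsity_after:.0%}")
```

The zeros survive `prune.remove`, but the tensor is still stored densely at its original size.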

Knowledge Distillation

Instead of compressing the model itself, train a small student model to mimic the behaviour of a large teacher model. The student learns from both the hard labels (correct answers) and the teacher's soft probability distributions (which carry more information).

Teacher (BERT-base, 110M params)

Input: "This movie is great!"

Output: Positive 0.92, Negative 0.08

↓ Soft labels (temperature=4)

Student (DistilBERT, 66M params)

Learns from the teacher's full distribution, not just the hard label

Result: 40% smaller, 60% faster, about 97% of the teacher's accuracy
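
To illustrate what the temperature does to the teacher's output, here is a small sketch using only the standard library. The logits are hypothetical, chosen so that the ordinary softmax (T=1) roughly reproduces the 0.92 / 0.08 split shown above.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softmax over logits divided by temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)                        # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for [Positive, Negative]
teacher_logits = [2.44, 0.0]

sharp = softmax_with_temperature(teacher_logits, T=1)
soft = softmax_with_temperature(teacher_logits, T=4)

print([round(p, 2) for p in sharp])  # roughly [0.92, 0.08]
print([round(p, 2) for p in soft])   # roughly [0.65, 0.35]
```

At T=4 the same logits give a flatter distribution, so the "Negative" class contributes more to the student's loss than it would at T=1.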

Python · Knowledge Distillation Training Loop
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, true_labels, T=4, alpha=0.5):
    # Soft labels loss (KL divergence between softened distributions)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T ** 2)

    # Hard labels loss (standard cross-entropy)
    hard_loss = F.cross_entropy(student_logits, true_labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

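# Training
# Assumes `teacher`, `student`, `optimizer`, and `dataloader` already exist,
# and that each batch is an (inputs, labels) pair.
teacher.eval()   # Teacher is frozen
student.train()
for inputs, labels in dataloader:
    with torch.no_grad():
        teacher_logits = teacher(inputs)

    student_logits = student(inputs)
    loss = distillation_loss(student_logits, teacher_logits, labels)

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()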
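
Written as a formula, the loss computed by `distillation_loss` looks like this. The symbols are notation introduced here, not taken from the code: z_s and z_t are the student and teacher logits, and y is the true label.

```latex
\mathcal{L}
  = \alpha \, T^{2} \,
    \mathrm{KL}\!\left( \mathrm{softmax}(z_t / T) \,\Vert\, \mathrm{softmax}(z_s / T) \right)
  + (1 - \alpha) \, \mathrm{CE}\!\left( \mathrm{softmax}(z_s),\, y \right)
```

The T² factor is there because the gradients of the soft term shrink roughly as 1/T². Multiplying by T² keeps the soft and hard terms on a comparable scale when you change the temperature, so `alpha` does not need re-tuning for every T.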

Real-World Example: DistilBERT

Hugging Face's DistilBERT is a publicly available distilled version of BERT. Trained with knowledge distillation from BERT-base, it achieves:

40% smaller (66M vs 110M params)

60% faster inference

97% of BERT accuracy retained

Python · Use DistilBERT
from transformers import pipeline

# A DistilBERT checkpoint fine-tuned for sentiment analysis (SST-2);
# used the same way as a BERT model, but faster
classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
result = classifier("This new framework is fantastic!")
print(result)  # e.g. [{'label': 'POSITIVE', 'score': 0.9998}]

Choosing Between Techniques

Technique             Speedup  Accuracy retained  Effort
Quantisation (INT8)   2–4×     ~99%               Low (library call)
Unstructured Pruning  1–2×     ~98%               Medium
Structured Pruning    2–5×     ~95%               Medium-High
Distillation          3–10×    ~95–98%            High (retraining)

Frequently Asked Questions

Can I combine pruning and distillation?

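Yes. The techniques are complementary, and combining them is sometimes called "progressive compression". You can prune the teacher first and then distil it into a smaller student, or prune the student while it is being distilled. Both also stack with quantisation: DistilBERT followed by INT8 quantisation is a popular combination that gives roughly 5–8× total compression at about 97% accuracy.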
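
As a back-of-the-envelope check on the 5–8× figure, the sketch below combines the parameter counts quoted on this page with the storage cost per weight. It assumes every weight goes from 4 bytes (FP32) to 1 byte (INT8) and ignores quantisation scales and any layers kept in higher precision, so treat the result as an estimate.

```python
BYTES_FP32 = 4   # bytes per weight before quantisation
BYTES_INT8 = 1   # bytes per weight after quantisation (simplifying assumption)

bert_base_params = 110e6    # parameter counts as quoted on this page
distilbert_params = 66e6

# Distillation alone: fewer parameters
param_reduction = 1 - distilbert_params / bert_base_params

# Distillation + INT8 quantisation: fewer parameters, fewer bytes each
bert_fp32_mb = bert_base_params * BYTES_FP32 / 1e6
distil_int8_mb = distilbert_params * BYTES_INT8 / 1e6
total_compression = bert_fp32_mb / distil_int8_mb

print(f"Parameter reduction: {param_reduction:.0%}")
print(f"BERT-base FP32: {bert_fp32_mb:.0f} MB, DistilBERT INT8: {distil_int8_mb:.0f} MB")
print(f"Total compression: {total_compression:.1f}x")
```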

What temperature should I use for distillation?

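Temperatures in the range 2–10 are typical. A higher T softens the teacher's distribution, which makes it more informative for the student. At T=1 the probabilities of the wrong classes are often so close to zero (say 0.001% versus 0.0001%) that the difference between them barely affects the loss, even though it carries signal about which classes the teacher considers similar. T=4 is a common starting point. Hinton et al. experimented with a range of temperatures and found that the best value depends on the capacity of the student, so treat T as a hyperparameter to tune.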
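
A quick way to build intuition for choosing T is to sweep it over a fixed set of logits and watch how much probability reaches the non-target classes. The three logits below are hypothetical (a correct class, a similar class, and an unrelated class), and the entropy helper is only there to summarise how flat each distribution is.

```python
import math

def soften(logits, T):
    """Softmax of logits / T."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Hypothetical teacher logits: [correct class, similar class, unrelated class]
logits = [10.0, 4.0, 1.0]

results = {}
for T in (1, 2, 4, 10):
    probs = soften(logits, T)
    results[T] = (probs, entropy(probs))
    formatted = ", ".join(f"{p:.4f}" for p in probs)
    print(f"T={T:>2}  probs=[{formatted}]  entropy={results[T][1]:.3f}")
```

At T=1 almost all of the mass sits on the correct class. As T grows, the gap between the "similar" and "unrelated" classes becomes large enough for the student to learn from. Past some point the distribution approaches uniform and that signal fades again.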
