Smaller model means less memory bandwidth. Faster inference because fewer bits and fewer multiplies. Lower power because integer ops are cheaper than float on most hardware. You get the idea.
There are a couple types of quantization: Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT). PTQ comes in two flavors: Dynamic and Static. Let’s break them down.
This is the lazy version of quantization—and that’s not a bad thing. Super simple. You quantize the weights after training. Activations stay in float32 during inference but get quantized on the fly. No calibration. No fuss.
import torch.quantization

model = YourModel()
model.eval()

quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
In this setup, the weights of the Linear layers are converted to int8 ahead of time, while activations stay in float32 and get quantized on the fly, per batch, at inference time.
So why use it? It’s stupid easy and gives you a fast win on CPUs—especially for LSTMs and Linear-heavy models.
But why doesn’t it work well for CNNs? Because Conv2d is not matmul. Conv layers have gnarly memory access patterns and need pre-quantized activations to be efficient. Dynamic quantization doesn’t do that. You’d be quantizing every sliding window patch per inference. Not fun. Not fast.
Also, dynamic quantization doesn’t touch things like BatchNorm or ReLU. It just swaps in dynamic versions of Linear and LSTM layers. So, no fusion here. No point fusing layers if you're not touching them.
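If you want to see what actually changed, just print the quantized model. Here's a quick sanity check using a throwaway Linear-heavy model (the toy model below is purely illustrative):

import torch
import torch.nn as nn

# A toy Linear-heavy model standing in for your real one
toy = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)).eval()

quantized_toy = torch.quantization.quantize_dynamic(
    toy, {nn.Linear}, dtype=torch.qint8
)
print(quantized_toy)  # the Linear layers show up as DynamicQuantizedLinear; ReLU is untouched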
Now we’re doing it properly. Both weights and activations go to int8. But to quantize activations, you first need to know their range. That’s where calibration comes in. You run some input data through the model (no labels needed) to collect min and max values of activations. PyTorch uses that to decide how to quantize.
This works great for CNNs and edge deployment. You get full int8 inference.
The workflow looks like this:
class QuantCNN(nn.Module):
    def fuse_model(self):
        # Fuse Conv + BatchNorm + ReLU groups into single modules before quantization.
        # If they live in a Sequential: torch.quantization.fuse_modules(self.conv_relu_group, ['0', '1', '2'], inplace=True)
        # Or fuse named child modules that follow a conv/bn/relu pattern:
        for module_name, module in self.named_children():
            if "conv_bn_relu" in module_name:  # a common naming pattern
                torch.quantization.fuse_modules(module, ['conv', 'bn', 'relu'], inplace=True)
model = QuantCNN() # Instantiate your model
model.eval()
model.fuse_model() # Apply fusion if defined in your model
# Set qconfig
model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # or 'qnnpack' for ARM
# Insert observers
torch.quantization.prepare(model, inplace=True)
# Calibration step
with torch.no_grad():
    for images in calibration_loader:  # calibration_loader should yield batches of input data
        model(images)
# Convert to int8
torch.quantization.convert(model, inplace=True)
Let’s talk about what’s going on.
Before quantizing, you fuse Conv + BatchNorm + ReLU into a single op. Why? It reduces memory overhead, speeds things up, and minimizes numerical errors.
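To see fusion in isolation, here's a self-contained example on a throwaway block (the toy Sequential is just for illustration):

import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),  # '0'
    nn.BatchNorm2d(16),              # '1'
    nn.ReLU(),                       # '2'
).eval()

# Fold BN into the conv weights and merge ReLU into the same op
fused = torch.quantization.fuse_modules(block, [['0', '1', '2']])
print(fused)  # '0' is now a fused ConvReLU2d; '1' and '2' are replaced with Identity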
QConfig tells PyTorch how to quantize: symmetric vs asymmetric, per-tensor vs per-channel. Example:
import torch
from torch.quantization import default_per_channel_weight_observer, MinMaxObserver

my_qconfig = torch.quantization.QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)  # specify dtype for weights
)
model.qconfig = my_qconfig # Then assign it to your model
Observers are how PyTorch measures activation ranges. Here's a basic one (conceptual, PyTorch's ObserverBase is more complex):
from torch.ao.quantization.observer import ObserverBase

class MinMaxObserverExample(ObserverBase):  # conceptual
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.min_val = float('inf')
        self.max_val = float('-inf')

    def forward(self, x):
        self.min_val = min(self.min_val, x.min().item())
        self.max_val = max(self.max_val, x.max().item())
        return x
During calibration, these observers record activation stats. Then PyTorch figures out the scale and zero-point needed to map float values to int8.
But MinMax isn’t always enough.
MinMax gets clobbered by outliers. Histogram observers are smarter. They build a histogram of the activations, simulate multiple clipping thresholds, and pick the one that gives the smallest quantization error.
import torch
import torch.quantization as tq

qconfig_hist = tq.QConfig(
    activation=tq.HistogramObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine
    ),
    weight=tq.default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)
model.qconfig = qconfig_hist
Histogram observers look at the L2 error between the original tensor and its quantized-dequantized version, and pick the range that best preserves accuracy.
This is where quantization actually starts acting intelligent. Especially important when your activations have long tails, skew, or multimodal distributions.
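Here's the idea in miniature. This is not PyTorch's actual HistogramObserver, just a hand-rolled sketch: fake-quantize the same activations with a few candidate clipping points and keep whichever gives the smallest L2 error.

import torch

def quant_error(x, lo, hi, qmin=0, qmax=255):
    # Fake-quantize x into [lo, hi] and measure the L2 reconstruction error
    scale = (hi - lo) / (qmax - qmin)
    zero_point = round(qmin - lo / scale)
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    x_hat = (q - zero_point) * scale
    return torch.norm(x - x_hat).item()

x = torch.relu(torch.randn(10_000) * 2)  # long-tailed, ReLU-style activations
clips = [x.max().item() * f for f in (1.0, 0.9, 0.8, 0.7, 0.6)]
errors = {round(hi, 3): quant_error(x, 0.0, hi) for hi in clips}
best_clip = min(errors, key=errors.get)
print(errors, "-> best clip:", best_clip)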
When you quantize a float:
Q = round(X / scale) + zero_point
And to dequantize:
X = scale * (Q - zero_point)
Symmetric quantization:
zero_point = 0
scale = max(abs(min_val), abs(max_val)) / 127 (for int8, typically range [-127, 127] or [-128, 127])
Used for weights. Assumes the distribution is centered around 0, which is usually true for weights.

Asymmetric quantization:
zero_point ≠ 0
scale = (x_max - x_min) / 255 (for quint8, typically range [0, 255])
zero_point = round(qmin - x_min / scale), adjusted to be within [qmin, qmax]
Used for activations. Better when data is skewed (e.g., ReLU output).
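Here are the quantize/dequantize formulas as a runnable round trip, using the asymmetric scheme (a hand-rolled sketch, not PyTorch's internal kernels):

import torch

x = torch.tensor([-1.2, -0.3, 0.0, 0.7, 2.5])

# Asymmetric quint8 parameters from the observed range
qmin, qmax = 0, 255
x_min, x_max = x.min().item(), x.max().item()
scale = (x_max - x_min) / (qmax - qmin)
zero_point = int(round(qmin - x_min / scale))

q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)  # quantize
x_hat = scale * (q - zero_point)                                  # dequantize
print(q, x_hat, (x - x_hat).abs().max())  # error for non-clipped values is at most about scale/2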
Bad calibration = bad quantization.
If your calibration dataset is too small or unrepresentative, the model sees weird ranges and either clips too much or wastes resolution. The more realistic your calibration inputs, the better your scale and zero-point.
Also, remember: PTQ is a one-shot deal. Once you calibrate and convert, you don’t get to fine-tune anymore. It’s set in stone.
Quantizing just part of a model? You’d better manage the float-int transitions carefully. These quant-dequant boundaries are expensive. You’re moving data from int8 to float and back. It kills performance and breaks layer fusion.
Once you quantize, stay quantized as long as possible. You can safely quantize ReLU, pooling, and many elementwise ops.
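In eager mode you mark those boundaries yourself with QuantStub and DeQuantStub; the rule of thumb is one quant at the entry, one dequant at the exit, and everything in between stays int8. A minimal sketch:

import torch
import torch.nn as nn

class PartiallyQuantized(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 boundary
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()                            # stays in the quantized region
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float boundary

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        x = self.dequant(x)  # cross back to float only once, at the end
        return x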
Use FX to trace the model and manipulate it as a graph.
from torch.fx import symbolic_trace

traced = symbolic_trace(quantized_model)
print(traced.graph)           # textual representation of the graph
traced.graph.print_tabular()  # more readable table of nodes
FX makes it easier to insert observers automatically, fuse layers, and rewrite the graph without touching your model code.
It’s the modern way to do quantization in PyTorch. Forget the old Eager mode unless you like pain.
import torch

qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
FX takes care of applying qconfig recursively, inserting only the observers you need, and fusing layers correctly.
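Put together, the FX graph mode PTQ flow looks roughly like this (a sketch; the tiny Sequential is just a stand-in, and recent PyTorch versions also expect the example_inputs argument shown here):

import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Any plain float model works; a tiny CNN as a stand-in
model_fp32 = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
).eval()

qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
example_inputs = (torch.randn(1, 3, 32, 32),)

prepared = prepare_fx(model_fp32, qconfig_dict, example_inputs)  # observers inserted, fusion handled

with torch.no_grad():  # calibration
    for _ in range(10):
        prepared(torch.randn(8, 3, 32, 32))

quantized = convert_fx(prepared)  # swap observed ops for real int8 kernels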
Let’s do an example with a CNN model to wrap up PTQ. We manually add a PerChannelMinMaxObserver for weights and a MinMaxObserver for activations.
import torch
import torch.nn as nn
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class ObserverCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)
        # Manual observers
        self.obs_act1 = MinMaxObserver(dtype=torch.quint8)
        self.obs_act2 = MinMaxObserver(dtype=torch.quint8)
        self.obs_fc_in = MinMaxObserver(dtype=torch.quint8)
        self.obs_weight1 = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)
        self.obs_weight2 = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)
        self.obs_weight_fc = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)  # for Linear weights

    def forward(self, x):
        # Observe conv1 weight, then the output of conv1 + relu1
        self.obs_weight1(self.conv1.weight)
        x = self.conv1(x)
        x = self.relu1(x)
        self.obs_act1(x)
        # Observe conv2 weight, then the output of conv2 + relu2
        self.obs_weight2(self.conv2.weight)
        x = self.conv2(x)
        x = self.relu2(x)
        self.obs_act2(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        self.obs_weight_fc(self.fc.weight)  # observe fc weight
        self.obs_fc_in(x)                   # observe input to the fc layer
        x = self.fc(x)
        return x
model_ptq_example = ObserverCNN().eval()
# Fake calibration pass
with torch.no_grad():
    for _ in range(10):  # simulate 10 batches
        dummy_input = torch.randn(8, 3, 32, 32)
        _ = model_ptq_example(dummy_input)
print("Conv1 activation range:", model_ptq_example.obs_act1.min_val.item(), model_ptq_example.obs_act1.max_val.item())
print("Conv2 activation range:", model_ptq_example.obs_act2.min_val.item(), model_ptq_example.obs_act2.max_val.item())
print("FC input range:", model_ptq_example.obs_fc_in.min_val.item(), model_ptq_example.obs_fc_in.max_val.item())
print("Conv1 weight qparams:", model_ptq_example.obs_weight1.calculate_qparams())
print("Conv2 weight qparams:", model_ptq_example.obs_weight2.calculate_qparams())
print("FC weight qparams:", model_ptq_example.obs_weight_fc.calculate_qparams())
calculate_qparams() is a method used by observers in PyTorch to compute scale and zero-point. That's the observer's job to begin with.
General quantization formula: quantized = round(clamp(x / scale + zero_point, qmin, qmax))
For symmetric quantization (weights):
scale = max(abs(min_val), abs(max_val)) / 127 (or 128)
zero_point = 0

For asymmetric quantization (activations):
scale = (max_val - min_val) / (qmax - qmin)
zero_point = round(qmin - min_val / scale)

Coming up next: Quantization-Aware Training, where we teach the model to embrace the pain and learn to live with it.
If PTQ drops too much accuracy, QAT is how you recover it. QAT requires retraining; it’s slower but more accurate. The model learns to suppress outliers, and it pairs well with fine-tuning from a pretrained checkpoint. You start by inserting observers (which in QAT are actually FakeQuantize modules), just like PTQ. These FakeQuantize modules are differentiable, so the model learns to adapt to quantization noise.
from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
qat_qconfig = get_default_qat_qconfig('fbgemm')
qconfig_dict_qat = {"": qat_qconfig}
model_for_qat = ObserverCNN()  # assuming ObserverCNN or a similar model
model_for_qat.train()  # prepare_qat_fx requires the model to be in training mode
# Recent PyTorch versions also require example inputs:
# model_prepared_qat = prepare_qat_fx(model_for_qat, qconfig_dict_qat, example_inputs=(torch.randn(1, 3, 32, 32),))
model_prepared_qat = prepare_qat_fx(model_for_qat, qconfig_dict_qat)
Then you train the model_prepared_qat and convert it to int8 using convert_fx(model_prepared_qat.eval()). Inside the FakeQuantize module, it takes x, computes scale and zero-point, and simulates quantization like:
q = round(clamp(x / scale + zero_point, qmin, qmax))
x_hat = (q - zero_point) * scale
QAT can learn activation clipping, push small weights out of dead zones, and learn quantization-friendly distributions.
QAT simulates quantization in the forward pass. Rounding and clamping are, however, both non-differentiable, so you can't directly backpropagate through them. QAT uses a Straight-Through Estimator (STE), which basically means this operation has no gradient for the non-differentiable part. It pretends the output equals the input during backpropagation for that specific part.
class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # straight-through: acts like identity for the rounding step
STE allows the model to learn parameters and "survive" quantization during training.
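To see the STE doing its job, here's a tiny hand-rolled fake-quantize built on the RoundSTE above (a sketch, not PyTorch's FakeQuantize):

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize-dequantize with a straight-through rounding step
    q = torch.clamp(RoundSTE.apply(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4, requires_grad=True)
y = fake_quantize(x, scale=0.1, zero_point=0)
y.sum().backward()
print(x.grad)  # all ones: the gradient passes straight through the rounding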
Now let’s go back to MovingAverageMinMaxObserver (often used within FakeQuantize for QAT). It uses an exponential moving average of min/max values observed during training. FakeQuantize modules are inserted during prepare_qat_fx() and are active during the training phase. After training, during convert_fx(), they’re replaced with real quantization operations.
The moving average update rules are typically:
min_val = (1 - α) * min_val + α * current_batch_min_val
max_val = (1 - α) * max_val + α * current_batch_max_val
Where α (alpha) is a momentum term.
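You can watch those running estimates directly; in PyTorch the momentum is the averaging_constant argument (a quick sketch):

import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

obs = MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8)
for _ in range(100):
    obs(torch.relu(torch.randn(512)))  # feed batches of ReLU-style activations

print(obs.min_val.item(), obs.max_val.item())  # smoothed running range
print(obs.calculate_qparams())                 # -> (scale, zero_point)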
Training a QAT model often requires lower learning rates and more epochs to converge. It also needs careful initialization, for example, starting from a pretrained FP32 model. Similar to static quantization, layers such as BatchNorm (BN) are fused with preceding convolution or linear layers during the conversion step (convert_fx).
After inserting FakeQuantize modules and monitoring ranges, it helps to fine-tune the model to the "frozen" quantization noise. This means that after an initial "warm-up" phase where quantization parameters (scale and zero-point) are being learned, these parameters are then fixed, and the model continues training to adapt to these specific, now constant, quantization effects.
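One common way to do that freeze is to disable the observers, and the BatchNorm statistics in fused modules, once the warm-up epochs are done, so scale and zero-point stop moving while fake quantization stays active. A sketch of the training loop; the epoch counts and the training helpers (train_one_epoch, optimizer, train_loader) are placeholders:

import torch

num_epochs, warmup_epochs = 20, 5
for epoch in range(num_epochs):
    # train_one_epoch / optimizer / train_loader stand for your usual training pieces
    train_one_epoch(model_prepared_qat, optimizer, train_loader)

    if epoch == warmup_epochs:
        # Freeze quantization ranges: observers stop updating, fake quant stays on
        model_prepared_qat.apply(torch.ao.quantization.disable_observer)
        # Freeze BatchNorm statistics in fused Conv+BN QAT modules
        model_prepared_qat.apply(torch.nn.intrinsic.qat.freeze_bn_stats)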
Inside the FakeQuantize modules, there is an observer instance (or logic) and a quantize-dequantize function. You can inspect them. The activation_post_process attribute on a layer (e.g., model.layer_name.activation_post_process) *is* the FakeQuantize module itself after prepare_qat_fx.
if hasattr(model_prepared_qat, 'conv1') and hasattr(model_prepared_qat.conv1, 'activation_post_process'):
    fq_module = model_prepared_qat.conv1.activation_post_process  # this IS the FakeQuantize module
    current_scale, current_zero_point = fq_module.calculate_qparams()
    print(f"Scale: {current_scale}, Zero-point: {current_zero_point}")
    if hasattr(fq_module, 'min_val') and hasattr(fq_module, 'max_val'):
        print(f"Observed min by FakeQuant: {fq_module.min_val.item()}, Observed max: {fq_module.max_val.item()}")
else:
    print("conv1 or activation_post_process not found in conceptual model_prepared_qat.")
During the forward pass, observers (within FakeQuantize) are enabled to calculate the scale and zero-point. The fake quantizer then applies the quantize-dequantize simulation and uses STE during backpropagation. During QAT training (warm-up phase), observers are active, updating the range statistics, and the model learns under these evolving quantization conditions. Once the warm-up is done, the scale and zero-point are often "frozen." Training continues with this consistent quantization noise to improve stability.
The most common issue: convert_fx() fails, and this happens a lot! When it does, start by dumping the scales and zero-points the FakeQuantize modules have recorded:
for name, module in model_prepared_qat.named_modules():
    if isinstance(module, torch.ao.quantization.FakeQuantize):
        # scale and zero_point are tensors (per-channel fake quants hold one value per channel)
        print(f"{name} -- scale: {module.scale}, zero_point: {module.zero_point}")
    elif hasattr(module, 'weight_fake_quant'):  # layers with a separate weight fake quant
        print(f"{name}.weight_fake_quant -- scale: {module.weight_fake_quant.scale}, "
              f"zero_point: {module.weight_fake_quant.zero_point}")
If you see scale is zero or NaN, then the observers aren’t seeing anything. If the scale is too large or too small, the range is skewed due to outliers.
It’s important to see what range of values a layer is producing, and to make sure the observers capture that range without being thrown off by outliers. If the range is off, you’ll get clipping, dead activations, or poor resolution from an overly wide range. Hook into a layer to see the float ranges and compare them with what the observers report.
import torch
import matplotlib.pyplot as plt

activation_histograms = {}

def capture_activations_hook(name):
    def hook(module, input_val, output_val):
        data_to_log = output_val
        if isinstance(output_val, tuple):
            data_to_log = output_val[0]  # handle tuple outputs
        if torch.is_tensor(data_to_log):
            activation_histograms[name] = data_to_log.detach().cpu().flatten().numpy()
    return hook

model_to_debug = ObserverCNN().eval()  # or your specific model; ObserverCNN is defined above
hook_conv1 = model_to_debug.conv1.register_forward_hook(capture_activations_hook("conv1_output"))
hook_relu1 = model_to_debug.relu1.register_forward_hook(capture_activations_hook("relu1_output"))

with torch.no_grad():
    _ = model_to_debug(torch.randn(1, 3, 32, 32))  # adjust input shape as needed

for name, data in activation_histograms.items():
    plt.figure()
    plt.hist(data, bins=100)
    plt.title(f"Activation Histogram: {name}")
    plt.xlabel("Value"); plt.ylabel("Frequency"); plt.grid(True)
    plt.show()

hook_conv1.remove()  # clean up hooks
hook_relu1.remove()
Compare float vs quantized outputs. If a single layer is the problem, comparing per-layer outputs will narrow it down.
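A quick way to quantify the gap at the output (a sketch; model_fp32 stands for the original float model you started QAT from, so treat it as a placeholder):

import torch

def sqnr_db(reference, test):
    # Signal-to-quantization-noise ratio in dB; higher means the quantized path tracks float more closely
    noise = reference - test
    return (10 * torch.log10(reference.pow(2).mean() / noise.pow(2).mean())).item()

x = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    ref = model_fp32(x)                 # original float model (placeholder)
    sim = model_prepared_qat.eval()(x)  # QAT-prepared model with fake quant active

print("output SQNR (dB):", sqnr_db(ref, sim))
print("max abs diff:", (ref - sim).abs().max().item())

To isolate a single layer, temporarily turn off its fake quantization and re-measure: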
if hasattr(model_prepared_qat, 'conv1') and \
   hasattr(model_prepared_qat.conv1, 'activation_post_process') and \
   hasattr(model_prepared_qat.conv1.activation_post_process, 'disable_fake_quant'):
    model_prepared_qat.conv1.activation_post_process.disable_fake_quant()
    # Re-evaluate, then re-enable with .enable_fake_quant() or re-prepare the model.
print_tabular(): This shows you which quantization observers are attached and where scales and zero-points are applied. Look for missing activation_post_process modules or quant/dequant pairs caused by bad layer fusion.
if hasattr(model_prepared_qat, 'graph'):
    model_prepared_qat.graph.print_tabular()
else:
    print("Model does not seem to be an FX graph module or 'graph' attribute is missing.")
activation_post_process: This is the name of the hook module (a FakeQuantize module in QAT) that PyTorch attaches to handle quantization of activations. It’s inserted when you call prepare_qat_fx().
# Example:
self.conv = nn.Conv2d(...)
self.conv.activation_post_process = FakeQuantize(...)
if hasattr(model_prepared_qat, 'conv1') and hasattr(model_prepared_qat.conv1, 'activation_post_process'):
    fq_module_inspect = model_prepared_qat.conv1.activation_post_process
    print(fq_module_inspect.scale, fq_module_inspect.zero_point)
    if hasattr(fq_module_inspect, 'min_val'):  # check if min_val/max_val are directly on the FakeQuant
        print("Observed min/max by FakeQuant:", fq_module_inspect.min_val.item(), fq_module_inspect.max_val.item())
If the values are outside min/max, they’ll be clipped during quantization.
torch.ao.quantization.quant_debug: This helps you record and compare float vs quantized activations, layer-by-layer differences, and histograms of differences. (The API varies across PyTorch versions.)
# Pseudocode: the exact debug helpers differ between PyTorch versions
debug_model = add_debug_observers_qat(model_prepared_qat, qconfig_dict_qat)
print(debug_model.activation_debug_stats)  # or a similar attribute; the API may change
import matplotlib.pyplot as plt

tensor = torch.randn(1000)  # example tensor
plt.hist(tensor.detach().cpu().flatten().numpy(), bins=100)
plt.title("Tensor Value Histogram")
plt.show()
You generally don’t use MovingAverageMinMaxObserver directly in PTQ (it’s primarily for QAT’s FakeQuantize modules); PTQ typically uses observers like MinMaxObserver or HistogramObserver for one-shot calibration.