Quantization

Smaller model means less memory bandwidth. Faster inference because fewer bits and fewer multiplies. Lower power because integer ops are cheaper than float on most hardware. You get the idea.

There are two main types of quantization: Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT). PTQ comes in two flavors: dynamic and static. Let’s break them down.

Dynamic Quantization

This is the lazy version of quantization—and that’s not a bad thing. Super simple. You quantize the weights after training. Activations stay in float32 during inference but get quantized on the fly. No calibration. No fuss.


import torch.quantization
model = YourModel()
model.eval()
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
     

In this setup, the weights of every nn.Linear layer are converted to int8 ahead of time, while activations stay in float32 and get quantized on the fly at runtime, using ranges computed per batch.

So why use it? It’s stupid easy and gives you a fast win on CPUs, especially for LSTMs and Linear-heavy models.

But why doesn’t it work well for CNNs? Because Conv2d is not matmul. Conv layers have gnarly memory access patterns and need pre-quantized activations to be efficient. Dynamic quantization doesn’t do that. You’d be quantizing every sliding window patch per inference. Not fun. Not fast.

Also, dynamic quantization doesn’t touch things like BatchNorm or ReLU. It just swaps in dynamic versions of Linear and LSTM layers. So, no fusion here. No point fusing layers if you're not touching them.
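
If you want to see exactly what gets swapped, print the model after quantize_dynamic. A quick sanity check on a throwaway model (the tiny Sequential below is just a stand-in):

import torch
import torch.nn as nn
import torch.quantization

toy = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)
toy.eval()

quantized_toy = torch.quantization.quantize_dynamic(toy, {nn.Linear}, dtype=torch.qint8)
print(quantized_toy)
# Conv2d and ReLU are left untouched; the Linear shows up as DynamicQuantizedLinear.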

Static Quantization

Now we’re doing it properly. Both weights and activations go to int8. But to quantize activations, you first need to know their range. That’s where calibration comes in. You run some input data through the model (no labels needed) to collect min and max values of activations. PyTorch uses that to decide how to quantize.

This works great for CNNs and edge deployment. You get full int8 inference.

The workflow looks like this:


import torch
import torch.nn as nn
import torch.quantization

class QuantCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 entry point
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float exit point

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return self.dequant(x)

    def fuse_model(self):
        # Fuse Conv + BN + ReLU into a single module before quantization
        torch.quantization.fuse_modules(self, ['conv', 'bn', 'relu'], inplace=True)

model = QuantCNN()
model.eval()
model.fuse_model()  # fusion must happen before preparing the model

# Set qconfig
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # or 'qnnpack' for ARM

# Insert observers
torch.quantization.prepare(model, inplace=True)

# Calibration step: run representative, unlabeled data through the model
with torch.no_grad():
    for images in calibration_loader:  # calibration_loader should yield batches of input data
        model(images)

# Convert to int8
torch.quantization.convert(model, inplace=True)
     

Let’s talk about what’s going on.

Fusion

Before quantizing, you fuse Conv + BatchNorm + ReLU into a single op. Why? It reduces memory overhead, speeds things up, and minimizes numerical errors.
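
Here is a quick way to see what fusion actually does (the tiny Sequential below is just for illustration):

import torch
import torch.nn as nn
import torch.quantization

m = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
m.eval()  # fuse in eval mode so the BN statistics fold into the conv weights

fused = torch.quantization.fuse_modules(m, [['0', '1', '2']])
print(fused)
# Index 0 becomes a single ConvReLU2d (with BN folded in);
# indices 1 and 2 are replaced by nn.Identity.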

QConfig

QConfig tells PyTorch how to quantize: symmetric vs asymmetric, per-tensor vs per-channel. Example:


import torch
from torch.quantization import QConfig, MinMaxObserver, default_per_channel_weight_observer

my_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),                  # per-tensor activations
    weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8),  # per-channel weights
)
model.qconfig = my_qconfig  # then assign it to your model
     

Observers

Observers are how PyTorch measures activation ranges. Here's a basic one (conceptual, PyTorch's ObserverBase is more complex):


import torch
from torch.ao.quantization.observer import ObserverBase

class MinMaxObserverExample(ObserverBase):  # conceptual; real observers handle far more cases
    def __init__(self, dtype=torch.quint8):
        super().__init__(dtype=dtype)
        self.min_val = float('inf')
        self.max_val = float('-inf')

    def forward(self, x):
        # Track the running min/max of every tensor that passes through
        self.min_val = min(self.min_val, x.min().item())
        self.max_val = max(self.max_val, x.max().item())
        return x

    def calculate_qparams(self):
        # Map the observed [min, max] range onto an 8-bit scale and zero-point
        scale = max(self.max_val - self.min_val, 1e-8) / 255.0
        zero_point = int(min(255, max(0, round(-self.min_val / scale))))
        return torch.tensor([scale]), torch.tensor([zero_point])
     

During calibration, these observers record activation stats. Then PyTorch figures out the scale and zero-point needed to map float values to int8.

But MinMax isn’t always enough.

Histogram Observers

MinMax gets clobbered by outliers. Histogram observers are smarter. They build a histogram of the activations, simulate multiple clipping thresholds, and pick the one that gives the smallest quantization error.


import torch.quantization as tq 
qconfig_hist = tq.QConfig(
    activation=tq.observer.HistogramObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine
    ),
    weight=tq.default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)
model.qconfig = qconfig_hist
     

They look at the L2 error between the original and the quantized-dequantized tensor and find the best range for preserving accuracy.

This is where quantization actually starts acting intelligent. Especially important when your activations have long tails, skew, or multimodal distributions.

Quick Refresher: Scale and Zero-Point

When you quantize a float:

Q = round(X / scale) + zero_point

And to dequantize:

X = scale * (Q - zero_point)
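
Here is a tiny round trip you can run to see both formulas in action (the scale and zero-point are picked arbitrarily):

import torch

x = torch.tensor([0.0, 0.1, 0.25, 1.0])
scale, zero_point = 0.01, 0

q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.quint8)
print(q.int_repr())    # the raw integer codes: round(X / scale) + zero_point
print(q.dequantize())  # back to float: scale * (Q - zero_point)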

Symmetric

Used for weights. Assumes distribution is centered around 0 (which is usually true for weights).

Asymmetric

Used for activations. Better when data is skewed (e.g., ReLU output).
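
To see the difference, feed the same skewed (ReLU-style) data to a symmetric and an asymmetric observer and compare the qparams they produce; a minimal sketch:

import torch
from torch.ao.quantization.observer import MinMaxObserver

x = torch.relu(torch.randn(10000))  # non-negative and skewed, like a ReLU output

sym = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
asym = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
sym(x)
asym(x)

print("symmetric:", sym.calculate_qparams())   # zero_point = 0, range spans [-max, max]
print("asymmetric:", asym.calculate_qparams()) # range fits [0, max], so the scale is finer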

Calibration Matters

Bad calibration = bad quantization.

If your calibration dataset is too small or unrepresentative, the model sees weird ranges and either clips too much or wastes resolution. The more realistic your calibration inputs, the better your scale and zero-point.
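
A few hundred representative samples are usually enough. A minimal sketch, assuming train_dataset is whatever dataset you trained on and model is the prepared (observer-instrumented) model:

import torch
from torch.utils.data import DataLoader, Subset

calib_indices = torch.randperm(len(train_dataset))[:512].tolist()  # random, representative subset
calibration_loader = DataLoader(Subset(train_dataset, calib_indices), batch_size=32)

with torch.no_grad():
    for images, _ in calibration_loader:  # labels are ignored during calibration
        model(images)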

Also, remember: PTQ is a one-shot deal. Once you calibrate and convert, you don’t get to fine-tune anymore. It’s set in stone.

Be Smart About Quant Boundaries

Quantizing just part of a model? You’d better manage the float-int transitions carefully. These quant-dequant boundaries are expensive. You’re moving data from int8 to float and back. It kills performance and breaks layer fusion.

Once you quantize, stay quantized as long as possible. You can safely quantize ReLU, pooling, and many elementwise ops.
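
In eager mode you control those boundaries explicitly with QuantStub and DeQuantStub: quantize once at the entry of the quantized region, dequantize once at the exit. A minimal sketch that keeps the convolutional backbone in int8 and leaves the classifier head in float:

import torch
import torch.nn as nn
import torch.quantization

class PartiallyQuantized(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # single float -> int8 entry
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.dequant = torch.quantization.DeQuantStub()  # single int8 -> float exit
        self.head = nn.Linear(32, 10)
        self.head.qconfig = None                         # keep the head in float

    def forward(self, x):
        x = self.quant(x)      # quantize once
        x = self.backbone(x)   # stay quantized through convs, ReLUs, and pooling
        x = self.dequant(x)    # dequantize once, right before the float head
        return self.head(x.flatten(1))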

Visualize with FX

Use FX to trace the model into a graph you can inspect and manipulate programmatically.


from torch.fx import symbolic_trace

traced = symbolic_trace(quantized_model)  # returns a GraphModule
print(traced.graph)            # textual representation of the graph
traced.graph.print_tabular()   # more readable table of nodes
     

FX makes it easier to see exactly where observers and quant/dequant nodes land, apply qconfigs per submodule, and get fusion patterns matched automatically instead of hand-writing fuse_model methods.

It’s the modern way to do quantization in PyTorch. Forget the old Eager mode unless you like pain.


import torch
import torch.quantization

qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}  # "" applies the qconfig to the whole model
     

FX takes care of applying qconfig recursively, inserting only the observers you need, and fusing layers correctly.
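
Here is the whole FX-based PTQ flow end to end; a minimal sketch reusing the QuantCNN and calibration_loader from above (recent PyTorch releases also want example inputs for tracing):

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

float_model = QuantCNN().eval()
qconfig_dict = {"": get_default_qconfig('fbgemm')}
example_inputs = (torch.randn(1, 3, 32, 32),)

prepared = prepare_fx(float_model, qconfig_dict, example_inputs)  # observers and fusion, automatically

with torch.no_grad():
    for images in calibration_loader:
        prepared(images)

quantized_model = convert_fx(prepared)  # int8 GraphModule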

PTQ Example: CNN with Manual Observers

Let’s do an example with a CNN model to wrap up PTQ. We manually add a PerChannelMinMaxObserver for each weight tensor and a MinMaxObserver for each activation.


import torch
import torch.nn as nn
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class ObserverCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

        # Manual observers
        self.obs_act1 = MinMaxObserver(dtype=torch.quint8)
        self.obs_act2 = MinMaxObserver(dtype=torch.quint8)
        self.obs_fc_in = MinMaxObserver(dtype=torch.quint8)

        self.obs_weight1 = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)
        self.obs_weight2 = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)
        self.obs_weight_fc = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8) # For Linear weights

    def forward(self, x):
        # Observe conv1 weight
        self.obs_weight1(self.conv1.weight)
        x = self.conv1(x) # Apply conv1
        x = self.relu1(x) # Apply relu1
        self.obs_act1(x)  # Observe output of relu1

        # Observe conv2 weight
        self.obs_weight2(self.conv2.weight)
        x = self.conv2(x) # Apply conv2
        x = self.relu2(x) # Apply relu2
        self.obs_act2(x)  # Observe output of relu2
        
        x = self.pool(x)
        x = x.view(x.size(0), -1)

        self.obs_weight_fc(self.fc.weight) # Observe fc weight
        self.obs_fc_in(x) # Observe input to fc layer
        x = self.fc(x) # Apply fc
        return x

model_ptq_example = ObserverCNN().eval()

# Fake calibration pass
with torch.no_grad():
    for _ in range(10):  # simulate 10 batches
        dummy_input = torch.randn(8, 3, 32, 32)
        _ = model_ptq_example(dummy_input)

print("Conv1 activation range:", model_ptq_example.obs_act1.min_val.item(), model_ptq_example.obs_act1.max_val.item())
print("Conv2 activation range:", model_ptq_example.obs_act2.min_val.item(), model_ptq_example.obs_act2.max_val.item())
print("FC input range:", model_ptq_example.obs_fc_in.min_val.item(), model_ptq_example.obs_fc_in.max_val.item())

print("Conv1 weight qparams:", model_ptq_example.obs_weight1.calculate_qparams())
print("Conv2 weight qparams:", model_ptq_example.obs_weight2.calculate_qparams())
print("FC weight qparams:", model_ptq_example.obs_weight_fc.calculate_qparams())
        

calculate_qparams() is a method used by observers in PyTorch to compute scale and zero-point. That's the observer's job to begin with.

General quantization formula: quantized = round(clamp(x / scale + zero_point, qmin, qmax))

For Symmetric Quantization:

scale = max(|min_val|, |max_val|) / 127 (for int8), zero_point = 0

For Asymmetric Quantization:

scale = (max_val - min_val) / (qmax - qmin), zero_point = qmin - round(min_val / scale)

Coming up next: Quantization Aware Training, where we teach the model to embrace the pain and learn to live with it.

Quantization-Aware Training (QAT)

If PTQ drops too much accuracy, QAT is how you recover it. QAT requires retraining, so it’s slower, but it’s more accurate: the model learns to suppress outliers and plays well with fine-tuning. You start by inserting observers (which in QAT are actually FakeQuantize modules), just like in PTQ. These FakeQuantize modules are differentiable, so the model learns to adapt to quantization noise.


import torch
from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

qat_qconfig = get_default_qat_qconfig('fbgemm')
qconfig_dict_qat = {"": qat_qconfig}

model_for_qat = ObserverCNN()  # or whatever model you're quantizing
model_for_qat.train()  # prepare_qat_fx expects the model in training mode
example_inputs = (torch.randn(1, 3, 32, 32),)  # recent releases require example inputs for tracing
model_prepared_qat = prepare_qat_fx(model_for_qat, qconfig_dict_qat, example_inputs)
        

Then you train the model_prepared_qat and convert it to int8 using convert_fx(model_prepared_qat.eval()). Inside the FakeQuantize module, it takes x, computes scale and zero-point, and simulates quantization like:


q = round(clamp(x / scale + zero_point, qmin, qmax))
x_hat = (q - zero_point) * scale
        

QAT can learn activation clipping, push small weights out of dead zones, and learn quantization-friendly distributions.

QAT simulates quantization in the forward pass. Rounding and clamping are, however, both non-differentiable, so you can't directly backpropagate through them. QAT uses a Straight-Through Estimator (STE), which basically means this operation has no gradient for the non-differentiable part. It pretends the output equals the input during backpropagation for that specific part.


class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # straight-through (acts like identity for the rounding part)
        

STE allows the model to learn parameters and "survive" quantization during training.
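
Plugged into a fake-quantize function, it looks like this (a minimal sketch reusing RoundSTE from above; the qmin/qmax defaults assume int8):

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Quantize-dequantize in float, with a straight-through gradient for round()
    q = torch.clamp(RoundSTE.apply(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4, requires_grad=True)
y = fake_quantize(x, scale=0.1, zero_point=0)
y.sum().backward()
print(x.grad)  # ones (away from the clamp bounds): round() acted as identity in backward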

Now let’s go back to MovingAverageMinMaxObserver (often used within FakeQuantize for QAT). It uses an exponential moving average of min/max values observed during training. FakeQuantize modules are inserted during prepare_qat_fx() and are active during the training phase. After training, during convert_fx(), they’re replaced with real quantization operations.

The moving average update rules are typically:


min_val = (1 - α) * min_val + α * current_batch_min_val
max_val = (1 - α) * max_val + α * current_batch_max_val
        

Where α (alpha) is a momentum term.
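
PyTorch’s built-in version is MovingAverageMinMaxObserver, where the momentum is called averaging_constant; a quick sketch:

import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

obs = MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.quint8)
for _ in range(100):
    obs(torch.relu(torch.randn(1024)))  # min/max get smoothed across batches by the EMA rule above

print(obs.min_val, obs.max_val)
print(obs.calculate_qparams())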

QAT Training & FakeQuantize Details

Training a QAT model often requires lower learning rates and more epochs to converge. It also needs careful initialization, for example, starting from a pretrained FP32 model. Similar to static quantization, layers such as BatchNorm (BN) are fused with preceding convolution or linear layers during the conversion step (convert_fx).

After inserting FakeQuantize modules and monitoring ranges, it helps to fine-tune the model to the "frozen" quantization noise. This means that after an initial "warm-up" phase where quantization parameters (scale and zero-point) are being learned, these parameters are then fixed, and the model continues training to adapt to these specific, now constant, quantization effects.
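
In PyTorch this freeze is usually done with two module-level switches applied partway through QAT training (as in the official QAT tutorial); a sketch, assuming model_prepared_qat from above:

import torch
import torch.ao.quantization as tq
import torch.nn.intrinsic.qat as nniqat

# After the warm-up epochs:
model_prepared_qat.apply(tq.disable_observer)      # freeze scales and zero-points
model_prepared_qat.apply(nniqat.freeze_bn_stats)   # freeze BatchNorm running statistics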

A note on observers and FakeQuantize modules

Inside each FakeQuantize module there is an observer (or equivalent range-tracking logic) plus a quantize-dequantize function, and you can inspect both. In eager-mode QAT, the activation_post_process attribute on a prepared layer (e.g., model.layer_name.activation_post_process) is the FakeQuantize module itself; in FX mode, the fake-quant modules show up as activation_post_process_N submodules on the traced graph module.


if hasattr(model_prepared_qat, 'conv1') and hasattr(model_prepared_qat.conv1, 'activation_post_process'):
    fq_module = model_prepared_qat.conv1.activation_post_process # This IS the FakeQuantize module
    
    current_scale, current_zero_point = fq_module.calculate_qparams()
    print(f"Scale: {current_scale}, Zero-point: {current_zero_point}")

    if hasattr(fq_module, 'min_val') and hasattr(fq_module, 'max_val'):
         print(f"Observed min by FakeQuant: {fq_module.min_val.item()}, Observed max: {fq_module.max_val.item()}")
else:
    print("conv1 or activation_post_process not found in conceptual model_prepared_qat.")
        

During the forward pass, observers (within FakeQuantize) are enabled to calculate the scale and zero-point. The fake quantizer then applies the quantize-dequantize simulation and uses STE during backpropagation. During QAT training (warm-up phase), observers are active, updating the range statistics, and the model learns under these evolving quantization conditions. Once the warm-up is done, the scale and zero-point are often "frozen." Training continues with this consistent quantization noise to improve stability.

Quantization Bugs

Common issues include scale values that come out as zero or NaN, ranges skewed by outliers, observers or quant/dequant pairs that go missing after bad layer fusion, clipped or dead activations, and accuracy that collapses in one specific layer.

How to Trace the Bugs?

Print quantization parameters


for name, module in model_prepared_qat.named_modules():
    if isinstance(module, torch.ao.quantization.FakeQuantize):
        # scale / zero_point may be scalars or per-channel tensors, so print them as-is
        print(f"{name} -- scale: {module.scale}, zero_point: {module.zero_point}")
    elif hasattr(module, 'weight_fake_quant'):  # layers carrying a separate weight fake-quant
        print(f"{name}.weight_fake_quant -- scale: {module.weight_fake_quant.scale}, "
              f"zero_point: {module.weight_fake_quant.zero_point}")
        

If you see scale is zero or NaN, then the observers aren’t seeing anything. If the scale is too large or too small, the range is skewed due to outliers.

Visualize activation ranges

It’s important to see what range of values a layer actually produces and to make sure the observers capture that range without being thrown off by outliers. If the range is too narrow you get clipping and dead activations; if it’s too wide you waste resolution. Hook into a layer to dump its float ranges and compare them against what the observers report.


import torch
import matplotlib.pyplot as plt

activation_histograms = {}

def capture_activations_hook(name):
    def hook(module, input_val, output_val):
        data_to_log = output_val
        if isinstance(output_val, tuple):
            data_to_log = output_val[0]  # handle modules that return tuples
        if torch.is_tensor(data_to_log):
            activation_histograms[name] = data_to_log.detach().cpu().flatten().numpy()
    return hook

model_to_debug = ObserverCNN().eval()  # or your specific model (ObserverCNN is defined above)
hook_conv1 = model_to_debug.conv1.register_forward_hook(capture_activations_hook("conv1_output"))
hook_relu1 = model_to_debug.relu1.register_forward_hook(capture_activations_hook("relu1_output"))

with torch.no_grad():
    _ = model_to_debug(torch.randn(1, 3, 32, 32))  # adjust input shape as needed

for name, data in activation_histograms.items():
    plt.figure()
    plt.hist(data, bins=100)
    plt.title(f"Activation Histogram: {name}")
    plt.xlabel("Value"); plt.ylabel("Frequency"); plt.grid(True)
    plt.show()

hook_conv1.remove()  # clean up hooks
hook_relu1.remove()
        

Compare float vs quantized outputs layer by layer. If a single layer is the culprit, the per-layer comparison narrows it down fast.

Disable quantization from layers one by one


if hasattr(model_prepared_qat, 'conv1') and \
   hasattr(model_prepared_qat.conv1, 'activation_post_process') and \
   hasattr(model_prepared_qat.conv1.activation_post_process, 'disable_fake_quant'):
    model_prepared_qat.conv1.activation_post_process.disable_fake_quant()
# Re-evaluate and then re-enable with .enable_fake_quant() or re-prepare the model.
        

Run FX print_tabular()

This shows you which quantization observers are attached and where scale and zero points are applied. Look for missing activation_post_process or quant/dequant pairs due to bad layer fusion.


if hasattr(model_prepared_qat, 'graph'):
    model_prepared_qat.graph.print_tabular()
else:
    print("Model does not seem to be an FX graph module or 'graph' attribute is missing.")
        

Inspect activation_post_process

activation_post_process is the name of the hook module (a FakeQuantize module in QAT) that PyTorch attaches to quantize activations. It’s inserted when you call prepare_qat_fx() (or prepare_qat in eager mode).


# Conceptually, preparation attaches something like:
#   self.conv = nn.Conv2d(...)
#   self.conv.activation_post_process = FakeQuantize(...)
if hasattr(model_prepared_qat, 'conv1') and hasattr(model_prepared_qat.conv1, 'activation_post_process'):
    fq_module_inspect = model_prepared_qat.conv1.activation_post_process
    print(fq_module_inspect.scale, fq_module_inspect.zero_point)
    if hasattr(fq_module_inspect, 'min_val'): # Check if min_val/max_val are directly on FakeQuant
        print("Observed min/max by FakeQuant:", fq_module_inspect.min_val.item(), fq_module_inspect.max_val.item())
        

If the values are outside min/max, they’ll be clipped during quantization.

Use the Numeric Suite (torch.ao.ns)

PyTorch’s Numeric Suite lets you record and compare float vs quantized activations and weights, layer by layer, so you can see where the error creeps in. The module path has moved across releases; recent versions expose it as torch.ao.ns._numeric_suite. A minimal sketch, assuming float_model, quantized_model, and a sample_input batch:


import torch.ao.ns._numeric_suite as ns

# Compare each weight tensor of the float model with its quantized counterpart
wt_compare = ns.compare_weights(float_model.state_dict(), quantized_model.state_dict())
for key in wt_compare:
    print(key, wt_compare[key]['float'].shape, wt_compare[key]['quantized'].shape)

# Record and match intermediate activations on a sample batch
act_compare = ns.compare_model_outputs(float_model, quantized_model, sample_input)
        

Look for outliers and dead activations


tensor = torch.randn(1000) # Example tensor
plt.hist(tensor.detach().cpu().flatten().numpy(), bins=100)
plt.title("Tensor Value Histogram")
plt.show()
        

Wrapping Up