Let’s take a look at an ASR model that blends CNNs with transformers for speech. Vanilla transformers are great for global context but are weak at capturing local details like phoneme edges and formants. CNNs are excellent for local structure, but they miss global dependencies. The Conformer, a "convolution-augmented transformer," has both.
Unlike models for text, an ASR model takes continuous audio as input. Because audio has a high sampling rate, such as 16 kHz, we first need to extract acoustic features. Common choices are log-mel spectrogram or filterbank features. For one second of speech, you might get 100 frames with 80 dimensions each ($100 \times 80$). Just as an embedding layer maps discrete tokens into a vector space, in speech processing a linear or convolutional layer projects the acoustic features into the model's dimension, $d_{model}$. For example, with an $80 \times 256$ projection matrix, your frame embeddings end up with dimensions $T \times d_{model}$ (e.g., $100 \times 256$).
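Here is a minimal sketch of that front end, assuming torchaudio is installed; the STFT settings (n_fft, hop_length) are illustrative choices, not requirements of the Conformer.

import torch
import torch.nn as nn
import torchaudio

d_model = 256
waveform = torch.randn(1, 16000)   # 1 second of 16 kHz audio (dummy signal)

# 80-dim log-mel features at a 10 ms hop -> roughly 100 frames per second
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=400, hop_length=160, n_mels=80
)(waveform)                         # (1, 80, T)
log_mel = (mel + 1e-6).log()
features = log_mel.transpose(1, 2)  # (1, T, 80)

# Project the 80-dim frames into the model dimension
input_proj = nn.Linear(80, d_model)
x = input_proj(features)            # (1, T, d_model)
print(x.shape)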
A transformer encoder block looks like this:
x → [MHSA] → +residual → [FFN] → +residual → out
Conformer changes this to:
x → [FFN(0.5)] → [MHSA] → [Conv] → [FFN(0.5)] → out
The Conformer block calls this pair of feed-forward networks a "macaron-style feed-forward network," so named because the attention and convolution modules are sandwiched between two FFNs, like the filling in a macaron. The FFNs add capacity and non-linearity but don't mix information across time steps. Each FFN's output is scaled by one half so that, together, the two contribute roughly what a single FFN does in a vanilla transformer.
The first FFN applies a 2-layer MLP to each frame independently: the dimension is expanded to $d_{ff}$ (e.g., 1024) with an activation in between, then projected back to $d_{model}$, and the half-scaled result is added to the input as a residual connection.
So far, the input is $B \times T \times d_{model}$; this is the output from the first FFN and its residual connection. Now, you project this input, X, into queries, keys, and values:
$$Q = XW_q^i, \quad K = XW_k^i, \quad V = XW_v^i$$

where, for $h$ heads, each per-head weight matrix is $d_{model} \times d_{model}/h$, so each head works in a $d_k = d_{model}/h$ subspace. You then compute attention:
$$\text{Attn}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Concatenate the heads, project back with $W_o$, and add the residual:
$$Y = X + \text{MHSA}(X)$$

This is almost the same as a vanilla transformer, but instead of sinusoidal encodings, it uses relative positional encodings, which generalize better to long sequences and relative timing. With absolute positional encodings, each input position is updated as $x_t \leftarrow x_t + PE(t)$: the model is told each frame's absolute index, and the queries and keys are shifted by it.
With relative positional encoding, we don’t alter the input embeddings directly. Instead, we introduce an extra bias term in the attention score based on the relative distance. The model learns biases based on the relative distances between tokens. Speech is shift-invariant—a phoneme like "ba" at 1 second and 10 seconds should be treated the same. What matters is how far apart the tokens are, not their absolute index.
$$\text{score}(i,j) = Q_i \cdot K_j + Q_i \cdot r_{(i-j)}$$

Here, $r_{(i-j)}$ is a learned embedding for the relative offset between positions $i$ and $j$. We define a maximum relative distance $L$, clamp the offset to the range $-L \dots 0 \dots L$, and store one $d_k$-dimensional vector per offset, so the embedding table has shape $(2L+1) \times d_k$. For a given pair, we look up the relative embedding, take its dot product with the query, and add the result to the attention score. This table can be shared across all heads.
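As a concrete sketch, the class below builds that $(2L+1) \times d_k$ table and returns the $Q_i \cdot r_{(i-j)}$ bias you would add to the attention scores in the MHSA module further down. This is the simplified bias-only form of the idea, not the full Transformer-XL-style scheme used in the Conformer paper; the class name and shapes are my own.

import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    # Simplified sketch: one learned embedding per clamped relative offset, shared across heads.
    def __init__(self, d_k, max_dist):
        super().__init__()
        self.max_dist = max_dist
        self.rel_emb = nn.Embedding(2 * max_dist + 1, d_k)   # rows: offsets -L .. 0 .. +L

    def forward(self, q):
        # q: (B, h, T, d_k) -> additive bias of shape (B, h, T, T)
        T = q.size(2)
        pos = torch.arange(T, device=q.device)
        offsets = (pos[:, None] - pos[None, :]).clamp(-self.max_dist, self.max_dist)
        r = self.rel_emb(offsets + self.max_dist)             # (T, T, d_k)
        return torch.einsum("bhid,ijd->bhij", q, r)           # Q_i . r_(i-j)

Inside the attention module, you would add this bias to the scores before the softmax, alongside the content term $Q K^T$.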
Next is the convolution module, which is the big difference from vanilla transformers. This is a depthwise separable convolution used to explicitly model local sequential structure.
The input is now the context-enriched sequence, wherein each frame has been mixed with others and long-range dependencies have been learned. The shape is still $B \times T \times d_{model}$. The convolution module takes this input and applies local filters to learn short-term temporal patterns on top of the long-range dependencies already captured by the attention mechanism.
The convolution module looks similar to the depthwise-separable blocks in EfficientNet and MobileNet: a pointwise ($1 \times 1$) convolution with a GLU activation first expands and gates the channels, and then a 1-D depthwise convolution slides over time with a per-channel kernel.
Let’s think: would you use LayerNorm or BatchNorm after this? LayerNorm normalizes each frame across all channels, so it could wash away the per-channel temporal filters the depthwise convolution just learned. BatchNorm normalizes each channel across time and the batch, so each channel keeps its identity and normalization is consistent across time, which helps stabilize the convolution filters. So, BatchNorm makes more sense here, although it comes with its own challenges at inference time (more on that below). BatchNorm outputs roughly zero-mean, unit-variance activations, so an activation function follows; Swish lets small negative values pass through, which can be crucial in audio processing.
Let’s briefly discuss BatchNorm during training and inference. During training, BatchNorm computes the mean and variance of the batch and time dimensions for each channel to normalize the values, along with learnable scale ($\gamma$) and shift ($\beta$) parameters for each channel. It additionally keeps a running estimate of both mean and variance, which is not used during the training forward pass but is stored to be used during inference later. During inference, you don’t use the batch statistics of the test data. You use the stored running stats along with the same learnable parameters to normalize the data:
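A quick sketch of this train/eval behavior with PyTorch's BatchNorm1d (the shapes here are illustrative):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)           # 4 channels: one gamma/beta and one running estimate per channel
x = torch.randn(8, 4, 100)       # (B, C, T)

bn.train()
_ = bn(x)                        # normalizes with batch stats; updates running_mean / running_var
print(bn.running_mean.shape)     # torch.Size([4]) -- one value per channel

bn.eval()
_ = bn(x)                        # normalizes with the stored running stats instead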
$$x_{norm} = \gamma \cdot \frac{x - \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}} + \beta$$

Now, with streaming ASR, the issue is that during training, the stats were computed on long sequences and full batches of data. During streaming, you’re normalizing with global averages that don’t match the local chunk you’re currently processing. Additionally, in streaming you don’t want to wait for future frames to compute stats, so you need a causal or online version of BatchNorm. There are causal variants where the mean and variance are computed over frames $\le t$, so normalization only depends on the past and is updated incrementally as new frames arrive. There’s also a sliding-window variant that uses only the last $W$ frames. In the beginning, these statistics are unreliable, so they typically only kick in once enough frames have been seen. Alternatively, you can use other normalizations that are more real-time friendly, like LayerNorm.
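Here is an illustrative sketch of the cumulative (causal) variant; this is not a standard PyTorch layer, and the exact update rule is one of several reasonable choices:

import torch
import torch.nn as nn

class CausalBatchNorm1d(nn.Module):
    # Sketch of causal normalization: stats at frame t use only frames <= t,
    # pooled over the batch, per channel. Early frames have few samples, so the
    # estimates are noisy at the start (as discussed above).
    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        # x: (B, C, T)
        B, C, T = x.shape
        t = torch.arange(1, T + 1, device=x.device)      # frames seen so far at each step
        count = B * t                                    # samples per channel at each step
        mean = x.cumsum(-1).sum(0) / count               # (C, T): E[x] over frames <= t
        sq_mean = (x ** 2).cumsum(-1).sum(0) / count     # (C, T): E[x^2] over frames <= t
        var = (sq_mean - mean ** 2).clamp_min(0.0)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma[None, :, None] * x_norm + self.beta[None, :, None]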
Anyway, the Conformer then applies BatchNorm and a Swish activation. Finally, another pointwise convolution with dropout is applied, performing a $d_{model} \rightarrow d_{model}$ mapping. This might seem redundant, but it adds another round of cross-channel interaction, which helps expressiveness because the depthwise convolution is a per-channel, local filter. Lastly, a residual connection is added from the start of the module to its output.
This is our convolution block. The output from the convolution block remains $B \times T \times d_{model}$. After this, another feed-forward network, similar to the one before, is applied to each frame with a residual connection to complete the macaron-style design. It expands $d_{model}$ to $d_{ff}$, applies a non-linearity, projects back to $d_{model}$ with dropout, and adds a scaled version of that as a residual. The shape remains $B \times T \times d_{model}$. This helps with stacking blocks, and we know that machine learning practitioners love stacking blocks!
The first FFN, before the attention module, enriches the frame-level features before they interact globally. The second FFN enriches the representation after it has been mixed both globally and locally. This FFN symmetry is what is claimed to make the block more expressive and stable.
import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------------
# 1. Macaron-style FeedForward
# -------------------------------
class ConformerFFN(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()  # Swish

    def forward(self, x):
        out = self.linear1(x)
        out = self.activation(out)
        out = self.dropout(out)
        out = self.linear2(out)
        out = self.dropout(out)
        return x + 0.5 * out  # Macaron: scale by 0.5


# -------------------------------
# 2. Multi-Head Self-Attention (vanilla version, can extend with relative PE)
# -------------------------------
class MHSA(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.h = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        Q = self.W_q(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return x + self.dropout(self.W_o(out))  # residual


# -------------------------------
# 3. Convolution Module
# -------------------------------
class ConformerConvModule(nn.Module):
    def __init__(self, d_model, kernel_size=15, dropout=0.1, causal=False):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.pointwise_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        padding = (kernel_size - 1) if causal else (kernel_size - 1) // 2
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model, kernel_size,
            groups=d_model, padding=padding
        )
        self.causal = causal
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)
        x = x.transpose(1, 2)  # (B, d_model, T)
        x = self.pointwise_conv1(x)
        a, gate = x.chunk(2, dim=1)
        x = a * torch.sigmoid(gate)  # GLU
        x = self.depthwise_conv(x)
        if self.causal:
            x = x[:, :, :residual.size(1)]  # trim extra frames if causal
        x = self.batch_norm(x)
        x = F.silu(x)
        x = self.pointwise_conv2(x)
        x = x.transpose(1, 2)  # back to (B, T, d_model)
        return residual + self.dropout(x)  # residual


# -------------------------------
# 4. Full Conformer Block
# -------------------------------
class ConformerBlock(nn.Module):
    def __init__(self, d_model=256, d_ff=1024, num_heads=4, kernel_size=15, dropout=0.1, causal=False):
        super().__init__()
        self.ffn1 = ConformerFFN(d_model, d_ff, dropout)
        self.mhsa = MHSA(d_model, num_heads, dropout)
        self.conv = ConformerConvModule(d_model, kernel_size, dropout, causal)
        self.ffn2 = ConformerFFN(d_model, d_ff, dropout)
        self.final_ln = nn.LayerNorm(d_model)  # often added at block end

    def forward(self, x, mask=None):
        x = self.ffn1(x)
        x = self.mhsa(x, mask)
        x = self.conv(x)
        x = self.ffn2(x)
        return self.final_ln(x)
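A quick shape check of the block above (the batch size and sequence length are just illustrative):

block = ConformerBlock(d_model=256, d_ff=1024, num_heads=4)
feats = torch.randn(8, 100, 256)   # (B, T, d_model): a batch of 100-frame utterances
out = block(feats)
print(out.shape)                    # torch.Size([8, 100, 256]) -- shape preserved for stacking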
Now, how do you make the Conformer work in real time? The attention mechanism, as we mentioned earlier, needs to change to a streaming attention, where we restrict how far queries can look back using a fixed attention window with a causal mask. Sometimes speech is ambiguous if you don’t peek a little into the future, so we can allow each frame a look-ahead of a fixed number of frames. This adds some latency but improves accuracy, and it can be configured easily in the attention mask, as sketched below. Additionally, you can turn your multi-head attention into a multi-scale multi-head attention: because speech units vary in length (a phoneme is short, while syllables and words are longer), one head might be configured to span more frames than others to capture these different relationships.
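As a sketch, here is a helper (my own naming) that builds such a mask for the MHSA module above, with a limited left context and a small look-ahead:

import torch

def streaming_attention_mask(T, left_context, lookahead, device=None):
    # Frame i may attend to frame j only if i - left_context <= j <= i + lookahead.
    # Returns a (T, T) 0/1 mask that broadcasts against the (B, h, T, T) scores in MHSA.
    i = torch.arange(T, device=device)[:, None]
    j = torch.arange(T, device=device)[None, :]
    allowed = (j >= i - left_context) & (j <= i + lookahead)
    return allowed.int()

# e.g., 64 frames of left context and 4 frames (~40 ms at a 10 ms hop) of look-ahead
mask = streaming_attention_mask(100, left_context=64, lookahead=4)
out = ConformerBlock(causal=True)(torch.randn(8, 100, 256), mask=mask)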
The convolution module uses "same" padding, which will need to change to causal padding, and the normalization, as we discussed in depth, needs to change as well.
Okay, now we have our Conformer block, and we can stack several Conformer blocks to enrich the acoustic embeddings. Now we need to map the output to text. What we have are frame-level embeddings that need to be converted back to words or tokens.
You can add a linear layer after the stack of Conformer blocks that takes the $d_{model}$ dimension to $V$, the number of words, tokens, or phonemes in the vocabulary. The output is $Z$, a tensor of shape $B \times T \times V$. Then, take the model outputs (logits, which are unnormalized scores) and pass them through a softmax function for each frame to get output probabilities.
$$P(t,v) = \frac{\exp(Z_{t,v})}{\sum_{v'} \exp(Z_{t,v'})}$$

Once you have this, you need a CTC loss to align the per-frame output probabilities with the target text. (CTC adds a special blank symbol to the vocabulary, so the output layer actually predicts $V + 1$ classes.)
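A minimal sketch of this output head and loss in PyTorch; the vocabulary size, sequence lengths, and blank index are placeholder assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, vocab_size = 256, 1000           # vocab_size includes the CTC blank (index 0 here)
output_proj = nn.Linear(d_model, vocab_size)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

encoder_out = torch.randn(8, 100, d_model)                   # (B, T, d_model) from the Conformer stack
log_probs = F.log_softmax(output_proj(encoder_out), dim=-1)  # (B, T, V) per-frame log-probabilities

targets = torch.randint(1, vocab_size, (8, 20))   # dummy token ids (0 is reserved for blank)
input_lengths = torch.full((8,), 100, dtype=torch.long)
target_lengths = torch.full((8,), 20, dtype=torch.long)

# nn.CTCLoss expects log-probabilities shaped (T, B, V)
loss = ctc_loss(log_probs.transpose(0, 1), targets, input_lengths, target_lengths)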