At their core, transformers are powered by a self-attention mechanism. Self-attention allows each token in a sequence to look at other tokens and decide what is important. For instance, when parsing "ramin writes the nicest tutorials," a model might need to pay more attention to "nicest" when processing "tutorials," depending on the task.
This is achieved using three vectors for each token: a Query (Q), a Key (K), and a Value (V).
These vectors are computed by multiplying the input matrix, $X$, by learnable weight matrices: $W_q$, $W_k$, and $W_v$.
Q = X @ W_q # Shape: (seq_len, d_k)
K = X @ W_k # Shape: (seq_len, d_k)
V = X @ W_v # Shape: (seq_len, d_v)
The dot product between a query vector $Q_i$ and a key vector $K_j$ produces a score indicating how much token i should attend to token j.
$$ \text{Score} = QK^T $$
These scores are then scaled by $\sqrt{d_k}$ so that large dot products don't saturate the softmax, and a softmax function is applied to normalize them into attention weights that sum to 1.
$$ \text{Attention Score} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) $$
Finally, the output for each token is a weighted sum of all tokens' Value vectors (including its own), using the calculated attention scores as weights.
$$ \text{Output} = \text{Attention Score} \cdot V $$
For example, with the input "Cats Are Bossy":
Attention_weights_1 = softmax(Q_1 @ [K_1, K_2, K_3])
Output_1 = Attention_weights_1 @ [V_1, V_2, V_3]
import numpy as np
# ---- Step 1: Toy input (3 tokens, embed_dim = 4) ----
X = np.array([
[1.0, 0.0, 1.0, 0.0], # token 1
[0.0, 2.0, 0.0, 2.0], # token 2
[1.0, 1.0, 1.0, 1.0], # token 3
]) # shape: (3, 4)
seq_len, embed_dim = X.shape
d_k = d_v = 4 # We'll keep the same dim for simplicity
# ---- Step 2: Random weight matrices (learnable in real models) ----
W_q = np.random.randn(embed_dim, d_k)
W_k = np.random.randn(embed_dim, d_k)
W_v = np.random.randn(embed_dim, d_v)
# ---- Step 3: Compute Q, K, V ----
Q = X @ W_q # shape: (3, 4)
K = X @ W_k
V = X @ W_v
# ---- Step 4: Compute attention scores ----
scores = Q @ K.T / np.sqrt(d_k) # shape: (3, 3)
# ---- Step 5: Softmax over rows ----
def softmax(x):
# subtracting a constant to avoid overflow
x_exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
return x_exp / np.sum(x_exp, axis=-1, keepdims=True)
attn_weights = softmax(scores) # shape: (3, 3)
# ---- Step 6: Weighted sum over V ----
output = attn_weights @ V # shape: (3, 4)
A single self-attention mechanism, or "head," captures only one type of relationship at a time. Multi-Head Attention runs multiple attention heads in parallel, each with its own set of learned projections. This allows the model to attend to different aspects of the input simultaneously.
Instead of one set of weights, you have h heads, each computing its own output. These outputs are then concatenated and projected back to the original embedding dimension using an output weight matrix.
def multi_head_attention(X, Wq, Wk, Wv, Wo, h):
batch_size, seq_len, embed_dim = X.shape
d_k = embed_dim // h
# 1. Linear projections for all heads
Q = X @ Wq # shape: (batch, seq_len, embed_dim)
K = X @ Wk
V = X @ Wv
# 2. Split into h heads
# Reshape and transpose to bring the head dimension forward
Q = Q.reshape(batch_size, seq_len, h, d_k).transpose(0, 2, 1, 3)
K = K.reshape(batch_size, seq_len, h, d_k).transpose(0, 2, 1, 3)
V = V.reshape(batch_size, seq_len, h, d_k).transpose(0, 2, 1, 3)
# 3. Scaled dot-product attention for each head (vectorized)
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
attn_weights = softmax(scores)
heads = attn_weights @ V # (batch, h, seq_len, d_k)
# 4. Concatenate heads
heads_concat = heads.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, embed_dim)
# 5. Final linear projection
output = heads_concat @ Wo # (batch, seq_len, embed_dim)
return output
Each transformer encoder block contains a multi-head attention layer followed by a position-wise Feed-Forward Network (FFN). Residual connections and layer normalization are applied after each of these two sub-layers.
x = LayerNorm(x + attn_out)
FFN(x) = Linear(ReLU(Linear(x)))
x = LayerNorm(x + ffn_out)
The FFN adds capacity to the model by applying the same non-linear transformation to each token separately, acting like a token-wise filter.
def encoder_block(x, mha_params, ffn_params):
# Unpack params
Wq, Wk, Wv, Wo = mha_params
W1, b1, W2, b2 = ffn_params
# 1. Multi-head attention
attn_out = multi_head_attention(x, Wq, Wk, Wv, Wo, h=8)
# 2. Residual + LayerNorm
x = layer_norm(x + attn_out)
# 3. Feedforward
ffn_out = feedforward(x, W1, b1, W2, b2)
# 4. Residual + LayerNorm
x = layer_norm(x + ffn_out)
return x
Self-attention is permutation-invariant, meaning it doesn't inherently know the order of tokens. To address this, we inject positional information into the input embeddings using Positional Encoding. A popular method is Sinusoidal Positional Encoding, which uses sine and cosine functions.
$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
$$ PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
This encoding is added to the input embeddings (x = x + PE). Since closer positions have similar encodings, the model can learn to compute relative positions from these patterns.
def positional_encoding(seq_len, d_model):
"""
Returns a (seq_len, d_model) positional encoding matrix.
"""
pos = np.arange(seq_len)[:, np.newaxis] # (T, 1)
i = np.arange(d_model)[np.newaxis, :] # (1, D)
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model) # (1, D)
angle_rads = pos * angle_rates # (T, D)
# Apply sin to even indices; cos to odd indices
angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
return angle_rads # shape: (seq_len, d_model)
pe = positional_encoding(seq_len=10, d_model=16)
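To make the shapes concrete, here is a small usage sketch; X_embed below is a hypothetical matrix of token embeddings (it is not defined in the code above), and the positional encoding is simply added to it element-wise.
X_embed = np.random.randn(10, 16)  # hypothetical (seq_len, d_model) token embeddings
x = X_embed + pe                   # positions are now baked into each token vector; shape (10, 16)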
A typical encoder consists of multiple encoder blocks stacked on top of each other.
x = input + positional_encoding
for _ in range(N):
x = EncoderBlock(x)
The decoder's role is to generate the output sequence. It has a similar structure to the encoder but with two key differences: its self-attention is masked (causal), so each position can only attend to earlier positions, and it adds a cross-attention layer that attends over the encoder's output.
Masking works by blocking attention from token $i$ to all positions $j > i$. This is typically done by setting the future scores to a very large negative number (like -1e9) before the softmax, which drives their attention weights to effectively zero.
def masked_self_attention(x, W_q, W_k, W_v):
B, T, D = x.shape
d_k = D
Q = x @ W_q
K = x @ W_k
V = x @ W_v
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
# Causal mask
causal_mask = np.tril(np.ones((T, T))).astype(bool)
scores = np.where(causal_mask[None, :, :], scores, -1e9)
weights = softmax(scores) # axis=-1 is default
out = weights @ V # (B, T, D)
return out
Cross-attention is like standard attention, but the Query (Q) vectors come from the decoder, while the Key (K) and Value (V) vectors come from the encoder's output. This allows each decoder token to ask a "question" (Q) and find the most relevant information from the entire input sequence (K, V).
def cross_attention(x_dec, x_enc, W_q, W_k, W_v):
# x_dec: (B, T_dec, D)
# x_enc: (B, T_enc, D)
# Linear projections
Q = x_dec @ W_q # (B, T_dec, D)
K = x_enc @ W_k # (B, T_enc, D)
V = x_enc @ W_v # (B, T_enc, D)
# Compute attention scores
scores = Q @ np.transpose(K, (0, 2, 1)) / np.sqrt(K.shape[-1])
# Softmax over encoder tokens
weights = softmax(scores) # (B, T_dec, T_enc)
# Weighted sum
output = weights @ V # (B, T_dec, D)
return output, weights
A full decoder block combines these elements.
def decoder_block(x, enc_out, params):
    # Unpack params for the three sub-layers
    self_attn_params, cross_attn_params, ffn_params = params
    Wq_s, Wk_s, Wv_s = self_attn_params
    Wq_c, Wk_c, Wv_c = cross_attn_params
    W1, b1, W2, b2 = ffn_params
    # 1. Masked Self-Attention
    x = layer_norm(x + masked_self_attention(x, Wq_s, Wk_s, Wv_s))
    # 2. Cross-Attention (queries from the decoder, keys/values from the encoder output)
    cross_out, _ = cross_attention(x, enc_out, Wq_c, Wk_c, Wv_c)
    x = layer_norm(x + cross_out)
    # 3. Feedforward
    x = layer_norm(x + feedforward(x, W1, b1, W2, b2))
    return x
The full transformer connects the encoder and decoder.
- x_enc (B, T_enc, D): the source sequence.
- x_dec (B, T_dec, D): the target sequence, shifted right.
- x_enc -> EncoderBlock -> enc_out
- x_dec + enc_out -> DecoderBlock -> output

The final decoder output is passed through a linear layer to produce logits, and a softmax turns those logits into a probability distribution over the vocabulary for predicting the next token.
class Transformer:
def __init__(self, num_layers, d_model, d_ff, vocab_size):
self.num_layers = num_layers
self.d_model = d_model
self.vocab_size = vocab_size
# Encoder and Decoder blocks
self.encoders = [EncoderBlock(d_model, d_ff) for _ in range(num_layers)]
self.decoders = [DecoderBlock(d_model, d_ff) for _ in range(num_layers)]
# Final projection to logits
self.W_out = np.random.randn(d_model, vocab_size)
self.b_out = np.zeros(vocab_size)
def forward(self, src, tgt):
# Encoder pass
x = src
for layer in self.encoders:
x = layer.forward(x)
enc_out = x
# Decoder pass
y = tgt
for layer in self.decoders:
y = layer.forward(y, enc_out)
# Final output
logits = y @ self.W_out + self.b_out
return logits
The training objective for building decoder-only language models is to maximize the probability of the next token given all previous tokens. During inference, the model autoregressively generates one token at a time by repeating the forward pass.
First, you need a vocabulary of tokens, for example, 10,000 words. The input text is converted into token IDs using a tokenizer.
For instance, if the input is $[t_1, t_2, t_3]$ and the target is $[t_2, t_3, t_4]$, the model learns to predict each token from the tokens before it.
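As a minimal sketch of this shifting (the token IDs and variable names below are illustrative, not taken from a real tokenizer):
import torch

# Hypothetical token IDs for a six-token training snippet
token_ids = torch.tensor([4, 17, 92, 8, 4, 51])

# Next-token prediction: the target is the input shifted left by one position
x = token_ids[:-1]  # model input:  [4, 17, 92, 8, 4]
y = token_ids[1:]   # targets:      [17, 92, 8, 4, 51]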
Tokens are embedded, and positional information is added since transformers are order-agnostic. The TokenPosEmbedding class embeds tokens and adds positional encodings:
class TokenPosEmbedding(nn.Module):
def __init__(self, vocab_size, embed_dim, max_len):
super().__init__()
self.token_embed = nn.Embedding(vocab_size, embed_dim)
self.pos_embed = nn.Embedding(max_len, embed_dim)
def forward(self, x):
B, T = x.shape
positions = torch.arange(T, device=x.device).unsqueeze(0) # [1, T]
x = self.token_embed(x) + self.pos_embed(positions) # [B, T, D]
return x
The decoder is causal, meaning it only attends to previous tokens, not future ones. This is achieved by applying a lower-triangular mask to the attention scores, setting future token attention to $-\infty$, which then becomes zero after softmax.
The causal_mask function creates this mask:
def causal_mask(T):
return torch.tril(torch.ones(T, T)).unsqueeze(0).unsqueeze(0) # [1, 1, T, T]
Here's the causal attention logic: each token's output is a weighted sum of the values of the tokens before it (and itself).
- The input $X$ (shape $[B, T, D]$) is projected and chunked into query ($\text{Q}$), key ($\text{K}$), and value ($\text{V}$) matrices, each of shape $[B, T, D]$.
- $\text{Q}$, $\text{K}$, and $\text{V}$ are reshaped for the number of heads to $[B, \text{h}, T, D_{\text{head}}]$.
- Raw attention scores are computed between every pair of positions, giving a tensor of shape $[B, \text{h}, T, T]$.
- A lower-triangular mask is applied so that position $t$ cannot attend to any future tokens.
- The scores are normalized with a softmax and applied to the values.
- The heads are merged and projected back to the original embedding space.
Because the output keeps the same dimensionality as the input, the result flows straight into the next transformer block.
class CausalSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.num_heads = num_heads
self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
B, T, D = x.shape
qkv = self.qkv(x) # [B, T, 3*D]
q, k, v = qkv.chunk(3, dim=-1)
# Split into heads: [B, heads, T, head_dim]
q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) # [B, heads, T, T]
# Mask out future tokens
mask = causal_mask(T).to(x.device)
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)
attn_output = attn_weights @ v # [B, heads, T, head_dim]
# Merge heads
out = attn_output.transpose(1, 2).contiguous().view(B, T, D)
return self.out_proj(out)
At each layer, the hidden state for token $t$ is a function of all preceding tokens. The transformer block includes a layer normalization, followed by a causal attention layer and a residual connection, another layer normalization, a Multi-Layer Perceptron (MLP), and a final residual connection. Layer normalization, which typically comes before attention in GPT models (some older transformer models used post-norm), helps with training stability. Residual connections ensure gradients flow, allowing the model to learn identity mappings or refine existing representations. The MLP is a two-layer per-token network that increases capacity, enabling richer token-wise transformations. Each token attends to prior tokens via causal attention.
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, mlp_dim):
super().__init__()
self.attn = CausalSelfAttention(embed_dim, num_heads)
self.ln1 = nn.LayerNorm(embed_dim)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_dim),
nn.ReLU(),
nn.Linear(mlp_dim, embed_dim),
)
self.ln2 = nn.LayerNorm(embed_dim)
def forward(self, x):
x = x + self.attn(self.ln1(x)) # Residual connection
x = x + self.mlp(self.ln2(x)) # Residual connection
return x
You stack $N$ transformer decoder blocks to learn complex patterns over long-range token dependencies. The MiniGPT model is defined as:
class MiniGPT(nn.Module):
def __init__(self, vocab_size, embed_dim, num_heads, mlp_dim, num_layers, max_len):
super().__init__()
self.embed = TokenPosEmbedding(vocab_size, embed_dim, max_len)
self.blocks = nn.Sequential(*[
TransformerBlock(embed_dim, num_heads, mlp_dim) for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, vocab_size)
def forward(self, x):
x = self.embed(x) # [B, T, D]
x = self.blocks(x) # [B, T, D]
x = self.ln_f(x) # [B, T, D]
return self.head(x) # [B, T, vocab_size]
We predict every next token using the cross-entropy loss. Tokens are generated one by one. For example, if the input is "the cat sat on," the target would be "cat sat on my." Our MiniGPT model learns word identity from token embeddings, word order from positional embeddings, long-range patterns via attention, and token-level transformations via MLPs.
model = MiniGPT(vocab_size=10000, embed_dim=256, num_heads=8, mlp_dim=1024, num_layers=4, max_len=128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model.train()
for epoch in range(10): # or however long you want to train
for x, y in train_loader:
# x: [B, T], y: [B, T]
logits = model(x) # [B, T, vocab]
loss = F.cross_entropy(logits.view(-1, 10000), y.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
During training, the model observes patterns in text, such as question/answer pairs, instructions, and dialogues. No weight changes occur during inference; the model utilizes what it has learned. During generation, we start with a prompt, such as "the cat." Logits for the next token are generated, focusing only on the prediction for the last token. We employ greedy sampling by selecting the most probable token from the distribution, appending it, and repeating the process. The idx tensor represents the current input tokens, with shape [batch_size, current_seq_len]. This sequence grows as new tokens are appended during generation. We crop idx because transformers have a maximum context length. This process repeats until the maximum number of new tokens is reached or an End-Of-Sequence (<EOS>) token is generated.
@torch.no_grad()
def generate(model, idx, max_new_tokens):
model.eval()
for _ in range(max_new_tokens):
# Crop input to max context size
idx_cond = idx[:, -model.embed.pos_embed.num_embeddings:]
# Forward pass
logits = model(idx_cond) # [B, T, vocab_size]
next_logits = logits[:, -1, :] # [B, vocab_size]
# Pick next token (greedy)
next_token = torch.argmax(next_logits, dim=-1, keepdim=True) # [B, 1]
# Append to current sequence
idx = torch.cat([idx, next_token], dim=1) # [B, T+1]
return idx
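A quick usage sketch for the function above; the prompt IDs are arbitrary placeholders rather than the output of a real tokenizer.
prompt = torch.tensor([[4, 17, 92]])                     # hypothetical prompt, shape [B=1, T=3]
generated = generate(model, prompt, max_new_tokens=20)
print(generated.shape)                                   # [1, 23]: the prompt plus 20 new tokens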
---
Decoder-only models do not require padding or padding masks if all sequences in a batch have the same length (the causal mask is still applied). For variable-length inference, no padding is used; the sequence simply grows one token at a time. In GPT and similar decoder-only transformers, the input embedding converts token IDs into vectors, and the output projection layer maps the model's final hidden vector back to vocabulary logits. These two matrices, each of shape [vocab_size, D], are often shared via weight tying, which reduces the model's parameter count. Tying also encourages consistency between the input and output spaces, which can lead to faster convergence because the model doesn't need to learn separately that similar embeddings should be decoded similarly.
output_projection.weight = token_embedding.weight
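In PyTorch, one way to tie the two matrices is to share the weight tensor directly. The sketch below applies this to the MiniGPT class defined earlier; MiniGPTTied is a hypothetical name used only for illustration.
class MiniGPTTied(MiniGPT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both weights have shape [vocab_size, D], so a single tensor can serve
        # as the input embedding table and the output projection
        self.head.weight = self.embed.token_embed.weight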
GPT-2 uses scaled initialization in attention and output layers. This compensates for model depth and width to prevent exploding or vanishing activations and gradients.
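As a hedged sketch of that idea, the helper below scales the residual-branch projections of the MiniGPT defined earlier by 1/sqrt(2 * num_layers), following the GPT-2 paper's description of scaling residual-layer weights by 1/sqrt(N); the function name and the exact choice of layers to scale are illustrative assumptions.
import math

def scale_residual_projections(model, num_layers):
    # Shrink the initial weights of the projections that write into the
    # residual stream so activations don't grow as more blocks are stacked
    for block in model.blocks:
        for proj in (block.attn.out_proj, block.mlp[2]):
            proj.weight.data.mul_(1.0 / math.sqrt(2 * num_layers))

# Call right after constructing the model, before training:
# scale_residual_projections(model, num_layers=4)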
Inference speed can be slow and expensive in decoder-only models. Generating one token at a time requires a full forward pass through the entire model for each token. This means recomputing attention over all previous tokens at every step. You take the entire input sequence, including all previously generated tokens, embed it, pass it through every transformer layer, and compute attention ($\text{Q}$, $\text{K}$, $\text{V}$). In standard attention, the complexity is $\mathcal{O}(T^2 \cdot D)$, where $T$ is the number of tokens seen so far and $D$ is the embedding size. Inference slows down with each new token generated. For example, generating 100 tokens from a prompt leads to quadratic growth in computation over time. Recomputing keys and values for all previous tokens in every forward pass results in slow generation, high memory usage, and latency spikes as the sequence grows.
Instead of recomputing everything from scratch, we can cache the keys and values from previous tokens and reuse them. For each layer and each head, k_prev and v_prev (shape [B, heads, T_prev, D_head]) are cached. Then, during inference, we only compute q_t for the current token and attend over:
$k_{all} = [k_{prev}, k_t]$
$v_{all} = [v_{prev}, v_t]$
This optimization reduces the attention cost at each generation step from quadratic to linear in the sequence length, keeping per-token latency nearly flat in practice. You compute $\text{K}/\text{V}$ for the current token once and append them to the cache.
At step 2: $q_2 \cdot [k_1, k_2]$
And at step 100: $q_{100} \cdot [k_1, \dots, k_{100}]$
You only compute $\text{Q}$ (plus the new $\text{K}/\text{V}$ entry) for the current token and reuse all past $\text{K}/\text{V}$ from previous steps, so each step costs roughly one token's worth of work instead of reprocessing the entire prefix.
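The sketch below illustrates one decoding step for a single attention head with a KV cache; the function name, the cache-as-tuple convention, and the shapes are illustrative assumptions rather than part of the model above. No causal mask is needed here because the query is the newest position, which may attend to everything already cached.
import math
import torch
import torch.nn.functional as F

def cached_attention_step(q_t, k_t, v_t, cache):
    # q_t, k_t, v_t: [B, 1, D_head] projections for the newest token only
    if cache is None:
        k_all, v_all = k_t, v_t
    else:
        k_prev, v_prev = cache
        k_all = torch.cat([k_prev, k_t], dim=1)  # [B, T_prev + 1, D_head]
        v_all = torch.cat([v_prev, v_t], dim=1)
    scores = q_t @ k_all.transpose(-2, -1) / math.sqrt(q_t.shape[-1])  # [B, 1, T]
    weights = F.softmax(scores, dim=-1)
    out = weights @ v_all                        # [B, 1, D_head]
    return out, (k_all, v_all)                   # attention output and updated cache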
After a language model generates logits (raw scores) for the next token, various decoding strategies can be employed to select the actual token. The choice of strategy significantly impacts the quality, diversity, and coherence of the generated text.
Greedy decoding is the simplest approach: the model always picks the token with the highest probability. It is deterministic, so the same prompt always produces the same output. While fast, it can lead to repetitive text and lacks diversity because it never explores alternative token paths.
next_token = torch.argmax(logits, dim=-1)
Multinomial sampling (also known as Temperature Sampling) involves randomly sampling a token proportional to its probability. This introduces randomness, allowing for more diverse outputs. The process involves applying a softmax function to the logits to get probabilities.
$$P_i = \frac{e^{(z_i/T)}}{\sum_j e^{(z_j/T)}}$$
Here, $z_i$ is the logit for token $i$, and $T$ is the temperature parameter.
For creative writing, a higher temperature might be considered, while for deterministic Q&A, a lower temperature is generally preferred.
probs = F.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
Top-K sampling aims to balance quality and diversity by sampling only from the $k$ most likely tokens, discarding the vast majority of low-probability "garbage" tokens. The process: keep only the $k$ highest logits, apply a softmax over just those, sample from the reduced distribution, and map the sampled index back to the full vocabulary.
def top_k_sample(logits, k=50, temperature=1.0):
logits = logits / temperature
top_k_vals, top_k_idx = torch.topk(logits, k)
probs = F.softmax(top_k_vals, dim=-1)
sample = torch.multinomial(probs, num_samples=1)
return top_k_idx.gather(1, sample) # map back to vocab indices
Top-P sampling (also known as Nucleus Sampling) dynamically selects the smallest set of tokens whose cumulative probability exceeds a predefined threshold $p$. This set is called the "nucleus." This method adapts to the model's uncertainty: if the model is confident, the nucleus will contain fewer tokens, while if it's uncertain, more tokens will be included.
def top_p_sample(logits, p=0.9, temperature=1.0):
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# Create mask for cumulative probability > p
mask = cumulative_probs > p
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False # Always include top token
# Zero out the masked probs and re-normalize
sorted_probs[mask] = 0
sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
sample = torch.multinomial(sorted_probs, num_samples=1)
next_token = sorted_indices.gather(-1, sample)
return next_token
Beam search is a more sophisticated search algorithm that explores multiple sequences in parallel at each decoding step. Instead of selecting just one best token, it keeps track of the $N$ most likely sequences (where $N$ is the beam size).
Algorithm Overview (assuming Beam Size $B$, Vocabulary Size $V$, Output Length $T$):
1. Start with the prompt as a single beam with a score of 0.
2. At each step, extend every beam with each of the $V$ candidate tokens and add that token's log-probability to the beam's score, producing $B \times V$ candidates.
3. Keep only the $B$ highest-scoring candidate sequences.
4. Repeat until the output reaches length $T$ or every beam ends with an End-Of-Sequence (<EOS>) token.

While beam search can produce higher-quality outputs by exploring more possibilities, it can sometimes favor shorter or less diverse (boring) outputs.
def beam_search(model, input_ids, beam_size=5, max_len=20):
beams = [(input_ids, 0.0)] # (sequence, score)
for _ in range(max_len):
all_candidates = []
for seq, score in beams:
# Get logits for the next token based on the current sequence
# Unsqueeze to add a batch dimension for model input, then get last token's logits
logits = model(seq.unsqueeze(0))[:, -1, :] # [1, vocab_size]
log_probs = F.log_softmax(logits, dim=-1).squeeze(0) # [vocab_size]
topk_logprobs, topk_ids = torch.topk(log_probs, beam_size)
for i in range(beam_size):
new_seq = torch.cat([seq, topk_ids[i].unsqueeze(0)], dim=0)
new_score = score + topk_logprobs[i].item()
all_candidates.append((new_seq, new_score))
# Keep top B overall candidates
beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_size]
# Note: For production use, you'd add logic here to check for EOS tokens
# and manage completed beams separately. For this example, we proceed for max_len.
return beams[0][0] # Best sequence