We all know what feedforward networks are. They assume a fixed-size input: the network takes an input, multiplies it by some weights, and applies non-linearities. The order of the input doesn’t matter. This is a problem for input sequences, like words in a sentence, stock prices over days, or audio samples, where order actually matters. Different sentences can have different lengths, which would require padding or truncating. And a feedforward network treats each position as just another feature, so the temporal structure is lost.
With recurrence, instead of processing all input at once, you process one input at a time and carry forward some memory.
Hidden State:
$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$
Output:
$$y_t = W_{hy} h_t + b_y$$
The current hidden state depends on the previous hidden state. That’s what gives it memory.
Let’s take it one step at a time. At each timestep $t$, the cell takes the current input $x_t$ and the previous hidden state $h_{t-1}$, and produces a new hidden state $h_t$ and an output $y_t$ using the two equations above.
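To make the recurrence concrete, here is a minimal sketch of a vanilla RNN cell in PyTorch that implements exactly these two equations (the class and variable names are just illustrative, not from a library):
import torch
import torch.nn as nn

class VanillaRNNCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.W_xh = nn.Linear(input_dim, hidden_dim)               # input-to-hidden
        self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)  # hidden-to-hidden
        self.W_hy = nn.Linear(hidden_dim, output_dim)              # hidden-to-output

    def forward(self, x_t, h_prev):
        # h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
        h_t = torch.tanh(self.W_xh(x_t) + self.W_hh(h_prev))
        # y_t = W_hy h_t + b_y
        y_t = self.W_hy(h_t)
        return y_t, h_t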
Unlike standard backpropagation, where gradients flow layer by layer, in an RNN they must also flow backward through time as well as through depth. This is called backpropagation through time (BPTT).
Say the loss is cross-entropy: $L = \sum_{t=1}^{T} L_t(y_t, \hat{y}_t)$
When we unroll the RNN for $T$ timesteps, it looks like a deep forward network of depth $T$, where each layer reuses the same weights ($W_{xh}, W_{hh}, W_{hy}$). During backprop, gradients flow through the sequence of hidden states. Since parameters are shared, their gradients accumulate across timesteps.
The trouble comes from the recurrent connection. Since $h_t$ depends on $h_{t-1}$, the gradient expands into a long chain rule across time:
$$\frac{\partial L_T}{\partial h_k} = \frac{\partial L_T}{\partial h_T} \prod_{t=k+1}^{T} \frac{\partial h_t}{\partial h_{t-1}}, \qquad \frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}\!\left(1 - h_t^2\right) W_{hh}$$
A product of many such factors is unstable. If the factors are repeatedly larger than 1, the gradients explode. If they are repeatedly smaller than 1, as with the derivatives of tanh, which are $\le 1$, they shrink exponentially. The model forgets long-term dependencies.
You can deal with exploding gradients using gradient clipping. This shrinks the gradient vector before updating the parameters. Typically, you take the L2 norm of all gradients flattened into one big vector.
Let $g = [\text{grad}(W_{xh}), \text{grad}(W_{hh}), \text{grad}(W_{hy}), \dots]$. The L2 norm is $\|g\|_2 = \sqrt{\sum_i g_i^2}$.
If this value is bigger than a threshold $\tau$, you rescale the entire gradient vector:
$$ \text{If } \|g\|_2 > \tau: \quad g \leftarrow g \cdot \frac{\tau}{\|g\|_2} $$
This keeps the direction of the gradient but reduces its magnitude.
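A minimal sketch of clipping in PyTorch, assuming a toy model and a dummy loss just to produce gradients; torch.nn.utils.clip_grad_norm_ computes exactly this global norm and rescales the gradients in place:
import torch
import torch.nn as nn

model = nn.RNN(input_size=10, hidden_size=20, batch_first=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(8, 50, 10)   # (batch, seq_len, input_dim)
out, h_n = model(x)
loss = out.pow(2).mean()     # dummy loss, just to get gradients

optimizer.zero_grad()
loss.backward()
# Rescale the global gradient norm if it exceeds tau = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()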
Now, the main issue with vanilla RNNs is really vanishing gradients, which hurt long-term memory. At every step the old memory gets mixed with the new input, and there’s no control over how information passes to the future, so it washes out. The RNN forgets context.
The GRU introduces gates to control what information to keep and what to forget.
For example, consider “Ramin wrote the best tutorial, he did it again!”. A vanilla RNN might overwrite “Ramin” and forget who “he” was. A GRU's update gate would say "don’t overwrite Ramin, hold on to that memory." If the update gate decides to just keep the old memory, the hidden state can flow unchanged for many steps. That means gradients can also flow back easily, solving the vanishing gradients problem.
Conceptually, the hidden state update is:
$$h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$
Where $z_t$ is the update gate. If $z_t=0$, then $h_t = h_{t-1}$. The memory is copied directly from one step to the next. No multiplication, no tanh—just an identity connection. During backprop, the gradients flow through this same path, unchanged. This is a shortcut path or linear highway for memory.
There are two gates, the update gate $z_t$ and the reset gate $r_t$. Both are little feedforward networks with a sigmoid activation, so their outputs are between 0 and 1:
$$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z), \qquad r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$
The reset gate controls how much of the old hidden state is used when forming the candidate:
$$\tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)$$
Finally, the output is computed from the hidden state:
$$y_t = W_{hy} h_t + b_y$$
import torch
import torch.nn as nn
import torch.nn.functional as F
class GRUCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GRUCell, self).__init__()
        self.hidden_dim = hidden_dim
        # Update gate parameters
        self.W_z = nn.Linear(input_dim, hidden_dim)
        self.U_z = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Reset gate parameters
        self.W_r = nn.Linear(input_dim, hidden_dim)
        self.U_r = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Candidate hidden state parameters
        self.W_h = nn.Linear(input_dim, hidden_dim)
        self.U_h = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev):
        # 1. Update gate
        z_t = torch.sigmoid(self.W_z(x_t) + self.U_z(h_prev))
        # 2. Reset gate
        r_t = torch.sigmoid(self.W_r(x_t) + self.U_r(h_prev))
        # 3. Candidate hidden state
        h_tilde = torch.tanh(self.W_h(x_t) + self.U_h(r_t * h_prev))
        # 4. Final hidden state
        h_t = (1 - z_t) * h_prev + z_t * h_tilde
        return h_t
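To run this cell over a sequence, you loop over the time steps yourself and carry the hidden state forward (a small usage sketch with made-up dimensions):
gru_cell = GRUCell(input_dim=10, hidden_dim=20)

x = torch.randn(4, 7, 10)           # (batch, seq_len, input_dim)
h = torch.zeros(4, 20)              # initial hidden state
outputs = []
for t in range(x.size(1)):
    h = gru_cell(x[:, t, :], h)     # one GRU step
    outputs.append(h)
H = torch.stack(outputs, dim=1)     # (batch, seq_len, hidden_dim)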
Even though the GRU doesn’t have a dedicated cell state, it can still preserve long-term memory because the update gate can create identity connections across time steps.
LSTM (Long Short-Term Memory) separates memory into two parts and uses three gates to carefully control what flows into, out of, and stays within these memories.
The three gates are the forget gate $f_t$, the input gate $i_t$, and the output gate $o_t$.
Cell State Update (Long-Term Memory):
The new long-term memory is formed by keeping some of the old memory and adding some new information.
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
Where $f_t$ is the forget gate, $i_t$ is the input gate, and $\tilde{c}_t$ is the candidate memory.
Candidate Cell State:
This is a proposal for what new information could be written into long-term memory.
$$\tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)$$
Hidden State Update (Short-Term Memory):
The short-term working memory is a filtered view of the long-term memory.
$$h_t = o_t \odot \tanh(c_t)$$
Where $o_t$ is the output gate.
The Gates:
All gates are a mix of the input ($x_t$) and the previous hidden state ($h_{t-1}$), passed through a sigmoid function, so their values are between 0 and 1:
$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$
$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$
$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(LSTMCell, self).__init__()
        self.hidden_dim = hidden_dim
        # Forget gate
        self.W_f = nn.Linear(input_dim, hidden_dim)
        self.U_f = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Input gate
        self.W_i = nn.Linear(input_dim, hidden_dim)
        self.U_i = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Candidate cell state
        self.W_c = nn.Linear(input_dim, hidden_dim)
        self.U_c = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Output gate
        self.W_o = nn.Linear(input_dim, hidden_dim)
        self.U_o = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev, c_prev):
        # 1. Forget gate
        f_t = torch.sigmoid(self.W_f(x_t) + self.U_f(h_prev))
        # 2. Input gate
        i_t = torch.sigmoid(self.W_i(x_t) + self.U_i(h_prev))
        # 3. Candidate memory
        c_hat_t = torch.tanh(self.W_c(x_t) + self.U_c(h_prev))
        # 4. Update cell state
        c_t = f_t * c_prev + i_t * c_hat_t
        # 5. Output gate
        o_t = torch.sigmoid(self.W_o(x_t) + self.U_o(h_prev))
        # 6. Hidden state (short-term memory)
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t
GRUs typically achieve similar effects to LSTMs but with fewer parameters.
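A quick way to check this claim is to count parameters of the built-in modules (the sizes below are arbitrary); the GRU has three weight blocks (update, reset, candidate) where the LSTM has four, so roughly 3/4 of the parameters:
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=1)
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)

gru_params = sum(p.numel() for p in gru.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())
print(gru_params, lstm_params)   # 1920 vs 2560 with these sizes, a 3:4 ratio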
================ GRU =================
Input (x_t)
|
v
[ Reset Gate ] ---> controls old hidden state
|
v
Candidate Hidden (h̃_t)
|
[ Update Gate ] ---> blends old hidden h_{t-1} with h̃_t
|
v
Hidden State (h_t) ----> used as output + passed forward
================ LSTM =================
Input (x_t)
|
v
[ Forget Gate ] --------> controls old Cell State (c_{t-1})
[ Input Gate ] --------> controls Candidate Cell State (c̃_t)
|
v
Cell State (c_t) -------> long-term memory (carried forward)
|
[ Output Gate ] --------> controls what part of c_t is shown
|
v
Hidden State (h_t) ----> short-term memory, used as output
We typically stack multiple recurrent layers on top of each other. The first layer processes the input sequence and produces hidden states, and the second processes those hidden states to learn higher-level temporal features.
nn.LSTM and nn.GRU
# A 2-layer LSTM
lstm = nn.LSTM(
    input_size=10,    # size of input vector at each timestep
    hidden_size=20,   # hidden state dimension
    num_layers=2,     # stacked layers
    batch_first=True  # batch, seq_len, input_size format
)
The batch_first=True argument is convenient when working with other PyTorch layers like CNNs, which expect the batch dimension first.
The call signature is output, (hn, cn) = lstm(input, (h0, c0)).
output: The hidden state for all time steps from the final layer.
hn, cn: The final hidden and cell states for each layer.
For a GRU, there is no cell state: output, hn = gru(input, h0).
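A quick shape check, reusing the 2-layer lstm defined above with arbitrary batch and sequence sizes, makes the conventions concrete:
import torch

x = torch.randn(32, 5, 10)   # (batch, seq_len, input_size)
output, (hn, cn) = lstm(x)   # h0 and c0 default to zeros

print(output.shape)  # torch.Size([32, 5, 20])  last layer's hidden state at every timestep
print(hn.shape)      # torch.Size([2, 32, 20])  final hidden state for each of the 2 layers
print(cn.shape)      # torch.Size([2, 32, 20])  final cell state for each of the 2 layers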
There are many variants of RNN architectures for different tasks, depending on how inputs map to outputs. Two common ones, shown below, are many-to-many with aligned inputs and outputs (one prediction per timestep) and sequence-to-sequence, where an encoder and a decoder handle input and output sequences of different lengths:
class ManyToManyAligned(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.rnn(x)   # out: (batch, seq_len, hidden_dim)
        return self.fc(out)    # (batch, seq_len, output_dim)
class Seq2Seq(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(output_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, src, tgt, teacher_forcing=True):
        # Encode source
        _, (h, c) = self.encoder(src)
        if teacher_forcing:
            # Feed the ground-truth target sequence to the decoder
            out, _ = self.decoder(tgt, (h, c))
            return self.fc(out)          # (batch, tgt_seq_len, output_dim)
        # Otherwise decode step by step, feeding predictions back in
        inp = tgt[:, :1, :]              # first target step as the initial input
        preds = []
        for _ in range(tgt.size(1)):
            out, (h, c) = self.decoder(inp, (h, c))
            preds.append(self.fc(out))   # (batch, 1, output_dim)
            inp = preds[-1]
        return torch.cat(preds, dim=1)   # (batch, tgt_seq_len, output_dim)
In Seq2Seq, teacher forcing is a training technique where the ground truth from the previous time step is fed as input to the decoder, rather than its own (potentially wrong) prediction. This makes training more stable and faster.
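A quick usage sketch of the Seq2Seq class above with dummy tensors (all sizes are arbitrary, and the targets are assumed to already live in the decoder’s output space):
model = Seq2Seq(input_dim=10, hidden_dim=32, output_dim=8)
src = torch.randn(4, 12, 10)                         # source sequence
tgt = torch.randn(4, 9, 8)                           # target sequence
train_out = model(src, tgt, teacher_forcing=True)    # (4, 9, 8), decoder sees ground truth
infer_out = model(src, tgt, teacher_forcing=False)   # (4, 9, 8), decoder feeds back its predictions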
A standard RNN processes a sequence from left to right. A bidirectional RNN processes the sequence in both directions (forward and backward) using two separate RNNs and concatenates their hidden states. This gives the hidden state at time $t$ information from both the past and the future.
$$ \text{BiRNN}(x) = \text{ForwardRNN}(x) \oplus \text{BackwardRNN}(x) $$
class BiGRU(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Forward and backward GRU cells
        self.gru_fwd = GRUCell(input_dim, hidden_dim)
        self.gru_bwd = GRUCell(input_dim, hidden_dim)

    def forward(self, x):
        """
        x: [B, T, D] (batch, sequence length, input_dim)
        """
        B, T, D = x.shape
        h_fwd = torch.zeros(B, self.hidden_dim, device=x.device)
        h_bwd = torch.zeros(B, self.hidden_dim, device=x.device)
        outputs_fwd, outputs_bwd = [], []
        # Forward direction
        for t in range(T):
            h_fwd = self.gru_fwd(x[:, t, :], h_fwd)
            outputs_fwd.append(h_fwd)
        # Backward direction
        for t in reversed(range(T)):
            h_bwd = self.gru_bwd(x[:, t, :], h_bwd)
            outputs_bwd.append(h_bwd)
        # Reverse backward outputs so they align with forward timesteps
        outputs_bwd = outputs_bwd[::-1]
        # Concatenate forward and backward hidden states
        H_fwd = torch.stack(outputs_fwd, dim=1)  # [B, T, H]
        H_bwd = torch.stack(outputs_bwd, dim=1)  # [B, T, H]
        H = torch.cat([H_fwd, H_bwd], dim=-1)    # [B, T, 2*H]
        return H
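In practice you would usually reach for the built-in bidirectional flag instead of hand-rolling the two loops; a minimal sketch with arbitrary sizes:
import torch
import torch.nn as nn

bigru = nn.GRU(input_size=10, hidden_size=20, batch_first=True, bidirectional=True)

x = torch.randn(4, 7, 10)
out, hn = bigru(x)
print(out.shape)  # torch.Size([4, 7, 40])  forward and backward states concatenated
print(hn.shape)   # torch.Size([2, 4, 20])  one final hidden state per direction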
In practice, you rarely backprop through entire long sequences due to memory constraints. With truncated backpropagation through time (TBPTT), you cut the sequence into shorter chunks. You still carry the hidden state forward between chunks, but you only backpropagate within each chunk, as in the sketch below.
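A sketch of TBPTT with nn.GRU, assuming one long per-timestep regression problem with dummy data; the key line is detaching the hidden state between chunks so gradients stop at the chunk boundary:
import torch
import torch.nn as nn

gru = nn.GRU(input_size=10, hidden_size=20, batch_first=True)
head = nn.Linear(20, 1)
optimizer = torch.optim.Adam(list(gru.parameters()) + list(head.parameters()), lr=1e-3)

x = torch.randn(8, 1000, 10)   # one long sequence per batch element
y = torch.randn(8, 1000, 1)    # per-timestep targets (dummy data)
chunk_len = 50

h = None                       # initial hidden state (defaults to zeros)
for start in range(0, x.size(1), chunk_len):
    x_chunk = x[:, start:start + chunk_len, :]
    y_chunk = y[:, start:start + chunk_len, :]

    out, h = gru(x_chunk, h)                    # carry the hidden state across chunks
    loss = ((head(out) - y_chunk) ** 2).mean()

    optimizer.zero_grad()
    loss.backward()                             # gradients flow only within this chunk
    optimizer.step()

    h = h.detach()                              # cut the graph before the next chunk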
Many modern GRU/LSTM cells follow a residual + layernorm design. The residual connection helps with gradient flow, and LayerNorm normalizes activations across the hidden dimension of each time step, making training more stable for longer sequences.
import torch
import torch.nn as nn
class ResidualLayerNormGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size * (2 if bidirectional else 1))

    def forward(self, x, h0=None):
        """
        x: [B, T, input_size]
        h0: optional initial hidden state
        """
        # GRU forward
        out, h = self.gru(x, h0)  # out: [B, T, H]
        # Residual connection (if input and hidden dims match)
        if out.shape[-1] == x.shape[-1]:
            out = out + x
        # Layer normalization
        out = self.layernorm(out)
        return out, h
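A quick usage note: the residual path above only kicks in when the input size matches the (possibly doubled) hidden size, for example:
block = ResidualLayerNormGRU(input_size=20, hidden_size=20)
out, h = block(torch.randn(4, 7, 20))   # residual applied: out has shape [4, 7, 20]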
Dropout in RNNs is subtle. In PyTorch’s stacked nn.GRU and nn.LSTM, the dropout argument applies dropout to the output of each layer except the last. Crucially, it acts between layers, not inside the recurrent cell across time steps, and the mask is sampled elementwise like ordinary dropout rather than being held fixed for the whole sequence.
Input (x_t)
|
v
[ GRU Layer 1 ] ---> produces h1_t at each t
|
v
Dropout on h1_t (applied between layers)
|
v
[ GRU Layer 2 ] ---> produces h2_t at each t
|
v
Dropout on h2_t (applied between layers)
|
v
[ GRU Layer 3 ] ---> produces h3_t at each t
|
v
Output (no dropout after the last layer)
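In the built-in modules this behaviour is controlled by the dropout argument, which only has an effect when num_layers > 1 (a minimal sketch with arbitrary sizes):
import torch
import torch.nn as nn

gru = nn.GRU(
    input_size=10,
    hidden_size=20,
    num_layers=3,    # dropout is applied after layers 1 and 2, not after layer 3
    dropout=0.3,     # probability of zeroing each element of the intermediate outputs
    batch_first=True
)

x = torch.randn(4, 7, 10)
gru.train()          # dropout active during training
out, hn = gru(x)
gru.eval()           # dropout disabled at evaluation time
out_eval, _ = gru(x)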