The Limits of Feedforward Networks & The Rise of RNNs

We all know what feedforward networks are. A feedforward network takes a fixed-size input, multiplies it by weights, and applies non-linearities. It has no built-in notion of order or time. This is a problem for sequences, like words in a sentence, stock prices over days, or audio samples, where order actually matters. Sequences also vary in length, so a feedforward network would require padding or truncating them to a fixed size. Flattening a sequence into one input vector also collapses the time axis: each weight is tied to one fixed position, so a pattern learned at one position does not transfer to another, and the temporal structure is lost.

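With recurrence, instead of processing all input at once, you process one element at a time and carry forward a memory (the hidden state) that summarizes what you have seen so far. The same weights are reused at every step, so any sequence length works.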

The Core RNN Equations

Hidden State:

$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$

Output:

$$y_t = W_{hy} h_t + b_y$$

The current hidden state depends on the previous hidden state. That’s what gives it memory.

Let’s take it one step at a time. At each timestep $t$:

  1. We take the input vector $x_t$ and apply the transformation $W_{xh} x_t$. This is what today’s input contributes to the memory.
  2. We take the previous hidden state $h_{t-1}$ and transform it using $W_{hh} h_{t-1}$. This is what we remember from yesterday.
  3. We combine these, add a bias, and apply a $\tanh$ non-linearity. That’s $h_t$, the memory at time $t$. It has today’s input and past memory.
  4. We then use another matrix $W_{hy}$ to turn the hidden state into the output: $y_t = W_{hy} h_t + b_y$. Based on my memory now, what should my prediction be?
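The four steps map directly onto a few lines of PyTorch. This is a minimal sketch rather than a reference implementation: the class name VanillaRNN, the zero initial state, and the [batch, seq_len, input_dim] input layout are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    """Hypothetical minimal RNN implementing h_t and y_t from the equations above."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W_xh = nn.Linear(input_dim, hidden_dim)               # W_xh x_t + b_h
        self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)  # W_hh h_{t-1}
        self.W_hy = nn.Linear(hidden_dim, output_dim)              # W_hy h_t + b_y

    def forward(self, x):
        # x: [batch, seq_len, input_dim]
        batch, seq_len, _ = x.shape
        h = torch.zeros(batch, self.hidden_dim, device=x.device)  # h_0 = 0
        outputs = []
        for t in range(seq_len):
            # Steps 1-3: mix today's input with yesterday's memory, then squash
            h = torch.tanh(self.W_xh(x[:, t, :]) + self.W_hh(h))
            # Step 4: read a prediction off the current memory
            outputs.append(self.W_hy(h))
        # [batch, seq_len, output_dim] and the final hidden state [batch, hidden_dim]
        return torch.stack(outputs, dim=1), h

rnn = VanillaRNN(input_dim=3, hidden_dim=8, output_dim=2)
y, h_last = rnn(torch.randn(4, 5, 3))
print(y.shape, h_last.shape)  # torch.Size([4, 5, 2]) torch.Size([4, 8])
```

Because the same three weight matrices are used at every step, the same module handles a sequence of length 5 or 500 without any change.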

Backpropagation Through Time (BPTT)

In a feedforward network, backpropagation flows layer by layer. In an RNN, gradients flow back through time as well as through depth. This is called backpropagation through time (BPTT).

Say the total loss is the sum of per-timestep losses, for example the cross-entropy between the output $y_t$ and the target $\hat{y}_t$: $L = \sum_{t=1}^{T} L_t(y_t, \hat{y}_t)$

When we unroll the RNN for $T$ timesteps, it looks like a deep feedforward network of depth $T$ in which every layer reuses the same weights ($W_{xh}, W_{hh}, W_{hy}$). During backprop, gradients flow backward through the sequence of hidden states. Since the parameters are shared, their gradients accumulate (are summed) across timesteps.

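The trouble comes from the recurrent connection. Since $h_t$ depends on $h_{t-1}$, the Jacobian of $h_t$ with respect to an earlier state $h_k$ expands into a long chain rule across time:

$$\frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} = \prod_{j=k+1}^{t} \operatorname{diag}\left(1 - h_j^2\right) W_{hh}$$

Here $1 - h_j^2$ (squared elementwise) is the $\tanh$ derivative, which is at most 1. Because the same $W_{hh}$ appears in every factor, the product shrinks exponentially with $t - k$ when the largest singular value of $W_{hh}$ is below 1 (vanishing gradients), and it can grow exponentially when that value is above 1 (exploding gradients).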
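A quick numerical check of this product. The toy setup is an assumption chosen so the arithmetic is exact: zero inputs, zero bias, $h_0 = 0$, and $W_{hh} = s I$. Every $h_t$ then stays at zero, so $\tanh' = 1$ at every step and $\partial h_T / \partial h_0$ is exactly $s^T I$. The helper name grad_norm_through_time is hypothetical.

```python
import torch

def grad_norm_through_time(scale: float, T: int, hidden_dim: int = 4) -> float:
    """Norm of d(sum h_T)/d h_0 for the toy recurrence h_t = tanh(W_hh h_{t-1})."""
    # Hypothetical setup: W_hh = scale * I, zero inputs, zero bias, h_0 = 0
    W_hh = scale * torch.eye(hidden_dim, dtype=torch.float64)
    h0 = torch.zeros(hidden_dim, dtype=torch.float64, requires_grad=True)
    h = h0
    for _ in range(T):
        h = torch.tanh(W_hh @ h)   # h stays at 0, so tanh'(0) = 1 at every step
    h.sum().backward()             # backprop through all T steps to h_0
    return h0.grad.norm().item()

for s in (0.5, 1.0, 1.5):
    print(s, grad_norm_through_time(s, T=30))
```

With $T = 30$ and a 4-dimensional state, the three norms are $2 \cdot 0.5^{30} \approx 1.9 \times 10^{-9}$, $2$, and $2 \cdot 1.5^{30} \approx 3.8 \times 10^{5}$ (the factor 2 is the norm of a 4-dimensional vector of ones). Real networks are less tidy, but the same repeated multiplication drives both effects.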

Gradient Clipping

You can deal with exploding gradients using gradient clipping, which rescales the gradient before the parameters are updated whenever it gets too large. Typically, you take the L2 norm of all parameter gradients flattened into one big vector.

Let $g = [\text{grad}(W_{xh}), \text{grad}(W_{hh}), \text{grad}(W_{hy}), \dots]$. The L2 norm is $\|g\|_2 = \sqrt{\sum_i g_i^2}$.

If this value is bigger than a threshold $\tau$, you rescale the entire gradient vector:

$$ \text{If } \|g\|_2 > \tau: \quad g \leftarrow g \cdot \frac{\tau}{\|g\|_2} $$

This keeps the direction of the gradient but reduces its magnitude.
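The rule translates directly into code. Below is a from-scratch sketch (the function name clip_by_global_norm is hypothetical), checked against PyTorch's built-in torch.nn.utils.clip_grad_norm_, which applies the same global-norm rescaling in place on each parameter's .grad (it adds a tiny epsilon to the denominator, so the results match up to rounding).

```python
import torch

def clip_by_global_norm(grads, tau):
    """Rescale a list of gradient tensors so their global L2 norm is at most tau."""
    # ||g||_2 over all gradients, as if they were flattened into one long vector g
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    if total_norm > tau:
        scale = tau / total_norm  # one shared factor keeps the gradient's direction
        grads = [g * scale for g in grads]
    return grads, total_norm

# Toy gradients with global norm sqrt(3^2 + 4^2) = 5
g1, g2 = torch.tensor([3.0, 0.0]), torch.tensor([[0.0, 4.0]])
clipped, norm = clip_by_global_norm([g1, g2], tau=1.0)
print(norm.item(), clipped)  # norm is 5.0; clipped values are 0.6, 0.0 and 0.0, 0.8

# The same operation with PyTorch's built-in, which modifies .grad in place
p1 = torch.zeros(2, requires_grad=True)
p2 = torch.zeros(1, 2, requires_grad=True)
p1.grad, p2.grad = g1.clone(), g2.clone()
torch.nn.utils.clip_grad_norm_([p1, p2], max_norm=1.0)
print(p1.grad, p2.grad)
```

In a training loop, the clipping call goes between loss.backward() and optimizer.step().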

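Clipping only addresses exploding gradients. The more fundamental issue with RNNs is vanishing gradients, which prevents them from learning long-term dependencies. At every step, the old memory is mixed with the new input through $W_{hh}$ and $\tanh$, and there is no mechanism to control which information passes to the future, so old information gradually washes out. The RNN forgets context.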


Gated Recurrent Unit (GRU)

The GRU introduces gates to control what information to keep and what to forget.

For example, consider “Ramin wrote the best tutorial, he did it again!”. A vanilla RNN might overwrite “Ramin” by the time it reaches “he” and forget who “he” was. A GRU's update gate can learn to say "don’t overwrite Ramin, hold on to that memory." If the update gate decides to keep the old memory, the hidden state can flow unchanged for many steps. That means gradients can also flow back easily, which greatly mitigates the vanishing gradient problem.

Conceptually, the hidden state update is:

$$h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

Here $z_t$ is the update gate. If $z_t = 0$, then $h_t = h_{t-1}$: the memory is copied directly from one step to the next, with no weight multiplication and no $\tanh$, just an identity connection. During backprop, gradients flow back along this same path essentially unchanged. This is a shortcut path, or linear highway, for memory.

GRU Equations

There are two gates. Each is a small single-layer network with a sigmoid activation, so its outputs lie between 0 and 1.

  1. Update Gate ($z_t$): Keep the old hidden state or update it? $$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)$$ The sigmoid squashes the result to a number between 0 and 1, interpreted as a switch. 0 means keep old memory, 1 means overwrite with new memory.
  2. Reset Gate ($r_t$): When computing the new candidate memory, how much of the old memory should be let in? Values near 0 mean ignore it. $$r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$
  3. Candidate Hidden State ($\tilde{h}_t$): A proposal for new memory. The reset gate $r_t$ decides how much of the previous state $h_{t-1}$ is allowed in. If $r_t$ is zero, the candidate ignores past memory and depends only on the input. $$\tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)$$
  4. Final Hidden State ($h_t$): The actual memory we carry forward. It’s a blend between the old hidden state and the candidate, controlled by the update gate. If $z_t$ is 0, it keeps the old memory ($h_t = h_{t-1}$), and if $z_t$ is 1, it replaces it ($h_t = \tilde{h}_t$). $$h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

Finally, the output is computed from the hidden state:

$$y_t = W_{hy} h_t + b_y$$

GRU Implementation


import torch
import torch.nn as nn
import torch.nn.functional as F

class GRUCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GRUCell, self).__init__()
        self.hidden_dim = hidden_dim

        # Update gate parameters
        self.W_z = nn.Linear(input_dim, hidden_dim)
        self.U_z = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Reset gate parameters
        self.W_r = nn.Linear(input_dim, hidden_dim)
        self.U_r = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Candidate hidden state parameters
        self.W_h = nn.Linear(input_dim, hidden_dim)
        self.U_h = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev):
        # 1. Update gate
        z_t = torch.sigmoid(self.W_z(x_t) + self.U_z(h_prev))

        # 2. Reset gate
        r_t = torch.sigmoid(self.W_r(x_t) + self.U_r(h_prev))

        # 3. Candidate hidden state
        h_tilde = torch.tanh(self.W_h(x_t) + self.U_h(r_t * h_prev))

        # 4. Final hidden state
        h_t = (1 - z_t) * h_prev + z_t * h_tilde
        return h_t

Even though the GRU doesn’t have a dedicated cell state, it can still preserve long-term memory because the update gate can create identity connections across time steps.
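A caveat before using PyTorch's built-in nn.GRU or nn.GRUCell (as later sections do): according to the PyTorch documentation, they parameterize the GRU slightly differently from the equations above. The update gate convention is flipped ($h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$, so $z_t = 1$ keeps the old state), and the reset gate multiplies the already-projected hidden state, $r_t \odot (W_{hn} h_{t-1} + b_{hn})$, rather than $h_{t-1}$ itself. The sketch below recomputes one nn.GRUCell step by hand from its packed weights (the helper name manual_gru_step is hypothetical) and checks it against the module.

```python
import torch
import torch.nn as nn

def manual_gru_step(cell: nn.GRUCell, x_t, h_prev):
    # Packed projections, each [B, 3*H], with gate blocks ordered (reset, update, new)
    gi = x_t @ cell.weight_ih.T + cell.bias_ih
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r_t = torch.sigmoid(i_r + h_r)
    z_t = torch.sigmoid(i_z + h_z)
    # Reset gate is applied after the hidden projection (including its bias)
    n_t = torch.tanh(i_n + r_t * h_n)
    # Flipped convention: z_t close to 1 keeps the previous state
    return (1 - z_t) * n_t + z_t * h_prev

torch.manual_seed(0)
cell = nn.GRUCell(input_size=3, hidden_size=5)
x_t, h_prev = torch.randn(2, 3), torch.randn(2, 5)
with torch.no_grad():
    print(torch.allclose(cell(x_t, h_prev), manual_gru_step(cell, x_t, h_prev), atol=1e-5))
```

Both variants have the same gating idea and the same parameter count; the difference matters only if you port weights between a from-scratch cell and the built-in one.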


Long Short-Term Memory (LSTM)

LSTM (Long Short-Term Memory) separates memory into two parts, a long-term cell state $c_t$ and a short-term hidden state $h_t$, and uses three gates to carefully control what flows into, out of, and stays within these memories.

The three gates are:

  1. Forget Gate: Decides what old info to erase from long-term memory.
  2. Input Gate: Decides what new info to store in long-term memory.
  3. Output Gate: Decides what part of the memory to use right now.

LSTM Equations

Cell State Update (Long-Term Memory):

The new long-term memory is formed by keeping some of the old memory and adding some new information.

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

Where $f_t$ is the forget gate, $i_t$ is the input gate, and $\tilde{c}_t$ is the candidate memory.

Candidate Cell State:

This is a proposal for what new information could be written into long-term memory.

$$\tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)$$

Hidden State Update (Short-Term Memory):

The short-term working memory is a filtered view of the long-term memory.

$$h_t = o_t \odot \tanh(c_t)$$

Where $o_t$ is the output gate.

The Gates:

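All three gates combine the input ($x_t$) and the previous hidden state ($h_{t-1}$) and pass the result through a sigmoid, so every gate value lies between 0 and 1:

$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$

$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$$

$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$$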

LSTM Implementation


import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(LSTMCell, self).__init__()
        self.hidden_dim = hidden_dim

        # Forget gate
        self.W_f = nn.Linear(input_dim, hidden_dim)
        self.U_f = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Input gate
        self.W_i = nn.Linear(input_dim, hidden_dim)
        self.U_i = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Candidate cell state
        self.W_c = nn.Linear(input_dim, hidden_dim)
        self.U_c = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Output gate
        self.W_o = nn.Linear(input_dim, hidden_dim)
        self.U_o = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x_t, h_prev, c_prev):
        # 1. Forget gate
        f_t = torch.sigmoid(self.W_f(x_t) + self.U_f(h_prev))

        # 2. Input gate
        i_t = torch.sigmoid(self.W_i(x_t) + self.U_i(h_prev))

        # 3. Candidate memory
        c_hat_t = torch.tanh(self.W_c(x_t) + self.U_c(h_prev))

        # 4. Update cell state
        c_t = f_t * c_prev + i_t * c_hat_t

        # 5. Output gate
        o_t = torch.sigmoid(self.W_o(x_t) + self.U_o(h_prev))

        # 6. Hidden state (short-term memory)
        h_t = o_t * torch.tanh(c_t)
        
        return h_t, c_t
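The from-scratch LSTMCell above uses the same equations as PyTorch's built-in nn.LSTMCell. Per the PyTorch docs, the built-in packs the four gates' weights into weight_ih and weight_hh in the order input, forget, candidate, output, and keeps two bias vectors (bias_ih and bias_hh) whose sum plays the role of each $b$. A sketch that recomputes one step by hand from those packed weights (the helper manual_lstm_step is hypothetical) and checks it against the module:

```python
import torch
import torch.nn as nn

def manual_lstm_step(cell: nn.LSTMCell, x_t, h_prev, c_prev):
    # All four gate pre-activations at once: [B, 4*H]
    gates = (x_t @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    # PyTorch packs the blocks as (input, forget, candidate, output)
    i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)
    i_t, f_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.sigmoid(o_t)
    c_tilde = torch.tanh(g_t)               # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde      # long-term memory update
    h_t = o_t * torch.tanh(c_t)             # filtered short-term memory
    return h_t, c_t

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=5)
x_t, h_prev, c_prev = torch.randn(2, 3), torch.randn(2, 5), torch.randn(2, 5)

with torch.no_grad():
    h_ref, c_ref = cell(x_t, (h_prev, c_prev))
    h_man, c_man = manual_lstm_step(cell, x_t, h_prev, c_prev)
print(torch.allclose(h_ref, h_man, atol=1e-5), torch.allclose(c_ref, c_man, atol=1e-5))
```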

GRU vs. LSTM at a Glance

GRUs often perform comparably to LSTMs while using fewer parameters: three gate-sized weight blocks per layer instead of four.


================ GRU =================

Input (x_t)
   |
   v
 [ Reset Gate ] ---> controls old hidden state
   |
   v
Candidate Hidden (h̃_t)
   |
[ Update Gate ] ---> blends old hidden h_{t-1} with h̃_t
   |
   v
Hidden State (h_t) ----> used as output + passed forward

================ LSTM =================

Input (x_t)
   |
   v
[ Forget Gate ] --------> controls old Cell State (c_{t-1})
[ Input Gate  ] --------> controls Candidate Cell State (c̃_t)
   |
   v
Cell State (c_t) -------> long-term memory (carried forward)
   |
[ Output Gate ] --------> controls what part of c_t is shown
   |
   v
Hidden State (h_t) ----> short-term memory, used as output
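The parameter savings are easy to verify by counting the parameters of PyTorch's built-in layers. Per layer, each gate-sized block has input weights, recurrent weights, and (in PyTorch) two bias vectors; an LSTM has four such blocks and a GRU three. The sizes here (input_size=10, hidden_size=20) are arbitrary.

```python
import torch.nn as nn

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

I, H = 10, 20
lstm = nn.LSTM(input_size=I, hidden_size=H)
gru = nn.GRU(input_size=I, hidden_size=H)

# One gate block: input weights (H*I) + recurrent weights (H*H) + two bias vectors (2*H)
block = H * I + H * H + 2 * H
print(n_params(lstm), 4 * block)  # 2560 2560
print(n_params(gru), 3 * block)   # 1920 1920
```

So a GRU layer has exactly three quarters of the parameters of an LSTM layer of the same size.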

Practical Implementations & Concepts in PyTorch

We typically stack multiple recurrent layers on top of each other. The first layer processes the input sequence and produces hidden states, and the second processes those hidden states to learn higher-level temporal features.

PyTorch nn.LSTM and nn.GRU


import torch.nn as nn

# A 2-layer LSTM
lstm = nn.LSTM(
    input_size=10,   # size of input vector at each timestep
    hidden_size=20,  # hidden state dimension
    num_layers=2,    # stacked layers
    batch_first=True # input as (batch, seq_len, input_size)
)

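By default, PyTorch's recurrent layers expect input shaped (seq_len, batch, input_size). The batch_first=True argument switches the input and output to (batch, seq_len, features), which matches most other PyTorch layers, like CNNs, that expect the batch dimension first. It does not change the hidden states, which stay shaped (num_layers * num_directions, batch, hidden_size).

The call signature is output, (hn, cn) = lstm(input, (h0, c0)). The initial states are optional and default to zeros.

For a GRU, there is no cell state: output, hn = gru(input, h0).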
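A quick shape check with arbitrary sizes makes these conventions concrete, including the fact that the hidden states are not batch-first:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

x = torch.randn(4, 7, 10)   # (batch=4, seq_len=7, input_size=10)
h0 = torch.zeros(2, 4, 20)  # (num_layers, batch, hidden_size): not batch-first
c0 = torch.zeros(2, 4, 20)

output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape, hn.shape, cn.shape)  # shapes (4, 7, 20), (2, 4, 20), (2, 4, 20)

output, hn = gru(x)                      # h0 omitted: defaults to zeros
print(output.shape, hn.shape)            # shapes (4, 7, 20), (2, 4, 20)

# For a unidirectional RNN, the last layer's final state equals the last output step
print(torch.allclose(output[:, -1, :], hn[-1]))
```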

Common RNN Architectures

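There are many variants of RNNs for different tasks. Two common layouts are many-to-many aligned models, which produce one output per input step (e.g. sequence tagging), and sequence-to-sequence (encoder-decoder) models, where the output sequence can have a different length than the input (e.g. translation):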


import torch
import torch.nn as nn

class ManyToManyAligned(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.rnn(x)      # out: (batch, seq_len, hidden_dim)
        return self.fc(out)       # (batch, seq_len, output_dim)

class Seq2Seq(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(output_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

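    def forward(self, src, tgt):
        # Encode source: src is (batch, src_seq_len, input_dim)
        _, (h, c) = self.encoder(src)

        # Decode with teacher forcing: tgt is the ground-truth output sequence
        # shifted right by one step, (batch, tgt_seq_len, output_dim)
        out, _ = self.decoder(tgt, (h, c))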
        return self.fc(out)  # (batch, tgt_seq_len, output_dim)

In Seq2Seq, teacher forcing is a training technique where the ground truth from the previous time step is fed as input to the decoder, rather than its own (potentially wrong) prediction. This makes training more stable and faster.
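At inference time there is no ground truth to feed in, so the decoder has to consume its own previous prediction, one step at a time. Below is a minimal sketch of that loop, assuming the same layout as the Seq2Seq model above (an nn.LSTM encoder, an nn.LSTM decoder whose input size equals output_dim, and a linear head) and a hypothetical all-zeros start token. The function name autoregressive_decode is illustrative.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def autoregressive_decode(encoder, decoder, fc, src, start, steps):
    """src: [B, S, input_dim]; start: [B, 1, output_dim]; returns [B, steps, output_dim]."""
    _, (h, c) = encoder(src)                   # summary of the source sequence
    y_prev, preds = start, []
    for _ in range(steps):
        out, (h, c) = decoder(y_prev, (h, c))  # one decoder step: out is [B, 1, hidden_dim]
        y_prev = fc(out)                       # prediction, fed back as the next input
        preds.append(y_prev)
    return torch.cat(preds, dim=1)

input_dim, hidden_dim, output_dim = 4, 16, 3
encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
decoder = nn.LSTM(output_dim, hidden_dim, batch_first=True)
fc = nn.Linear(hidden_dim, output_dim)

src = torch.randn(2, 5, input_dim)
start = torch.zeros(2, 1, output_dim)
preds = autoregressive_decode(encoder, decoder, fc, src, start, steps=6)
print(preds.shape)  # torch.Size([2, 6, 3])
```

For discrete outputs such as tokens, you would instead take the argmax of the logits and embed it before feeding it back.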

Bidirectional RNNs

A standard RNN processes a sequence from left to right. A bidirectional RNN processes the sequence in both directions (forward and backward) using two separate RNNs and concatenates their hidden states. This gives the hidden state at time $t$ information from both the past and the future.

$$ \text{BiRNN}(x) = \text{ForwardRNN}(x) \oplus \text{BackwardRNN}(x) $$

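import torch
import torch.nn as nn

# Builds on the from-scratch GRUCell defined in the GRU Implementation section
class BiGRU(nn.Module):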
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Forward and backward GRU cells
        self.gru_fwd = GRUCell(input_dim, hidden_dim)
        self.gru_bwd = GRUCell(input_dim, hidden_dim)

    def forward(self, x):
        """
        x: [B, T, D]  (batch, sequence length, input_dim)
        """
        B, T, D = x.shape
        h_fwd = torch.zeros(B, self.hidden_dim, device=x.device)
        h_bwd = torch.zeros(B, self.hidden_dim, device=x.device)

        outputs_fwd, outputs_bwd = [], []

        # Forward direction
        for t in range(T):
            h_fwd = self.gru_fwd(x[:, t, :], h_fwd)
            outputs_fwd.append(h_fwd)

        # Backward direction
        for t in reversed(range(T)):
            h_bwd = self.gru_bwd(x[:, t, :], h_bwd)
            outputs_bwd.append(h_bwd)

        # Reverse backward outputs so they align with forward timesteps
        outputs_bwd = outputs_bwd[::-1]

        # Concatenate forward and backward hidden states
        H_fwd = torch.stack(outputs_fwd, dim=1)  # [B, T, H]
        H_bwd = torch.stack(outputs_bwd, dim=1)  # [B, T, H]
        H = torch.cat([H_fwd, H_bwd], dim=-1)     # [B, T, 2*H]

        return H
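In practice you would usually let PyTorch do this with nn.GRU(..., bidirectional=True). Its output has the same [B, T, 2*H] layout as BiGRU above: the first H features come from the forward direction and the last H from the backward one. A quick check with arbitrary sizes, also showing where each direction's final state sits (the backward pass finishes at $t = 0$):

```python
import torch
import torch.nn as nn

B, T, D, H = 2, 6, 3, 5
bigru = nn.GRU(input_size=D, hidden_size=H, batch_first=True, bidirectional=True)
x = torch.randn(B, T, D)
output, h_n = bigru(x)

print(output.shape)  # (B, T, 2*H), here torch.Size([2, 6, 10])
print(h_n.shape)     # (num_directions, B, H), here torch.Size([2, 2, 5])

# Forward direction ends at the last timestep, backward direction ends at the first
print(torch.allclose(output[:, -1, :H], h_n[0]))
print(torch.allclose(output[:, 0, H:], h_n[1]))
```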

Truncated Backpropagation Through Time (TBPTT)

In practice, you rarely backprop through entire long sequences due to memory constraints. With TBPTT, you cut the sequence into shorter chunks. You still carry the hidden state forward between chunks, but you only backpropagate within each chunk.
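In PyTorch, the usual way to cut the graph at a chunk boundary is to call detach() on the hidden state: its value is carried into the next chunk, but gradients stop there. A minimal sketch on random data; the model, the chunk length k, and the regression loss are all hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D, H, k = 4, 100, 3, 16, 20  # k = chunk length for TBPTT
x, y = torch.randn(B, T, D), torch.randn(B, T, 1)

rnn = nn.GRU(D, H, batch_first=True)
head = nn.Linear(H, 1)
opt = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()), lr=1e-3)

h = None  # nn.GRU treats None as a zero initial state
losses = []
for start in range(0, T, k):
    x_chunk, y_chunk = x[:, start:start + k], y[:, start:start + k]
    out, h = rnn(x_chunk, h)  # continue from the previous chunk's state
    loss = nn.functional.mse_loss(head(out), y_chunk)

    opt.zero_grad()
    loss.backward()           # gradients flow back at most k steps
    opt.step()

    h = h.detach()            # keep the value, cut the graph
    losses.append(loss.item())

print(len(losses))  # 5 chunks of length 20
```

Without the detach() call, the second backward() would try to reach into the first chunk's already-freed graph and raise an error.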


Modern RNN Enhancements

Residuals + LayerNorm

A common pattern in modern recurrent stacks is to wrap each GRU/LSTM layer with a residual connection and LayerNorm. The residual connection helps gradients flow through deep stacks, and LayerNorm normalizes activations across the hidden dimension at each time step, which tends to make training more stable, particularly on longer sequences.


import torch
import torch.nn as nn

class ResidualLayerNormGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size * (2 if bidirectional else 1))

    def forward(self, x, h0=None):
        """
        x: [B, T, input_size]
        h0: optional initial hidden state
        """
        # GRU forward
        out, h = self.gru(x, h0)  # out: [B, T, H * num_directions]

        # Residual connection (if input and hidden dims match)
        if out.shape[-1] == x.shape[-1]:
            out = out + x

        # Layer normalization
        out = self.layernorm(out)

        return out, h
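To make "normalizes across the hidden dimension at each time step" concrete: nn.LayerNorm(hidden_size) computes a mean and variance over the last dimension only, separately for every (batch, time) position. At initialization its learnable scale and shift are 1 and 0, so the output has roughly zero mean and unit variance at each position. A quick check on arbitrary stand-in data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
out = 3.0 * torch.randn(2, 5, 8) + 7.0  # stand-in for GRU outputs: [B, T, H]
ln = nn.LayerNorm(8)                    # statistics over the last (hidden) dim only
y = ln(out)

mean = y.mean(dim=-1)                                         # one value per (b, t): [B, T]
var = ((y - y.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)  # biased variance per (b, t)
print(mean.abs().max().item())       # close to 0 at every (batch, time) position
print((var - 1).abs().max().item())  # close to 0: unit variance at every position
```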

Dropout in RNNs

Dropout in RNNs is subtle. With the built-in dropout argument of PyTorch's recurrent layers (e.g. nn.GRU(..., num_layers=3, dropout=0.2)), dropout is applied to the outputs of each RNN layer, except the last layer. Crucially, it is not applied to the recurrent connections inside the cell across time steps. It is ordinary dropout on the layer outputs, though: a fresh mask is sampled for every element, including at every time step, rather than one mask shared across time.


          ┌───────────────────────────────┐
          │         GRU Layer 1           │
x_t ───►  │     produces h1_t at each t   │
          └───────────────────────────────┘
                        │
                        ▼
          Dropout on h1_t (fresh mask at each t)
                        │
                        ▼
          ┌───────────────────────────────┐
          │         GRU Layer 2           │
          │     produces h2_t at each t   │
          └───────────────────────────────┘
                        │
                        ▼
          Dropout on h2_t (fresh mask at each t)
                        │
                        ▼
          ┌───────────────────────────────┐
          │         GRU Layer 3           │
          │     produces h3_t at each t   │
          └───────────────────────────────┘
                        │
                        ▼
                 Output (no dropout here)
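If you want one mask shared across all time steps (often called variational or "locked" dropout), you have to apply it yourself between single-layer RNNs. A minimal sketch; LockedDropout is a hypothetical helper, not a PyTorch module.

```python
import torch
import torch.nn as nn

class LockedDropout(nn.Module):
    """Hypothetical helper: samples one mask per sequence and reuses it at every step."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x):  # x: [B, T, H]
        if not self.training or self.p == 0.0:
            return x
        keep = 1.0 - self.p
        # Mask of shape [B, 1, H], broadcast over the time axis, rescaled like nn.Dropout
        mask = x.new_empty((x.size(0), 1, x.size(2))).bernoulli_(keep) / keep
        return x * mask

# Usage between two single-layer GRUs
gru1 = nn.GRU(8, 16, batch_first=True)
gru2 = nn.GRU(16, 16, batch_first=True)
drop = LockedDropout(p=0.3)

x = torch.randn(2, 10, 8)
h1, _ = gru1(x)
h2, _ = gru2(drop(h1))  # the same hidden units are dropped at every time step
print(h2.shape)         # torch.Size([2, 10, 16])
```

Like nn.Dropout, the helper is a no-op in eval mode, so calling model.eval() disables it at inference.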