Deep Learning Loss Functions

You know me, I’m gonna talk about loss functions now! Let’s start with binary cross entropy.

1. Binary Cross Entropy (BCE)

The formula for BCE is defined as:

$$ L = -\sum_i (y_i \log(p'_i) + (1-y_i)\log(1-p'_i)) $$

Where $p'_i$ is the $\text{sigmoid}(\text{logit}_i)$ and $y_i$ is whether a class was present (ground truth).

First, you squash logits into probabilities using a sigmoid. Then you compare that prediction with the ground truth using BCE loss. With BCE, each class is an independent binary prediction. You can even extend this to multiple labels: pass each logit through a sigmoid, compute binary cross entropy per class, and average over classes. This lets you model multiple classes co-occurring.

In the multi-label case you can’t compress the label into one integer anymore, so you need a multi-hot vector. You compute the binary cross-entropy for each class and then sum (or average) over them.

$$ BCE(y,p) = -\sum_{i=1}^C(y_i \log(p_i) + (1-y_i) \log(1-p_i)) $$

This checks every class independently. BCE needs a multi-hot vector because each class is treated as an independent yes/no problem.

import torch

def bce_with_logits_loss(logits, targets):
    """
    logits: (N, C) raw outputs
    targets: (N, C) multi-hot labels {0,1}
    """
    # sigmoid
    probs = 1 / (1 + torch.exp(-logits))

    # clamp to avoid log(0); real implementations use the log-sum-exp form instead
    eps = 1e-7
    probs = probs.clamp(min=eps, max=1 - eps)

    # binary cross entropy per class
    bce = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))

    # average over classes and batch
    loss = bce.mean()
    return loss
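
As a quick sanity check, this hand-rolled version should agree with PyTorch’s built-in torch.nn.functional.binary_cross_entropy_with_logits, which computes the same loss in a numerically stable way. A minimal sketch with random inputs:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)                      # N=4 samples, C=5 classes
targets = torch.randint(0, 2, (4, 5)).float()   # multi-hot labels

print(bce_with_logits_loss(logits, targets))
print(F.binary_cross_entropy_with_logits(logits, targets))  # default reduction='mean'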

2. Categorical Cross Entropy

Categorical cross entropy is what you use when the target is one-hot: only one class is true at a time. Instead of independent per-class sigmoids as in BCE, the logits are coupled into a single distribution over classes.

With categorical cross entropy you apply a softmax over the logits and get a probability distribution over $C$ classes:

$$ p_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)} $$

Then you apply categorical cross-entropy against a one-hot target. If the predicted probability of the correct class $k$ is $p_k$, then CE is:

$$ CE = -\log p_k $$

It only looks at the correct class’s probability. CE assumes mutual exclusivity: exactly one class is correct. That’s why, with categorical cross entropy, you can represent the target as a single class index. Internally, PyTorch treats that index as if it were a one-hot vector, and the loss reduces to picking out the probability of the correct class.

Remember that the negative log means lower loss corresponds to higher probability; it’s just the negative log likelihood of the true class under the predicted distribution. If the model was pretty sure, say 0.9, then $-\log(0.9) \approx 0.1$, and if it wasn’t sure, say 0.1, then $-\log(0.1) \approx 2.3$. CE punishes the model when it assigns low probability to the true class.

import torch

def cross_entropy_loss(logits, target_index):
    """
    logits: (N, C) raw outputs
    target_index: (N,) integer labels (0 <= y < C)
    """
    # softmax (subtract the row max first so exp() doesn't overflow)
    shifted = logits - logits.max(dim=1, keepdim=True).values
    exp_logits = torch.exp(shifted)
    probs = exp_logits / exp_logits.sum(dim=1, keepdim=True)

    # pick probability of the correct class
    correct_class_probs = probs[torch.arange(len(target_index)), target_index]

    # negative log likelihood
    loss = -torch.log(correct_class_probs).mean()
    return loss
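
Same idea here: as a sketch with random inputs, this should agree with PyTorch’s built-in torch.nn.functional.cross_entropy, which takes raw logits and fuses the softmax and log for numerical stability.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # N=4 samples, C=10 classes
targets = torch.randint(0, 10, (4,))   # integer class labels

print(cross_entropy_loss(logits, targets))
print(F.cross_entropy(logits, targets))  # default reduction='mean'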

3. The Alignment Problem

If we had a sequence of input frames, and we knew the label for each frame (we don’t here) then we could just apply CE at each frame and average.

$$ \text{Loss} = -\frac{1}{T} \sum_{t=1}^T(\log(p_{(t,y_t)})) $$

Where $y_t$ is the ground truth label for frame $t$. But this assumes a frame-to-label alignment. We don’t have labels for every frame, and the target tokens are much shorter than the input: a speech input is typically around 1,000 frames for 10 seconds, while the text might be only about 20 tokens, and we don’t know the frame-to-token alignment.

We need a loss that doesn’t require explicit alignment and can collapse variable-length frame predictions into a shorter token sequence. This is where “Connectionist Temporal Classification” (CTC) loss comes into play.

4. Connectionist Temporal Classification (CTC)

CTC augments the vocabulary with a blank token, meaning “no label”. This allows frames to correspond to no output symbol. So now you have $T$ frames, and each frame gets a probability distribution over $V+1$ symbols (tokens + blank).

To get the final transcription from a frame sequence you have to merge the per-frame predictions: merge repeated tokens, then remove blanks. CTC loss uses dynamic programming to efficiently sum the probabilities of all possible alignments between frames and target tokens.
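
Here is a minimal sketch of that collapse rule; the function name ctc_collapse and the "_" blank symbol are just choices for this example:

from itertools import groupby

def ctc_collapse(frame_symbols, blank="_"):
    """Merge consecutive duplicates, then drop blanks."""
    merged = [sym for sym, _ in groupby(frame_symbols)]
    return [sym for sym in merged if sym != blank]

print(ctc_collapse(list("_CC_AA_T__")))  # ['C', 'A', 'T']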

Merging Predictions

Each frame has probability mass over many tokens, so many different alignments are possible. We sum over all valid alignments that lead to the ground truth. But this is a lot of alignments: over $T$ frames with $|V| + 1$ possible symbols per frame, you get $(|V| + 1)^T$ paths. Even with a small vocabulary of 50 and 100 frames, you would get on the order of $50^{100}$ paths.

With CTC the model doesn’t output one label per frame; it outputs a probability distribution over all possible labels at each frame. Alignments are not generated arbitrarily from the predictions; they are constrained by the ground truth. An alignment is one possible frame-level sequence that, when collapsed, equals the ground truth.

The predictions supply the probability at each frame, and the ground truth constrains which sequences of choices we care about. We compute the total probability of the ground truth string by summing over all valid alignment paths.

$$ P(y | x ) = \sum_{\pi \in B^{-1}(y)} \prod_{t=1}^T (p(\pi_t|x_t)) $$

Since multiple alignments can collapse to the same label sequence $y$, the total probability of $y$ is the sum of all alignment probabilities.

The CTC loss is then just the negative log likelihood:

$$ L_{ctc} = -\log P(y|x) $$

Well, that’s a lot of math you probably don’t need, but the big idea is that we need the alignments that collapse to the target sequence, and we use the model’s per-frame probability predictions to score each valid alignment. CTC doesn’t enumerate every single alignment, though. It uses a forward-backward dynamic programming algorithm to sum their probabilities efficiently.

Dynamic Programming for CTC

At each input frame you can output either a symbol or a blank token. A valid alignment is any frame-level sequence that collapses to the ground truth after merging consecutive duplicates and removing blanks. You can think of this as a LeetCode-hard path-counting problem: we need to count how many paths collapse to the target sequence, and once we have that, we can replace the counts with probabilities.

A naive alternative would be a per-frame softmax with repeat collapse: take the single argmax alignment and collapse repeats accordingly. It’s simple and fast, and it doesn’t work well. It assumes each character or phoneme is well aligned in time or clearly segmented, so misalignment and variable timing cause errors.
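
Here is a minimal sketch of that greedy decode, assuming a vocab list whose order matches the columns of probs and "_" as the blank:

import numpy as np
from itertools import groupby

def greedy_ctc_decode(probs, vocab, blank="_"):
    """
    probs: (T, V) per-frame softmax probabilities
    vocab: list of symbols, vocab[i] is the symbol for column i
    """
    best = probs.argmax(axis=1)                           # most likely symbol per frame
    frame_symbols = [vocab[i] for i in best]
    merged = [sym for sym, _ in groupby(frame_symbols)]   # collapse repeats
    return [sym for sym in merged if sym != blank]        # drop blanks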

Example

Let’s take a small example: the target is CAT, we have 10 frames, and the vocabulary is $\{C, A, T, \_\}$. Valid alignments are 10-frame sequences that collapse to CAT, for example _CC_AA_T__ or CCCAAATTTT.

With 4 possible symbols per frame and 10 frames there are $4^{10}$ frame sequences in total, although most would not collapse to CAT.

CTC builds an extended target sequence with blanks inserted: $Y’ = \_C\_A\_T\_$.

Where $U$ is the target length, so the extended length is $2U + 1$: our original target of length 3 becomes 7 positions. Blanks are guard rails: they allow for variable timing and disambiguate repeated characters. Now let’s define a DP table: think of aligning the 10 frames to the 7 positions in the extended sequence, moving through the extended target as you consume frames. At each step you can:

  1. Stay at the same position
  2. Move forward by 1 (always)
  3. Move forward by 2 (only if skipping a blank between different symbols)

This is like a lattice: each node is $(t,s)$ for time $t$ and extended-target position $s$. So our table is 10 frames (rows) by 7 extended-target positions (columns), i.e. $T \times (2U+1)$. The number of rows is the input length and the number of columns is the extended target length. The time axis tracks how far we have processed the input, and the extended-target axis tracks how far we’ve matched into the target string. A path through this grid is one alignment.

The DP value at $(t,s)$ is the total probability (or count) of all partial alignments that have consumed frames $1..t$ and are currently at position $s$ in the extended target.

Forward Implementation

The transition rule: to compute $dp[t][s]$ you collect contributions from staying on the same symbol $(t-1,s)$, advancing by 1 $(t-1,s-1)$, or skipping over a blank when valid $(t-1, s-2)$. This is the forward algorithm. After filling the whole table, the total probability of all valid alignments is $dp[T-1][S-1] + dp[T-1][S-2]$, that is, paths ending on the trailing blank or on the final character.

import numpy as np

def ctc_forward(probs, target, vocab):
    """
    probs: array (T, V) with softmax probs for each frame
           V includes all symbols + blank '_'
    target: list of symbols, e.g. ['C','A','T']
    vocab: dict mapping each symbol (incl. '_') to its column index in probs
    """
    # 1. Build extended target: blank between (and around) every symbol
    extended = ['_']
    for ch in target:
        extended += [ch, '_']
    S = len(extended)

    T = probs.shape[0]
    alpha = np.zeros((T, S))

    # 2. Initialize first frame: paths can only start on the first blank or the first symbol
    alpha[0][0] = probs[0][vocab[extended[0]]]   # prob of blank at t=0
    alpha[0][1] = probs[0][vocab[extended[1]]]   # prob of first symbol at t=0

    # 3. Fill DP table
    for t in range(1, T):
        for s in range(S):
            stay = alpha[t-1][s]
            prev = alpha[t-1][s-1] if s-1 >= 0 else 0
            skip = alpha[t-1][s-2] if (s-2 >= 0 and extended[s] != extended[s-2]) else 0
            alpha[t][s] = probs[t][vocab[extended[s]]] * (stay + prev + skip)

    # 4. Final probability: end on the trailing blank or the last symbol
    # (real implementations work in log space to avoid underflow on long inputs)
    return alpha[T-1][S-1] + alpha[T-1][S-2]

This is our $P(y|x)$; we just computed the sum efficiently. It is the probability of the target under the model.
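
A tiny usage sketch with made-up, uniform per-frame probabilities; the vocab mapping is an arbitrary choice and just has to match the columns of probs:

import numpy as np

vocab = {'C': 0, 'A': 1, 'T': 2, '_': 3}

# 10 frames, 4 symbols, uniform probabilities purely for illustration
probs = np.full((10, 4), 0.25)

p_cat = ctc_forward(probs, ['C', 'A', 'T'], vocab)
print(p_cat)  # total probability of all alignments that collapse to CAT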

5. Training with Forward-Backward

But during training we also need gradients w.r.t. the model’s softmax outputs at each frame. We need to know, at a given frame, how much probability mass was aligned to “A” versus blank. This is the posterior probability.

The forward pass gives you the prefix probability: the mass of all partial paths up to $(t,s)$. The backward pass gives you the suffix probability: the mass of all partial paths from $(t,s)$ to the end. Combining both tells you how likely it is that frame $t$ really belongs to target position $s$.

import numpy as np

def ctc_forward_backward(probs, target, vocab):
    """
    probs: np.array of shape (T, V) with softmax probabilities
           (V includes all symbols + blank '_')
    target: list of symbols, e.g. ['C','A','T']
    vocab: dict mapping each symbol (incl. '_') to its column index in probs
    """
    # Build extended target with blanks
    extended = ['_']
    for ch in target:
        extended += [ch, '_']
    S = len(extended)   # extended length
    T = probs.shape[0]  # number of frames

    # Forward DP: alpha[t][s] = prob of all prefixes ending at position s after frame t
    alpha = np.zeros((T, S))
    alpha[0][0] = probs[0, vocab[extended[0]]]
    alpha[0][1] = probs[0, vocab[extended[1]]]

    for t in range(1, T):
        for s in range(S):
            stay = alpha[t-1][s]
            prev = alpha[t-1][s-1] if s-1 >= 0 else 0
            skip = alpha[t-1][s-2] if (s-2 >= 0 and extended[s] != extended[s-2]) else 0
            alpha[t][s] = probs[t, vocab[extended[s]]] * (stay + prev + skip)

    # Backward DP: beta[t][s] = prob of all suffixes starting at position s at frame t
    beta = np.zeros((T, S))
    # initialize last frame: paths can only end on the trailing blank or the last symbol
    beta[T-1][S-1] = probs[T-1, vocab[extended[S-1]]]   # last blank
    beta[T-1][S-2] = probs[T-1, vocab[extended[S-2]]]   # last symbol

    for t in reversed(range(T-1)):  # from T-2 down to 0
        for s in range(S):
            same = beta[t+1][s]
            nxt = beta[t+1][s+1] if s+1 < S else 0
            skip = beta[t+1][s+2] if (s+2 < S and extended[s] != extended[s+2]) else 0
            beta[t][s] = probs[t, vocab[extended[s]]] * (same + nxt + skip)

    # Total probability (the backward pass should give the same value via beta[0][0] + beta[0][1])
    total_prob = alpha[T-1][S-1] + alpha[T-1][S-2]

    return alpha, beta, total_prob, extended
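
From these two tables you can get the per-frame posteriors that the gradient needs. A sketch that follows the conventions of the code above (both alpha and beta include the emission at frame t, so it gets divided out once); the helper name ctc_posteriors is my own:

import numpy as np

def ctc_posteriors(alpha, beta, probs, extended, vocab, total_prob):
    """gamma[t][s]: posterior prob that frame t is aligned to extended position s."""
    T, S = alpha.shape
    gamma = np.zeros((T, S))
    for t in range(T):
        for s in range(S):
            emit = probs[t, vocab[extended[s]]]
            if emit > 0 and total_prob > 0:
                # alpha and beta both count the emission at frame t; divide it out once
                gamma[t][s] = alpha[t][s] * beta[t][s] / (emit * total_prob)
    return gamma

# alpha, beta, total_prob, extended = ctc_forward_backward(probs, ['C','A','T'], vocab)
# gamma = ctc_posteriors(alpha, beta, probs, extended, vocab, total_prob)

Summing gamma over the extended positions that carry the same symbol gives, per frame, the target distribution that the CTC gradient compares against the model’s softmax outputs.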

6. Inference: Beam Search

At inference, we need to figure out which output sequence to return. If the model is confident and the output space is small or constrained then a greedy decoding is okay. But, that would throw away a lot of probability mass. Beam search recovers some of it by merging multiple paths that collapse to the same candidate sequence.

import numpy as np

def simple_ctc_beam_search(probs, vocab, beam_size=3, top_k=3):
    """
    probs: (T, V) softmax probabilities over the vocab (incl blank "_")
    vocab: list of symbols, vocab[i] is the symbol for column i

    Simplified: prefixes don't track whether they ended on a blank, so this
    version can't produce repeated characters like "LL"; a full CTC prefix
    beam search keeps separate blank / non-blank probabilities per prefix.
    """
    beams = {"": 1.0}  # start with empty prefix and prob=1

    for t in range(len(probs)):
        new_beams = {}
        # pick the top_k symbol indices at this frame
        top_idx = np.argsort(probs[t])[-top_k:]
        for prefix, prob in beams.items():
            for idx in top_idx:
                sym = vocab[idx]
                p = probs[t][idx]
                if sym == "_":
                    new_prefix = prefix
                else:
                    # collapse rule: if repeat, don't double add
                    if prefix.endswith(sym):
                        new_prefix = prefix  # same seq
                    else:
                        new_prefix = prefix + sym
                # accumulate probability
                new_beams[new_prefix] = new_beams.get(new_prefix, 0) + prob * p
        # prune to top beam_size
        beams = dict(sorted(new_beams.items(), key=lambda x: -x[1])[:beam_size])

    return max(beams.items(), key=lambda x: x[1])
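
A quick usage sketch with made-up probabilities; the vocab list is again an assumption that just has to match the columns of probs:

import numpy as np

vocab = ['C', 'A', 'T', '_']

# 6 frames that mostly favor C, C, A, _, T, _
probs = np.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.7],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
])

best_prefix, score = simple_ctc_beam_search(probs, vocab, beam_size=3, top_k=3)
print(best_prefix, score)  # should decode to "CAT"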