Self-Supervised Learning (SSL) for Time Series Data

Self-Supervised Learning (SSL) is a type of representation learning where models learn from unlabeled data by generating their own training signals (pseudo-labels) using **pretext tasks**. The learned features are then utilized for downstream tasks. This tutorial will focus on SSL applied to time series data.

SSL asks the model to predict one part of the data from another. This forces the model to learn meaningful internal data structures without requiring human labels. A typical SSL pipeline involves:

  1. Taking raw, unlabeled data.
  2. Creating a pretext task the model must solve (e.g., mask prediction, contrastive matching).
  3. Learning robust features by solving that task.
  4. Applying the learned features to a downstream task, such as classification or detection.
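Step 4 is often done by freezing the pre-trained encoder and training only a small head on labeled data, sometimes called a linear probe. The sketch below illustrates that step under stated assumptions: the encoder is a hypothetical stand-in for one already pre-trained with a pretext task, and the labeled data and the 3-class setup are random placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical encoder standing in for one already pre-trained on a pretext task (steps 1-3)
encoder = nn.Sequential(
    nn.Conv1d(1, 32, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(),           # [B, 32] feature vector
)
for p in encoder.parameters():
    p.requires_grad = False                          # keep the learned features fixed
frozen_weight = encoder[0].weight.clone()

# Step 4: train only a small classifier on a (hypothetical) labeled downstream set
x_labeled = torch.randn(64, 1, 400)                  # [B, C, T]
y_labeled = torch.randint(0, 3, (64,))               # 3 example classes
classifier = nn.Linear(32, 3)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-2)

for step in range(20):
    with torch.no_grad():
        feats = encoder(x_labeled)                   # features from the frozen encoder
    loss = F.cross_entropy(classifier(feats), y_labeled)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Freezing the encoder measures how useful the self-supervised features are on their own; unfreezing it (full fine-tuning) is the other common option.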

There are five primary categories of pretext tasks:

SSL is best known as the pre-training approach behind Large Language Models (LLMs), but it has recently found its way into sensor data applications (e.g., https://arxiv.org/abs/2410.13638). Time series data, such as EEG and ECoG recordings, contain rich temporal and spatial patterns. SSL can learn generalized representations from this data that transfer well to new tasks and handle missing or noisy data more robustly.

In a nutshell, SSL removes reliance on human labels by learning rich and transferable features through pre-training on large, uncurated datasets.


Reconstruction-Based SSL

If a model can accurately reconstruct missing or masked input, this suggests it has learned meaningful patterns in the data. Reconstruction tasks push the model to learn temporal dependencies, local and global structure, and relationships between features.

Autoencoders

Autoencoders are commonly used for reconstruction-based SSL. A simple autoencoder compresses the input into a lower-dimensional latent representation and then reconstructs the original input from this representation, typically using Mean Squared Error (MSE) loss.

$$ \text{x} \longrightarrow \text{Encoder} \longrightarrow \text{z} \longrightarrow \text{Decoder} \longrightarrow \hat{\text{x}} $$

Masked Autoencoders (MAE) extend this by partitioning the input into patches, randomly masking a subset, encoding *only* the visible patches, and having the decoder reconstruct *all* patches. However, the loss is computed only on the masked portions using MSE.

For time series data, the input is a segment with shape [channels x time]. The encoder compresses this into a latent representation, and the decoder reconstructs the original signal.

If the input x has shape [batch_size, channels, time], we can use strided Conv1d layers in the encoder to downsample and extract temporal features, learning a compact latent representation (e.g., 64 channels at one eighth of the original length in the example below). ConvTranspose1d layers in the decoder then upsample back to the original input shape.

$$ \text{Input x} \longrightarrow \text{Encoder (Conv1D / Linear)} \longrightarrow \text{Latent z} \longrightarrow \text{Decoder} \longrightarrow \text{Reconstructed } \hat{\text{x}} $$

What is ConvTranspose1d?

Before diving back into the autoencoder, let's briefly look at ConvTranspose1d. It reverses the shape change of Conv1d: while a strided Conv1d downsamples the time dimension, ConvTranspose1d upsamples it. It is also called transposed convolution, learned upsampling, or (misleadingly) deconvolution. It is the transpose of a convolution, not its mathematical inverse, so it restores the length of the signal but not its original values. For example, if you downsample with stride=2, you can upsample by a factor of 2 using ConvTranspose1d with stride=2.
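The halving and doubling can be checked with the output-length formulas from the PyTorch documentation (shown here for dilation=1). The sketch below uses the same kernel/stride/padding settings as the autoencoder that follows; the helper function names are illustrative, not part of PyTorch.

```python
import torch
import torch.nn as nn

def conv1d_out_len(L, k, s, p):
    # nn.Conv1d output length (dilation=1)
    return (L + 2 * p - k) // s + 1

def conv_transpose1d_out_len(L, k, s, p, output_padding=0):
    # nn.ConvTranspose1d output length (dilation=1)
    return (L - 1) * s - 2 * p + k + output_padding

x = torch.randn(8, 1, 400)                                            # [B, C, T]
down = nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2)           # halves T
up = nn.ConvTranspose1d(16, 1, kernel_size=4, stride=2, padding=1)    # doubles T

h = down(x)
y = up(h)
print(h.shape, conv1d_out_len(400, 5, 2, 2))             # torch.Size([8, 16, 200]) 200
print(y.shape, conv_transpose1d_out_len(200, 4, 2, 1))   # torch.Size([8, 1, 400]) 400
```

With kernel_size=4, stride=2, padding=1 the transposed formula reduces to (L - 1) * 2 - 2 + 4 = 2L, which is why each decoder layer exactly doubles the length.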

Unlike fixed upsampling such as nearest-neighbor or linear interpolation, which has no learnable parameters, ConvTranspose1d upsamples with a learnable kernel. Conceptually, it inserts stride - 1 zeros between consecutive input samples, which creates room in a longer signal, and then runs a standard convolution over this zero-stuffed signal with padding of kernel_size - 1 on each side, so the output expands to (L_in - 1) * stride + kernel_size samples. During training, the model learns kernel weights that fill the gaps with meaningful values, spreading and blending each input value into its neighbors to form a longer signal. The scratch implementation below follows these two steps.


import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvTranspose1dScratch(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        # Same weight layout as nn.ConvTranspose1d: [in_channels, out_channels, kernel_size]
        self.weight = nn.Parameter(torch.randn(in_channels, out_channels, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        B, C_in, L_in = x.shape

        # Step 1: Place the original values `stride` steps apart, with zeros in between
        L_upsampled = (L_in - 1) * self.stride + 1
        x_upsampled = x.new_zeros((B, C_in, L_upsampled))
        x_upsampled[:, :, ::self.stride] = x

        # Step 2: Run a normal convolution over the zero-stuffed signal.
        # A transposed convolution equals conv1d with the kernel flipped in time and
        # the in/out channel axes swapped. Padding of kernel_size - 1 on both sides gives
        # an output length of (L_in - 1) * stride + kernel_size (nn.ConvTranspose1d, padding=0).
        w = self.weight.flip(-1).transpose(0, 1) # [out_channels, in_channels, kernel_size]
        out = F.conv1d(x_upsampled, w, self.bias, stride=1, padding=self.kernel_size - 1)

        return out
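A quick way to check the zero-stuffing idea is to compare it with PyTorch's own F.conv_transpose1d. The sketch below is a functional version of the same two steps; the helper name conv_transpose1d_via_conv1d is hypothetical, and it only covers padding=0, output_padding=0, and dilation=1. It shows that the results match only when the kernel is flipped in time and its channel axes are swapped.

```python
import torch
import torch.nn.functional as F

def conv_transpose1d_via_conv1d(x, weight, bias, stride):
    """Transposed convolution as zero-stuffing + ordinary conv1d.
    weight uses the nn.ConvTranspose1d layout [C_in, C_out, k]."""
    B, C_in, L_in = x.shape
    k = weight.shape[-1]
    up = x.new_zeros((B, C_in, (L_in - 1) * stride + 1))
    up[:, :, ::stride] = x                        # original samples `stride` apart, zeros between
    w = weight.flip(-1).transpose(0, 1)           # flip in time, swap channels -> [C_out, C_in, k]
    return F.conv1d(up, w, bias, padding=k - 1)   # full padding -> length (L_in - 1) * stride + k

torch.manual_seed(0)
x = torch.randn(2, 3, 10)
weight = torch.randn(3, 5, 4)
bias = torch.randn(5)

ours = conv_transpose1d_via_conv1d(x, weight, bias, stride=2)
ref = F.conv_transpose1d(x, weight, bias, stride=2)
print(ours.shape)                                 # torch.Size([2, 5, 22])
print(torch.allclose(ours, ref, atol=1e-5))       # True
```

Because the kernel is learned, omitting the flip would still train; it would just not match a pretrained nn.ConvTranspose1d weight element for element.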
    

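Now, back to our autoencoder. The encoder turns the signal into a shorter but deeper representation (fewer time steps, more channels), forming a latent space that captures the signal’s essential structure. Each decoder layer then doubles the number of time steps (a ConvTranspose1d with kernel_size=4, stride=2, padding=1 maps length L to 2L) while reducing the channel count, learning how to fill in the signal and recover the original time series from the latent features. Training with MSE loss pushes the network to preserve the shape and structure of the original signal. Because the encoder halves the length three times and the decoder doubles it three times, the output matches the input length exactly only when T is divisible by 8 (e.g., T = 400 gives a latent of length 50).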


import torch
import torch.nn as nn

class TimeSeriesAutoencoder(nn.Module):
    def __init__(self, input_channels=1, latent_dim=64):
        super().__init__()
        
        # Encoder: downsample with Conv1d
        self.encoder = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=5, stride=2, padding=2), # [B, 16, T/2]
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=5, stride=2, padding=2), # [B, 32, T/4]
            nn.ReLU(),
            nn.Conv1d(32, latent_dim, kernel_size=5, stride=2, padding=2), # [B, latent_dim, T/8]
            nn.ReLU(),
        )

        # Decoder: upsample with ConvTranspose1d
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(latent_dim, 32, kernel_size=4, stride=2, padding=1), # [B, 32, T/4]
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1), # [B, 16, T/2]
            nn.ReLU(),
            nn.ConvTranspose1d(16, input_channels, kernel_size=4, stride=2, padding=1), # [B, 1, T]
        )

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
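A minimal training sketch for this kind of autoencoder, assuming toy data (noisy sine waves with 400 steps, so T is divisible by 8) and a compact nn.Sequential stand-in with the same layer shapes as TimeSeriesAutoencoder so that the block runs on its own. The loss compares the reconstruction with the input itself; no labels are involved.

```python
import math
import torch
import torch.nn as nn

# Compact stand-in with the same layer shapes as TimeSeriesAutoencoder (latent_dim=64)
autoencoder = nn.Sequential(
    nn.Conv1d(1, 16, 5, stride=2, padding=2), nn.ReLU(),
    nn.Conv1d(16, 32, 5, stride=2, padding=2), nn.ReLU(),
    nn.Conv1d(32, 64, 5, stride=2, padding=2), nn.ReLU(),           # latent: [B, 64, T/8]
    nn.ConvTranspose1d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose1d(16, 1, 4, stride=2, padding=1),              # back to [B, 1, T]
)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

torch.manual_seed(0)
t = torch.linspace(0, 1, 400)
# Hypothetical toy data: noisy sine waves with random frequencies, shape [256, 1, 400]
freqs = torch.rand(256, 1, 1) * 10 + 1
signals = torch.sin(2 * math.pi * freqs * t) + 0.1 * torch.randn(256, 1, 400)

history = []
for epoch in range(5):
    for x in signals.split(32):
        x_hat = autoencoder(x)
        loss = loss_fn(x_hat, x)        # reconstruction error against the input itself
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    history.append(loss.item())
    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
```

After training, only the encoder half (the first six modules here) would typically be kept to produce latent features for downstream tasks.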
    

Masked Autoencoders (MAE)

An **MAE (Masked Autoencoder)** model only sees a portion of the input signal (e.g., 25%) and learns to reconstruct the entire signal, specifically focusing on the masked parts. The encoder takes the partial input, encodes it, and the decoder reconstructs the whole signal. The MSE loss is applied only to the masked values.

The first step in MAE is splitting the time signal into non-overlapping patches and projecting each patch to a vector, called a **patch embedding**. This is typically done with a Conv1d layer whose stride equals its kernel size, so each output step sees exactly one patch.


# ----------------------
# Patch Embedding Module
# ----------------------
class PatchEmbed1D(nn.Module):
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        # Conv1d with kernel and stride equal to patch_size extracts non-overlapping patches
        self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, T]
        x = self.proj(x) # [B, embed_dim, T//patch_size] - patches as features
        x = x.transpose(1, 2) # [B, N_patches, embed_dim] - N_patches is now sequence length
        return x
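Each token covers a contiguous window of patch_size samples. The sketch below checks the shapes for the running example (400 steps, patches of 16) and also shows how to cut the raw signal into the same patches with reshape, which yields the raw values a reconstruction target would hold; the raw variable is an illustration, not part of PatchEmbed1D.

```python
import torch
import torch.nn as nn

B, C, T, P = 2, 1, 400, 16
x = torch.arange(B * C * T, dtype=torch.float32).reshape(B, C, T)   # easy-to-read values

# Learned patch embedding: one Conv1d window per patch
proj = nn.Conv1d(C, 128, kernel_size=P, stride=P)
tokens = proj(x).transpose(1, 2)
print(tokens.shape)     # torch.Size([2, 25, 128])

# Raw patches with the same split: [B, N_patches, C * P]
raw = x.reshape(B, C, T // P, P).permute(0, 2, 1, 3).reshape(B, T // P, C * P)
print(raw.shape)        # torch.Size([2, 25, 16])
print(raw[0, 1, :4])    # tensor([16., 17., 18., 19.])
```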
    

Next, we randomly keep a subset of these patches (e.g., 25%), tracking the original indices of both kept and masked patches so the original order can be restored later. The encoder transforms only the visible patches, producing an output of shape [B, N_visible, embed_dim]. In PyTorch, nn.TransformerEncoderLayer contains multi-head self-attention, a feed-forward network, two layer norms, and residual connections; the MAE encoder below stacks several of these layers with nn.TransformerEncoder.


# ----------------------
# MAE Encoder
# ----------------------
class MAEEncoder(nn.Module):
    def __init__(self, embed_dim, depth):
        super().__init__()
        # Use a standard TransformerEncoderLayer as the base layer
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=embed_dim*4)
        # Stack multiple layers using TransformerEncoder
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):
        # TransformerEncoder expects input as [SequenceLength, BatchSize, EmbedDim]
        x = x.transpose(0, 1) # [N_patches, B, embed_dim]
        x = self.encoder(x)
        return x.transpose(0, 1) # [B, N_patches, embed_dim] - revert to original batch-first
    

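The random_masking function selects a random subset of patches to keep. For each sample in the batch, it draws random noise, sorts the patch indices by that noise (giving a random permutation, ids_shuffle), and takes the first len_keep indices as the visible patches (ids_keep). torch.gather() then selects those patches from the input, giving [B, N_keep, D]. ids_restore, the argsort of ids_shuffle, is the inverse permutation: it maps each original patch position to its position in the shuffled order and is used later to undo the shuffle.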


# ----------------------
# Random Masking
# ----------------------
def random_masking(x, mask_ratio):
    B, N, D = x.shape # B: batch_size, N: num_patches, D: embed_dim
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=x.device) # Random noise for shuffling
    ids_shuffle = torch.argsort(noise, dim=1) # Get shuffled indices
    ids_restore = torch.argsort(ids_shuffle, dim=1) # Get indices to restore original order

    ids_keep = ids_shuffle[:, :len_keep] # Select indices of patches to keep
    # Use gather to select the visible patches
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))

    return x_masked, ids_restore, ids_keep
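The role of ids_restore is easiest to see on a small example. The sketch below copies the random_masking logic so it runs on its own, checks that gathering with ids_restore undoes the shuffle, and builds the binary loss mask (1 = masked, 0 = visible) in the original patch order; the sizes are arbitrary choices for illustration.

```python
import torch

def random_masking(x, mask_ratio):
    # Same logic as random_masking above
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return x_masked, ids_restore, ids_keep

torch.manual_seed(0)
x = torch.randn(2, 25, 8)                       # [B, N_patches, D]
x_masked, ids_restore, ids_keep = random_masking(x, mask_ratio=0.75)
print(x_masked.shape)                           # torch.Size([2, 6, 8])

# ids_restore undoes the shuffle: gathering the shuffled sequence with it gives back x
ids_shuffle = torch.argsort(ids_restore, dim=1)
x_shuffled = torch.gather(x, 1, ids_shuffle.unsqueeze(-1).expand(-1, -1, 8))
x_back = torch.gather(x_shuffled, 1, ids_restore.unsqueeze(-1).expand(-1, -1, 8))
print(torch.equal(x_back, x))                   # True

# Loss mask in original patch order: the first len_keep shuffled slots are visible (0)
mask = torch.ones(2, 25)
mask[:, :6] = 0
mask = torch.gather(mask, 1, ids_restore)
print(mask.sum(dim=1))                          # tensor([19., 19.])
```

The same gather is what places the mask tokens at the right positions inside the decoder.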
    

The decoder reconstructs masked patches using learnable mask tokens. It receives no separate flag saying which patches were visible; instead, its input sequence is built from the visible encoded patches plus a learnable placeholder for each missing patch. The self.proj layer first maps the encoder output to the decoder dimension (the two may differ). A single shared mask token of shape [1, 1, decoder_dim] is expanded to [B, N_masked, decoder_dim] and concatenated with the projected visible patches, giving [B, N_visible + N_masked, decoder_dim] in shuffled order. Gathering with ids_restore puts every token back in its original position, and positional embeddings are then added so each mask token knows which position it stands for.

The full sequence passes through a Transformer, which lets the decoder model global dependencies between all patches, masked and visible. A final linear layer (self.head) maps each token back to the raw patch size ([B, N_patches, patch_dim]). The decoder predicts values for all patches, including visible ones, but the MSE loss is computed *only* on the masked patches.


# ----------------------
# MAE Decoder
# ----------------------
class MAEDecoder(nn.Module):
    def __init__(self, embed_dim, decoder_dim, patch_dim, depth, num_patches):
        super().__init__()
        self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_dim))
        self.proj = nn.Linear(embed_dim, decoder_dim) # Project encoder output to decoder_dim if needed
        
        # Decoder is also a TransformerEncoder stack
        layer = nn.TransformerEncoderLayer(d_model=decoder_dim, nhead=4, dim_feedforward=decoder_dim*4)
        self.decoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(decoder_dim, patch_dim) # Maps decoder output to the raw values of one patch

    def forward(self, x_encoded, ids_restore):
        B, N_vis, _ = x_encoded.shape
        N_total = ids_restore.shape[1]
        N_mask = N_total - N_vis

        # Project encoded visible patches to decoder dimension
        x_vis = self.proj(x_encoded)
        # Expand mask tokens to batch size and number of masked patches
        mask_tokens = self.mask_token.expand(B, N_mask, -1)

        # Concatenate visible and mask tokens, then restore original order
        x_full = torch.cat([x_vis, mask_tokens], dim=1)
        x_full = torch.gather(x_full, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_full.size(-1)))
        
        # Add positional embeddings to the full sequence
        x_full = x_full + self.pos_embed[:, :N_total]

        # Pass through the decoder Transformer
        # TransformerEncoder expects (SeqLen, Batch, EmbedDim)
        x_full = self.decoder(x_full.transpose(0,1)).transpose(0,1)
        
        # Project back to the raw values of each patch
        return self.head(x_full) # [B, N_patches, patch_dim]
    

And here's the full MAE model:


# ----------------------
# Full MAE Model
# ----------------------
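class MAEModel(nn.Module):
    def __init__(self, in_channels=1, patch_size=16, embed_dim=128, encoder_depth=4, decoder_dim=64, decoder_depth=2, seq_len=400):
        super().__init__()
        assert seq_len % patch_size == 0, "seq_len must be divisible by patch_size"
        self.patch_size = patch_size
        num_patches = seq_len // patch_size
        self.patch_embed = PatchEmbed1D(in_channels, embed_dim, patch_size)
        # Learnable positional embedding for the encoder: once patches are shuffled and
        # dropped, the encoder has no other way to know where each visible patch came from
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        self.encoder = MAEEncoder(embed_dim, encoder_depth)
        # The decoder's head outputs the raw values of one patch (patch_size * in_channels)
        self.decoder = MAEDecoder(embed_dim, decoder_dim, patch_size * in_channels, decoder_depth, num_patches=num_patches)

    def patchify(self, x):
        # [B, C, T] -> [B, N_patches, C * patch_size]: raw signal values of each patch
        B, C, T = x.shape
        N = T // self.patch_size
        return x.reshape(B, C, N, self.patch_size).permute(0, 2, 1, 3).reshape(B, N, C * self.patch_size)

    def forward(self, x, mask_ratio=0.75):
        # 1. Convert input signal to patch embeddings and add positions
        tokens = self.patch_embed(x) + self.pos_embed # [B, N_patches, embed_dim]
        # 2. Randomly mask patches
        x_masked, ids_restore, ids_keep = random_masking(tokens, mask_ratio)
        # 3. Encode visible patches
        encoded = self.encoder(x_masked)
        # 4. Decode to reconstruct all patches
        pred = self.decoder(encoded, ids_restore) # [B, N_patches, patch_dim]
        # 5. The reconstruction target is the raw signal, not the learned embeddings
        target = self.patchify(x) # [B, N_patches, patch_dim]
        return pred, target, ids_restore # Return predictions, raw target patches, and restore indices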
    

And here's the training loop. Let's assume we have a single-channel time series with 400 time steps and extract patches of 16 samples, giving 400 // 16 = 25 patches. The decoder must rely on the context from visible patches to infer the values of the masked ones. This encourages a global understanding of the sequence, similar to BERT's masked language modeling, where the model infers masked words from context. For the loss, we build a binary mask over the patches, laid out in the original patch order, that is 1 for masked patches and 0 for visible ones. We then multiply the squared prediction error element-wise by this mask, so the MSE is computed *only* over the masked patches.


import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random # For augmentations

# Assuming you have a dataset: each sample is [C, T]
class DummyDataset(Dataset):
    def __getitem__(self, idx):
        return torch.randn(1, 400) # [C, T] - Example: 1 channel, 400 time steps
    def __len__(self):
        return 1000

dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = MAEModel(in_channels=1, patch_size=16) # num_patches will be 400 // 16 = 25
optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss(reduction='none') # Use reduction='none' to apply mask manually

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(10):
    model.train()
    total_loss = 0.0

    for x in loader:
        x = x.to(device) # [B, C, T]

        # Forward pass through MAE model
        pred, target_patches, ids_restore = model(x, mask_ratio=0.75) # pred and target_patches are [B, N_patches, patch_dim]

        # Prepare mask for loss calculation (only on masked patches)
        B, N_patches, patch_dim = target_patches.shape
        len_keep = int(N_patches * (1 - 0.75)) # same rule as random_masking

        # In shuffled order the first len_keep patches are visible (0) and the rest masked (1);
        # gathering with ids_restore puts the mask back into the original patch order
        mask = torch.ones(B, N_patches, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        mask = mask.unsqueeze(-1).expand(-1, -1, patch_dim) # Expand mask to match patch_dim

        # Compute squared error for all patches
        loss_all_patches = (pred - target_patches) ** 2
        # Apply the mask: only count error for masked patches (where mask is 1)
        loss = (loss_all_patches * mask).sum() / mask.sum() # Sum and then divide by number of masked elements

        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(loader):.4f}")
    

Contrastive Learning

The core idea of **Contrastive Learning** is to pull the representations of similar samples (positive pairs) together in the latent space and push dissimilar samples (negatives) farther apart. There are several variations of contrastive learning.

SimCLR (Simple Contrastive Learning of Representations)

SimCLR applies two different random augmentations to the same input signal. Both augmented "views" pass through the same encoder, and training maximizes the similarity between their latent representations while pushing them away from the representations of the other samples in the batch. This forces the encoder to learn features that are invariant to the augmentations.

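First, we define the augmentations. Good augmentations preserve the identity of the signal while making it look different enough to challenge the model. For time series, the code below uses two: random cropping of a contiguous window (time_crop) and additive Gaussian noise, or jitter (time_jitter).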

These augmentations create two different versions of the same signal, which serve as **positive pairs** for contrastive learning.


import random # for time_crop
import torch

# ------------------------
# Time-series augmentations
# ------------------------
def time_crop(x, crop_ratio=0.8):
    """Keep a random contiguous window of crop_ratio * T steps (same window for the whole batch)."""
    B, C, T = x.shape
    new_T = int(T * crop_ratio)
    if new_T >= T: # Nothing to crop
        return x
    start = random.randint(0, T - new_T)
    return x[:, :, start:start + new_T]

def time_jitter(x, sigma=0.01):
    """Add small Gaussian noise to every sample."""
    return x + sigma * torch.randn_like(x)

def augment(x):
    # Crop, then jitter. The output is shorter than the input (crop_ratio * T).
    # That is fine for the SimCLR encoder below, because AdaptiveAvgPool1d pools
    # over whatever time length it receives. Models that need a fixed T would
    # have to pad or resample the crop back to the original length.
    cropped_x = time_crop(x)
    return time_jitter(cropped_x)
    

SimCLR processes two augmented views of each input; both pass through the same encoder and then a projection head. The SimCLRModel below combines a backbone encoder (Conv1d layers for feature extraction and downsampling, followed by AdaptiveAvgPool1d and Flatten to produce a fixed-length feature vector) with a ProjectionHead. The SimCLR paper showed that adding a projection head improves the learned representations, but the head is usually discarded after pre-training, and downstream tasks use the encoder output instead.


# ------------------------
# SimCLR Projection Head
# ------------------------
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=256, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)

# ------------------------
# SimCLR Encoder + Projection
# ------------------------
class SimCLRModel(nn.Module):
    def __init__(self, in_channels=1, encoder_dim=256):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=5, stride=2, padding=2), # Output: [B, 64, T/2]
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),       # Output: [B, 128, T/4]
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), # Pools across time dimension to get [B, 128, 1]
            nn.Flatten(),            # Flattens to [B, 128]
            nn.Linear(128, encoder_dim) # Projects to final encoder_dim [B, encoder_dim]
        )
        self.projector = ProjectionHead(encoder_dim)

    def forward(self, x):
        feat = self.encoder(x) # [B, encoder_dim]
        proj = self.projector(feat) # [B, out_dim from ProjectionHead]
        return F.normalize(proj, dim=1) # L2 normalize for cosine similarity
    

Each sample's representation is L2-normalized so its norm is 1, which places all vectors on the unit hypersphere. On the hypersphere, the dot product of two vectors equals their cosine similarity, which is what the contrastive loss uses, and the model cannot reduce the loss by simply changing vector magnitudes; it has to learn meaningful directions.

The SimCLR loss function, **NT-Xent (Normalized Temperature-scaled Cross-Entropy Loss)**, is at the heart of SimCLR. The temperature parameter is a critical hyperparameter that affects training dynamics; lower values make the softmax distribution sharper, leading to harder contrasts.
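Written out in the standard SimCLR form, the loss for one anchor $z_i$ among the $2B$ concatenated representations, with positive partner $z_{j(i)}$ (the other view of the same sample) and temperature $\tau$, is:

$$ \ell_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_{j(i)})/\tau\right)}{\sum_{k=1}^{2B} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}, \qquad \mathcal{L} = \frac{1}{2B}\sum_{i=1}^{2B} \ell_i, \qquad \mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert} $$

The indicator $\mathbb{1}_{[k \neq i]}$ is what the identity-matrix mask in the code implements, and the average over all $2B$ rows is what F.cross_entropy computes with its default mean reduction.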

Given two representations, $z_1$ and $z_2$, from augmented views of the same batch (concatenated to shape [2B, D]), we compute pairwise similarities. The sim matrix, shaped [2B, 2B], contains the similarity of each sample with every other sample; its diagonal ($i=j$) holds the trivial self-similarities. Each sample should be compared with its positive pair and with all other samples as negatives, but *not* with itself. An identity-matrix mask therefore sets the self-similarity scores to a very large negative number (`-9e15`) so they contribute effectively nothing after the softmax. sim_targets gives, for each row of the [2B, 2B] similarity matrix, the column index of its positive pair, which serves as the class label for the cross-entropy loss. The NT-Xent loss thus pulls each $z_1$ toward its own $z_2$ (and vice versa) and pushes it away from all other samples in the batch.


import torch.nn.functional as F

# ------------------------
# SimCLR Loss (NT-Xent)
# ------------------------
def nt_xent_loss(z1, z2, temperature=0.5):
    B = z1.size(0) # Batch size
    z = torch.cat([z1, z2], dim=0) # Concatenate both views: [2B, D]

    # Compute pairwise cosine similarity: [2B, 2B]
    # z.unsqueeze(1) -> [2B, 1, D]
    # z.unsqueeze(0) -> [1, 2B, D]
    # Result of cosine_similarity -> [2B, 2B]
    sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
    sim /= temperature # Scale similarities by temperature

    # Create labels for positive pairs
    # For z1 samples (0 to B-1), their positive pair is at B to 2B-1
    # For z2 samples (B to 2B-1), their positive pair is at 0 to B-1
    # E.g., if B=2: labels = [2, 3, 0, 1]
    labels = torch.arange(B, device=z.device) # [0, 1, ..., B-1]
    sim_targets = torch.cat([labels + B, labels], dim=0) # [B, B+1, ..., 2B-1, 0, 1, ..., B-1]

    # Create a mask to remove self-similarity (diagonal elements)
    mask = torch.eye(2 * B, device=z.device).bool()
    sim.masked_fill_(mask, -9e15) # Set self-similarity to a very small number

    # Compute cross-entropy loss. `sim` are the logits, `sim_targets` are the true classes.
    # The loss for each row `i` tries to classify `sim_targets[i]` as the correct positive.
    loss = F.cross_entropy(sim, sim_targets)
    return loss
    

And here is the training loop:
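A minimal sketch of such a loop, written to run on its own: it re-defines compact versions of augment, the encoder and projection head (same shapes as SimCLRModel), and the NT-Xent loss, and it trains on random placeholder data. The NT-Xent variant here computes cosine similarity as a matrix product of normalized vectors, which gives the same matrix as F.cosine_similarity with less memory.

```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
random.seed(0)

def augment(x, crop_ratio=0.8, sigma=0.01):
    # Random crop (same window for the whole batch) followed by Gaussian jitter
    T = x.shape[-1]
    new_T = int(T * crop_ratio)
    start = random.randint(0, T - new_T)
    x = x[:, :, start:start + new_T]
    return x + sigma * torch.randn_like(x)

def nt_xent_loss(z1, z2, temperature=0.5):
    B = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)    # [2B, D] on the unit sphere
    sim = z @ z.T / temperature                            # cosine similarities [2B, 2B]
    self_mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, -9e15)                # exclude self-similarity
    targets = torch.cat([torch.arange(B, 2 * B), torch.arange(B)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Compact encoder + projection head with the same shapes as SimCLRModel above
encoder = nn.Sequential(
    nn.Conv1d(1, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(128, 256),
)
projector = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 128))
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(projector.parameters()), lr=1e-3)

data = torch.randn(256, 1, 400)   # hypothetical unlabeled signals [N, C, T]
for epoch in range(2):
    total = 0.0
    for x in data.split(64):
        z1 = projector(encoder(augment(x)))   # view 1
        z2 = projector(encoder(augment(x)))   # view 2: different crop and noise
        loss = nt_xent_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total / 4:.4f}")
```

After pre-training, projector would be dropped and encoder kept for downstream use.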

Batch size is crucial for SimCLR: larger batches provide more negative examples, which generally improves performance, so SimCLR often uses large batches (e.g., 512 or more). Alternatives such as memory banks, or the queue of past embeddings used by MoCo, supply many negatives without requiring a large batch.

Momentum Contrast (MoCo)

MoCo (Momentum Contrast) addresses the large-batch requirement of end-to-end contrastive methods such as SimCLR. In the NT-Xent loss, each sample is contrasted with every other sample in the current batch, so the number of negatives is tied to the batch size; more negatives give a stronger learning signal, but large batches are impractical on a single GPU. MoCo decouples the number of negatives from the batch size by maintaining a queue of previously encoded samples, providing a large and dynamic set of negative examples. To keep the embeddings in this queue consistent, MoCo uses a **momentum encoder** for the "key" side of the contrastive pair:

$$ \theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q $$

There are two encoders: a **query encoder** ($\theta_q$) and a **key encoder** ($\theta_k$). Parameters of the query encoder are updated normally via gradient descent, while parameters of the key encoder are updated slowly using an exponential moving average (EMA) of the query encoder's parameters. The queue holds key representations from this slowly changing key encoder to stabilize training. It acts like a FIFO (First-In, First-Out) buffer, storing old key embeddings and overwriting the oldest entries.
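To see how slowly the key encoder moves, hold $\theta_q$ fixed and apply the update $\theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q$ $n$ times. Since $\theta_k^{(n)} - \theta_q = m\,(\theta_k^{(n-1)} - \theta_q)$:

$$ \theta_k^{(n)} = m^n \, \theta_k^{(0)} + (1 - m^n)\, \theta_q $$

With $m = 0.999$ (the default in the MoCo code below), $m^{1000} \approx e^{-1} \approx 0.37$, so after 1000 steps the key encoder has closed only about 63% of the gap to a fixed query encoder. In practice $\theta_q$ keeps changing, so $\theta_k$ behaves like an exponentially weighted average of recent query encoders over roughly $1/(1-m) = 1000$ steps.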

For a given input, an augmented view x_q is encoded by the query encoder, and another augmented view x_k is encoded by the momentum-updated key encoder (with gradients stopped for x_k's path).


import copy # For deepcopying the encoder

class MoCo(nn.Module):
    def __init__(self, encoder, feature_dim=128, queue_size=1024, momentum=0.999):
        super().__init__()
        self.query_encoder = encoder # The encoder that gets updated by gradients
        self.key_encoder = copy.deepcopy(encoder) # A copy for the key encoder

        # Freeze key encoder parameters
        for param in self.key_encoder.parameters():
            param.requires_grad = False 

        # Register buffer for the queue of negative samples
        self.register_buffer("queue", torch.randn(queue_size, feature_dim))
        self.queue = F.normalize(self.queue, dim=1) # Normalize queue contents
        
        # Pointer for the circular queue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.momentum = momentum
        self.queue_size = queue_size
        self.feature_dim = feature_dim

    @torch.no_grad() # This update happens without gradient tracking
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder"""
        for param_q, param_k in zip(self.query_encoder.parameters(), self.key_encoder.parameters()):
            param_k.data = self.momentum * param_k.data + (1 - self.momentum) * param_q.data

    @torch.no_grad() # Queue operations do not require gradients
    def _dequeue_and_enqueue(self, keys):
        keys = keys.detach() # Detach keys from the computation graph
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr[0])
        
        # Replace the oldest entries with current keys
        # If batch_size + ptr > queue_size, it will wrap around
        if ptr + batch_size > self.queue_size:
            # Handle wrap-around
            overflow = (ptr + batch_size) - self.queue_size
            self.queue[ptr:] = keys[:self.queue_size - ptr]
            self.queue[:overflow] = keys[self.queue_size - ptr:]
        else:
            self.queue[ptr:ptr + batch_size] = keys

        ptr = (ptr + batch_size) % self.queue_size # Update pointer
        self.queue_ptr[0] = ptr

    def forward(self, x_q, x_k):
        # Step 1: Encode
        q = self.query_encoder(x_q) # Query embedding: [B, D]
        q = F.normalize(q, dim=1) # Normalize query

        with torch.no_grad(): # No gradient for key encoder path
            self._momentum_update_key_encoder() # Update key encoder
            k = self.key_encoder(x_k) # Key embedding: [B, D]
            k = F.normalize(k, dim=1) # Normalize key

        # Step 2: Compute logits
        # Positives: dot product of query with its corresponding key
        pos = torch.sum(q * k, dim=1, keepdim=True) # [B, 1]

        # Negatives: dot product of query with all entries in the queue
        neg = torch.matmul(q, self.queue.clone().detach().T) # [B, K] - clone() and detach() to ensure it's not part of graph

        # Concatenate positive and negative logits
        logits = torch.cat([pos, neg], dim=1) # [B, 1 + K]

        # Labels for cross-entropy: first column (positives) is the correct class
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        # Step 3: Update queue with current batch's keys
        self._dequeue_and_enqueue(k)

        return logits, labels
    

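The _momentum_update_key_encoder method updates the key encoder smoothly using the query encoder's weights. This matters for the queue: if the key encoder were simply a copy of the rapidly changing query encoder, keys enqueued at different steps would come from noticeably different encoders and would be inconsistent with each other, while a completely frozen key encoder would never improve. The slow EMA update lets the key encoder track the query encoder while changing little from step to step, so the embeddings in the queue stay comparable.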

The core contrastive step involves taking the dot product of matching q and k (positive pairs) and the dot product of q with all entries in the queue (negatives). These are concatenated to form the logits for the contrastive loss. The queue is then updated by storing the current batch's k embeddings and removing the oldest entries, acting as a circular buffer (_dequeue_and_enqueue).

The buffer (initialized with self.register_buffer("queue", ...) and self.register_buffer("queue_ptr", ...)) is what allows MoCo to scale contrastive learning without needing huge batch sizes. It is tracked by PyTorch as part of the model but does not receive gradients. The _dequeue_and_enqueue method acts like a rolling memory of negative samples, where key embeddings are generated by the momentum-updated encoder. The negatives are not from the current batch but from past batches stored in this buffer.

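In the training loop, two augmented views of each input, x_q and x_k, are passed through the model, and the returned logits and labels go into a standard cross-entropy loss (the positive key always sits in column 0 of the logits).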
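A minimal sketch of such a training run, written so it runs on its own: the MoCo class's steps (momentum update, key encoding, positive and negative logits, enqueue) are unrolled into a plain loop with a hypothetical small encoder, jitter-only augmentation, and random data. Following the MoCo paper's pseudocode, the logits are divided by a temperature (0.07 here, the paper's default) before the cross-entropy; the MoCo class above returns unscaled logits, so the same division would apply to its output.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
feature_dim, queue_size, momentum, temperature = 128, 1024, 0.999, 0.07

# Hypothetical small encoder: [B, 1, T] -> [B, feature_dim]
query_encoder = nn.Sequential(
    nn.Conv1d(1, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(64, feature_dim),
)
key_encoder = copy.deepcopy(query_encoder)
for p in key_encoder.parameters():
    p.requires_grad = False

queue = F.normalize(torch.randn(queue_size, feature_dim), dim=1)   # negatives [K, D]
ptr = 0

def augment(x, sigma=0.1):
    # Hypothetical augmentation: Gaussian jitter only, so both views keep length T
    return x + sigma * torch.randn_like(x)

optimizer = torch.optim.SGD(query_encoder.parameters(), lr=0.03, momentum=0.9)
data = torch.randn(256, 1, 400)   # hypothetical unlabeled signals

for batch in data.split(64):
    x_q, x_k = augment(batch), augment(batch)
    q = F.normalize(query_encoder(x_q), dim=1)                  # [B, D], receives gradients

    with torch.no_grad():
        # Momentum (EMA) update of the key encoder, then encode the keys
        for p_q, p_k in zip(query_encoder.parameters(), key_encoder.parameters()):
            p_k.mul_(momentum).add_(p_q, alpha=1 - momentum)
        k = F.normalize(key_encoder(x_k), dim=1)                # [B, D], no gradients

    pos = (q * k).sum(dim=1, keepdim=True)                      # [B, 1]
    neg = q @ queue.T                                           # [B, K]
    logits = torch.cat([pos, neg], dim=1) / temperature         # [B, 1 + K]
    labels = torch.zeros(logits.size(0), dtype=torch.long)      # positive is column 0
    loss = F.cross_entropy(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Enqueue the new keys, overwriting the oldest entries (circular buffer)
    idx = (ptr + torch.arange(k.size(0))) % queue_size
    queue[idx] = k
    ptr = (ptr + k.size(0)) % queue_size
    print(f"loss {loss.item():.4f}")
```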

Here, q = query_encoder(x_q) is used for gradient-based learning, while k = key_encoder(x_k) is generated under torch.no_grad() using momentum-updated weights. The logits are formed from [q·k (positives), q·queue (negatives)]. The _momentum_update_key_encoder function ensures the key encoder weights $\theta_k$ smoothly track $\theta_q$, keeping the queue embeddings stable across training steps. This consistency is vital for the negative queue to work effectively with small batches and delayed negatives.

To summarize, MoCo enables contrastive learning without relying on large batch sizes by maintaining a large queue of past key embeddings. These keys are encoded using a momentum-updated encoder, which ensures their consistency over time.

CPC (Contrastive Predictive Coding)

CPC (Contrastive Predictive Coding) learns representations by using the early part of a signal to predict later parts in a contrastive manner. Given a sequence, it first encodes it into latent vectors; a context encoder (often an autoregressive model such as a GRU) then summarizes the past. Instead of reconstructing or regressing future input values, CPC predicts future representations in latent space and uses a contrastive loss to make each prediction similar to the true future representation and dissimilar to negative samples (other latent representations from the batch). The core idea: "Let me summarize the past, then guess what the future will look like in latent space. I'm correct if my prediction is more similar to the true future than to all other possible futures."

Here is CPC step by step:

  1. Raw input is a time series, $x = [x_1, x_2, \dots, x_T]$ with shape [B, C, T].
  2. Pass it through an encoder (often a stack of Conv1d layers) to get latent embeddings $z = \text{encoder}(x)$, with shape [B, T', D].
  3. Pass latent embeddings through a context encoder (e.g., GRU or Transformer) to get $c_t$, a summary of the past up to time $t$, also with shape [B, T', D]. $c_t$ summarizes $z_{1:t}$.
  4. At each timestep $t$, predict the next $k$ future embeddings from $c_t$. Each future offset $k$ has its own learnable linear layer $W_k$, which outputs a predicted future latent $\hat{z}_{t+k} = W_k c_t$ for the true latent $z_{t+k}$.
  5. Given context $c_t$, the goal is to maximize the similarity between the prediction $\hat{z}_{t+k}$ and the true future latent $z_{t+k}$, while minimizing its similarity to negative samples (all other latents $z_j$ in the batch that are not $z_{t+k}$). This is done with the **InfoNCE loss**:

$$ \text{loss} = -\log\left(\frac{\exp(\text{sim}(\hat{z}_{t+k}, z_{t+k}))}{\sum_j \exp(\text{sim}(\hat{z}_{t+k}, z_j))}\right) $$

Here the sum over $j$ runs over the true future latent $z_{t+k}$ and all negatives, and sim is a similarity score (cosine similarity divided by a temperature in the info_nce_loss implementation below). The InfoNCE loss maximizes $\text{sim}(\hat{z}_{t+k}, z_{t+k})$ and minimizes $\text{sim}(\hat{z}_{t+k}, z_j)$ for the negative samples $z_j$, which are typically other latent vectors from the batch that do not correspond to the true future target. Because the true future can only be singled out reliably by modeling how the signal evolves over time, CPC learns representations that capture temporal structure.

Here’s an example of the CPC encoder:


import torch.nn as nn

class CPCEncoder(nn.Module):
    def __init__(self, in_channels=1, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=10, stride=5, padding=3), # Downsample
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=8, stride=4, padding=2), # Downsample more
            nn.ReLU(),
            nn.Conv1d(128, latent_dim, kernel_size=4, stride=2, padding=1), # Final downsample to latent_dim
            nn.ReLU(),
        )

    def forward(self, x):
        """
        x: [B, C, T] - Input time series
        output: [B, T', D] - Latent embeddings, T' is reduced time length, D is latent_dim
        """
        z = self.encoder(x) # [B, D_latent, T'] (Conv1d outputs [B, out_channels, out_length])
        z = z.permute(0, 2, 1) # [B, T', D_latent] - Permute to (Batch, SequenceLength, FeatureDim) for GRU
        return z
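
The three strided convolutions in CPCEncoder shrink the time axis by roughly 5 × 4 × 2 = 40. A small sketch, using the output-length formula from the PyTorch nn.Conv1d documentation (dilation 1 by default), computes T' exactly for a given input length; the helper names are hypothetical.

```python
def conv1d_out_len(length, kernel_size, stride, padding, dilation=1):
    """Output length of nn.Conv1d, per the formula in the PyTorch Conv1d docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def cpc_encoder_out_len(T):
    """T' produced by the CPCEncoder layers above for an input of length T."""
    # (kernel_size, stride, padding) of the three Conv1d layers in CPCEncoder
    for kernel_size, stride, padding in [(10, 5, 3), (8, 4, 2), (4, 2, 1)]:
        T = conv1d_out_len(T, kernel_size, stride, padding)
    return T
```

For example, a 16,000-sample input (1 s of audio at 16 kHz) gives T' = 400, one latent vector every 40 input samples: 16000 → 3200 → 800 → 400.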
    

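And here’s the CPC context encoder. The context vector $c_t$ (i.e., c[:, t, :]) summarizes everything up to and including time step $t$, because the unidirectional GRU has only seen $z_{1:t}$ when it produces $c_t$.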


class CPCContext(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=128):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True) # GRU processes sequence

    def forward(self, z):
        """
        z: [B, T', D] - Latent embeddings from CPCEncoder
        c: [B, T', D] - Contextualized embeddings at each time step
        """
        c, _ = self.gru(z) # c: context at each time step from GRU
        return c
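
The claim that $c_t$ depends only on $z_{1:t}$ can be checked directly. This sketch uses a bare nn.GRU with the same batch_first layout as CPCContext (the sizes are arbitrary assumptions): changing the latents after step $t$ leaves the earlier context vectors untouched and changes the later ones.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=8, batch_first=True)  # same layout as CPCContext

z = torch.randn(2, 10, 8)  # [B, T', D] latent sequence
t = 4  # 0-based index of the last "past" step
z_perturbed = z.clone()
z_perturbed[:, t + 1:, :] += 1.0  # modify only the latents after step t

with torch.no_grad():
    c, _ = gru(z)
    c_perturbed, _ = gru(z_perturbed)

# Context vectors up to step t are unchanged: a unidirectional GRU cannot see the future
past_unchanged = torch.allclose(c[:, :t + 1], c_perturbed[:, :t + 1])
# Context vectors after step t do change
future_changed = not torch.allclose(c[:, t + 1:], c_perturbed[:, t + 1:])
```

This causality is what makes the prediction task meaningful: if the context could see $z_{t+k}$, predicting it would be trivial.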
    

We feed the context vector at time $t$ to the prediction head to predict the latent embeddings of future time steps: $\hat{z}_{t+k} = W_k c_t$, where $W_k$ is a learnable linear layer for future offset $k$. For offset $k$, we take the first $T'-k$ time steps of the context (where $T'$ is the length of the latent sequence), because only those steps have a true latent $k$ steps ahead, and pass them through $W_k$. Each offset has its own predictor: $W_1$ predicts $z_{t+1}$ from $c_t$, $W_2$ predicts $z_{t+2}$ from $c_t$, and so on. Separate layers allow specialization, since the relationship between $c_t$ and $z_{t+1}$ is not necessarily the same as between $c_t$ and $z_{t+5}$: $W_1$ might specialize in short-term predictions, while $W_5$ might capture coarser, longer-range structure.


class CPCPredictor(nn.Module):
    def __init__(self, latent_dim=128, k_steps=3):
        super().__init__()
        self.k_steps = k_steps
        self.predictors = nn.ModuleList([
            nn.Linear(latent_dim, latent_dim) for _ in range(k_steps) # A linear layer for each future step k
        ])

    def forward(self, context):
        """
        context: [B, T', D] (output of GRU context encoder)
        returns: list of predictions:
            preds[k-1]: [B, T' - k, D], prediction of z_{t+k} from c_t (using context up to T'-k)
        """
        B, T_prime, D = context.shape
        preds = []

        for k, predictor in enumerate(self.predictors, start=1):
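            # context[:, :T' - k] holds c_1 .. c_{T'-k}, the only steps with a true latent k steps ahead
            # W_k maps each c_t to a guess of z_{t+k}, so the matching targets are z[:, k:] (z_{1+k} .. z_{T'})
            # e.g., k=1: c_1..c_{T'-1} predict z_2..z_{T'}; k=2: c_1..c_{T'-2} predict z_3..z_{T'}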
            pred = predictor(context[:, :T_prime - k, :])
            preds.append(pred)

        return preds # list of length k_steps, each item is [B, T'-k, D]
    

So far, we've encoded the signal, used a context encoder to summarize the past, and used separate linear layers to predict the future. Next, we apply the InfoNCE loss between predicted and actual future embeddings. At each time step $t$ and for each offset $k$, the true future latent $z_{t+k}$ should score as similar to the prediction $\hat{z}_{t+k}$ made from context $c_t$, while all other latents score as negatives.

The targets passed to InfoNCE are all positive latent vectors (the true future embeddings we want to predict); the negatives are implicit in the similarity matrix. We compute an $N \times N$ similarity matrix between every predicted embedding and every target embedding, where $N$ counts all (batch, time) positions being predicted. Diagonal entries correspond to positive pairs, and off-diagonal entries correspond to negative pairs. Cross-entropy then treats each row as an $N$-way classification problem rather than a binary one: the label for row $i$ is simply $i$, which tells the loss which column holds the positive. Minimizing it pushes each diagonal entry above every other entry in its row.
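To see that the "labels = row index" trick really is the InfoNCE formula, this small sketch (random embeddings, sizes chosen arbitrarily) compares F.cross_entropy with the explicit negative log-softmax of each diagonal entry.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 5, 16
preds = F.normalize(torch.randn(N, D), dim=1)
targets = F.normalize(torch.randn(N, D), dim=1)

sim = preds @ targets.T / 0.07  # [N, N]: row i scores prediction i against every target
labels = torch.arange(N)  # row i's positive is column i (the diagonal)

ce = F.cross_entropy(sim, labels)
# The same quantity written out: -log(exp(s_ii) / sum_j exp(s_ij)), averaged over rows
manual = -(sim.diag() - torch.logsumexp(sim, dim=1)).mean()
```

Each row's denominator is the sum over all $N$ targets, so the $N - 1$ off-diagonal entries act as the negatives without any explicit sampling.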


import torch
import torch.nn.functional as F

def info_nce_loss(preds, targets, temperature=0.07):
    """
    Computes InfoNCE loss.
    preds: [N, D] predicted latents (flattened batch and time for all predicted steps/batches)
    targets: [N, D] true latents (flattened batch and time for all corresponding true future steps/batches)
    """
    # Normalize embeddings to unit vectors for cosine similarity
    preds_norm = F.normalize(preds, dim=1)    # [N, D]
    targets_norm = F.normalize(targets, dim=1) # [N, D]

    # Compute similarity matrix: (N_preds, N_targets)
    # The diagonal elements are positive pairs, off-diagonals are negatives
    similarity_matrix = torch.matmul(preds_norm, targets_norm.T) / temperature

    # Labels for cross-entropy are the indices of the positive targets
    # E.g., for preds[i], its positive target is targets[i]. So label is i.
    labels = torch.arange(preds.shape[0], device=preds.device)

    # Cross-entropy loss: maximizes the log-likelihood of correctly classifying
    # the true positive target among all other targets in the batch.
    loss = F.cross_entropy(similarity_matrix, labels)
    return loss
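
The pieces above still need to be wired together: each predictor's output must be aligned with the right slice of $z$ and flattened before InfoNCE. The following is a minimal sketch under the shapes used above. The function name cpc_loss and averaging over offsets $k$ are assumptions, and the inner computation inlines the same steps as info_nce_loss so the block is self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cpc_loss(z, c, predictors, temperature=0.07):
    """
    Sketch of the full CPC objective (hypothetical helper).
    z: [B, T', D] latents from the encoder (prediction targets)
    c: [B, T', D] context vectors from the GRU
    predictors: nn.ModuleList of linear layers W_1 .. W_K (requires K < T')
    """
    _, T_prime, D = z.shape
    losses = []
    for k, W_k in enumerate(predictors, start=1):
        pred = W_k(c[:, :T_prime - k, :])  # guesses of z_{t+k} from c_1 .. c_{T'-k}
        target = z[:, k:, :]  # true z_{1+k} .. z_{T'}, aligned position by position
        # Flatten batch and time: every other (sample, time) latent acts as a negative
        p = F.normalize(pred.reshape(-1, D), dim=1)
        y = F.normalize(target.reshape(-1, D), dim=1)
        logits = p @ y.T / temperature  # [N, N] with N = B * (T' - k)
        labels = torch.arange(logits.shape[0], device=z.device)
        losses.append(F.cross_entropy(logits, labels))
    return torch.stack(losses).mean()  # average over prediction offsets k
```

In a training step, z would come from CPCEncoder, c from CPCContext, and predictors from CPCPredictor.predictors; gradients then flow into all three modules.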
    

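Because CPC predicts in latent space rather than reconstructing raw samples, it does not have to model every low-level detail of the signal, which is often cited as the reason it tends to be robust to noise and to irrelevant parts of the signal. After self-supervised pre-training, the learned representations can be applied to downstream tasks such as classification, regression, and clustering. A common protocol is to freeze the encoder and train a new head (for example, a classifier) on top for the specific downstream task.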
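A frozen-encoder setup like this can be sketched as follows. LinearProbe is a hypothetical name, and mean-pooling the [B, T', D] latents over time is an assumption (other pooling choices, such as using the last context vector, are also common). The train() override keeps the frozen encoder in inference mode so that any dropout or normalization layers inside it behave consistently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearProbe(nn.Module):
    """Hypothetical downstream head on top of a frozen, pre-trained encoder."""

    def __init__(self, encoder, latent_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        self.encoder.requires_grad_(False)  # freeze the pre-trained weights
        self.head = nn.Linear(latent_dim, num_classes)

    def train(self, mode=True):
        super().train(mode)
        self.encoder.eval()  # keep the frozen encoder in inference mode
        return self

    def forward(self, x):
        with torch.no_grad():
            z = self.encoder(x)  # assumed shape [B, T', D]
        return self.head(z.mean(dim=1))  # mean-pool over time, then classify
```

Only the head's parameters should be given to the optimizer, e.g. torch.optim.SGD(probe.head.parameters(), lr=0.1).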


BYOL (Bootstrap Your Own Latent)

Before discussing BYOL in detail, let's clarify its terminology, in particular the two networks it relies on.

BYOL (Bootstrap Your Own Latent) is a self-supervised learning method that learns from positive pairs only, without requiring negative samples or explicit labels. The model effectively creates its own training signal. It uses two networks:

  1. An **online network**: This network is trainable via standard gradient descent.
  2. A **target network**: This network's parameters are a slowly moving average (Exponential Moving Average, EMA) of the online network's parameters. It receives no gradients and is never updated by the optimizer.

The online network is trained to predict the target network's representation. In essence, the model improves by predicting the output of a slowly updated version of itself, learning to map different augmented views of the same data to nearby points in latent space without external supervision. This is the "bootstrapping" aspect: the model pulls itself up by comparing itself to its own evolving past, with the stable, slowly updated target network serving as the prediction target.

The task is to predict the target projection of one augmented view from the online prediction of the other view. Two augmented views of the same data ($x_1$ and $x_2$) are generated, and each view passes through both networks. The online path consists of an encoder, a projector, and a predictor; the target path consists of an encoder and a projector (no predictor). The loss is based on the cosine similarity between the online prediction and the target projection. Backpropagation updates only the online network, while the target network's weights are updated using EMA.

Let's assume we have a base encoder. We'll wrap it with projector and predictor MLPs, which output embeddings of size out_dim.


import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), # BatchNorm can help stabilize training
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)
    

The online and target networks are initially identical. The _update_target_network method performs the EMA update, smoothly tracking the online network's parameters to keep the target network stable.
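The EMA rule itself is target ← decay · target + (1 − decay) · online, applied parameter by parameter. As a standalone sketch (ema_update_ is a hypothetical helper, not part of the BYOL class), it can be written with Tensor.lerp_, which computes target + weight · (online − target).

```python
import torch
import torch.nn as nn


@torch.no_grad()
def ema_update_(target: nn.Module, online: nn.Module, decay: float) -> None:
    """In-place EMA: target <- decay * target + (1 - decay) * online."""
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        # lerp_(end, weight) moves t_param a fraction `weight` of the way toward o_param
        t_param.lerp_(o_param, 1.0 - decay)
```

With decay = 0.99, the target keeps about 37% (0.99^100 ≈ 0.366) of its weights from 100 updates ago, which is why it changes slowly enough to serve as a stable target; decay = 0 reduces to a plain copy.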


class BYOL(nn.Module):
    def __init__(self, base_encoder_fn, in_dim, hidden_dim=4096, out_dim=256, ema_decay=0.99):
        super().__init__()
        
        # Online network: trainable via gradients
        self.online_encoder = base_encoder_fn() # base_encoder_fn should be a callable that returns an encoder instance
        self.online_projector = MLPHead(in_dim, hidden_dim, out_dim)
        self.online_predictor = MLPHead(out_dim, hidden_dim, out_dim) # Predictor adds asymmetry

        # Target network: parameters updated via EMA, kept frozen during forward pass
        self.target_encoder = base_encoder_fn()
        self.target_projector = MLPHead(in_dim, hidden_dim, out_dim)

        # Initialize target network parameters to match online network
        self._update_target_network(ema=0) # ema=0 means direct copy

        # Freeze the target network: it is updated only via EMA, never by the optimizer
        for param in list(self.target_encoder.parameters()) + list(self.target_projector.parameters()):
            param.requires_grad_(False)

        self.ema_decay = ema_decay

    @torch.no_grad() # Ensure no gradients are computed for target network update
    def _update_target_network(self, ema):
        """EMA update: target = ema * target + (1 - ema) * online (ema=0 copies the online weights).
        Call it once per training step, after optimizer.step()."""
        online_params = list(self.online_encoder.parameters()) + list(self.online_projector.parameters())
        target_params = list(self.target_encoder.parameters()) + list(self.target_projector.parameters())
        for online_param, target_param in zip(online_params, target_params):
            target_param.mul_(ema).add_(online_param, alpha=1 - ema)

    # The forward pass of BYOL model
    def forward(self, x1, x2):
        # Online network forward pass
        # x1 -> encoder -> projector -> predictor -> o1
        o1 = self.online_predictor(self.online_projector(self.online_encoder(x1)))
        # x2 -> encoder -> projector -> predictor -> o2
        o2 = self.online_predictor(self.online_projector(self.online_encoder(x2)))

        # Target network forward pass (gradients are stopped; the EMA update happens
        # in the training loop after optimizer.step(), not here)
        with torch.no_grad():
            # x1 -> target_encoder -> target_projector -> t1
            t1 = self.target_projector(self.target_encoder(x1))
            # x2 -> target_encoder -> target_projector -> t2
            t2 = self.target_projector(self.target_encoder(x2))

        # Normalize outputs to unit vectors for cosine similarity loss
        o1 = F.normalize(o1, dim=-1)
        o2 = F.normalize(o2, dim=-1)
        t1 = F.normalize(t1, dim=-1)
        t2 = F.normalize(t2, dim=-1)

        # Symmetric cross-view loss: the average of (2 - 2*cos(o1, t2)) and (2 - 2*cos(o2, t1)),
        # which equals (1 - cos(o1, t2)) + (1 - cos(o2, t1))
        # t1 and t2 are detached (no gradients flow to the target network)
        loss = 2 - 2 * (
            (o1 * t2.detach()).sum(dim=-1).mean() + # Dot product = cosine similarity for normalized vectors
            (o2 * t1.detach()).sum(dim=-1).mean()
        ) / 2 # Average over the two terms
        return loss
    

Two augmented views of the same input are forwarded through the network. The projector maps the encoder output to a latent space, and the predictor maps the online projection to a prediction of the target projection (both have size out_dim). The predictor introduces asymmetry between the online and target paths. Together with the stop-gradient and the slowly moving EMA target, this asymmetry keeps BYOL from collapsing to a trivial constant output (representation collapse); empirically, removing the predictor tends to cause such collapse. The predictor also gives the online network the flexibility to learn how to align its representations with the target network's stable representations.

Specifically, $x_1$ and $x_2$ go through the online network (encoder $\rightarrow$ projector $\rightarrow$ predictor) to produce $o_1$ and $o_2$. Separately, they go through the target network (encoder $\rightarrow$ projector) to produce $t_1$ and $t_2$, with gradients stopped on this path. BYOL uses cosine similarity, so all outputs are normalized to unit vectors. The loss matches $o_1$ with $t_2$ and $o_2$ with $t_1$ (cross-view prediction), so the model can only succeed by learning features that are invariant to the augmentations rather than taking a shortcut through view-specific details.

The training loop would involve taking a batch $X$, augmenting it into $x_1$ and $x_2$, computing the BYOL loss between the online and target views, backpropagating through the online model, and then updating the target network's weights using EMA.
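The loop described above can be sketched end to end with toy stand-in networks. Everything here is an assumption for illustration: the augmentation (random amplitude scaling plus Gaussian jitter, two common time-series choices), the tiny Conv1d encoder, the layer sizes, and the learning rate. What the sketch shows is the order of operations: cross-view loss, backward through the online side only, optimizer step, then the EMA update of the target.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def augment(x, noise_std=0.05, scale_range=(0.9, 1.1)):
    """Hypothetical time-series augmentation: random amplitude scaling plus Gaussian jitter."""
    scale = torch.empty(x.shape[0], 1, 1, device=x.device).uniform_(*scale_range)
    return x * scale + noise_std * torch.randn_like(x)


def byol_loss(p, z):
    """2 - 2 * cosine similarity, with z treated as a fixed target."""
    return 2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1).mean()


# Toy stand-ins for the encoder and MLP heads, for inputs of shape [B, 1, T]
encoder = nn.Sequential(nn.Conv1d(1, 16, 5, padding=2), nn.ReLU(), nn.AdaptiveAvgPool1d(1), nn.Flatten())
projector = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
predictor = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
online = nn.Sequential(encoder, projector)
target = copy.deepcopy(online).requires_grad_(False)  # identical start, frozen

optimizer = torch.optim.Adam(list(online.parameters()) + list(predictor.parameters()), lr=1e-3)
ema_decay = 0.99


def train_step(x):
    x1, x2 = augment(x), augment(x)
    p1, p2 = predictor(online(x1)), predictor(online(x2))
    with torch.no_grad():
        z1, z2 = target(x1), target(x2)
    loss = (byol_loss(p1, z2) + byol_loss(p2, z1)) / 2  # symmetric cross-view loss
    optimizer.zero_grad()
    loss.backward()  # gradients reach only the online network and predictor
    optimizer.step()
    with torch.no_grad():  # afterwards, move the target toward the online network
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.lerp_(o_param, 1 - ema_decay)
    return loss.item()
```

Note that the predictor has no target counterpart, so only the encoder and projector parameters are averaged.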