BYOL (Bootstrap Your Own Latent) is the second primary SSL objective in SSRL-ECG. Unlike SimCLR, BYOL requires no negative samples — instead it uses an online network that predicts the representations of a momentum-updated target network to avoid representational collapse. BYOL classes are defined in ssrl_ecg/train_ssl_byol.py and are composed with the ECGEncoder1DCNN backbone from ssrl_ecg.models.cnn.

BYOLProjector

BYOLProjector is the projection MLP that maps pooled encoder representations to a latent space shared by the online and target networks. It adds a BatchNorm1d between the two linear layers, which stabilises BYOL training by normalising the intermediate activations.
Linear(in_features, hidden_dim)  →  BatchNorm1d(hidden_dim)  →  ReLU  →  Linear(hidden_dim, out_dim)

Constructor Parameters

in_features (int, required): Dimension of the encoder output after global average pooling. For ECGEncoder1DCNN with default width=64 this is 256.
hidden_dim (int, default 2048): Width of the hidden projection layer.
out_dim (int, default 256): Dimension of the projected output vector.

Forward

def forward(x: Tensor) -> Tensor
x (Tensor, required): Pooled encoder features of shape [batch, in_features].
Returns z (Tensor): Projected representation of shape [batch, out_dim].
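To make the layer stack and the shapes concrete, here is a minimal sketch that re-creates the documented Linear → BatchNorm1d → ReLU → Linear layout in plain PyTorch. ProjectorSketch is a hypothetical name, not the class in ssrl_ecg.train_ssl_byol, and the feature map fed to it is random data with an arbitrary time length of 125.

```python
import torch
import torch.nn as nn


class ProjectorSketch(nn.Module):
    """Hypothetical re-creation of the documented BYOLProjector layout:
    Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, in_features: int, hidden_dim: int = 2048, out_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # normalises each hidden feature over the batch
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, in_features] -> z: [batch, out_dim]
        return self.net(x)


# Illustrative encoder feature map [batch, channels, time]; time length 125 is arbitrary.
feature_map = torch.randn(8, 256, 125)
features = feature_map.mean(dim=-1)  # global average pooling over time -> [8, 256]

projector = ProjectorSketch(in_features=256)
z = projector(features)
print(z.shape)  # torch.Size([8, 256])
```

Because of the BatchNorm1d layer, a module like this needs a batch size greater than 1 while in training mode; a single-sample batch raises an error.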

BYOLPredictor

BYOLPredictor sits on top of the online projector and predicts the (stop-gradient) target projections. Its architecture is identical to BYOLProjector. Only the online network has a predictor; this asymmetry, combined with the slowly moving EMA target, is what prevents representational collapse without requiring negatives.
Linear(in_features, hidden_dim)  →  BatchNorm1d(hidden_dim)  →  ReLU  →  Linear(hidden_dim, out_dim)

Constructor Parameters

in_features (int, required): Dimension of the online projector output (projection_dim, default 256).
hidden_dim (int, default 2048): Width of the hidden predictor layer.
out_dim (int, default 256): Output dimension. Must match BYOLProjector.out_dim to allow the L2 regression loss.

Forward

def forward(x: Tensor) -> Tensor
x (Tensor, required): Online projected features of shape [batch, in_features].
Returns p (Tensor): Predicted target projection of shape [batch, out_dim].
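The asymmetry between the two branches, and the reason out_dim must match, can be sketched with plain PyTorch modules. The mlp helper is hypothetical, the pooled features are random stand-ins, and the target projector is initialised independently here only for brevity (in BYOLModel it starts as a copy of the online projector).

```python
import torch
import torch.nn as nn


def mlp(in_features: int, hidden_dim: int = 2048, out_dim: int = 256) -> nn.Sequential:
    # Linear -> BatchNorm1d -> ReLU -> Linear, the layout shared by projector and predictor.
    return nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )


projection_dim = 256
online_projector = mlp(256, out_dim=projection_dim)
online_predictor = mlp(projection_dim, out_dim=projection_dim)  # in_features = projector out_dim
target_projector = mlp(256, out_dim=projection_dim)             # the target side has no predictor

h_online = torch.randn(8, 256)  # pooled online-encoder features (random stand-in)
h_target = torch.randn(8, 256)  # pooled target-encoder features (random stand-in)

p = online_predictor(online_projector(h_online))  # online path: projector, then predictor
with torch.no_grad():
    z_target = target_projector(h_target)         # target path: projector only, no gradient

print(p.shape, z_target.shape)  # torch.Size([8, 256]) torch.Size([8, 256])
```

The prediction p and the target z_target have the same shape, so they can be compared element-wise by the regression loss, and only p carries gradients back into the online network.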

BYOLModel

BYOLModel is the full BYOL system. It maintains five networks, three online and two target:
| Network | Gradient | Role |
|---|---|---|
| encoder | ✅ trained | Online backbone |
| online_projector | ✅ trained | Online projector |
| online_predictor | ✅ trained | Online predictor |
| target_encoder | ❌ EMA only | Momentum backbone |
| target_projector | ❌ EMA only | Momentum projector |
The target networks are initialised as copies of the online networks and then updated exclusively via exponential moving average (EMA). They receive no gradients and are never updated by the optimiser.
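One common way to get "identical initial weights, no gradients" is copy.deepcopy followed by requires_grad_(False). The sketch below shows that pattern on a small toy network; it is an assumption about how such a copy can be made, not a transcript of BYOLModel, and the toy layers are not ECGEncoder1DCNN.

```python
import copy

import torch
import torch.nn as nn

# Toy online branch (hypothetical stand-in for encoder + projector, not ECGEncoder1DCNN).
online = nn.Sequential(
    nn.Conv1d(12, 32, kernel_size=7, padding=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(32, 16),
)

# The target starts as an exact copy of the online weights ...
target = copy.deepcopy(online)
# ... and is taken out of autograd, so only EMA updates will ever change it.
target.requires_grad_(False)

same_init = all(torch.equal(a, b) for a, b in zip(online.parameters(), target.parameters()))
print(same_init)  # True
```

deepcopy allocates separate storage for every tensor, so later optimiser steps on the online network leave the target untouched until the EMA update runs.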

Constructor Parameters

encoder (nn.Module, required): An ECGEncoder1DCNN instance used as the online backbone. A deep copy is made automatically for the target encoder, initialised with identical weights.
projection_dim (int, default 256): Output dimension of both BYOLProjector and BYOLPredictor.
hidden_dim (int, default 2048): Hidden dimension used in both the projector and predictor MLPs.

Forward

def forward(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
BYOL’s forward pass processes both augmented views through the online network and passes them (with torch.no_grad()) through the target network to produce stop-gradient targets.
x1 (Tensor, required): First augmented ECG view, shape [batch, channels, time].
x2 (Tensor, required): Second augmented ECG view, shape [batch, channels, time].

Returns a tuple (pred1, p1_target, pred2, p2_target):
pred1 (Tensor): Online predictor output for view 1, shape [batch, projection_dim].
p1_target (Tensor): Target projector output for view 1 (stop-gradient), shape [batch, projection_dim].
pred2 (Tensor): Online predictor output for view 2, shape [batch, projection_dim].
p2_target (Tensor): Target projector output for view 2 (stop-gradient), shape [batch, projection_dim].
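The forward pass described above can be sketched end to end with toy modules. Everything here is hypothetical: TinyEncoder stands in for ECGEncoder1DCNN and already returns pooled features, and BYOLSketch takes an explicit feat_dim argument that the documented BYOLModel constructor does not have. The return order follows the documented tuple.

```python
import copy
from typing import Tuple

import torch
import torch.nn as nn


def mlp(in_features: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
    # Linear -> BatchNorm1d -> ReLU -> Linear, as documented for projector and predictor.
    return nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for ECGEncoder1DCNN that already returns pooled features."""

    def __init__(self, in_ch: int = 12, feat_dim: int = 64):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, feat_dim, kernel_size=7, padding=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [batch, channels, time] -> [batch, feat_dim] via global average pooling over time
        return torch.relu(self.conv(x)).mean(dim=-1)


class BYOLSketch(nn.Module):
    def __init__(self, encoder: nn.Module, feat_dim: int,
                 projection_dim: int = 256, hidden_dim: int = 2048):
        super().__init__()
        self.encoder = encoder
        self.online_projector = mlp(feat_dim, hidden_dim, projection_dim)
        self.online_predictor = mlp(projection_dim, hidden_dim, projection_dim)
        # Target networks: deep copies with identical initial weights, excluded from autograd.
        self.target_encoder = copy.deepcopy(encoder).requires_grad_(False)
        self.target_projector = copy.deepcopy(self.online_projector).requires_grad_(False)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Online path: encoder -> projector -> predictor, with gradients.
        pred1 = self.online_predictor(self.online_projector(self.encoder(x1)))
        pred2 = self.online_predictor(self.online_projector(self.encoder(x2)))
        # Target path: encoder -> projector only, under no_grad (stop-gradient).
        with torch.no_grad():
            p1_target = self.target_projector(self.target_encoder(x1))
            p2_target = self.target_projector(self.target_encoder(x2))
        return pred1, p1_target, pred2, p2_target


model = BYOLSketch(TinyEncoder(), feat_dim=64, projection_dim=32, hidden_dim=128)
x1, x2 = torch.randn(4, 12, 500), torch.randn(4, 12, 500)
pred1, p1_target, pred2, p2_target = model(x1, x2)
print(pred1.shape, p1_target.shape)  # torch.Size([4, 32]) torch.Size([4, 32])
```

Wrapping the target path in torch.no_grad() is what makes the targets stop-gradient: they come out with requires_grad=False, so loss.backward() only reaches the online encoder, projector and predictor.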

Momentum Encoder Update

After each optimiser step the target networks must be refreshed using EMA:
@torch.no_grad()
def update_target_network(self, tau: float = 0.999):
    for online_p, target_p in zip(self.encoder.parameters(),
                                  self.target_encoder.parameters()):
        target_p.data = tau * target_p.data + (1 - tau) * online_p.data

    for online_p, target_p in zip(self.online_projector.parameters(),
                                  self.target_projector.parameters()):
        target_p.data = tau * target_p.data + (1 - tau) * online_p.data
The momentum parameter tau controls how slowly the target network tracks the online network. Values close to 1.0 produce a very stable target and are critical for BYOL’s collapse-free behaviour.
| tau | Effect |
|---|---|
| 0.996 | Faster adaptation; useful for short training runs |
| 0.999 | Default; good balance for 20–100 epoch training |
| 0.9999 | Very slow adaptation; suitable for large-scale pretraining |
Never backpropagate through the target network. The update_target_network method is decorated with @torch.no_grad() and must be called after optimizer.step(), not before.
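The EMA rule target ← tau · target + (1 − tau) · online can also be written with in-place tensor ops instead of reassigning .data. The sketch below is a free function over two modules (ema_update is a hypothetical name, not part of ssrl_ecg) and checks the arithmetic on a one-weight layer with tau = 0.9: 0.9 · 0 + 0.1 · 1 = 0.1.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float = 0.999) -> None:
    # target <- tau * target + (1 - tau) * online, applied in place to each parameter.
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1.0 - tau)


online = nn.Linear(1, 1, bias=False)
target = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    online.weight.fill_(1.0)  # online weight = 1
    target.weight.fill_(0.0)  # target weight = 0

ema_update(online, target, tau=0.9)
print(target.weight.item())  # ≈ 0.1
```

With the online weights held fixed, after k updates the target has closed a fraction 1 − tau^k of the gap, so its memory spans roughly 1 / (1 − tau) steps: about 250 steps for 0.996, 1,000 for 0.999 and 10,000 for 0.9999. Like the update_target_network snippet, this loop touches parameters only; BatchNorm running statistics (buffers) are not averaged.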

BYOL Loss Function

The BYOL objective is a symmetric L2 regression loss between predictions and stop-gradient targets:
def byol_loss(pred1, target1, pred2, target2):
    def regression_loss(pred, target):
        pred   = F.normalize(pred,   dim=1)
        target = F.normalize(target, dim=1)
        return 2 - 2 * (pred * target).sum(dim=1).mean()

    return regression_loss(pred1, target1) + regression_loss(pred2, target2)
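Both pred and target are L2-normalised before the dot product, so each term equals the batch mean of the squared Euclidean distance between unit vectors: ‖p̂ − ẑ‖² = 2 − 2·cos(p, z). Each term therefore lies in [0, 4]. The symmetric form runs every view through both the online and the target network, so each view serves once as the predictor input and once as the regression target.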
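The identity between the normalised dot-product form and the squared distance between unit vectors can be checked numerically. This sketch copies regression_loss from the snippet above and compares it with two equivalent formulations on random tensors; the extreme cases give about 0 for identical directions and about 4 for opposite ones.

```python
import torch
import torch.nn.functional as F


def regression_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Same definition as in byol_loss above.
    pred = F.normalize(pred, dim=1)
    target = F.normalize(target, dim=1)
    return 2 - 2 * (pred * target).sum(dim=1).mean()


torch.manual_seed(0)
p = torch.randn(16, 256)
z = torch.randn(16, 256)

# Mean squared Euclidean distance between the unit-normalised rows.
sq_dist = (F.normalize(p, dim=1) - F.normalize(z, dim=1)).pow(2).sum(dim=1).mean()
# The same quantity written with cosine similarity.
cos_form = (2 - 2 * F.cosine_similarity(p, z, dim=1)).mean()

print(torch.allclose(regression_loss(p, z), sq_dist, atol=1e-5))  # True
print(regression_loss(p, p).item())   # ≈ 0 (identical directions)
print(regression_loss(p, -p).item())  # ≈ 4 (opposite directions)
```

Because only directions matter after normalisation, rescaling pred or target by any positive factor leaves the loss unchanged.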

Full Training Loop

import os

import torch
from ssrl_ecg.models.cnn import ECGEncoder1DCNN
from ssrl_ecg.train_ssl_byol import BYOLModel, BYOLAugmentations, byol_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Build model ───────────────────────────────────────────────────────────────
encoder = ECGEncoder1DCNN(in_ch=12, width=64)
model   = BYOLModel(encoder, projection_dim=256, hidden_dim=2048).to(device)
aug     = BYOLAugmentations(signal_length=1000, prob=0.8)

# Only the online network is optimised; the target networks change only via EMA
optimizer = torch.optim.Adam(
    list(model.encoder.parameters()) +
    list(model.online_projector.parameters()) +
    list(model.online_predictor.parameters()),
    lr=1e-3,
)

# ── Training step (repeat for every batch) ────────────────────────────────────
x = torch.randn(64, 12, 1000).to(device)   # stand-in for a raw ECG batch [batch, leads, samples]

x1, x2 = aug(x)                            # two augmented views
x1, x2 = x1.to(device), x2.to(device)

# Forward: online predictions and stop-gradient targets
pred1, p1_target, pred2, p2_target = model(x1, x2)

# Cross-view pairing: view 1's prediction regresses onto view 2's target, and vice versa
loss = byol_loss(pred1, p2_target, pred2, p1_target)   # scalar

optimizer.zero_grad()
loss.backward()
optimizer.step()

# EMA update of target networks; MUST happen after optimizer.step()
model.update_target_network(tau=0.999)

print(f"BYOL loss: {loss.item():.4f}")

# ── Save encoder checkpoint ───────────────────────────────────────────────────
os.makedirs("checkpoints", exist_ok=True)
torch.save({"encoder": model.encoder.state_dict()}, "checkpoints/ssl_byol.pt")

BYOL vs SimCLR

| Property | BYOL | SimCLR |
|---|---|---|
| Negative samples | Not required | Required (large batch) |
| Minimum batch size | ~32 | ~256 |
| Target network | EMA momentum copy | None |
| Loss | L2 regression | NT-Xent cross-entropy |
| Collapse prevention | Predictor asymmetry + EMA | In-batch negatives |
| Typical pretraining epochs | 20–200 | 200–1000 |
BYOL is the preferred choice when GPU memory is limited and large batches are not feasible. For best results, cosine-anneal the momentum tau from 0.996 to 1.0 over the course of training as described in the original paper.
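The cosine schedule from the BYOL paper is tau = 1 − (1 − tau_base) · (cos(π · k / K) + 1) / 2, where k is the current training step, K the total number of steps and tau_base = 0.996. A minimal sketch (byol_tau is a hypothetical helper name):

```python
import math


def byol_tau(step: int, total_steps: int, tau_base: float = 0.996) -> float:
    # tau = 1 - (1 - tau_base) * (cos(pi * step / total_steps) + 1) / 2
    # step = 0 gives tau_base; step = total_steps gives exactly 1.0.
    return 1.0 - (1.0 - tau_base) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0


total_steps = 1000
for step in (0, 500, 1000):
    print(step, round(byol_tau(step, total_steps), 6))
# 0 0.996
# 500 0.998
# 1000 1.0
```

In the training loop above, the fixed tau=0.999 passed to model.update_target_network would be replaced by byol_tau(global_step, total_steps), where global_step and total_steps are counters you maintain yourself (hypothetical names). The schedule lets the target track the online network closely early on and freezes it gradually towards the end.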
