Self-supervised pretraining lets you extract rich ECG representations from the full unlabeled training set before ever touching a diagnostic label. SSRL-ECG supports two complementary SSL frameworks — SimCLR and BYOL — both built on the same ECGEncoder1DCNN backbone and the same suite of seven domain-adaptive augmentations. The resulting encoder checkpoint is a drop-in input for the fine-tuning pipeline.

Architecture Overview

Both frameworks share the same encoder backbone. The encoder is ECGEncoder1DCNN with in_ch=12 (12-lead ECG) and width=64 (three stacked Conv1D blocks expanding to a 256-dimensional latent). The two frameworks differ only in their objective and projection heads:

SimCLR

Contrastive learning with NT-Xent loss. Projects to a 128-dim space and maximizes agreement between two augmented views of the same ECG. Best-performing variant: AUROC 0.8717 after fine-tuning on 10% of labels.

BYOL

Momentum-based learning with an online/target network pair; no negative pairs are required. Projects to a 256-dim space. Alternative: AUROC 0.8565.

Both checkpoints save only the encoder state dict under the "encoder" key: torch.save({"encoder": model.encoder.state_dict()}, out_path). The projection heads are discarded after pretraining.
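The exact layer configuration of ECGEncoder1DCNN is not given on this page. The sketch below is a hypothetical stand-in that assumes each block is Conv1d → BatchNorm → ReLU → MaxPool and that the channel count doubles per block (64 → 128 → 256), followed by global average pooling to reach the 256-dim latent. The class name ECGEncoderSketch and the kernel size are illustrative assumptions.

```python
import torch
import torch.nn as nn


class ECGEncoderSketch(nn.Module):
    """Hypothetical stand-in for ECGEncoder1DCNN (layer details are assumed).

    Three Conv1D blocks widen the channels width -> 2*width -> 4*width,
    so width=64 yields a 256-dim latent after global average pooling.
    """

    def __init__(self, in_ch: int = 12, width: int = 64):
        super().__init__()
        chans = [in_ch, width, 2 * width, 4 * width]
        blocks = []
        for c_in, c_out in zip(chans[:-1], chans[1:]):
            blocks += [
                nn.Conv1d(c_in, c_out, kernel_size=7, padding=3),  # kernel size assumed
                nn.BatchNorm1d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool1d(2),  # halves the time axis
            ]
        self.features = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool1d(1)  # collapse the time axis to length 1
        self.out_dim = 4 * width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, leads, time) -> (batch, 4 * width)
        return self.pool(self.features(x)).squeeze(-1)


encoder = ECGEncoderSketch(in_ch=12, width=64)
z = encoder(torch.randn(4, 12, 1000))  # 4 ECGs, 12 leads, 10 s at 100 Hz
print(z.shape)  # torch.Size([4, 256])
```

The SimCLR and BYOL projection heads would sit on top of this 256-dim output and are discarded after pretraining.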

Domain-Adaptive Augmentations

Seven augmentations are applied stochastically during pretraining to create physiologically plausible views of each ECG:
| Augmentation | Mechanism | Application rate |
|---|---|---|
| Frequency warping | ±5% heart rate variation | 50% |
| Medical mixup | ECG-aware Beta blending | 40% |
| Bandpass filtering | f_low ∈ [0.5, 1.5] Hz, f_high ∈ [40, 60] Hz | 30% |
| Segment CutMix | 10–30% temporal masking | 25% |
| Motion artifacts | Baseline wander 0.01–0.05 mV @ 0.5–2 Hz | 20% |
| Per-channel noise | 0.5–2% per-channel std | 60% |
| Temporal dropout | 5–20% masking + interpolation | 30% |

Together these augmentations contribute a +12.15% F1 improvement over a supervised CNN trained without them.
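To make the stochastic application concrete, here is a minimal NumPy sketch of four of the seven augmentations, each applied with the rate from the augmentation table. It is not the project's SimCLRAugmentations implementation: the function names, the application order, the contiguous dropout window, the linear-interpolation resampling, and the assumption that signals are in mV sampled at 100 Hz are all illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(42)
FS = 100  # Hz; assumes the records100/ (100 Hz) PTB-XL signals


def frequency_warp(x, max_scale=0.05):
    # Resample the time axis by a factor in [0.95, 1.05] to mimic heart-rate variation.
    # Positions past the end of the record are clamped to the last sample.
    n = x.shape[-1]
    scale = 1.0 + rng.uniform(-max_scale, max_scale)
    src = np.clip(np.arange(n) * scale, 0, n - 1)
    return np.stack([np.interp(src, np.arange(n), lead) for lead in x])


def per_channel_noise(x, low=0.005, high=0.02):
    # Gaussian noise whose std is 0.5-2% of each lead's own std.
    frac = rng.uniform(low, high, size=(x.shape[0], 1))
    return x + rng.standard_normal(x.shape) * frac * x.std(axis=-1, keepdims=True)


def baseline_wander(x, fs=FS):
    # Additive low-frequency sinusoid: 0.01-0.05 mV at 0.5-2 Hz (signal assumed in mV).
    t = np.arange(x.shape[-1]) / fs
    amp, freq = rng.uniform(0.01, 0.05), rng.uniform(0.5, 2.0)
    phase = rng.uniform(0, 2 * np.pi)
    return x + amp * np.sin(2 * np.pi * freq * t + phase)


def temporal_dropout(x, low=0.05, high=0.2):
    # Mask a contiguous 5-20% window and refill it by linear interpolation
    # between the surrounding kept samples.
    n = x.shape[-1]
    width = int(n * rng.uniform(low, high))
    start = rng.integers(0, n - width)
    keep = np.ones(n, dtype=bool)
    keep[start:start + width] = False
    idx = np.arange(n)
    return np.stack([np.interp(idx, idx[keep], lead[keep]) for lead in x])


AUGMENTATIONS = [  # (function, application probability)
    (frequency_warp, 0.5),
    (baseline_wander, 0.2),
    (per_channel_noise, 0.6),
    (temporal_dropout, 0.3),
]


def augment(x):
    # Each augmentation fires independently with its own probability.
    for fn, p in AUGMENTATIONS:
        if rng.random() < p:
            x = fn(x)
    return x


ecg = rng.standard_normal((12, 1000))  # synthetic 12-lead, 10 s record
view1, view2 = augment(ecg), augment(ecg)
print(view1.shape, view2.shape)  # (12, 1000) (12, 1000)
```

Because each transform fires independently, two calls on the same record usually produce different views, which is what the contrastive and BYOL objectives need.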

Training

SimCLR uses the NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss. Two augmented views of each ECG are produced by SimCLRAugmentations, passed through the shared encoder and a two-layer MLP projector, and then compared. The Adam optimizer trains the encoder and projector jointly.

Recommended command:
python -m ssrl_ecg.train_ssl_simclr \
  --data-root data/PTB-XL \
  --epochs 20 \
  --batch-size 256 \
  --lr 1e-3 \
  --temperature 0.07 \
  --projection-dim 128 \
  --signal-length 1000 \
  --seed 42 \
  --out checkpoints/ssl_simclr.pt
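NT-Xent can be written as a cross-entropy over all 2N projected views in a batch of N ECGs: each view must pick out its partner among the other 2N−1 views. The sketch below is a generic PyTorch formulation of that loss, not necessarily the code in train_ssl_simclr.py; the function name nt_xent is hypothetical.

```python
import torch
import torch.nn.functional as F


def nt_xent(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """NT-Xent over N positive pairs (2N projected views in total)."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, d), unit length
    sim = z @ z.t() / temperature                        # cosine similarity / tau
    # Exclude self-similarity so a view is never its own candidate.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for view i is view i + N, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Softmax over 2N - 1 candidates: 1 positive and 2N - 2 negatives per anchor.
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
z1 = torch.randn(8, 128)  # e.g. projector outputs with --projection-dim 128
aligned = nt_xent(z1, z1 + 0.01 * torch.randn(8, 128))
random_pairs = nt_xent(z1, torch.randn(8, 128))
print(aligned.item() < random_pairs.item())  # True
```

Near-identical views give a much lower loss than unrelated ones. The `2N - 2` negatives per anchor are also why larger `--batch-size` values tend to help SimCLR.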

CLI Arguments

--data-root
Path
default:"data/PTB-XL"
Root directory of the PTB-XL dataset. Must contain ptbxl_database.csv, scp_statements.csv, and a records100/ folder.
--epochs
int
default:"20"
Number of full passes over the unlabeled training split (PTB-XL folds 1–8, 17,489 samples).
--batch-size
int
default:"256"
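Number of ECGs per optimization step. Each ECG yields two augmented views, so with a batch of N every anchor is contrasted against 2N−2 in-batch negatives. Larger batches therefore provide more negatives for NT-Xent and generally improve SimCLR performance. Reduce to 64 if you encounter CUDA out-of-memory errors.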
--lr
float
default:"1e-3"
Learning rate for the Adam optimizer applied to both the encoder and projection head.
--temperature
float
default:"0.07"
Temperature parameter τ in the NT-Xent loss. Lower values sharpen the softmax over similarities and push harder on the most similar negatives. The default of 0.07 is a common choice in contrastive learning (it is, for example, the value used by MoCo).
--projection-dim
int
default:"128"
Output dimensionality of the two-layer MLP projection head. The encoder produces a 256-dim latent; this head maps it down to 128 dims for the contrastive objective.
--signal-length
int
default:"1000"
Number of time steps to load per ECG record. At 100 Hz this corresponds to 10 seconds, matching the full PTB-XL recording length.
--seed
int
default:"42"
Global random seed passed to set_seed() for reproducible data loading, weight initialization, and augmentation sampling.
--out
Path
default:"checkpoints/ssl_simclr.pt"
Destination path for the saved encoder checkpoint. Parent directories are created automatically. The file stores {"encoder": <state_dict>}.

Results

| Metric | Value |
|---|---|
| AUROC (after fine-tuning on 10% labels) | 0.8717 |
| F1 Macro (after fine-tuning) | 0.6448 |
| AUROC multi-seed (10 seeds) | 0.8717 ± 0.0032 |
| 95% CI | 0.8671–0.8763 |

Checkpoint Format

Both pretraining scripts write a single .pt file containing only the encoder weights. The projection heads and (for BYOL) the target network are not saved.
# Written by both train_ssl_simclr.py and train_ssl_byol.py
torch.save({"encoder": model.encoder.state_dict()}, args.out)
To inspect a checkpoint:
import torch

ckpt = torch.load("checkpoints/ssl_simclr.pt", map_location="cpu")
print(ckpt.keys())          # dict_keys(['encoder'])
print(type(ckpt["encoder"])) # <class 'collections.OrderedDict'>

Comparing SimCLR and BYOL

SimCLR is the recommended choice. It achieves higher AUROC (0.8717 vs 0.8565) and its temperature hyperparameter is easy to tune. The NT-Xent loss benefits from large batch sizes — use --batch-size 256 or higher where GPU memory allows.
BYOL is a strong alternative when batch size is constrained (e.g., on a single GPU with limited VRAM). Because it does not rely on in-batch negatives, smaller batches do not degrade the signal the same way they do for SimCLR. The EMA momentum (hardcoded to tau=0.999) controls how slowly the target network is updated.
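BYOL replaces the contrastive denominator with a regression target: an online network plus a predictor tries to match the projection produced by a target network, and the target is updated as an exponential moving average of the online weights with tau=0.999. The sketch below shows a symmetric BYOL loss and the EMA update in generic PyTorch. The small MLPs stand in for the encoder/projector and predictor, and the names byol_loss and ema_update are hypothetical, not taken from train_ssl_byol.py.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def byol_loss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # 2 - 2 * cosine similarity between online prediction p and target projection z.
    # The target branch is detached, so it receives no gradient.
    return (2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1)).mean()


@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float = 0.999) -> None:
    # target <- tau * target + (1 - tau) * online, parameter by parameter.
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1 - tau)


# Hypothetical stand-ins: "online" plays encoder + 256-dim projector, plus a predictor.
online = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
predictor = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
target = copy.deepcopy(online).requires_grad_(False)  # starts as a copy, never backpropagated
opt = torch.optim.Adam(list(online.parameters()) + list(predictor.parameters()), lr=1e-3)

x1, x2 = torch.randn(32, 256), torch.randn(32, 256)  # stand-ins for two augmented views
# Symmetrized loss: each view predicts the target's projection of the other view.
loss = byol_loss(predictor(online(x1)), target(x2)) + byol_loss(predictor(online(x2)), target(x1))
opt.zero_grad()
loss.backward()
opt.step()
ema_update(online, target, tau=0.999)  # target drifts slowly toward the online weights
```

With tau=0.999 the target moves only 0.1% of the way toward the online weights per step, which keeps the regression targets stable without needing negatives.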

Next Steps

Once pretraining completes, pass the checkpoint to the fine-tuning script to train a classification head on labeled data.
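The checkpoint only holds `{"encoder": <state_dict>}`, so consuming it amounts to rebuilding an encoder with the same architecture, calling `load_state_dict`, and adding a classification head. The sketch below round-trips a checkpoint in that format through a temporary file. make_encoder is a hypothetical stand-in for ECGEncoder1DCNN(in_ch=12, width=64), and NUM_CLASSES = 5 (for example, the PTB-XL diagnostic superclasses) is an assumption, not a setting taken from the fine-tuning script.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_encoder() -> nn.Module:
    # Hypothetical stand-in for ECGEncoder1DCNN(in_ch=12, width=64);
    # any module whose state dict matches the checkpoint works the same way.
    return nn.Sequential(
        nn.Conv1d(12, 64, 7, padding=3), nn.ReLU(),
        nn.Conv1d(64, 128, 7, padding=3), nn.ReLU(),
        nn.Conv1d(128, 256, 7, padding=3), nn.ReLU(),
        nn.AdaptiveAvgPool1d(1), nn.Flatten(),  # -> (batch, 256)
    )


with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "ssl_simclr.pt")
    # Pretraining side: only the encoder weights are saved, under "encoder".
    torch.save({"encoder": make_encoder().state_dict()}, ckpt_path)

    # Fine-tuning side: rebuild the encoder and restore the pretrained weights.
    encoder = make_encoder()
    state = torch.load(ckpt_path, map_location="cpu")
    encoder.load_state_dict(state["encoder"])

NUM_CLASSES = 5  # assumption: e.g. the five PTB-XL diagnostic superclasses
encoder.requires_grad_(False)  # linear probe; skip this line for full fine-tuning
classifier = nn.Sequential(encoder, nn.Linear(256, NUM_CLASSES))

logits = classifier(torch.randn(2, 12, 1000))
print(logits.shape)  # torch.Size([2, 5])
```

Freezing the encoder trains only the linear layer; leaving it trainable turns the same setup into full fine-tuning.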

Fine-Tuning

Train a linear classifier on top of the pretrained encoder using 10% of PTB-XL labels.

Supervised Baseline

Compare SSL against a CNN trained from scratch to measure the SSL gain.
