Fine-tuning transfers the representations learned during self-supervised pretraining to a downstream cardiovascular disease classifier. A linear classification head is attached to the ECGEncoder1DCNN encoder, which can be either frozen or left trainable, and the model is trained on the labeled fraction of PTB-XL. With only 10% of labels (1,747 training samples), focal loss, and oversampling, the SimCLR encoder achieves an AUROC of 0.8717 and a macro-F1 of 0.6448, a +12.15% relative F1 improvement over the supervised CNN baseline.

How It Works

1. Load the pretrained encoder

The .pt checkpoint produced by train_ssl_simclr.py or train_ssl_byol.py is loaded and its "encoder" state dict is restored into ECGEncoder1DCNN(in_ch=12, width=64).

2. Attach a classification head

ECGClassifier wraps the encoder with a linear layer that maps the 256-dim latent to n_classes=5 logits (one per PTB-XL diagnostic superclass: NORM, MI, STTC, HYP, CD).

3. Sample labeled data

sample_labelled_indices stratifies the PTB-XL train split (folds 1–8) and returns the fraction specified by --label-fraction (default 10%).

4. Balance training data

The data loader applies the strategy from --balance-strategy (oversample by default) to compensate for the 3.32× class imbalance between NORM and HYP.

5. Train with imbalance-aware loss

FocalLoss(alpha=0.25, gamma=2.0) (default) down-weights easy, well-classified examples so that training focuses on hard examples, which come disproportionately from minority classes.

6. Save the best checkpoint

The epoch with the highest validation macro-F1 is saved to --out as {"model": <state_dict>}.
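
The checkpoint-selection step can be illustrated in plain Python. This is a minimal sketch, not the project's training loop: macro_f1 and select_best_epoch are hypothetical helpers, and predictions are assumed to be already thresholded to 0/1 per class.

```python
from typing import List, Sequence


def macro_f1(y_true: Sequence[Sequence[int]], y_pred: Sequence[Sequence[int]]) -> float:
    """Unweighted mean of per-class F1 over multi-label 0/1 matrices."""
    n_classes = len(y_true[0])
    scores: List[float] = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 1 and p[c] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 0 and p[c] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 1 and p[c] == 0)
        denom = 2 * tp + fp + fn
        # A class with no positives and no predictions contributes 0 here.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / n_classes


def select_best_epoch(val_scores: Sequence[float]) -> int:
    """Return the index of the first epoch with the highest validation score."""
    best_epoch, best_score = -1, float("-inf")
    for epoch, score in enumerate(val_scores):
        if score > best_score:  # strict '>' keeps the earliest epoch on ties
            best_epoch, best_score = epoch, score
            # A real loop would save {"model": state_dict} to --out at this point.
    return best_epoch


# Toy validation labels and predictions for two classes.
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_pred = [[1, 0], [0, 0], [1, 1], [1, 0]]
score = macro_f1(y_true, y_pred)
best = select_best_epoch([0.50, 0.61, 0.64, 0.64, 0.60])
```

Because only improvements trigger a save, later epochs that plateau or regress leave the stored checkpoint untouched.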

Training Commands

python -m ssrl_ecg.train_finetune \
  --data-root data/PTB-XL \
  --ssl-checkpoint checkpoints/ssl_simclr.pt \
  --epochs 20 \
  --batch-size 64 \
  --lr 1e-3 \
  --label-fraction 0.1 \
  --signal-length 1000 \
  --loss focal \
  --balance-strategy oversample \
  --seed 42 \
  --out checkpoints/finetuned_simclr.pt
Result: F1=0.6448, AUROC=0.8717

To use linear probing (freeze the encoder and train only the classification head), add the --freeze-encoder flag:
python -m ssrl_ecg.train_finetune \
  --ssl-checkpoint checkpoints/ssl_simclr.pt \
  --freeze-encoder \
  --epochs 20 \
  --out checkpoints/finetuned_simclr_frozen.pt

CLI Arguments

--data-root
Path
default:"data/PTB-XL"
Root directory of the PTB-XL dataset. Must contain ptbxl_database.csv, scp_statements.csv, and a records100/ folder with .hea/.dat files.
--ssl-checkpoint
Path
required
Path to the pretrained SSL encoder checkpoint (.pt file produced by train_ssl_simclr.py or train_ssl_byol.py). This argument is required — there is no default. The checkpoint must contain an "encoder" key.
--epochs
int
default:"20"
Number of fine-tuning epochs. The best model is tracked by validation macro-F1 and saved at the end, so additional epochs only help if the model keeps improving.
--batch-size
int
default:"64"
Training batch size. Smaller than the SSL pretraining default because labeled data is scarce and the classification head benefits from careful gradient updates.
--lr
float
default:"1e-3"
Learning rate for the Adam optimizer. Only parameters with requires_grad=True are updated — if --freeze-encoder is set, only the classification head is optimized.
--label-fraction
float
default:"0.1"
Fraction of the PTB-XL training split to use as labeled data. 0.1 corresponds to approximately 1,747 samples out of 17,489. Valid range: (0, 1].
--signal-length
int
default:"1000"
Number of time steps per ECG record. At 100 Hz this equals 10 seconds. Must match the value used during SSL pretraining.
--freeze-encoder
flag
When set, all encoder parameters have requires_grad set to False before training. Only the linear classification head is updated. This is true linear probing and is faster, but full fine-tuning (default, no flag) typically achieves higher performance.
--seed
int
default:"42"
Global random seed passed to set_seed() for reproducible sampling, weight initialization, and data loading order.
--out
Path
default:"checkpoints/finetuned.pt"
Path where the best checkpoint is saved. The file stores {"model": <state_dict>}. Parent directories are created automatically.
--loss
str
default:"focal"
Loss function for the multi-label classification objective. Choices:
  • focal: Focal Loss with alpha=0.25, gamma=2.0. Recommended.
  • bce: Standard BCEWithLogitsLoss. Treats all classes equally.
  • weighted: WeightedBCELoss with per-class weights computed by inverse frequency.
  • class_balanced: ClassBalancedLoss with beta=0.9999, based on the effective number of samples.
--balance-strategy
str
default:"oversample"
Data loader balancing strategy to address the 3.32× class imbalance in PTB-XL. Choices:
  • oversample — Duplicate minority-class samples to equalize class frequency. Recommended.
  • stratified — Stratified sampling ensures each batch reflects the true class distribution.
  • standard — No rebalancing; samples are drawn in natural order.
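
For orientation, the arguments documented above could be declared with argparse roughly as follows. This is a reconstruction from the names, types, and defaults listed on this page, not the project's source; build_parser is a hypothetical name.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Declare the fine-tuning CLI using the documented names and defaults."""
    p = argparse.ArgumentParser(prog="train_finetune")
    p.add_argument("--data-root", type=Path, default=Path("data/PTB-XL"))
    p.add_argument("--ssl-checkpoint", type=Path, required=True)  # no default
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--label-fraction", type=float, default=0.1)
    p.add_argument("--signal-length", type=int, default=1000)
    p.add_argument("--freeze-encoder", action="store_true")  # flag, off by default
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--out", type=Path, default=Path("checkpoints/finetuned.pt"))
    p.add_argument(
        "--loss",
        choices=["focal", "bce", "weighted", "class_balanced"],
        default="focal",
    )
    p.add_argument(
        "--balance-strategy",
        choices=["oversample", "stratified", "standard"],
        default="oversample",
    )
    return p


# Example: linear probing with every other option left at its default.
args = build_parser().parse_args(
    ["--ssl-checkpoint", "checkpoints/ssl_simclr.pt", "--freeze-encoder"]
)
```

argparse converts dashes to underscores, so the values are read as args.label_fraction, args.balance_strategy, and so on.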

Loss Functions

Weighted BCE (--loss weighted)
Per-class weights are derived from inverse class frequency: a class appearing in 10% of samples gets 10× the weight of one appearing in 100% of samples. Useful when you want simple, interpretable class weighting.
from ssrl_ecg.models.losses import WeightedBCELoss, compute_class_weights
weights = compute_class_weights(train_labels, method="inverse_frequency").to(device)
criterion = WeightedBCELoss(class_weights=weights, reduction="mean")
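
To make the inverse-frequency rule concrete, here is a small plain-Python sketch. inverse_frequency_weights is a hypothetical helper written for illustration; the project's compute_class_weights may normalize or clip the weights differently.

```python
from typing import List, Sequence


def inverse_frequency_weights(labels: Sequence[Sequence[int]]) -> List[float]:
    """Weight each class by 1 / (fraction of samples in which it is positive)."""
    n_samples = len(labels)
    n_classes = len(labels[0])
    weights: List[float] = []
    for c in range(n_classes):
        positives = sum(row[c] for row in labels)
        # Guard against classes with no positives in the sampled subset.
        weights.append(n_samples / positives if positives else 0.0)
    return weights


# Toy label matrix: class 0 is positive in all 10 samples, class 1 in 1 of 10.
toy_labels = [[1, 1]] + [[1, 0]] * 9
weights = inverse_frequency_weights(toy_labels)
```

In this toy case the class present in 10% of samples receives ten times the weight of the class present in every sample.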
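Class-balanced (--loss class_balanced)
Uses the effective number of samples E_n = (1 - beta^n) / (1 - beta), where n is the number of samples in a class, to compute per-class weights, following Cui et al. (2019). beta controls how quickly the benefit of additional samples saturates: at beta=0 all classes are weighted equally, and as beta approaches 1 the weights approach inverse class frequency. The default is beta=0.9999.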
from ssrl_ecg.models.losses import ClassBalancedLoss
class_counts = (train_labels == 1).sum(dim=0)
criterion = ClassBalancedLoss(class_counts, beta=0.9999, reduction="mean")
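
The effective-number formula can be evaluated directly. The sketch below is an assumption-laden illustration in plain Python: class_balanced_weights is a hypothetical helper, it expects strictly positive counts, and it rescales the weights to sum to the number of classes as in Cui et al. (2019); ClassBalancedLoss itself may handle normalization differently.

```python
from typing import List, Sequence


def effective_number(n: int, beta: float) -> float:
    """E_n = (1 - beta^n) / (1 - beta)."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(counts: Sequence[int], beta: float = 0.9999) -> List[float]:
    """Weights proportional to 1 / E_n, rescaled to sum to the number of classes."""
    raw = [1.0 / effective_number(n, beta) for n in counts]
    scale = len(counts) / sum(raw)
    return [w * scale for w in raw]


counts = [9514, 2649]  # NORM and HYP counts quoted on this page
weights = class_balanced_weights(counts)
ratio = weights[1] / weights[0]  # how much more HYP is weighted than NORM
```

With beta=0.9999 the resulting weight ratio is smaller than the raw count ratio, which is the softening effect the effective-number formulation is designed to give.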
Standard BCE (--loss bce)
torch.nn.BCEWithLogitsLoss with no class weighting. Use it as a sanity-check baseline to isolate the effect of the loss choice, but expect lower performance on minority classes.
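
The default focal option has no snippet in this section. As a rough illustration of its down-weighting behavior, the sketch below implements the standard binary focal loss of Lin et al. (2017) for a single logit in plain Python. Whether the project's FocalLoss applies alpha in exactly this way is an assumption, and binary_focal_loss is a hypothetical helper that is not numerically hardened for extreme logits.

```python
import math


def binary_focal_loss(logit: float, target: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one logit: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    p = 1.0 / (1.0 + math.exp(-logit))               # sigmoid probability of the positive class
    p_t = p if target == 1 else 1.0 - p              # probability assigned to the true label
    alpha_t = alpha if target == 1 else 1.0 - alpha  # class-dependent scaling factor
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = binary_focal_loss(logit=4.0, target=1)   # confident and correct
hard = binary_focal_loss(logit=-1.0, target=1)  # positive label, low predicted probability
```

The factor (1 - p_t)^gamma shrinks toward zero for confident, correct predictions, so the easy example contributes far less to the total than the hard one. Setting gamma=0 removes that factor and leaves an alpha-scaled binary cross-entropy.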

Data Balancing Strategies

The PTB-XL training set has a 3.32× imbalance ratio (NORM: 9,514 vs HYP: 2,649). The --balance-strategy argument controls how create_balanced_dataloader addresses this:
| Strategy | Behavior | Use case |
| --- | --- | --- |
| oversample | Minority classes are duplicated until frequencies are equalized | Recommended; best empirical performance |
| stratified | Batches are assembled to reflect original class proportions | When you want unbiased gradient estimates |
| standard | Standard PyTorch DataLoader shuffle; no resampling | Baseline comparison |
Do not apply oversample to the validation or test sets. The val/test loaders always use standard ordering regardless of --balance-strategy.
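
One common way to implement oversampling is to draw training indices with replacement, weighting each sample by the inverse of its class count. The sketch below shows that idea in plain Python under a simplifying assumption of one class per record (PTB-XL records can carry several superclasses). It is not the logic of create_balanced_dataloader, which is not shown on this page; both function names are hypothetical. In PyTorch, torch.utils.data.WeightedRandomSampler plays a similar role.

```python
import random
from collections import Counter
from typing import List, Sequence


def oversampling_weights(class_ids: Sequence[int]) -> List[float]:
    """Per-sample weight 1 / (number of samples sharing that sample's class)."""
    counts = Counter(class_ids)
    return [1.0 / counts[c] for c in class_ids]


def draw_balanced_indices(class_ids: Sequence[int], n_draws: int, seed: int = 42) -> List[int]:
    """Sample indices with replacement so each class is drawn about equally often."""
    rng = random.Random(seed)
    weights = oversampling_weights(class_ids)
    return rng.choices(range(len(class_ids)), weights=weights, k=n_draws)


# Toy single-label data: 80 majority samples (class 0), 20 minority samples (class 1).
class_ids = [0] * 80 + [1] * 20
indices = draw_balanced_indices(class_ids, n_draws=1000)
minority_share = sum(class_ids[i] for i in indices) / len(indices)
```

Each class ends up with the same total sampling weight, so the minority class rises from 20% of the data to roughly half of the drawn indices. Validation and test data would be iterated without such resampling.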

Results

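| Encoder | Loss | Balance | AUROC | F1 Macro | Sensitivity | Specificity |
| --- | --- | --- | --- | --- | --- | --- |
| SimCLR | focal | oversample | 0.8717 | 0.6448 | 0.6831 | 0.9411 |
| BYOL | focal | oversample | 0.8565 | 0.6301 | 0.6648 | 0.9278 |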
SimCLR fine-tuning outperforms the supervised CNN baseline (AUROC=0.8606, F1=0.5750) by +12.15% F1 using the same labeled data fraction.
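
The percentage is a relative gain, which can be checked from the rounded scores quoted above. The rounded values give about 12.1%; the reported +12.15% was presumably computed from unrounded scores, which is an assumption.

```python
baseline_f1 = 0.5750  # supervised CNN baseline, macro-F1
simclr_f1 = 0.6448    # SimCLR fine-tuning, macro-F1

absolute_gain = simclr_f1 - baseline_f1                  # difference in F1 points
relative_gain_pct = 100.0 * absolute_gain / baseline_f1  # gain relative to the baseline

print(f"absolute: {absolute_gain:.4f}, relative: {relative_gain_pct:.1f}%")
```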

Loading a Saved Checkpoint

After fine-tuning, the best model is saved as {"model": <state_dict>}. To reload it for evaluation or inference:
import torch
from ssrl_ecg.models.cnn import ECGClassifier, ECGEncoder1DCNN

# Reconstruct the model
encoder = ECGEncoder1DCNN(in_ch=12, width=64)
model = ECGClassifier(encoder=encoder, n_classes=5)

# Load the fine-tuned weights
ckpt = torch.load("checkpoints/finetuned.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
model.eval()

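# Run inference
# ecg_tensor is assumed to be a float tensor of shape (batch, 12, signal_length),
# for example (batch, 12, 1000) with the default --signal-length
with torch.no_grad():
    logits = model(ecg_tensor)          # shape: (batch, 5)
    probs = torch.sigmoid(logits)       # multi-label probabilities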
Always pass map_location="cpu" when loading on a machine that may not have the same GPU configuration as the training machine. Move to GPU afterwards with model.to(device).
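
Because the task is multi-label, each of the five probabilities is thresholded independently rather than passed through an argmax. The plain-Python sketch below shows one way to turn a row of logits into superclass names. The index order of SUPERCLASSES and the 0.5 threshold are assumptions, and decode_logits is a hypothetical helper; check the label order used by the project's dataset code before relying on it.

```python
import math
from typing import List, Sequence

# Assumed index order, taken from the order in which this page lists the superclasses.
SUPERCLASSES = ["NORM", "MI", "STTC", "HYP", "CD"]


def decode_logits(logits: Sequence[float], threshold: float = 0.5) -> List[str]:
    """Apply a sigmoid to each logit and keep the classes at or above the threshold."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [name for name, p in zip(SUPERCLASSES, probs) if p >= threshold]


# One record's logits; a record may match several superclasses or none.
labels = decode_logits([-2.0, 1.5, 0.3, -0.7, -3.0])
```

A per-class threshold tuned on the validation split may suit imbalanced classes better than a single fixed value.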

Next Steps

SSL Pretraining

Generate the SSL encoder checkpoint that feeds into this fine-tuning pipeline.

Supervised Baseline

See how much SSL pretraining improves over training the same CNN from scratch.
