Traditional machine learning offers a fast, interpretable starting point for ECG classification before committing to deep learning infrastructure. SSRL-ECG includes train_traditional_ml.py, which extracts 22 handcrafted statistical and energy-based features from each 12-lead ECG recording and fits a multi-label Random Forest or XGBoost classifier on a labeled subset of PTB-XL. No GPU is required, training typically completes in minutes, and feature importances are directly inspectable.
Traditional ML is most useful for rapid prototyping, interpretability audits, and low-compute environments. It establishes a feature-engineering baseline that highlights how much representation learning (supervised CNN or SSL) adds on top of handcrafted features.

Feature Extraction

Each 12-lead ECG is converted to a fixed-length feature vector by extract_ecg_features(). Most features are computed on lead 0 only; RMS is computed separately for every lead. The values are collected into a flat dictionary that becomes a row in a pandas DataFrame.

Per-Lead Statistics (5 features, computed on lead 0)

| Feature | Formula |
| --- | --- |
| mean | np.mean(lead) |
| std | np.std(lead) |
| max | np.max(lead) |
| min | np.min(lead) |
| range | np.ptp(lead) (peak-to-peak) |

Gradient Features (2 features)

Computed from the first-order difference of lead 0, serving as a rough QRS proxy:

| Feature | Formula |
| --- | --- |
| mean_gradient | np.mean(np.abs(np.diff(lead))) |
| max_gradient | np.max(np.abs(np.diff(lead))) |
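
To see why the first-order difference works as a rough QRS proxy, the sketch below compares a smooth baseline with the same baseline plus one sharp spike. The signal is synthetic and purely illustrative, and the helper name gradient_features is an assumption, not part of the SSRL-ECG code.

```python
import numpy as np

def gradient_features(lead: np.ndarray) -> dict:
    """Mean and max absolute first-order difference of a 1-D signal."""
    abs_diff = np.abs(np.diff(lead))
    return {
        "mean_gradient": float(np.mean(abs_diff)),
        "max_gradient": float(np.max(abs_diff)),
    }

# Synthetic 10 s trace at 100 Hz: a slow baseline wave (not real ECG data).
t = np.arange(1000) / 100.0
baseline = 0.1 * np.sin(2 * np.pi * 0.5 * t)

# Same trace with one sharp, QRS-like deflection added at a single sample.
spiky = baseline.copy()
spiky[500] += 1.0

flat_feats = gradient_features(baseline)
spiky_feats = gradient_features(spiky)
print(flat_feats, spiky_feats)
```

The spike barely changes the mean of the signal but dominates max_gradient, which is the property the feature relies on.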

Energy and Entropy (2 features)

| Feature | Formula |
| --- | --- |
| energy | np.sum(lead²) |
| entropy | −Σ p × log(p) where p = lead² / Σ(lead² + ε) |
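
The entropy feature treats the normalized squared samples as a probability distribution over time. A minimal sketch with two extreme synthetic inputs shows the range it covers: energy spread evenly over N samples gives a value near log(N), while energy concentrated in one sample gives a value near 0. The helper name energy_entropy is an assumption; the arithmetic follows the formula above with ε = 1e-10.

```python
import numpy as np

def energy_entropy(lead: np.ndarray, eps: float = 1e-10) -> float:
    """Shannon entropy of the normalized per-sample energy distribution."""
    power = lead ** 2
    p = power / np.sum(power + eps)          # p sums to (almost) 1
    return float(-np.sum(p * np.log(p + eps)))

n = 1000
uniform = np.ones(n)       # energy spread evenly over all samples
impulse = np.zeros(n)
impulse[0] = 1.0           # all energy in a single sample

h_uniform = energy_entropy(uniform)
h_impulse = energy_entropy(impulse)
print(h_uniform, np.log(n), h_impulse)
```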

Zero-Crossing Rate (1 feature)

Zero crossings relative to the signal mean act as a simple heart rate proxy:

| Feature | Formula |
| --- | --- |
| zero_crossings | Count of sign changes in lead − mean(lead) |
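
As a quick illustration of the heart-rate-proxy idea, the sketch below counts mean-relative sign changes for two synthetic sine waves with a DC offset. Doubling the frequency roughly doubles the count, and the offset has no effect because the mean is subtracted first. The sine waves and the helper name count_zero_crossings are illustrative assumptions.

```python
import numpy as np

def count_zero_crossings(lead: np.ndarray) -> int:
    """Number of sign changes in the mean-centered signal."""
    centered = lead - np.mean(lead)
    return int(np.sum(np.diff(np.sign(centered)) != 0))

fs = 100                          # sampling rate in Hz
t = np.arange(10 * fs) / fs       # 10 seconds of samples

# A small phase shift keeps samples from landing exactly on zero.
slow = np.sin(2 * np.pi * 1.0 * t + 0.1) + 5.0   # 1 Hz with DC offset
fast = np.sin(2 * np.pi * 2.0 * t + 0.1) + 5.0   # 2 Hz with DC offset

slow_count = count_zero_crossings(slow)
fast_count = count_zero_crossings(fast)
print(slow_count, fast_count)
```

On real ECG the count also reacts to noise and baseline wander, so it is best read as a coarse activity measure.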

Per-Lead RMS (12 features)

For multi-lead inputs, RMS energy is computed independently for each of the 12 leads:

| Feature | Formula |
| --- | --- |
| rms_lead_0 … rms_lead_11 | √(mean(lead_i²)) for i ∈ [0, 11] |

Total feature vector size: 5 + 2 + 2 + 1 + 12 = 22 features for 12-lead inputs. For single-lead inputs, the 12 per-lead RMS features are replaced by one rms feature, giving 11 features. All features are standardized with StandardScaler (zero mean, unit variance) before being passed to the classifier.
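
The per-lead RMS values can also be computed in one vectorized call. The sketch below assumes a (leads, samples) array layout and uses random data as a stand-in for a real record; it checks that the loop form and the vectorized form agree.

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.normal(size=(12, 1000))   # synthetic stand-in, shape (leads, samples)

# Loop form, mirroring the per-lead definition.
rms_loop = np.array(
    [np.sqrt(np.mean(signal[i] ** 2)) for i in range(signal.shape[0])]
)

# Vectorized form: reduce over the time axis for all leads at once.
rms_vec = np.sqrt(np.mean(signal ** 2, axis=1))

# Name the values the same way the feature table does.
feature_names = [f"rms_lead_{i}" for i in range(signal.shape[0])]
rms_features = dict(zip(feature_names, rms_vec.tolist()))
print(rms_features["rms_lead_0"], rms_features["rms_lead_11"])
```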

Classifiers

Random Forest (--model rf) fits an ensemble of 100 decision trees with maximum depth 15. It uses all available CPU cores (n_jobs=-1) and supports multi-output classification natively via RandomForestClassifier.
python -m ssrl_ecg.train_traditional_ml \
  --data-root data/PTB-XL \
  --label-fraction 0.1 \
  --model rf \
  --seed 42
Internal configuration:
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(
    n_estimators=100,
    max_depth=15,
    random_state=args.seed,
    n_jobs=-1,
)
Random Forest is the default (--model rf) and requires no additional dependencies beyond scikit-learn.
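
One practical detail of native multi-output support: with a 2-D label matrix, scikit-learn's predict_proba returns a list with one array per label rather than a single matrix. The sketch below uses synthetic features and labels (not PTB-XL) to show how the positive-class probabilities can be stacked into an (n_samples, n_labels) score matrix. Whether the script does exactly this is an assumption.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 22))            # synthetic 22-feature rows
# Five binary labels driven by the first five features (toy setup).
Y = (X[:, :5] > 0).astype(int)

model = RandomForestClassifier(
    n_estimators=100, max_depth=15, random_state=42, n_jobs=-1
)
model.fit(X, Y)

# For multi-output targets: one (n_samples, n_classes) array per label.
proba_per_label = model.predict_proba(X)

# Column 1 holds the positive-class probability; stack to (n_samples, n_labels).
scores = np.column_stack([p[:, 1] for p in proba_per_label])
print(len(proba_per_label), scores.shape)
```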

CLI Arguments

--data-root (Path, default: "data/PTB-XL")
Root directory of the PTB-XL dataset. Must contain ptbxl_database.csv, scp_statements.csv, and records100/ with .hea/.dat ECG records.

--label-fraction (float, default: 0.1)
Fraction of the PTB-XL training split (folds 1–8) to use for fitting the classifier. 0.1 corresponds to approximately 1,747 samples, matching the labeled data budget used by the deep learning baselines for a fair comparison.

--seed (int, default: 42)
Random seed passed to np.random.seed() and to the classifier's random_state. Controls which samples are drawn by sample_labelled_indices and the internal tree-building randomness.

--model (str, default: "rf")
Which classifier to train. Choices:
  • rf: RandomForestClassifier(n_estimators=100, max_depth=15). No extra dependencies.
  • xgb: XGBClassifier(n_estimators=100, max_depth=5, learning_rate=0.1). Requires the xgboost package; falls back to RF if it is not available.
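
The arguments and the xgboost fallback described above could be wired together roughly as follows. This is a sketch under stated assumptions, not the script's actual code: build_parser and build_model are hypothetical names, and the real script may structure its parser and fallback differently.

```python
import argparse
from pathlib import Path

from sklearn.ensemble import RandomForestClassifier

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented arguments and defaults."""
    parser = argparse.ArgumentParser(description="Traditional ML baseline (sketch)")
    parser.add_argument("--data-root", type=Path, default=Path("data/PTB-XL"))
    parser.add_argument("--label-fraction", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--model", type=str, default="rf", choices=["rf", "xgb"])
    return parser

def build_model(name: str, seed: int):
    """Return the requested classifier, falling back to RF without xgboost."""
    if name == "xgb":
        try:
            from xgboost import XGBClassifier  # optional dependency
            return XGBClassifier(
                n_estimators=100, max_depth=5, learning_rate=0.1, random_state=seed
            )
        except ImportError:
            print("xgboost not installed; falling back to Random Forest")
    return RandomForestClassifier(
        n_estimators=100, max_depth=15, random_state=seed, n_jobs=-1
    )

# Parse an explicit argument list so the sketch does not read sys.argv.
args = build_parser().parse_args(["--model", "rf", "--label-fraction", "0.25"])
model = build_model(args.model, args.seed)
print(type(model).__name__, args.label_fraction, args.data_root)
```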

Training Pipeline

1. Load metadata and splits
   load_ptbxl_metadata reads the PTB-XL CSV files. make_default_splits partitions the records into train (folds 1–8), val (fold 9), and test (fold 10) using the standard PTB-XL protocol.

2. Sample labeled training indices
   sample_labelled_indices stratifies the training split and returns label_fraction × N indices, matching the low-data regime used by the CNN and SSL baselines.

3. Extract features for all splits
   Each ECG signal is loaded via PTBXLRecordDataset, converted to a NumPy array, and passed to extract_ecg_features(). The result is a pandas DataFrame with one row per sample. Missing values are filled with 0.

4. Standardize features
   StandardScaler is fit on the training set only. The same scaler transforms the validation and test sets to prevent data leakage.

5. Fit the classifier
   The script calls model.fit(X_train, y_train), where y_train is the multi-label binary matrix of shape (n_samples, 5).

6. Evaluate on validation and test sets
   Per-class AUROC and macro-F1 are printed for both splits. No checkpoint is saved; re-run the script to reproduce results.
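
Steps 4 to 6 can be sketched end to end on synthetic data. The features and labels below are random stand-ins for the extracted PTB-XL features, make_split is a hypothetical helper, and the metric calls are one reasonable way to compute per-class AUROC and macro-F1; the script's exact evaluation code may differ.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)

def make_split(n: int):
    """Synthetic stand-in for extracted features and 5 multi-label targets."""
    X = rng.normal(size=(n, 22))
    Y = (X[:, :5] + 0.5 * rng.normal(size=(n, 5)) > 0).astype(int)
    return X, Y

X_train, y_train = make_split(400)
X_val, y_val = make_split(200)

# Step 4: fit the scaler on training data only, then reuse it for validation.
scaler = StandardScaler().fit(X_train)
X_train_s = scaler.transform(X_train)
X_val_s = scaler.transform(X_val)

# Step 5: fit the multi-label classifier.
model = RandomForestClassifier(
    n_estimators=100, max_depth=15, random_state=42, n_jobs=-1
)
model.fit(X_train_s, y_train)

# Step 6: positive-class probability per label, stacked to (n_samples, n_labels).
val_scores = np.column_stack([p[:, 1] for p in model.predict_proba(X_val_s)])
val_pred = model.predict(X_val_s)

per_class_auroc = [
    roc_auc_score(y_val[:, k], val_scores[:, k]) for k in range(y_val.shape[1])
]
macro_f1 = f1_score(y_val, val_pred, average="macro", zero_division=0)
print(per_class_auroc, macro_f1)
```

Fitting the scaler before splitting, or on all splits together, would let validation statistics influence training, which is the leakage step 4 avoids.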

Feature Extraction Code

The full extract_ecg_features function from train_traditional_ml.py:
import numpy as np

def extract_ecg_features(signal: np.ndarray, sampling_rate: int = 100) -> dict:
    """Extract handcrafted ECG features from a single 12-lead record."""
    features = {}

    # Per-lead statistics (lead 0)
    lead = signal[0] if signal.ndim > 1 else signal
    features["mean"]  = float(np.mean(lead))
    features["std"]   = float(np.std(lead))
    features["max"]   = float(np.max(lead))
    features["min"]   = float(np.min(lead))
    features["range"] = float(np.ptp(lead))

    # Gradient features (QRS proxy)
    diff = np.diff(lead)
    features["mean_gradient"] = float(np.mean(np.abs(diff)))
    features["max_gradient"]  = float(np.max(np.abs(diff)))

    # Energy and entropy
    features["energy"]  = float(np.sum(lead ** 2))
    features["entropy"] = float(
        -np.sum(
            (lead**2 / np.sum(lead**2 + 1e-10))
            * np.log(lead**2 / np.sum(lead**2 + 1e-10) + 1e-10)
        )
    )

    # Zero-crossing rate (heart rate proxy)
    zero_crossings = np.sum(np.diff(np.sign(lead - np.mean(lead))) != 0)
    features["zero_crossings"] = float(zero_crossings)

    # Per-lead RMS (all 12 leads)
    if signal.ndim > 1:
        for i in range(min(12, signal.shape[0])):
            features[f"rms_lead_{i}"] = float(np.sqrt(np.mean(signal[i] ** 2)))
    else:
        features["rms"] = float(np.sqrt(np.mean(lead ** 2)))

    return features
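
The dictionaries returned by this function become DataFrame rows. A minimal sketch of that assembly, using hand-written dictionaries with made-up values in place of real function output, shows why filling missing values with 0 is needed: records with different key sets (for example, a single-lead record with rms instead of rms_lead_i) leave gaps in the table.

```python
import pandas as pd

# Hypothetical feature dicts shaped like extract_ecg_features() output
# (values are invented; only a few keys are shown for brevity).
rows = [
    {"mean": 0.01, "std": 0.2, "rms_lead_0": 0.2, "rms_lead_1": 0.3},  # multi-lead
    {"mean": -0.02, "std": 0.1, "rms": 0.1},                           # single-lead
]

# Columns are the union of all keys; keys absent from a row become NaN.
features_df = pd.DataFrame(rows)

# Fill the gaps with 0 so the classifier receives a dense numeric matrix.
features_df = features_df.fillna(0)
X = features_df.to_numpy(dtype=float)
print(features_df)
```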

When to Use Traditional ML

Fast Prototyping

No GPU required. Training and evaluation complete in minutes even on a laptop. Ideal for quickly validating dataset loading, preprocessing, and label quality before investing in deep learning training runs.

Interpretability

Random Forest feature importances reveal which statistical properties (energy, zero crossings, per-lead RMS) are most discriminative for each disease class. Useful for clinical hypothesis generation.
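
A fitted forest exposes its importances through the feature_importances_ attribute. The sketch below ranks a handful of feature names on a toy problem where the label depends on a single column; the data is synthetic and the feature subset is chosen only for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Illustrative subset of the feature names used on this page.
feature_names = ["mean", "std", "energy", "zero_crossings", "rms_lead_0"]

rng = np.random.default_rng(0)
X = rng.normal(size=(300, len(feature_names)))
y = (X[:, 2] > 0).astype(int)   # toy label that depends only on "energy"

model = RandomForestClassifier(
    n_estimators=100, max_depth=15, random_state=0
).fit(X, y)

# Sort features from most to least important.
order = np.argsort(model.feature_importances_)[::-1]
ranked = [(feature_names[i], float(model.feature_importances_[i])) for i in order]
for name, importance in ranked:
    print(f"{name:>15s}  {importance:.3f}")
```

Note that a single multi-output forest exposes one importance vector shared across all labels; obtaining a separate ranking per class would require fitting one model per label, which is an extension beyond what this page describes.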

Low-Resource Deployment

The fitted scikit-learn model can be serialized with joblib and deployed on edge devices or servers without PyTorch or GPU drivers.
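
Since the training script does not save a checkpoint, serialization would be an extra step. A minimal sketch, assuming the scaler and model are bundled into one file (the file name and the dictionary layout are hypothetical), using synthetic data:

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 22))          # synthetic feature rows
y = (X[:, 0] > 0).astype(int)

scaler = StandardScaler().fit(X)
model = RandomForestClassifier(n_estimators=20, random_state=0)
model.fit(scaler.transform(X), y)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "rf_baseline.joblib"   # hypothetical file name
    # Save scaler and model together so inference uses matching preprocessing.
    joblib.dump({"scaler": scaler, "model": model}, path)
    bundle = joblib.load(path)

restored_pred = bundle["model"].predict(bundle["scaler"].transform(X))
original_pred = model.predict(scaler.transform(X))
print((restored_pred == original_pred).all())
```

Files written by joblib are generally tied to the scikit-learn version that produced them, so the deployment environment should pin the same version.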

Baseline Benchmark

Establishes a feature-engineering floor. The gap between traditional ML and the SSL-pretrained CNN quantifies how much end-to-end representation learning adds beyond hand-designed features.
Note: Traditional ML baselines do not use the Focal Loss or oversampling strategies available in the deep learning pipeline. Class imbalance affects RF and XGBoost differently, so consider setting class_weight="balanced" in RandomForestClassifier for a fairer comparison on minority classes.
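
A minimal sketch of the class_weight suggestion, on synthetic multi-label data with one common and one rare label. The data is illustrative only, and the other hyperparameters are copied from the Random Forest configuration on this page.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 22))

# Toy imbalanced labels: label 0 is common, label 1 is rare (about 10% positive).
Y = np.column_stack([
    (X[:, 0] > 0).astype(int),
    (X[:, 1] > 1.28).astype(int),
])

model = RandomForestClassifier(
    n_estimators=100,
    max_depth=15,
    class_weight="balanced",   # reweight samples inversely to class frequency
    random_state=42,
    n_jobs=-1,
)
model.fit(X, Y)
pred = model.predict(X)
print(Y.mean(axis=0), pred.mean(axis=0))
```

For multi-output targets scikit-learn multiplies the per-label weights together for each sample, so the effect on any single rare class is worth checking against per-class metrics rather than assumed.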

Next Steps

Supervised CNN Baseline

Train the CNN from scratch to see the gain from end-to-end deep feature learning over handcrafted features.

SSL Pretraining

Pretrain with SimCLR or BYOL to push performance beyond both traditional ML and supervised CNN baselines.
