Traditional machine learning offers a fast, interpretable starting point for ECG classification before committing to deep learning infrastructure. SSRL-ECG includes train_traditional_ml.py, which extracts 24+ handcrafted per-lead statistical and spectral features from each 12-lead ECG recording and fits a multi-label Random Forest or XGBoost classifier on a labeled subset of PTB-XL. No GPU is required, training typically completes in minutes, and feature importances are directly inspectable.
Feature Extraction
Each 12-lead ECG is converted to a fixed-length feature vector by extract_ecg_features(). Features are computed per lead and collected into a flat dictionary that becomes a row in a pandas DataFrame.
Per-Lead Statistics (5 features, computed on lead 0)
| Feature | Formula |
|---|---|
| mean | np.mean(lead) |
| std | np.std(lead) |
| max | np.max(lead) |
| min | np.min(lead) |
| range | np.ptp(lead) (peak-to-peak) |
Gradient Features (2 features)
Computed from the first-order difference of lead 0, serving as a rough QRS proxy:
| Feature | Formula |
|---|---|
| mean_gradient | np.mean(np.abs(np.diff(lead))) |
| max_gradient | np.max(np.abs(np.diff(lead))) |
Energy and Entropy (2 features)
| Feature | Formula |
|---|---|
| energy | np.sum(lead²) |
| entropy | −Σ p × log(p), where p = lead² / Σ(lead² + ε) |
Zero-Crossing Rate (1 feature)
Zero crossings relative to the signal mean act as a simple heart rate proxy:
| Feature | Formula |
|---|---|
| zero_crossings | Count of sign changes in lead − mean(lead) |
Per-Lead RMS (12 features)
For multi-lead inputs, RMS is computed independently for each of the 12 leads:
| Feature | Formula |
|---|---|
| rms_lead_0 … rms_lead_11 | √(mean(lead_i²)) for i ∈ [0, 11] |
For 1-lead inputs, a single rms feature is computed instead.
All features are standardized with StandardScaler (zero mean, unit variance) before being passed to the classifier.
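The following is a minimal sketch of that standardization step. It assumes the features already sit in a pandas DataFrame with one row per record, alongside a PTB-XL-style strat_fold array (folds 1–8 train, 9 validation, 10 test). The data are synthetic and the helper name standardize_splits is hypothetical. The key point is that StandardScaler learns its mean and variance from the training rows only and is then reused unchanged for validation and test.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler


def standardize_splits(features: pd.DataFrame, strat_fold: np.ndarray):
    """Hypothetical helper: fold-based split, zero-fill, train-only scaling."""
    features = features.fillna(0)  # missing feature values become 0
    train = features[strat_fold <= 8]
    val = features[strat_fold == 9]
    test = features[strat_fold == 10]
    # Mean and standard deviation are estimated on the training folds only
    scaler = StandardScaler().fit(train)
    return scaler, scaler.transform(train), scaler.transform(val), scaler.transform(test)


# Synthetic example: 500 records x 22 features, some values missing
rng = np.random.default_rng(0)
feats = pd.DataFrame(rng.normal(5.0, 2.0, size=(500, 22)),
                     columns=[f"f{i}" for i in range(22)])
feats.iloc[::25, 0] = np.nan
folds = rng.integers(1, 11, size=500)
scaler, X_train, X_val, X_test = standardize_splits(feats, folds)
print(X_train.mean(axis=0).round(6)[:3], X_train.std(axis=0).round(6)[:3])
```

Only the training matrix is guaranteed to have exactly zero mean and unit variance. Validation and test features are shifted by the training statistics, which is what keeps them out of the preprocessing fit.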
Classifiers
- Random Forest (--model rf) is the default and requires no dependencies beyond scikit-learn. It fits an ensemble of 100 decision trees with maximum depth 15, uses all available CPU cores (n_jobs=-1), and supports multi-output classification natively via RandomForestClassifier.
- XGBoost (--model xgb) trains gradient-boosted trees with XGBClassifier and requires the xgboost package; its configuration and fallback behavior are listed under CLI Arguments.
CLI Arguments
- Dataset root: the root directory of the PTB-XL dataset. It must contain ptbxl_database.csv, scp_statements.csv, and records100/ with the .hea/.dat ECG records.
- Label fraction: the fraction of the PTB-XL training split (folds 1–8) used to fit the classifier. A value of 0.1 corresponds to approximately 1,747 samples, matching the labeled-data budget of the deep learning baselines for a fair comparison.
- Random seed: passed to np.random.seed() and to the classifier's random_state. It controls which samples sample_labelled_indices draws as well as the internal tree-building randomness.
- --model: which classifier to train. Choices:
  - rf: RandomForestClassifier(n_estimators=100, max_depth=15). No extra dependencies.
  - xgb: XGBClassifier(n_estimators=100, max_depth=5, learning_rate=0.1). Requires the xgboost package; falls back to rf if it is not installed.
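The label-fraction, seed, and --model arguments can be sketched as two small helpers. Both are hypothetical.

- sample_labelled_indices_sketch stands in for the repository's sample_labelled_indices. As an assumption, it stratifies on each record's first positive label, with all-negative records joining stratum 0. This may not match the real implementation.
- build_classifier mirrors the documented --model choices and the ImportError fallback. Wrapping XGBClassifier in scikit-learn's MultiOutputClassifier (one booster per label) is also an assumption, made so that multi-label fitting works regardless of xgboost version.

The sketch uses NumPy's Generator API instead of the global np.random.seed().

```python
import warnings

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier


def sample_labelled_indices_sketch(labels, train_idx, label_fraction, seed):
    """Hypothetical stratified sampler over the training folds.

    Assumption: a record's stratum is the index of its first positive label
    (all-negative records fall into stratum 0).
    """
    rng = np.random.default_rng(seed)
    strata = labels[train_idx].argmax(axis=1)
    chosen = []
    for s in np.unique(strata):
        members = train_idx[strata == s]
        k = max(1, int(round(label_fraction * len(members))))
        chosen.append(rng.choice(members, size=k, replace=False))
    return np.sort(np.concatenate(chosen))


def build_classifier(model, seed):
    """Hypothetical factory mirroring the documented --model choices."""
    if model == "xgb":
        try:
            from xgboost import XGBClassifier
        except ImportError:
            warnings.warn("xgboost is not installed; falling back to rf")
        else:
            # Assumption: one binary XGBoost model per label
            return MultiOutputClassifier(
                XGBClassifier(n_estimators=100, max_depth=5,
                              learning_rate=0.1, random_state=seed))
    elif model != "rf":
        raise ValueError(f"unknown model: {model!r}")
    # RandomForestClassifier fits a 2-D multi-label target natively
    return RandomForestClassifier(n_estimators=100, max_depth=15,
                                  n_jobs=-1, random_state=seed)


# Usage on synthetic multi-hot labels (5 classes, first 320 rows = training folds)
rng = np.random.default_rng(0)
X = rng.standard_normal((400, 22))
y = (X[:, :5] > 0.5).astype(int)
train_idx = np.arange(320)
lab_idx = sample_labelled_indices_sketch(y, train_idx, label_fraction=0.1, seed=0)
clf = build_classifier("rf", seed=0).fit(X[lab_idx], y[lab_idx])
print(len(lab_idx), clf.predict(X[320:]).shape)
```

Because the same seed feeds both the sampler and random_state, rerunning with identical arguments reproduces both the labeled subset and the fitted forest.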
Training Pipeline
1. Load metadata and splits. load_ptbxl_metadata reads the PTB-XL CSV files, and make_default_splits partitions the records into train (folds 1–8), validation (fold 9), and test (fold 10) following the standard PTB-XL protocol.
2. Sample labeled training indices. sample_labelled_indices draws a stratified subset of label_fraction × N indices from the training folds, where N is the number of training records. This matches the low-data regime used by the CNN and SSL baselines.
3. Extract features for all splits. Each ECG signal is loaded via PTBXLRecordDataset, converted to a NumPy array, and passed to extract_ecg_features(). The result is a pandas DataFrame with one row per sample; missing values are filled with 0.
4. Standardize features. StandardScaler is fit on the training set only, and the same fitted scaler transforms the validation and test sets to prevent data leakage.
5. Fit the classifier. model.fit(X_train, y_train), where y_train is the multi-label binary matrix of shape (n_samples, 5).
Feature Extraction Code
The full extract_ecg_features function is defined in train_traditional_ml.py.
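A minimal NumPy sketch of such a function is shown below. It is reconstructed from the feature tables above rather than copied from train_traditional_ml.py. The following details are assumptions, and the repository's implementation may differ in any of them:

- the name extract_ecg_features_sketch;
- the channels-first (n_leads, n_samples) input layout;
- the eps value and the extra eps inside the logarithm;
- counting zero crossings with np.signbit, which treats exact zeros as positive.

```python
import numpy as np


def extract_ecg_features_sketch(signal, eps=1e-10):
    """Hypothetical re-implementation of the tabulated handcrafted features.

    Assumes `signal` has shape (n_leads, n_samples), or (n_samples,) for a
    single lead; this layout is an assumption, not taken from the repository.
    """
    signal = np.asarray(signal, dtype=float)
    if signal.ndim == 1:
        signal = signal[np.newaxis, :]  # treat a 1-D input as one lead
    lead = signal[0]  # scalar statistics use lead 0 only

    # Per-lead statistics (lead 0)
    features = {
        "mean": np.mean(lead),
        "std": np.std(lead),
        "max": np.max(lead),
        "min": np.min(lead),
        "range": np.ptp(lead),
    }

    # Gradient features: absolute first differences as a rough QRS proxy
    grad = np.abs(np.diff(lead))
    features["mean_gradient"] = np.mean(grad)
    features["max_gradient"] = np.max(grad)

    # Energy and entropy of the normalized squared signal
    squared = lead ** 2
    features["energy"] = np.sum(squared)
    p = squared / np.sum(squared + eps)
    # Assumption: eps inside the log guards against log(0)
    features["entropy"] = -np.sum(p * np.log(p + eps))

    # Zero crossings of the mean-removed signal (simple heart-rate proxy)
    centered = lead - np.mean(lead)
    features["zero_crossings"] = int(np.count_nonzero(np.diff(np.signbit(centered))))

    # RMS per lead, or a single "rms" fallback for 1-lead inputs
    if signal.shape[0] > 1:
        for i, channel in enumerate(signal):
            features[f"rms_lead_{i}"] = np.sqrt(np.mean(channel ** 2))
    else:
        features["rms"] = np.sqrt(np.mean(squared))
    return features


rng = np.random.default_rng(0)
twelve_lead = extract_ecg_features_sketch(rng.standard_normal((12, 1000)))
single_lead = extract_ecg_features_sketch(np.array([1.0, -1.0, 1.0, -1.0]))
print(len(twelve_lead), single_lead["zero_crossings"], single_lead["rms"])
```

Each call returns one flat dictionary, so a list of per-recording results can be turned into the per-sample DataFrame with pd.DataFrame(rows).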
When to Use Traditional ML
Fast Prototyping
No GPU required. Training and evaluation complete in minutes even on a laptop. Ideal for quickly validating dataset loading, preprocessing, and label quality before investing in deep learning training runs.
Interpretability
Random Forest feature importances reveal which statistical properties (energy, zero crossings, per-lead RMS) are most discriminative for each disease class. Useful for clinical hypothesis generation.
Low-Resource Deployment
The fitted scikit-learn model can be serialized with joblib and deployed on edge devices or servers without PyTorch or GPU drivers.
Baseline Benchmark
Establishes a feature-engineering floor. The gap between traditional ML and the SSL-pretrained CNN quantifies how much end-to-end representation learning adds beyond hand-designed features.
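The interpretability and low-resource deployment points can be illustrated with scikit-learn and joblib. The sketch trains on synthetic data. The helper top_importances, the feature subset, the file name rf_ecg.joblib, and the choice to bundle the fitted StandardScaler with the model are all assumptions for illustration.

A single multi-output RandomForestClassifier exposes one feature_importances_ vector aggregated over all labels. Class-specific rankings would require per-label models.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler


def top_importances(model, feature_names, k=5):
    """Return the k features with the highest impurity-based importance."""
    importances = model.feature_importances_
    order = np.argsort(importances)[::-1][:k]
    return [(feature_names[i], float(importances[i])) for i in order]


# Synthetic data: 15 named features, 2 labels driven by "energy" and "zero_crossings"
rng = np.random.default_rng(0)
feature_names = ["energy", "entropy", "zero_crossings"] + [f"rms_lead_{i}" for i in range(12)]
X = rng.standard_normal((300, len(feature_names)))
y = np.column_stack([X[:, 0] > 0, X[:, 2] > 0]).astype(int)

scaler = StandardScaler().fit(X)
model = RandomForestClassifier(n_estimators=100, max_depth=15, random_state=0)
model.fit(scaler.transform(X), y)
print(top_importances(model, feature_names, k=3))

# Persist scaler and model together so inference reuses the same standardization
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rf_ecg.joblib")
    joblib.dump({"scaler": scaler, "model": model}, path)
    bundle = joblib.load(path)
preds = bundle["model"].predict(bundle["scaler"].transform(X[:4]))
print(preds)
```

Loading the bundle needs only scikit-learn, joblib, and NumPy, which is what makes this path attractive for CPU-only deployment.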
Next Steps
Supervised CNN Baseline
Train the CNN from scratch to see the gain from end-to-end deep feature learning over handcrafted features.
SSL Pretraining
Pretrain with SimCLR or BYOL to push performance beyond both traditional ML and supervised CNN baselines.