Cardiovascular disease datasets are inherently class-imbalanced: rare conditions like bundle branch blocks appear far less frequently than normal sinus rhythm. The ssrl_ecg.models.losses module provides four specialised loss functions and a utility for computing class weights. All of them are designed for multi-label binary classification, where each ECG can carry multiple simultaneous diagnoses.
FocalLoss
FocalLoss adapts the focal loss from RetinaNet (Lin et al., 2017) to multi-label ECG classification via sigmoid activation. It down-weights easy, well-classified examples using the modulating factor (1 − p_t)^γ, concentrating training on hard or misclassified samples. This is the recommended loss for datasets with moderate-to-severe imbalance.
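The loss is computed element-wise as:

FL(p_t) = −α · (1 − p_t)^γ · log(p_t)

where p_t = σ(logit) for positive targets and p_t = 1 − σ(logit) for negative targets, and α and γ are the alpha and gamma constructor parameters.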
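The following pure-Python sketch makes the formula concrete. It was written independently of the module and is not the ssrl_ecg implementation. It rests on three choices:

- Targets are assumed to be binary 0/1 values.
- alpha is applied uniformly, as the constructor parameter description states.
- log(p_t) is evaluated through a numerically stable softplus.

All helper names are hypothetical.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_element(logit: float, target: int) -> float:
    """Plain binary cross-entropy on a logit, i.e. -log(p_t)."""
    s = logit if target == 1 else -logit
    return softplus(-s)


def focal_element(logit: float, target: int,
                  alpha: float = 0.25, gamma: float = 2.0) -> float:
    """FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t) for one element."""
    # s > 0 means the logit agrees with the label, and p_t = sigmoid(s)
    s = logit if target == 1 else -logit
    log_p_t = -softplus(-s)          # log(sigmoid(s)), stable for large |s|
    p_t = math.exp(log_p_t)
    return -alpha * (1.0 - p_t) ** gamma * log_p_t


def focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    """Apply focal_element over [batch][num_classes] nested lists."""
    losses = [[focal_element(x, t, alpha, gamma) for x, t in zip(xs, ts)]
              for xs, ts in zip(logits, targets)]
    flat = [v for row in losses for v in row]
    if reduction == "mean":
        return sum(flat) / len(flat)
    if reduction == "sum":
        return sum(flat)
    if reduction == "none":
        return losses                # per-element losses
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With gamma=0.0, focal_element reduces exactly to alpha * bce_element, which is the "weighted BCE" special case. For a confident correct prediction, the loss shrinks relative to BCE by the factor (1 − p_t)^γ. At a logit of 4.0 with γ = 2, that factor is about 3 × 10⁻⁴.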
When to use
Use FocalLoss when a handful of classes dominate the dataset. The defaults alpha=0.25 and gamma=2.0 follow the RetinaNet recipe and work well out of the box for most ECG tasks.
Constructor Parameters

- alpha (default 0.25): scalar weighting factor in (0, 1), applied uniformly across all elements. It balances the contribution of positive vs. negative examples at the population level.
- gamma (default 2.0): focusing exponent. Higher values reduce the loss contribution from easy examples more aggressively. gamma=0 recovers standard weighted BCE.
- reduction: reduction to apply across the batch. One of "mean", "sum", or "none". Use "none" to obtain per-element losses for custom aggregation.

Forward

- Predictions: raw model outputs (logits, before sigmoid) of shape [batch, num_classes].
- Targets: binary ground-truth labels of shape [batch, num_classes]. Entries are normally 0 or 1, but floating-point soft labels are also accepted.
- Returns: scalar focal loss, or a per-element tensor if reduction="none".

WeightedBCELoss
WeightedBCELoss wraps nn.BCEWithLogitsLoss to apply per-class scalar weights to the standard binary cross-entropy. It is the simplest re-weighting strategy and serves as a useful baseline for diagnosing class imbalance.
Two weighting mechanisms are supported:

- class_weights: per-class multipliers applied to the computed loss matrix of shape [batch, num_classes].
- pos_weight: passed directly to nn.BCEWithLogitsLoss, which scales the loss on positive examples per class.

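The difference between the two mechanisms is easiest to see on single elements. The sketch below re-derives the nn.BCEWithLogitsLoss formula in plain Python and is an illustrative re-implementation, not the module's code. It makes these choices:

- pos_weight multiplies only the positive-label term, following the PyTorch documentation.
- class_weights multiply the whole element for their class.
- Only "mean" and "sum" reductions are supported.

Function names are hypothetical.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(x: float, y: float, pos_weight: float = 1.0) -> float:
    """-[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    # log(sigmoid(x)) = -softplus(-x); log(1 - sigmoid(x)) = -softplus(x)
    return pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)


def weighted_bce(logits, targets, class_weights=None, pos_weight=None,
                 reduction="mean"):
    """Weighted BCE over [batch][num_classes] nested lists."""
    if class_weights is not None and pos_weight is not None:
        raise ValueError("class_weights and pos_weight cannot be combined")
    num_classes = len(logits[0])
    cw = class_weights if class_weights is not None else [1.0] * num_classes
    pw = pos_weight if pos_weight is not None else [1.0] * num_classes
    # class_weights scale every element of column c; pos_weight only positives
    flat = [cw[c] * bce_with_logits(xs[c], ys[c], pw[c])
            for xs, ys in zip(logits, targets) for c in range(num_classes)]
    if reduction == "mean":
        return sum(flat) / len(flat)
    if reduction == "sum":
        return sum(flat)
    raise ValueError("this sketch only supports 'mean' and 'sum'")
```

For a class with a weight of 5, a negative label at logit 0 costs log 2 under pos_weight but 5 · log 2 under class_weights. The PyTorch documentation suggests setting a class's pos_weight to its ratio of negative to positive examples, for example 300 / 100 = 3.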
When to use
Use WeightedBCELoss as a lightweight alternative to FocalLoss when inverse-frequency or custom weights are known ahead of time. Combine it with compute_class_weights to derive weights automatically from the training labels.
Constructor Parameters

- class_weights: per-class loss multipliers of shape [num_classes]. Higher weights increase the penalty for misclassifying the corresponding class. If None, equal weights are used.
- reduction: reduction applied after weighting. One of "mean", "sum", or "none".
- pos_weight: per-class positive-example weights of shape [num_classes], forwarded to nn.BCEWithLogitsLoss. Cannot be combined with class_weights in the same constructor call.

Forward

- Predictions: raw model outputs of shape [batch, num_classes].
- Targets: binary labels of shape [batch, num_classes].
- Returns: weighted BCE loss scalar, or a per-sample tensor if reduction="none".

ClassBalancedLoss
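ClassBalancedLoss implements the effective-number-of-samples re-weighting scheme from Cui et al. (CVPR 2019). Rather than using raw class frequencies, it computes effective sample counts that account for the diminishing marginal benefit of additional samples from the same class:

E_c = (1 − β^(n_c)) / (1 − β),   w_c ∝ 1 / E_c = (1 − β) / (1 − β^(n_c))

Here n_c is the number of positive training examples for class c, and the weights w_c are normalised to sum to num_classes. Higher beta values give more weight to rare classes.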
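The effective-number weights can be checked numerically with a short sketch. The function names are hypothetical, and it makes two assumptions that are not documented behaviour:

- The weights are normalised to sum to num_classes.
- A class with zero positives is treated as having one sample.

beta must be strictly below 1, otherwise 1 − β is zero.

```python
def effective_number(n: int, beta: float) -> float:
    """E_n = (1 - beta**n) / (1 - beta), the effective number of samples."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(samples_per_class, beta=0.9999):
    """Weights proportional to 1 / E_n, normalised to sum to num_classes."""
    num_classes = len(samples_per_class)
    # Assumption: a class with no positives is treated as having one sample.
    raw = [1.0 / effective_number(max(n, 1), beta) for n in samples_per_class]
    scale = num_classes / sum(raw)
    return [w * scale for w in raw]
```

At beta = 0, every class has an effective number of 1, so the weights are uniform. As beta approaches 1, the weights approach plain inverse frequency. With beta = 0.999, classes with 10 000 and 100 000 positives both have an effective number of about 1000, that is 1 / (1 − beta). This is the saturation behaviour the next paragraph relies on.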
When to use
ClassBalancedLoss is preferred when class counts vary by orders of magnitude (e.g., 10 vs. 10 000 samples per class). It is more robust than raw inverse-frequency weighting because the effective-number formula saturates for very large classes.
Constructor Parameters

- Samples per class: number of positive training examples per class, shape [num_classes]. For example, it can be computed with targets.sum(dim=0).
- beta: hyper-parameter in (0, 1) that controls how quickly the effective number saturates. Larger values (closer to 1) up-weight rare classes more aggressively.
- reduction: one of "mean", "sum", or "none".

Forward

- Predictions: raw model outputs of shape [batch, num_classes].
- Targets: binary labels of shape [batch, num_classes].
- Returns: class-balanced BCE loss.
DynamicWeightedLoss
DynamicWeightedLoss supports curriculum learning by updating class weights during training based on per-class validation performance. Classes with lower metrics (e.g., F1 score) receive proportionally higher weights, automatically shifting focus to underperforming classes as training progresses.
When to use
Use DynamicWeightedLoss when you have a validation set and want to adaptively re-balance the loss during training. Call update_weights after each validation epoch with the per-class F1 or AUROC scores.
Constructor Parameters

- Number of classes: total number of output classes. Weights are initialised uniformly to [1, 1, ..., 1].
- Base loss: underlying loss function applied before dynamic weighting. One of:
  - "bce": standard binary cross-entropy.
  - "focal": FocalLoss with the default alpha=0.25, gamma=2.0.
- reduction: one of "mean", "sum", or "none".

update_weights

- Per-class performance metric (e.g., F1 scores) of shape [num_classes]. Weights are set proportional to 1 / (metric + ε) and then normalised.

Forward

- Predictions: raw model outputs of shape [batch, num_classes].
- Targets: binary labels of shape [batch, num_classes].
- Returns: dynamically weighted loss scalar.
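The update rule can be sketched as follows. This is not the module's implementation, and several details are assumptions:

- The ε value of 1e-6 is an assumption.
- The description only says the weights are "then normalised". This sketch normalises them to sum to num_classes.
- DynamicWeightsSketch and combine are hypothetical names.

```python
class DynamicWeightsSketch:
    """Hypothetical stand-in for the weighting logic of DynamicWeightedLoss."""

    def __init__(self, num_classes: int, eps: float = 1e-6):
        self.eps = eps                       # assumed value for the ε term
        self.weights = [1.0] * num_classes   # uniform initial weights

    def update_weights(self, metric):
        """Set weights proportional to 1 / (metric + eps), then normalise."""
        raw = [1.0 / (m + self.eps) for m in metric]
        # Assumption: normalise so the weights sum to num_classes.
        scale = len(raw) / sum(raw)
        self.weights = [r * scale for r in raw]

    def combine(self, per_class_loss):
        """Weighted mean of per-class base-loss values (e.g., mean BCE per class)."""
        total = sum(w * l for w, l in zip(self.weights, per_class_loss))
        return total / len(self.weights)
```

After update_weights([0.9, 0.6, 0.3]), the weights become roughly 0.55, 0.82 and 1.64. The class with the lowest F1 therefore receives three times the weight of the best-performing one.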
compute_class_weights
compute_class_weights derives per-class weights from the training label matrix using one of three inverse-frequency strategies. The resulting weights are normalised so their sum equals num_classes, making them directly comparable across datasets.
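The three strategies named in the parameter table can be sketched in plain Python, assuming labels arrive as nested lists of 0/1 values. class_weights_sketch and its strategy argument are hypothetical names, since the actual parameter names are not shown here. Treating a zero-count class as having one positive under "inverse_frequency" and "sqrt_inverse" is also an assumption.

```python
import math


def class_weights_sketch(labels, strategy="inverse_frequency"):
    """Per-class weights from a [num_samples][num_classes] 0/1 label matrix."""
    num_classes = len(labels[0])
    # n_c: number of positive examples of each class
    counts = [sum(row[c] for row in labels) for c in range(num_classes)]
    if strategy == "inverse_frequency":
        raw = [1.0 / max(n, 1) for n in counts]
    elif strategy == "log_inverse":
        raw = [1.0 / math.log(n + 2) for n in counts]   # +2 keeps log > 0
    elif strategy == "sqrt_inverse":
        raw = [1.0 / math.sqrt(max(n, 1)) for n in counts]
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    # Normalise so the weights sum to num_classes.
    scale = num_classes / sum(raw)
    return [w * scale for w in raw]
```

Take labels [[1, 0], [1, 0], [1, 0], [1, 1]], which give positive counts of 4 and 1. The three strategies produce:

- "inverse_frequency": [0.4, 1.6]
- "sqrt_inverse": about [0.67, 1.33]
- "log_inverse": about [0.76, 1.24]

This ordering matches the strongest-to-softest ranking in the table.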
Parameters

- Labels: binary training labels of shape [num_samples, num_classes]. Accepts both torch.Tensor and numpy.ndarray inputs.
- Strategy: weighting strategy, one of the following, where n_c is the number of positive examples of class c:

| Value | Formula | Notes |
|---|---|---|
| "inverse_frequency" | w_c = 1 / n_c | Strongest re-weighting; sensitive to very rare classes |
| "log_inverse" | w_c = 1 / log(n_c + 2) | Softer re-weighting; good default for moderate imbalance |
| "sqrt_inverse" | w_c = 1 / √n_c | Intermediate between the above two |

Returns: per-class weights of shape [num_classes], normalised so that weights.sum() == num_classes.

Usage Examples
FocalLoss with Pre-computed Class Weights
ClassBalancedLoss from Training Set
DynamicWeightedLoss with Curriculum Update
WeightedBCELoss with pos_weight
Choosing the Right Loss
| Scenario | Recommended loss |
|---|---|
| Mild imbalance (< 5× ratio) | WeightedBCELoss with compute_class_weights |
| Moderate imbalance (5–50×) | FocalLoss with gamma=2.0 |
| Severe imbalance (> 50×) | ClassBalancedLoss with beta=0.9999 |
| Varying performance per class | DynamicWeightedLoss with periodic update_weights |
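
The first three rows of the table can be turned into a rough helper. The function names are hypothetical, and the sketch rests on two assumptions:

- The imbalance ratio is the largest per-class positive count divided by the smallest.
- The boundaries are read as < 5, 5 to 50, and > 50.

The last row depends on validation metrics rather than label counts, so it is left to judgement.

```python
def imbalance_ratio(labels) -> float:
    """Largest over smallest per-class positive count (assumed definition)."""
    num_classes = len(labels[0])
    counts = [sum(row[c] for row in labels) for c in range(num_classes)]
    return max(counts) / max(min(counts), 1)   # guard against empty classes


def suggest_loss(ratio: float) -> str:
    """Map an imbalance ratio to the table's recommendation."""
    if ratio < 5:
        return "WeightedBCELoss"      # mild imbalance
    if ratio <= 50:
        return "FocalLoss"            # moderate imbalance
    return "ClassBalancedLoss"        # severe imbalance
```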