
Cardiovascular disease datasets are inherently class-imbalanced — rare conditions like bundle branch blocks appear far less frequently than normal sinus rhythm. The ssrl_ecg.models.losses module provides four specialised loss functions and a utility for computing class weights, all designed for multi-label binary classification where each ECG can carry multiple simultaneous diagnoses.

FocalLoss

FocalLoss adapts the focal loss from RetinaNet (Lin et al., 2017) to multi-label ECG classification via sigmoid activation. It down-weights easy, well-classified examples using the modulating factor (1 − p_t)^γ, concentrating training on hard or misclassified samples. This is the recommended loss for datasets with moderate-to-severe imbalance. The loss is computed element-wise as:
FL(p_t) = α · (1 − p_t)^γ · BCE(logits, targets)
where p_t = σ(logit) for positive targets and p_t = 1 − σ(logit) for negatives.
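To make the element-wise formula concrete, here is a plain-Python sketch of one focal-loss term that uses only the standard library. It illustrates the equation above and is not the ssrl_ecg implementation. The helper names are hypothetical. p_t is written as p·y + (1 − p)·(1 − y), which matches the definition above for 0/1 targets and also covers soft labels.

```python
import math


def stable_bce(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, in the numerically stable form
    max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def focal_term(logit: float, target: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """One element of FL = alpha * (1 - p_t)^gamma * BCE (hypothetical helper)."""
    p = 1.0 / (1.0 + math.exp(-logit))              # sigmoid(logit)
    p_t = p * target + (1.0 - p) * (1.0 - target)   # p for positives, 1 - p for negatives
    return alpha * (1.0 - p_t) ** gamma * stable_bce(logit, target)


easy = focal_term(4.0, 1.0)    # confident and correct: p_t is about 0.98
hard = focal_term(-4.0, 1.0)   # confident and wrong: p_t is about 0.02
print(f"easy={easy:.2e}  hard={hard:.4f}")
```

The easy example contributes several orders of magnitude less than the hard one, so the gradient is dominated by misclassified samples. With gamma=0 the modulating factor equals 1, and each term reduces to alpha times plain BCE.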

When to use

Use FocalLoss when a handful of classes dominate the dataset. The default alpha=0.25, gamma=2.0 follows the RetinaNet recipe and works well out of the box for most ECG tasks.

Constructor Parameters

alpha
float
default:"0.25"
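Scalar weighting factor in (0, 1), applied uniformly to every element as in the formula above. Because the same α multiplies positive and negative terms alike, it scales the overall loss magnitude rather than shifting the balance between positive and negative examples.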
gamma
float
default:"2.0"
Focusing exponent. Higher values reduce the loss contribution from easy examples more aggressively. gamma=0 removes the modulating factor, leaving BCE scaled by alpha.
reduction
str
default:"mean"
Reduction to apply across the batch. One of "mean", "sum", or "none". Use "none" to obtain per-element losses for custom aggregation.

Forward

def forward(logits: Tensor, targets: Tensor) -> Tensor
logits
Tensor
required
Raw model predictions (before sigmoid) of shape [batch, num_classes].
targets
Tensor
required
Binary ground-truth labels of shape [batch, num_classes], normally 0 or 1. Floating-point soft labels in [0, 1] are also accepted.
loss
Tensor
Scalar focal loss (or per-element tensor if reduction="none").

WeightedBCELoss

WeightedBCELoss wraps nn.BCEWithLogitsLoss to apply per-class scalar weights to the standard binary cross-entropy. It is the simplest re-weighting strategy and serves as a useful baseline for diagnosing class imbalance. Two weighting mechanisms are supported:
  • class_weights: per-class multipliers applied to the computed loss matrix [batch, classes].
  • pos_weight: passed directly to BCEWithLogitsLoss, which scales the loss on positive examples per class.
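The difference between the two mechanisms is easiest to see on a single element. The plain-Python sketch below assumes two things. First, it uses the per-element form documented for nn.BCEWithLogitsLoss, where pos_weight multiplies only the positive term. Second, it follows the description above, where class_weights multiply the finished element loss. The helper names are hypothetical.

```python
import math


def log_sigmoid(x: float) -> float:
    """Stable log(sigmoid(x)) = -softplus(-x)."""
    return -(max(-x, 0.0) + math.log1p(math.exp(-abs(x))))


def weighted_bce_term(x: float, y: float, pos_weight: float = 1.0, class_weight: float = 1.0) -> float:
    # pos_weight rescales only the y * log(sigmoid(x)) part (BCEWithLogitsLoss semantics);
    # log(1 - sigmoid(x)) is computed as log_sigmoid(-x).
    loss = -(pos_weight * y * log_sigmoid(x) + (1.0 - y) * log_sigmoid(-x))
    # class_weight rescales the whole element, whether the target is positive or negative.
    return class_weight * loss


for y in (1.0, 0.0):
    print(y, weighted_bce_term(0.0, y, pos_weight=3.0), weighted_bce_term(0.0, y, class_weight=3.0))
```

The two knobs behave differently. pos_weight shifts the trade-off between false negatives and false positives for a class. class_weights make every mistake on that class count more.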

When to use

Use WeightedBCELoss as a lightweight alternative to FocalLoss when inverse-frequency or custom weights are known ahead of time. Combine with compute_class_weights to derive weights automatically from the training labels.

Constructor Parameters

class_weights
Tensor | None
default:"None"
Per-class loss multipliers of shape [num_classes]. Higher weights increase the penalty for misclassifying the corresponding class. If None, equal weights are used.
reduction
str
default:"mean"
Reduction applied after weighting. One of "mean", "sum", or "none".
pos_weight
Tensor | None
default:"None"
Per-class positive-example weights of shape [num_classes], forwarded to nn.BCEWithLogitsLoss. Cannot be combined with class_weights in the same call.

Forward

def forward(logits: Tensor, targets: Tensor) -> Tensor
logits
Tensor
required
Raw model predictions of shape [batch, num_classes].
targets
Tensor
required
Binary labels of shape [batch, num_classes].
loss
Tensor
Weighted BCE loss scalar (or per-sample tensor if reduction="none").

ClassBalancedLoss

ClassBalancedLoss implements the effective-number-of-samples re-weighting scheme from Cui et al. (CVPR 2019). Rather than using raw class frequencies, it computes effective sample counts that account for the diminishing marginal benefit of additional samples from the same class:
E_n = (1 − β^n) / (1 − β)
weight_c = 1 / E_{n_c} = (1 − β) / (1 − β^{n_c})
The weights are then normalised so they sum to num_classes. Higher beta values give more weight to rare classes.
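As a worked example of the formula, the plain-Python sketch below computes normalised class-balanced weights for a 10-sample class and a 10 000-sample class. It follows the equations above but is not the library's code. The function names are hypothetical, and every count is assumed to be at least 1.

```python
def effective_number(n: int, beta: float) -> float:
    """E_n = (1 - beta^n) / (1 - beta); approaches 1 / (1 - beta) as n grows."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(counts, beta=0.9999):
    """weight_c proportional to 1 / E_{n_c}, rescaled so the weights sum to len(counts)."""
    raw = [1.0 / effective_number(n, beta) for n in counts]
    scale = len(counts) / sum(raw)
    return [w * scale for w in raw]


counts = [10, 10_000]
w = class_balanced_weights(counts, beta=0.9999)
print([round(x, 4) for x in w])
print(f"weight ratio: {w[0] / w[1]:.1f} (raw inverse frequency would give 1000)")
```

The ratio comes out near 632 rather than 1000. The effective number of the large class is about 6 321, not 10 000, so very frequent classes are penalised less harshly than under raw inverse frequency. With beta close to 0, every E_n is close to 1 and the weights become nearly uniform.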

When to use

ClassBalancedLoss is preferred when class counts vary by orders of magnitude (e.g., 10 vs. 10 000 samples per class). It is more robust than raw inverse-frequency weighting because the effective-number formula saturates for very large classes.

Constructor Parameters

class_counts
Tensor
required
Number of positive training examples per class, shape [num_classes]. For example, compute it as targets.sum(dim=0) over the training label matrix.
beta
float
default:"0.9999"
Hyper-parameter in (0, 1) that controls how quickly the effective number saturates. Values closer to 1 up-weight rare classes more aggressively.
reduction
str
default:"mean"
One of "mean", "sum", or "none".

Forward

def forward(logits: Tensor, targets: Tensor) -> Tensor
logits
Tensor
required
Raw model predictions of shape [batch, num_classes].
targets
Tensor
required
Binary labels of shape [batch, num_classes].
loss
Tensor
Class-balanced BCE loss.

DynamicWeightedLoss

DynamicWeightedLoss supports curriculum learning by updating class weights during training based on per-class validation performance. Classes with lower metrics (e.g., F1 score) receive proportionally higher weights, automatically shifting focus to underperforming classes as training progresses.

When to use

Use DynamicWeightedLoss when you have a validation set and want to adaptively re-balance the loss during training. Call update_weights after each validation epoch with the per-class F1 or AUROC scores.

Constructor Parameters

num_classes
int
required
Total number of output classes. Initialises uniform weights [1, 1, ..., 1].
base_criteria
str
default:"bce"
Underlying loss function applied before dynamic weighting. One of:
  • "bce": standard binary cross-entropy.
  • "focal": FocalLoss with default alpha=0.25, gamma=2.0.
reduction
str
default:"mean"
One of "mean", "sum", or "none".

update_weights

def update_weights(per_class_metrics: Tensor) -> None
per_class_metrics
Tensor
required
Per-class performance metric (e.g., F1 scores) of shape [num_classes]. Weights are set proportional to 1 / (metric + ε) and then normalised.
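The update rule can be sketched in plain Python. The value of ε and the exact normalisation are not specified above. This sketch assumes ε = 1e-6 and, by analogy with compute_class_weights, rescales the weights to sum to num_classes. Both are assumptions, and the function name is hypothetical.

```python
def dynamic_weights(per_class_metrics, eps=1e-6):
    # ASSUMPTION: the eps value and the sum-to-num_classes normalisation are illustrative only.
    raw = [1.0 / (m + eps) for m in per_class_metrics]  # low metric -> large weight
    scale = len(raw) / sum(raw)
    return [r * scale for r in raw]


f1 = [0.85, 0.60, 0.45, 0.30, 0.72]  # the per-class F1 values used in the curriculum example
w = dynamic_weights(f1)
print([round(x, 3) for x in w])
```

Because the weights are inversely proportional to the metric, the class with F1 = 0.30 ends up with roughly 2.8 times the weight of the class with F1 = 0.85.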

Forward

def forward(logits: Tensor, targets: Tensor) -> Tensor
logits
Tensor
required
Raw predictions of shape [batch, num_classes].
targets
Tensor
required
Binary labels of shape [batch, num_classes].
loss
Tensor
Dynamically-weighted loss scalar.

compute_class_weights

compute_class_weights derives per-class weights from the training label matrix using one of three inverse-frequency strategies. The resulting weights are normalised so their sum equals num_classes, making them directly comparable across datasets.
def compute_class_weights(
    targets: Tensor,
    method: str = "inverse_frequency",
) -> Tensor

Parameters

targets
Tensor
required
Binary training labels of shape [num_samples, num_classes]. Accepts both torch.Tensor and numpy.ndarray inputs.
method
str
default:"inverse_frequency"
Weighting strategy, where n_c is the number of positive labels for class c. One of:
| Value | Formula | Notes |
|---|---|---|
| "inverse_frequency" | w_c = 1 / n_c | Strongest re-weighting; sensitive to very rare classes |
| "log_inverse" | w_c = 1 / log(n_c + 2) | Softer re-weighting; good default for moderate imbalance |
| "sqrt_inverse" | w_c = 1 / √n_c | Intermediate between the above two |
weights
Tensor
Per-class weights of shape [num_classes], normalised so that weights.sum() == num_classes.
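The three strategies in the table can be reproduced in a few lines of plain Python, which makes it easy to compare how strongly each one re-weights. This sketch makes two assumptions: n_c is the number of positive labels per class, and every class has at least one positive. How ssrl_ecg handles zero counts is not documented here. The function name is hypothetical.

```python
import math


def class_weights_from_counts(counts, method="inverse_frequency"):
    """Per-class weights normalised to sum to len(counts) (hypothetical helper)."""
    if method == "inverse_frequency":
        raw = [1.0 / n for n in counts]
    elif method == "log_inverse":
        raw = [1.0 / math.log(n + 2) for n in counts]
    elif method == "sqrt_inverse":
        raw = [1.0 / math.sqrt(n) for n in counts]
    else:
        raise ValueError(f"unknown method: {method!r}")
    scale = len(counts) / sum(raw)
    return [w * scale for w in raw]


counts = [800, 300, 200, 150, 100]  # positives per class, as in the FocalLoss example
for method in ("inverse_frequency", "sqrt_inverse", "log_inverse"):
    w = class_weights_from_counts(counts, method)
    print(f"{method:18s} rarest/commonest = {w[-1] / w[0]:.2f}")
```

The rarest-to-most-common weight ratio is 8.00 for inverse_frequency, about 2.83 for sqrt_inverse and about 1.45 for log_inverse. This matches the strongest, intermediate and softest ordering in the table.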

Usage Examples

FocalLoss with Pre-computed Class Weights

from ssrl_ecg.models.losses import FocalLoss, compute_class_weights
import torch

# Toy training label matrix [samples, classes]
targets = torch.zeros(1747, 5)
# Simulated positive counts per class: [800, 300, 200, 150, 100]
targets[:800, 0] = 1
targets[:300, 1] = 1
targets[:200, 2] = 1
targets[:150, 3] = 1
targets[:100, 4] = 1

# Inverse-frequency weights of shape [5]; rare classes get higher weights
weights = compute_class_weights(targets, method='inverse_frequency')
print(weights)

# FocalLoss takes a scalar alpha, so apply the per-class weights manually:
# request per-element losses, scale each class column, then average.
criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='none')
logits = torch.randn(64, 5)
labels = torch.randint(0, 2, (64, 5)).float()
per_element = criterion(logits, labels)   # [64, 5]
loss = (per_element * weights).mean()     # weights broadcast across the batch
print(f"Weighted focal loss: {loss.item():.4f}")

ClassBalancedLoss from Training Set

from ssrl_ecg.models.losses import ClassBalancedLoss
import torch

# Count positive samples per class from the training set
train_labels  = torch.randint(0, 2, (1747, 5)).float()
class_counts  = train_labels.sum(dim=0)    # [5]

criterion = ClassBalancedLoss(class_counts=class_counts, beta=0.9999)

logits = torch.randn(64, 5)
labels = torch.randint(0, 2, (64, 5)).float()
loss   = criterion(logits, labels)
print(f"Class-balanced loss: {loss.item():.4f}")

DynamicWeightedLoss with Curriculum Update

from ssrl_ecg.models.losses import DynamicWeightedLoss
import torch

criterion = DynamicWeightedLoss(num_classes=5, base_criteria='bce')

# --- After each validation epoch ---
# Simulate per-class F1 scores from validation
val_f1_scores = torch.tensor([0.85, 0.60, 0.45, 0.30, 0.72])

# Update: low-F1 classes get higher loss weight
criterion.update_weights(val_f1_scores)

# --- Training step ---
logits = torch.randn(64, 5)
labels = torch.randint(0, 2, (64, 5)).float()
loss   = criterion(logits, labels)
print(f"Dynamic weighted loss: {loss.item():.4f}")

WeightedBCELoss with pos_weight

from ssrl_ecg.models.losses import WeightedBCELoss
import torch

train_labels  = torch.randint(0, 2, (1747, 5)).float()
class_counts  = train_labels.sum(dim=0)
neg_counts    = train_labels.shape[0] - class_counts

# pos_weight = num_negatives / num_positives per class
pos_weight = neg_counts / (class_counts + 1e-8)

criterion = WeightedBCELoss(pos_weight=pos_weight)

logits = torch.randn(32, 5)
labels = torch.randint(0, 2, (32, 5)).float()
loss   = criterion(logits, labels)
print(f"Weighted BCE loss: {loss.item():.4f}")

Choosing the Right Loss

| Scenario | Recommended loss |
|---|---|
| Mild imbalance (< 5× ratio) | WeightedBCELoss with compute_class_weights |
| Moderate imbalance (5–50×) | FocalLoss with gamma=2.0 |
| Severe imbalance (> 50×) | ClassBalancedLoss with beta=0.9999 |
| Varying performance per class | DynamicWeightedLoss with periodic update_weights |
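When in doubt, start with FocalLoss(alpha=0.25, gamma=2.0). It has only two hyperparameters and is robust across a wide range of imbalance levels. Because alpha is a single scalar, per-class weighting needs a manual step when it is insufficient. Compute weights with compute_class_weights(method="log_inverse"), construct FocalLoss with reduction="none", and multiply the per-element loss by the weights before averaging.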
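The ratio bands in the table do not say how the ratio is measured. One reasonable reading, assumed here, is the largest per-class positive count divided by the smallest. The plain-Python helper below is hypothetical and not part of ssrl_ecg. It applies that reading to pick a starting point from the table. The DynamicWeightedLoss row depends on validation results rather than counts, so it is left out.

```python
def suggest_loss(class_counts):
    """Map an imbalance ratio (max count / min count) to the table's suggestion.

    ASSUMPTION: the ratio definition and the handling of boundary values are
    illustrative; the table itself gives only approximate bands.
    """
    ratio = max(class_counts) / min(class_counts)
    if ratio < 5:
        return ratio, "WeightedBCELoss + compute_class_weights"
    if ratio <= 50:
        return ratio, "FocalLoss(gamma=2.0)"
    return ratio, "ClassBalancedLoss(beta=0.9999)"


print(suggest_loss([800, 300, 200, 150, 100]))  # (8.0, 'FocalLoss(gamma=2.0)')
```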
