The ECG encoder module provides the foundational backbone for all downstream tasks in SSRL-ECG. It defines a six-layer 1D convolutional encoder that maps raw multi-lead ECG signals to dense temporal feature maps, along with a supervised classifier head and a masked-reconstruction pretraining model. All classes live in ssrl_ecg.models.cnn.

ConvBlock

ConvBlock is the atomic building block used throughout the encoder. Each block applies a 1D convolution, batch normalization, and ReLU activation in sequence. Padding is computed automatically as kernel_size // 2 to preserve temporal resolution when stride=1.
from ssrl_ecg.models.cnn import ConvBlock
import torch

block = ConvBlock(in_ch=12, out_ch=64, kernel_size=7, stride=1)
x = torch.randn(4, 12, 1000)
out = block(x)  # shape: [4, 64, 1000]

Constructor Parameters

in_ch
int
required
Number of input channels (e.g., 12 for a standard 12-lead ECG).
out_ch
int
required
Number of output channels produced by the convolution.
kernel_size
int
default:"7"
Width of the 1D convolutional kernel. Padding is set to kernel_size // 2 so temporal length is preserved when stride=1.
stride
int
default:"1"
Stride of the convolution. Use stride=2 to halve temporal resolution.

Forward

def forward(x: Tensor) -> Tensor
x
Tensor
required
Input tensor of shape [batch, in_ch, time].
output
Tensor
Output tensor of shape [batch, out_ch, time'], where time' = ceil(time / stride).
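The length rule above can be checked with plain arithmetic. The following is a minimal sketch, assuming ConvBlock wraps a standard 1D convolution with dilation 1 and padding = kernel_size // 2; the helper name `conv1d_out_len` is hypothetical and not part of ssrl_ecg.

```python
import math


def conv1d_out_len(time: int, kernel_size: int = 7, stride: int = 1) -> int:
    """Output length of a 1D convolution with padding = kernel_size // 2, dilation 1."""
    padding = kernel_size // 2
    return (time + 2 * padding - kernel_size) // stride + 1


# For odd kernel sizes the result equals ceil(time / stride).
for kernel_size, stride in [(11, 2), (7, 1), (7, 2), (5, 1), (5, 2), (3, 1)]:
    for time in (999, 1000, 1001):
        assert conv1d_out_len(time, kernel_size, stride) == math.ceil(time / stride)

# An even kernel size adds one sample at stride=1 under the same padding rule.
print(conv1d_out_len(1000, kernel_size=4, stride=1))
```

Under these assumptions, the length-preserving behaviour at stride=1 holds for odd kernel sizes only; an even kernel_size would lengthen the output by one sample.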

ECGEncoder1DCNN

ECGEncoder1DCNN is a six-layer 1D CNN that serves as the primary backbone for both self-supervised pretraining (SimCLR, BYOL) and supervised fine-tuning. Three stride-2 layers reduce temporal resolution by a factor of 8 overall, while the channel width doubles twice (width → width × 2 → width × 4) to capture increasingly abstract features.

Architecture

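| Layer | Kernel | Stride | In channels | Out channels |
|-------|--------|--------|-------------|--------------|
| 1 | 11 | 2 | in_ch | width |
| 2 | 7 | 1 | width | width |
| 3 | 7 | 2 | width | width × 2 |
| 4 | 5 | 1 | width × 2 | width × 2 |
| 5 | 5 | 2 | width × 2 | width × 4 |
| 6 | 3 | 1 | width × 4 | width × 4 |
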
At the default width=64, the encoder produces 256-channel feature maps (out_channels = width * 4). The overall temporal stride is 8 (2 × 2 × 2, from the three stride-2 layers), so a 1,000-sample input yields feature maps of length 125.
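To see where the 256 channels and length 125 come from, the layer table can be traced step by step. This is a sketch in plain Python, assuming each layer pads by kernel // 2 as ConvBlock does; `encoder_layer_specs` and `trace_shapes` are hypothetical helper names used only for illustration.

```python
def encoder_layer_specs(in_ch: int = 12, width: int = 64):
    """(kernel, stride, in_channels, out_channels) per layer, from the Architecture table."""
    return [
        (11, 2, in_ch, width),
        (7, 1, width, width),
        (7, 2, width, width * 2),
        (5, 1, width * 2, width * 2),
        (5, 2, width * 2, width * 4),
        (3, 1, width * 4, width * 4),
    ]


def trace_shapes(time: int, in_ch: int = 12, width: int = 64):
    """Return the (channels, length) pair of the input and after each layer."""
    shapes = [(in_ch, time)]
    for kernel, stride, _, out_ch in encoder_layer_specs(in_ch, width):
        padding = kernel // 2  # assumed padding rule, as in ConvBlock
        time = (time + 2 * padding - kernel) // stride + 1
        shapes.append((out_ch, time))
    return shapes


for channels, length in trace_shapes(1000):
    print(channels, length)
```

The temporal length halves only at layers 1, 3, and 5; the stride-1 layers in between leave it unchanged.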

Constructor Parameters

in_ch
int
default:"12"
Number of input ECG leads. Standard 12-lead ECGs use 12; single-lead setups may use 1.
width
int
default:"64"
Base channel width that controls model capacity. Final output channels equal width * 4 (256 at the default). Reduce to 32 for a lighter model; increase to 128 for higher capacity.
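As a rough guide to how width affects capacity, the convolution kernel weights can be counted from the layer table. This sketch is an estimate under stated assumptions: it counts kernel weights only and ignores biases and batch-norm parameters, so the real parameter count will be somewhat higher. The function name `conv_weight_count` is hypothetical.

```python
def conv_weight_count(in_ch: int = 12, width: int = 64) -> int:
    """Count convolution kernel weights only (no biases, no batch-norm parameters)."""
    layers = [  # (kernel, in_channels, out_channels), from the Architecture table
        (11, in_ch, width),
        (7, width, width),
        (7, width, width * 2),
        (5, width * 2, width * 2),
        (5, width * 2, width * 4),
        (3, width * 4, width * 4),
    ]
    return sum(kernel * c_in * c_out for kernel, c_in, c_out in layers)


for width in (32, 64, 128):
    # Columns: width, output channels, convolution weight count
    print(width, width * 4, conv_weight_count(width=width))
```

Because every layer after the first has both its input and output channels proportional to width, doubling width roughly quadruples the weight count.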

Attributes

| Attribute | Type | Description |
|-----------|------|-------------|
| features | nn.Sequential | The six-block convolutional stack. |
| out_channels | int | Number of output channels (width * 4). |

Forward

def forward(x: Tensor) -> Tensor
x
Tensor
required
Raw ECG tensor of shape [batch, in_ch, time].
output
Tensor
Feature map of shape [batch, width*4, time/8], with the length rounded up when time is not a multiple of 8. For the default settings this is [batch, 256, time/8].

Example Usage

from ssrl_ecg.models.cnn import ECGEncoder1DCNN, ECGClassifier
import torch

# Build encoder and run a forward pass
encoder = ECGEncoder1DCNN(in_ch=12, width=64)
x = torch.randn(4, 12, 1000)
features = encoder(x)   # shape: [4, 256, 125]
print(features.shape)   # torch.Size([4, 256, 125])

# Attach a supervised classification head
classifier = ECGClassifier(encoder, n_classes=5)
logits = classifier(x)  # shape: [4, 5]
print(logits.shape)     # torch.Size([4, 5])

ECGClassifier

ECGClassifier wraps any ECGEncoder1DCNN backbone with an adaptive average-pooling layer and a linear classification head. It is the standard model for supervised fine-tuning after SSL pretraining.
ECGEncoder1DCNN  →  AdaptiveAvgPool1d(1)  →  squeeze  →  Linear(out_channels, n_classes)

Constructor Parameters

encoder
ECGEncoder1DCNN
required
A pretrained (or randomly initialised) ECGEncoder1DCNN instance. The head dimension is inferred from encoder.out_channels.
n_classes
int
default:"5"
Number of output classes. For the PTB-XL five-class diagnostic task, use the default 5.

Forward

def forward(x: Tensor) -> Tensor
x
Tensor
required
ECG tensor of shape [batch, in_ch, time].
logits
Tensor
Raw (un-sigmoidised) class logits of shape [batch, n_classes]. Apply torch.sigmoid for multi-label probabilities or torch.softmax for mutually exclusive classes.
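The difference between the two activations can be illustrated without PyTorch. The sketch below uses made-up logit values for a single record; `sigmoid` and `softmax` here are small standard-library stand-ins for torch.sigmoid and torch.softmax, written out only to show the arithmetic.

```python
import math


def sigmoid(logits):
    """Independent per-class probabilities (multi-label)."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


def softmax(logits):
    """Probabilities that sum to 1 across classes (mutually exclusive)."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, -1.0, 0.5, -3.0, 1.5]  # illustrative values, not model output

multi_label = sigmoid(logits)
single_label = softmax(logits)

# Multi-label: every class whose probability reaches 0.5 is predicted.
predicted_labels = [i for i, p in enumerate(multi_label) if p >= 0.5]
# Single-label: only the most probable class is predicted.
predicted_class = max(range(len(single_label)), key=single_label.__getitem__)

print(predicted_labels, predicted_class)
```

With sigmoid, several classes can be active for the same record, which matches the binary_cross_entropy_with_logits loss used in the fine-tuning example; softmax forces the classes to compete.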

Fine-tuning Example

import torch
from ssrl_ecg.models.cnn import ECGEncoder1DCNN, ECGClassifier

# Load a pretrained encoder checkpoint
encoder = ECGEncoder1DCNN(in_ch=12, width=64)
ckpt = torch.load("checkpoints/ssl_simclr.pt", map_location="cpu")
encoder.load_state_dict(ckpt["encoder"])

# Freeze encoder weights (linear evaluation protocol)
for param in encoder.parameters():
    param.requires_grad = False

# Attach classifier head and train
model = ECGClassifier(encoder, n_classes=5)
optimizer = torch.optim.Adam(model.head.parameters(), lr=1e-3)

x = torch.randn(32, 12, 1000)
labels = torch.zeros(32, 5)
logits = model(x)                         # [32, 5]
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
loss.backward()
optimizer.step()

SSLReconstructionModel

SSLReconstructionModel pairs the ECGEncoder1DCNN backbone with a mirrored transposed-convolution decoder for masked reconstruction pretraining. The decoder progressively upsamples the latent feature map back to the original signal length, restoring all input leads. This is an alternative SSL objective to contrastive methods like SimCLR and BYOL.

Decoder Architecture

The decoder mirrors the encoder’s 8× downsampling using three ConvTranspose1d layers (each with stride 2), followed by a final Conv1d that projects back to in_ch leads:
ConvTranspose1d(c, c/2, k=4, s=2, p=1)  →  ReLU
ConvTranspose1d(c/2, c/4, k=4, s=2, p=1)  →  ReLU
ConvTranspose1d(c/4, c/8, k=4, s=2, p=1)  →  ReLU
Conv1d(c/8, in_ch, k=3, p=1)
where c = encoder.out_channels. The output is cropped or padded along the time axis so that its length exactly matches the input length.
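The need for a final length adjustment follows from the layer arithmetic. This sketch assumes the standard transposed-convolution length formula with dilation 1 and no output padding, and the encoder layout from the Architecture table; the helper names are hypothetical.

```python
def encoder_out_len(time: int) -> int:
    """Latent length after the six encoder layers (padding = kernel // 2 assumed)."""
    for kernel, stride in [(11, 2), (7, 1), (7, 2), (5, 1), (5, 2), (3, 1)]:
        time = (time + 2 * (kernel // 2) - kernel) // stride + 1
    return time


def conv_transpose1d_out_len(length: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Output length of a transposed 1D convolution (dilation 1, output_padding 0)."""
    return (length - 1) * stride - 2 * padding + kernel


def decoder_out_len(latent_len: int) -> int:
    """Length after three k=4, s=2, p=1 upsampling layers; the final k=3, p=1 conv keeps it."""
    for _ in range(3):
        latent_len = conv_transpose1d_out_len(latent_len)
    return latent_len


for time in (1000, 1001, 996):
    latent = encoder_out_len(time)
    decoded = decoder_out_len(latent)
    # Columns: input length, latent length, raw decoder length, surplus to crop
    print(time, latent, decoded, decoded - time)
```

Each k=4, s=2, p=1 layer exactly doubles the length, so the raw decoder output is 8 times the latent length. Under these assumptions it matches the input only when the input length is a multiple of 8, and is otherwise slightly longer.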

Constructor Parameters

in_ch
int
default:"12"
Number of ECG leads in both input and reconstructed output.
width
int
default:"64"
Base channel width passed through to the internal ECGEncoder1DCNN.

Forward

def forward(x: Tensor) -> Tensor
x
Tensor
required
Masked ECG tensor of shape [batch, in_ch, time]. Masking (e.g., zeroing out random segments) should be applied before this call.
x_hat
Tensor
Reconstructed ECG tensor of shape [batch, in_ch, time], with identical length to the input.
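Since masking happens outside the model, the caller chooses the masking policy. The following is one possible sketch of random segment masking in plain Python; `sample_mask_spans` and `apply_mask` are hypothetical helpers, not part of ssrl_ecg, and the span count and length are arbitrary illustrative defaults.

```python
import random


def sample_mask_spans(time, n_spans=3, span_len=100, seed=None):
    """Pick n_spans random [start, end) windows inside [0, time); windows may overlap."""
    if span_len > time:
        raise ValueError("span_len must not exceed the signal length")
    rng = random.Random(seed)
    spans = []
    for _ in range(n_spans):
        start = rng.randint(0, time - span_len)  # randint bounds are inclusive
        spans.append((start, start + span_len))
    return sorted(spans)


def apply_mask(signal, spans):
    """Return a copy of a [leads][time] nested list with the given windows zeroed."""
    masked = [list(lead) for lead in signal]
    for start, end in spans:
        for lead in masked:
            lead[start:end] = [0.0] * (end - start)
    return masked


signal = [[1.0] * 1000 for _ in range(12)]  # stand-in for one 12-lead record
spans = sample_mask_spans(1000, seed=0)
masked = apply_mask(signal, spans)
print(spans)
```

With a tensor of shape [batch, in_ch, time], the same spans could be applied by assigning 0.0 to x_masked[:, :, start:end] on a clone of the input, as in the usage example in this section.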

Example Usage

from ssrl_ecg.models.cnn import SSLReconstructionModel
import torch

model = SSLReconstructionModel(in_ch=12, width=64)
x = torch.randn(4, 12, 1000)

# Simulate random segment masking
mask_start, mask_len = 200, 400
x_masked = x.clone()
x_masked[:, :, mask_start:mask_start + mask_len] = 0.0

x_hat = model(x_masked)     # shape: [4, 12, 1000]
recon_loss = torch.nn.functional.mse_loss(x_hat, x)
recon_loss.backward()
