The ssrl_ecg.visualization module provides utilities for producing consistent, publication-ready figures from ECG classification experiments. This page documents two of them. set_publication_style configures matplotlib globally so that every subsequent figure shares the same typographic and aesthetic settings. plot_roc_curve draws a labelled ROC curve on a matplotlib axes object and handles both binary and multi-class inputs automatically.
from ssrl_ecg.visualization import set_publication_style, plot_roc_curve

set_publication_style

set_publication_style() -> None
Applies a curated set of matplotlib.rcParams updates that configure every subsequent figure for print or journal submission quality. Call this once at the top of any script or notebook that generates figures. Returns None. All effects are applied as global side effects on plt.rcParams.

rcParams applied

| Parameter | Value | Effect |
|---|---|---|
| figure.dpi | 300 | Screen rendering at print resolution |
| figure.figsize | (8, 6) | Default figure dimensions in inches |
| font.size | 11 | Base font size for all text elements |
| font.family | "sans-serif" | Clean sans-serif typeface matching most journal styles |
| axes.labelsize | 12 | Axis label font size |
| axes.titlesize | 13 | Axes title font size |
| xtick.labelsize | 10 | X-axis tick label size |
| ytick.labelsize | 10 | Y-axis tick label size |
| legend.fontsize | 10 | Legend entry font size |
| lines.linewidth | 2 | Default line weight |
| lines.markersize | 6 | Default marker size |
| axes.grid | True | Grid enabled on all axes |
| grid.alpha | 0.3 | Subtle, non-distracting grid |
| savefig.dpi | 300 | DPI used when saving to file |
| savefig.bbox | "tight" | Auto-crop whitespace on save |
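Every setting in the table is a plain rcParams key, so one plt.rcParams.update call can reproduce the same effect. The sketch below shows how such a function could look. It is an assumption, not the ssrl_ecg source. The dictionary name PUBLICATION_RCPARAMS and the function set_publication_style_sketch are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

# Values copied from the rcParams table above (hypothetical constant name).
PUBLICATION_RCPARAMS = {
    "figure.dpi": 300,
    "figure.figsize": (8, 6),
    "font.size": 11,
    "font.family": "sans-serif",
    "axes.labelsize": 12,
    "axes.titlesize": 13,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "lines.linewidth": 2,
    "lines.markersize": 6,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
}


def set_publication_style_sketch() -> None:
    """Hypothetical stand-in: apply the table's values globally via rcParams."""
    plt.rcParams.update(PUBLICATION_RCPARAMS)


set_publication_style_sketch()
```

Keeping the values in a single dictionary makes them easy to review against the table. To undo the change, call matplotlib's plt.rcdefaults(), which restores the built-in defaults.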

Example

from ssrl_ecg.visualization import set_publication_style
import os
import matplotlib.pyplot as plt

set_publication_style()

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
ax.set_title("Example")
os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig("figures/example.png")  # saved at 300 DPI with a tight bounding box, no extra arguments needed
Place set_publication_style() immediately after your imports and before any plt.subplots() call. Because it modifies global state, you only need to call it once per script. In Jupyter notebooks, call it in its own cell near the top so that all subsequent plotting cells inherit the style automatically.
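The style is applied as a global side effect, so it stays active until something changes rcParams again. If you want the style for one figure only, matplotlib's rc_context context manager applies a set of rcParams temporarily and restores the previous values on exit. The sketch below does not call set_publication_style itself. It uses a small illustrative subset of its settings, and the STYLE dictionary is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

# Hypothetical subset of the publication settings, for illustration only.
STYLE = {"figure.dpi": 300, "savefig.bbox": "tight", "axes.grid": True}

default_dpi = plt.rcParams["figure.dpi"]

with plt.rc_context(STYLE):
    fig, ax = plt.subplots()  # created while STYLE is active
    ax.plot([0, 1], [0, 1])
    dpi_inside = fig.dpi

plt.close(fig)
dpi_after = plt.rcParams["figure.dpi"]  # previous value restored on leaving the block
print(dpi_inside, dpi_after)
```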

plot_roc_curve

plot_roc_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    label: str = "",
    ax=None,
    color: str = "b",
) -> tuple
Plots a Receiver Operating Characteristic (ROC) curve on a matplotlib axes object. Multi-class inputs are handled automatically: each sample's scores are averaged across the class columns before the curve is computed. The plotted line is labelled with the computed AUC score.
y_true (np.ndarray, required)
Ground-truth labels. Accepts:
  • 1D [n_samples]: binary labels {0, 1}
  • 2D [n_samples, n_classes]: multi-label binary matrix; each row is averaged across its class columns to a single value per sample before the curve is computed
y_prob (np.ndarray, required)
Predicted scores or probabilities. Accepts:
  • 1D [n_samples]: binary classifier output
  • 2D [n_samples, n_classes]: multi-class probability matrix; each row is averaged across its class columns to a single score per sample
See Multi-class averaging below for details on this behaviour.
label (str, default: "")
Name of the model or method to display in the legend, rendered as "{label} (AUC = 0.XXX)". Leave empty to show only the AUC value.
ax (matplotlib.axes.Axes, default: None)
Axes object on which to draw. When None, a new figure and axes are created internally. Pass an existing ax to overlay multiple curves on a single plot.
color (str, default: "b")
Matplotlib color for the ROC curve line. Any color format accepted by matplotlib.axes.Axes.plot works (e.g. "blue", "#1f77b4", "C0").
Returns
A tuple (fig, ax, roc_auc):
fig (matplotlib.figure.Figure or None)
The newly created figure, or None when an existing ax was passed (in that case the caller owns the figure).
ax (matplotlib.axes.Axes)
The axes object containing the plotted curve. Always returned, whether or not ax was supplied.
roc_auc (float)
The scalar AUC value computed for this curve. Useful for programmatic comparison or logging without parsing the legend text.
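The return contract above can be illustrated with a minimal, binary-only sketch. It assumes the curve is computed with scikit-learn's roc_curve and auc. The page names these functions in the multi-class note but does not show the implementation. Several details are assumptions or omissions:

  • plot_roc_curve_sketch is a hypothetical name.
  • The multi-class reduction and the axes decoration are left out.
  • The legend text for an empty label is a guess.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import auc, roc_curve


def plot_roc_curve_sketch(y_true, y_prob, label="", ax=None, color="b"):
    """Hypothetical stand-in mirroring the documented signature and (fig, ax, roc_auc) return."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    fig = None
    if ax is None:
        # Only create (and return) a figure when the caller did not supply axes.
        fig, ax = plt.subplots()

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    # Legend text follows the documented "{label} (AUC = 0.XXX)" pattern;
    # the empty-label format is an assumption.
    text = f"{label} (AUC = {roc_auc:.3f})" if label else f"AUC = {roc_auc:.3f}"
    ax.plot(fpr, tpr, color=color, label=text)
    return fig, ax, roc_auc


fig, ax, roc_auc = plot_roc_curve_sketch([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8], label="demo")
print(f"{roc_auc:.2f}")  # 0.75
```

Returning None for fig when axes are passed in signals that the caller, not the function, manages that figure.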

Multi-class averaging

When y_prob is a 2D array with more than one column, plot_roc_curve reduces it to a 1D score vector by averaging each row across the class columns (y_prob.mean(axis=1)). The same averaging is applied to y_true when it is 2D. The result is a single aggregate ROC curve summarizing overall classifier performance across all classes. This is not the same as macro-averaging, which computes one ROC curve or AUC per class and then averages those per-class results.
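The reduction described above is a mean over the class axis, giving one value per sample. Here is a minimal sketch, assuming NumPy arrays shaped [n_samples, n_classes].

Averaging a binary label matrix generally yields fractional values, such as 0.33 when one of three labels is positive. scikit-learn's roc_curve raises a ValueError for such continuous targets. For that reason, the sketch binarizes the averaged labels with a 0.5 threshold. That threshold, and the helper name reduce_to_1d, are hypothetical choices for illustration. This page does not say how ssrl_ecg handles fractional averaged labels.

```python
import numpy as np


def reduce_to_1d(y_true, y_prob, label_threshold=0.5):
    """Collapse [n_samples, n_classes] inputs to one value per sample by averaging across classes."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    if y_prob.ndim == 2:
        y_prob = y_prob.mean(axis=1)  # mean over the class axis -> shape [n_samples]
    if y_true.ndim == 2:
        # Hypothetical: threshold the averaged labels so they stay binary for roc_curve.
        y_true = (y_true.mean(axis=1) >= label_threshold).astype(int)
    return y_true, y_prob


y_true_2d = np.array([[1, 0, 0], [1, 1, 0], [0, 0, 0], [1, 1, 1]])
y_prob_2d = np.array([[0.9, 0.1, 0.2], [0.8, 0.7, 0.3], [0.1, 0.2, 0.1], [0.9, 0.8, 0.7]])
y_true_1d, y_prob_1d = reduce_to_1d(y_true_2d, y_prob_2d)
print(y_true_1d, y_prob_1d)
```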
Column-wise averaging is a summary heuristic, not a standard multi-class ROC method. For rigorous per-class or One-vs-Rest curves (e.g. for a paper’s supplementary material), compute roc_curve and auc independently for each class column and plot them separately. The averaging approach here is designed for fast, holistic model comparison plots such as a SimCLR vs BYOL comparison figure.
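Following that recommendation, you can draw a per-class (One-vs-Rest) figure by looping over the class columns and treating each one as its own binary problem. This is a sketch, assuming a 2D binary label matrix and scikit-learn's roc_curve and auc. plot_per_class_roc and the class names "A" and "B" are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import auc, roc_curve


def plot_per_class_roc(y_true, y_prob, class_names, ax):
    """Draw one ROC curve per class column on `ax` and return {class_name: AUC}."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    aucs = {}
    for k, name in enumerate(class_names):
        # Column k as a binary problem: class k vs. everything else.
        fpr, tpr, _ = roc_curve(y_true[:, k], y_prob[:, k])
        aucs[name] = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{name} (AUC = {aucs[name]:.3f})")
    # Draw the chance diagonal once, after all class curves.
    ax.plot([0, 1], [0, 1], linestyle="--", color="black", alpha=0.5, label="Random")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    return aucs


# Tiny illustrative inputs with two hypothetical classes.
y_true = np.array([[0, 1], [0, 0], [1, 1], [1, 0]])
y_prob = np.array([[0.1, 0.9], [0.2, 0.2], [0.8, 0.3], [0.9, 0.4]])

fig, ax = plt.subplots()
per_class_auc = plot_per_class_roc(y_true, y_prob, ["A", "B"], ax)
print(per_class_auc)
```

If you also need a single summary number that is macro-averaged in the standard sense, use scikit-learn's roc_auc_score(y_true, y_prob, average="macro"). For a multi-label indicator matrix, it computes one AUC per label column and averages them.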

Axes decoration applied

Every call to plot_roc_curve applies the following decoration to the target axes:
| Element | Value |
|---|---|
| X-axis label | "False Positive Rate" |
| Y-axis label | "True Positive Rate" |
| Title | "Receiver Operating Characteristic Curve" |
| Diagonal reference line | Dashed black, alpha=0.5, labelled "Random" |
| Axis limits | [-0.02, 1.02] on both axes |
| Grid | True, alpha=0.3 |
| Legend | Lower right, fontsize=10, includes AUC value |
Because plot_roc_curve sets the title, axis labels, and legend on every call, it is safe to overlay multiple curves (e.g. SimCLR, BYOL, supervised baseline) on the same axes. Each call re-applies the same static decoration and adds one new line to the legend.
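The decoration table can be sketched as a small helper. One consequence is worth checking in your own figures. If the "Random" diagonal is redrawn on every call, as the table implies, the legend can list it once per overlaid curve.

The sketch below reproduces the table's settings and then rebuilds the legend with one entry per label using Axes.get_legend_handles_labels. decorate_roc_axes and dedupe_legend are hypothetical helpers, not ssrl_ecg functions, and the two plotted curves are made-up shapes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt


def decorate_roc_axes(ax):
    """Hypothetical helper applying the decoration listed in the table."""
    ax.plot([0, 1], [0, 1], linestyle="--", color="black", alpha=0.5, label="Random")
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver Operating Characteristic Curve")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right", fontsize=10)


def dedupe_legend(ax, **legend_kwargs):
    """Rebuild the legend, keeping only the first artist for each label."""
    handles, labels = ax.get_legend_handles_labels()
    unique = {}
    for handle, label in zip(handles, labels):
        unique.setdefault(label, handle)  # first handle seen for this label wins
    ax.legend(list(unique.values()), list(unique.keys()), **legend_kwargs)


fig, ax = plt.subplots()
ax.plot([0, 0.1, 1], [0, 0.9, 1], label="Model A")  # made-up curve shapes
decorate_roc_axes(ax)
ax.plot([0, 0.3, 1], [0, 0.7, 1], label="Model B")
decorate_roc_axes(ax)  # second call draws a second "Random" line
dedupe_legend(ax, loc="lower right", fontsize=10)

labels = [text.get_text() for text in ax.get_legend().get_texts()]
print(labels)
```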

Complete Multi-Model Comparison Example

import os
import numpy as np
import matplotlib.pyplot as plt
from ssrl_ecg.visualization import set_publication_style, plot_roc_curve

# Apply publication style once
set_publication_style()

# Simulated evaluation outputs: replace with real model outputs
rng = np.random.default_rng(0)  # seeded so the figure is reproducible
n_samples, n_classes = 500, 5
y_true          = rng.integers(0, 2, size=(n_samples, n_classes))
y_prob_simclr   = rng.random((n_samples, n_classes))
y_prob_byol     = rng.random((n_samples, n_classes))
y_prob_baseline = rng.random((n_samples, n_classes))

# Create a single axes and overlay three curves
fig, ax = plt.subplots(figsize=(8, 6))

_, ax, auc_simclr   = plot_roc_curve(y_true, y_prob_simclr,   label="SimCLR",              ax=ax, color="blue")
_, ax, auc_byol     = plot_roc_curve(y_true, y_prob_byol,     label="BYOL",                ax=ax, color="orange")
_, ax, auc_baseline = plot_roc_curve(y_true, y_prob_baseline, label="Supervised Baseline", ax=ax, color="green")

# Keep the lower-right placement; a bare ax.legend() falls back to rcParams["legend.loc"] ("best" by default)
ax.legend(loc="lower right")
os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig("figures/roc_comparison.png", dpi=300, bbox_inches="tight")

print(f"SimCLR AUC:   {auc_simclr:.4f}")
print(f"BYOL AUC:     {auc_byol:.4f}")
print(f"Baseline AUC: {auc_baseline:.4f}")

Single-Model Example

import os
import numpy as np
from ssrl_ecg.visualization import set_publication_style, plot_roc_curve

set_publication_style()

y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.2, 0.7, 0.3])

# No ax is passed, so plot_roc_curve creates and returns a new figure
fig, ax, roc_auc = plot_roc_curve(y_true, y_prob, label="My Model", color="steelblue")

print(f"AUC: {roc_auc:.4f}")
os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig("figures/roc_single.png")
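You can sanity-check the single-model example by hand. ROC AUC equals the fraction of (positive, negative) pairs in which the positive sample gets the higher score, with ties counting as half. For the eight samples above, 15 of the 16 pairs are ordered correctly; the only misordered pair is 0.35 vs. 0.4. Any standard ROC AUC computation should therefore give 0.9375. The sketch below checks this with NumPy and scikit-learn's roc_auc_score and does not call ssrl_ecg.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.9, 0.2, 0.7, 0.3])

pos = y_prob[y_true == 1]  # scores of positive samples
neg = y_prob[y_true == 0]  # scores of negative samples

# Every (positive, negative) score difference; > 0 means the pair is ranked correctly.
pairs = pos[:, None] - neg[None, :]
pairwise_auc = ((pairs > 0).sum() + 0.5 * (pairs == 0).sum()) / pairs.size

print(pairwise_auc)  # 0.9375
print(roc_auc_score(y_true, y_prob))
```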

Collecting AUC Scores Programmatically

When running a sweep over multiple checkpoints or seeds, capture the returned roc_auc values for downstream aggregation without parsing the legend:
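from ssrl_ecg.visualization import set_publication_style, plot_roc_curve
import os
import numpy as np
import matplotlib.pyplot as plt

set_publication_style()
fig, ax = plt.subplots(figsize=(8, 6))

# Placeholder data: replace y_true and model_outputs with your own evaluation outputs
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
model_outputs = [  # iterable of (name, probabilities, color)
    ("ModelA", rng.random(500), "C0"),
    ("ModelB", rng.random(500), "C1"),
]

results = {}
for name, y_prob, color in model_outputs:
    _, ax, auc_val = plot_roc_curve(
        y_true, y_prob, label=name, ax=ax, color=color
    )
    results[name] = auc_val

ax.legend(loc="lower right")
os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig("figures/roc_sweep.png")

# results maps each model name to its AUC, e.g. {"ModelA": <auc>, "ModelB": <auc>}
best = max(results, key=results.get)
print(f"Best model: {best} (AUC = {results[best]:.4f})")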
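A sweep may produce several AUCs per model, for example one per seed. In that case you might store a list per model name instead of a single value and summarize each list as a mean and sample standard deviation. Here is a minimal sketch. summarize_aucs is a hypothetical helper, and the random-score inputs exist only to make the block runnable; they are not results.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def summarize_aucs(aucs_by_model):
    """Map each model name to (mean AUC, sample std) over its runs, e.g. one AUC per seed."""
    summary = {}
    for name, values in aucs_by_model.items():
        values = np.asarray(values, dtype=float)
        # Sample std (ddof=1) needs at least two runs; report 0.0 for a single run.
        std = float(values.std(ddof=1)) if values.size > 1 else 0.0
        summary[name] = (float(values.mean()), std)
    return summary


# Illustrative input only: AUCs of uninformative random scores, one per seed.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)
aucs_by_model = {
    "random_scores": [
        roc_auc_score(y_true, np.random.default_rng(seed).random(200)) for seed in range(3)
    ]
}

for name, (mean_auc, std_auc) in summarize_aucs(aucs_by_model).items():
    print(f"{name}: AUC = {mean_auc:.4f} ± {std_auc:.4f}")
```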
