
Overview

The VisualizationGenerator class creates publication-quality matplotlib visualizations from experiment state. It produces three types of plots to analyze experiment performance over time.

Features

  • Headless rendering - Works in server environments without display
  • Graceful error handling - Each plot wrapped in try/except to prevent crashes
  • Clean styling - Professional defaults with customizable DPI and sizing
  • Automatic best result highlighting - Visual emphasis on top-performing experiments

Class Definition

VisualizationGenerator

from src.execution.visualization_generator import VisualizationGenerator

viz = VisualizationGenerator()
No parameters required. Matplotlib settings are configured automatically with professional defaults.

Methods

generate()

Generate all visualizations from experiment state.
plot_paths = viz.generate(
    state=experiment_state,
    output_dir=Path("outputs/")
)
Parameters:
  • state (ExperimentState, required) - Complete experiment state after all iterations, containing:
      • experiments: List of all ExperimentResult objects
      • best_experiment: Name of the best-performing experiment
      • config.primary_metric: Metric to visualize
  • output_dir (Path, required) - Base output directory. Plots are saved to output_dir/plots/
Returns: list[Path] - Paths to the generated PNG files.
Generated plots:
  1. metric_progression.png - Line chart of metric over iterations
  2. model_comparison.png - Bar chart comparing model types
  3. improvement_over_baseline.png - Baseline vs best model comparison

Generated Visualizations

1. Metric Progression

Line plot showing primary metric value across all iterations. Features:
  • Blue line connecting all successful experiments
  • Gold star marking the best result
  • X-axis: Iteration number
  • Y-axis: Primary metric value
Example use case:
# Shows how RMSE improves over 10 iterations
# Best result at iteration 7 marked with star
Only successful experiments with the primary metric are included.
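A minimal sketch of this plot, assuming experiment results arrive as (iteration, metric) pairs; the variable names and output filename here are illustrative, not the class's internal API:

```python
# Headless backend must be selected before importing pyplot.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path

# Illustrative RMSE values per iteration (lower is better).
results = [(1, 0.92), (2, 0.85), (3, 0.78), (4, 0.81), (5, 0.74)]
iterations, values = zip(*results)
best_idx = min(range(len(values)), key=values.__getitem__)

fig, ax = plt.subplots()
ax.plot(iterations, values, color='#2196F3', marker='o')       # blue progression line
ax.plot(iterations[best_idx], values[best_idx], marker='*',
        markersize=18, color='gold', zorder=3)                 # gold star on best result
ax.set_xlabel('Iteration')
ax.set_ylabel('RMSE')
out = Path('metric_progression_demo.png')
fig.savefig(out)
plt.close(fig)
```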

2. Model Comparison

Bar chart comparing best performance of each model type. Features:
  • Groups results by model type
  • Shows best metric achieved per model
  • Best overall model highlighted in green
  • Value labels on top of each bar
  • Rotated x-axis labels for readability
Example models compared:
  • LogisticRegression
  • RandomForestClassifier
  • XGBClassifier
  • etc.
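The grouping step behind this chart can be sketched as follows, assuming each result exposes a model_type field and a metric value; the dict-based records and field names are illustrative:

```python
# Keep the best metric achieved per model type (higher is better for F1).
results = [
    {"model_type": "LogisticRegression",     "f1": 0.71},
    {"model_type": "RandomForestClassifier", "f1": 0.78},
    {"model_type": "RandomForestClassifier", "f1": 0.81},
    {"model_type": "XGBClassifier",          "f1": 0.83},
]

best_per_model: dict[str, float] = {}
for r in results:
    name = r["model_type"]
    best_per_model[name] = max(best_per_model.get(name, float("-inf")), r["f1"])

# The overall winner gets the green highlight in the chart.
best_overall = max(best_per_model, key=best_per_model.get)
print(best_overall)  # XGBClassifier
```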

3. Improvement Over Baseline

Side-by-side comparison of baseline vs best model. Features:
  • Two bars: baseline (orange) and best (green)
  • Percentage improvement in title
  • Value labels on bars
  • Automatic handling of “higher is better” vs “lower is better” metrics
Improvement calculation:
# For metrics where lower is better (RMSE, MAE)
improvement = ((baseline - best) / abs(baseline)) * 100

# For metrics where higher is better (accuracy, F1)
improvement = ((best - baseline) / abs(baseline)) * 100
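A runnable version of the two formulas above, with illustrative metric values:

```python
def improvement_pct(baseline: float, best: float, lower_is_better: bool) -> float:
    """Percentage improvement of best over baseline, direction-aware."""
    if lower_is_better:
        return (baseline - best) / abs(baseline) * 100
    return (best - baseline) / abs(baseline) * 100

# RMSE drops from 0.90 to 0.72 -> 20% improvement.
rmse_gain = improvement_pct(0.90, 0.72, lower_is_better=True)
# Accuracy rises from 0.80 to 0.88 -> 10% improvement.
acc_gain = improvement_pct(0.80, 0.88, lower_is_better=False)
```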

Matplotlib Configuration

Default settings applied:
plt.rcParams.update({
    'figure.figsize': (10, 6),     # Width x Height in inches
    'figure.dpi': 100,              # Display DPI
    'savefig.bbox': 'tight',        # Remove whitespace
    'savefig.dpi': 150,             # High-quality output
})

Metric Direction Detection

Automatically determines if lower or higher values are better:
# Lower is better
lower_better = ["rmse", "mse", "mae", "log_loss", "error"]

# Checked via substring matching
is_lower_better = any(m in metric_name.lower() for m in lower_better)
Lower is better:
  • RMSE, MSE, MAE
  • Log Loss
  • Error rates
Higher is better:
  • Accuracy, Precision, Recall, F1
  • R² Score, ROC AUC
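The substring check can be exercised directly; note that it is deliberately loose, so any name containing "error" counts as lower-is-better:

```python
lower_better = ["rmse", "mse", "mae", "log_loss", "error"]

def is_lower_better(metric_name: str) -> bool:
    # Substring match against the lowercased metric name.
    return any(m in metric_name.lower() for m in lower_better)

print(is_lower_better("val_RMSE"))             # True  ("rmse" matches)
print(is_lower_better("f1_macro"))             # False
print(is_lower_better("mean_absolute_error"))  # True  ("error" matches)
```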

Error Handling

Each plot generation is wrapped individually:
try:
    path = self._plot_metric_progression(state, plots_dir)
    if path:
        plot_paths.append(path)
except Exception as e:
    print_warning(f"Failed to generate metric progression plot: {e}")
    plt.close('all')
    # Continue to next plot
If a plot fails to generate, a warning is printed but other plots continue. No exceptions are raised to the caller.

Complete Example

from pathlib import Path
from src.execution.visualization_generator import VisualizationGenerator
from src.orchestration.state import ExperimentState

# Load experiment state
state = ExperimentState.load(Path("state_abc123.json"))

# Generate visualizations
viz = VisualizationGenerator()
plot_paths = viz.generate(
    state=state,
    output_dir=Path("outputs/")
)

# Display results
print(f"Generated {len(plot_paths)} visualizations:")
for path in plot_paths:
    print(f"  - {path.name}")

# Example output:
# Generated 3 visualizations:
#   - metric_progression.png
#   - model_comparison.png
#   - improvement_over_baseline.png

Integration with MLflow

Visualization paths can be logged to MLflow:
from src.persistence.mlflow_tracker import MLflowTracker

tracker = MLflowTracker(experiment_name="my_experiment")

# Generate plots
plot_paths = viz.generate(state, output_dir)

# Log to MLflow
tracker.log_visualizations(plot_paths)

Plot Generation Conditions

Plots are only generated when sufficient data exists:

Metric Progression

  • Requires: At least 1 successful experiment with primary metric
  • Skips if: No successful experiments or primary_metric not set

Model Comparison

  • Requires: At least 1 successful experiment with primary metric
  • Skips if: No successful experiments or primary_metric not set
  • Groups by: model_type field from results

Improvement Over Baseline

  • Requires:
    • First experiment (baseline) succeeded
    • Baseline has primary metric
    • Best metric exists and is non-zero
  • Skips if: Any requirement not met
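The baseline-plot guard can be sketched as a series of early returns, assuming results expose a success flag and a metrics dict (field names are illustrative, and a lower-is-better metric is assumed for the "best" lookup):

```python
def can_plot_baseline(experiments: list[dict], primary_metric: str) -> bool:
    if not experiments:
        return False
    baseline = experiments[0]
    if not baseline.get("success"):
        return False                                  # baseline must have succeeded
    if primary_metric not in baseline.get("metrics", {}):
        return False                                  # baseline must carry the metric
    candidates = [e["metrics"][primary_metric]
                  for e in experiments
                  if e.get("success") and primary_metric in e.get("metrics", {})]
    best = min(candidates)                            # lower-is-better assumed here
    return best != 0                                  # best metric must be non-zero

runs = [
    {"success": True,  "metrics": {"rmse": 0.90}},
    {"success": True,  "metrics": {"rmse": 0.72}},
    {"success": False, "metrics": {}},
]
print(can_plot_baseline(runs, "rmse"))  # True
```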

Customizing Output Directory

# Default: plots/ subdirectory
plot_paths = viz.generate(state, Path("outputs/"))
# Creates: outputs/plots/metric_progression.png, etc.

# Custom location
plot_paths = viz.generate(state, Path("/tmp/experiment_123"))
# Creates: /tmp/experiment_123/plots/metric_progression.png, etc.
The plots/ subdirectory is always created automatically.

Accessing Individual Plots

# Generate all plots
plot_paths = viz.generate(state, output_dir)

# Find specific plot
for path in plot_paths:
    if path.name == "metric_progression.png":
        print(f"Metric progression saved to: {path}")
    elif path.name == "model_comparison.png":
        print(f"Model comparison saved to: {path}")
    elif path.name == "improvement_over_baseline.png":
        print(f"Improvement plot saved to: {path}")

Headless Rendering

The generator uses the Agg backend for headless environments:
import matplotlib
matplotlib.use('Agg')  # Must be before pyplot import
import matplotlib.pyplot as plt
This allows visualization generation on servers without graphical displays (Docker, SSH sessions, etc.).

Colors and Styling

Standard colors:
  • Primary line/bars: #2196F3 (blue)
  • Best result marker: #4CAF50 (green)
  • Baseline: #FF9800 (orange)
  • Best model: #4CAF50 (green)
Typography:
  • Font: System default
  • Value labels: 9-11pt, bold for emphasis
  • Grid: Alpha 0.3 for subtle guidance

Source Location

~/workspace/source/src/execution/visualization_generator.py
