Skip to main content

get_dataset

Loads semantic anomaly datasets with automatic train subset filtering.
from laft.utils import get_dataset

data = get_dataset(
    dataset_name="color_mnist",
    dataset_config=None,
    dataset_root="./data"
)

Parameters

dataset_name
Literal['color_mnist', 'waterbirds', 'celeba']
required
Name of the semantic anomaly dataset to load.
dataset_config
dict | None
default:"None"
Dataset-specific configuration. When None, uses default configuration.
dataset_kwargs
dict | None
default:"None"
Additional keyword arguments passed to dataset constructor.
transform
Callable | None
default:"None"
Transform to apply to images.
dataset_root
str
default:"./data"
Root directory for dataset files.
splits
Sequence[Literal['train', 'valid', 'test']]
default:"('train', 'test')"
Dataset splits to load.
verbose
bool
default:"True"
Whether to print progress information.
print_fn
Callable
default:"print"
Function to use for printing (defaults to built-in print).

Returns

data
dict[str, tuple]
Dictionary mapping split names to (subset, attrs) tuples.
  • For train split: subset contains only normal samples, attrs are all False
  • For other splits: subset is the full dataset, attrs are the original attributes

Examples

Load train and test splits

from laft.utils import get_dataset
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data = get_dataset(
    dataset_name="color_mnist",
    transform=transform,
    dataset_root="./data"
)

# Access splits
train_subset, train_attrs = data['train']
test_subset, test_attrs = data['test']

print(f"Train samples (normal only): {len(train_subset)}")
print(f"Test samples (all): {len(test_subset)}")

Load all splits with custom config

custom_config = {
    "number": {i: i >= 5 for i in range(10)},
    "color": {"red": False, "green": True, "blue": True}
}

data = get_dataset(
    dataset_name="color_mnist",
    dataset_config=custom_config,
    splits=["train", "valid", "test"],
    verbose=True
)

for split, (subset, attrs) in data.items():
    print(f"{split}: {len(subset)} samples")

Use with DataLoader

from torch.utils.data import DataLoader

data = get_dataset("color_mnist", transform=transform)
train_subset, _ = data['train']

train_loader = DataLoader(
    train_subset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

for images, targets in train_loader:
    print(f"Batch shape: {images.shape}")
    break

get_clip_cached_features

Computes and caches CLIP features for datasets, enabling fast repeated experiments.
from laft.utils import get_clip_cached_features

model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist",
    device="cuda"
)

Parameters

model_name
str
required
Name of the CLIP model to use (e.g., "ViT-B-16", "ViT-L-14").
dataset_name
Literal['color_mnist', 'waterbirds', 'celeba']
required
Name of the dataset.
splits
Sequence[Literal['train', 'train-all', 'valid', 'test']]
default:"('train', 'test')"
Dataset splits to compute features for.
  • train: Only normal samples from training set
  • train-all: All samples from training set
  • valid: Validation set
  • test: Test set
device
str | torch.device
default:"'cuda' if available else 'cpu'"
Device to run model on.
model_root
str | None
default:"./checkpoints/open_clip"
Directory to download/load CLIP model from.
dataset_root
str
default:"./data"
Root directory for dataset files.
dataset_config
dict | None
default:"None"
Dataset-specific configuration.
dataset_kwargs
dict | None
default:"None"
Additional keyword arguments for dataset.
verbose
bool
default:"True"
Whether to print progress information.
print_fn
Callable
default:"print"
Function to use for printing.
cache_root
str
default:"./.cache"
Directory to store cached features.
flush
bool
default:"False"
Whether to recompute features even if cache exists.

Returns

model
CLIP model
The loaded CLIP model.
data
dict[str, tuple[torch.Tensor, torch.Tensor]]
Dictionary mapping split names to (features, attrs) tuples.
  • features: Tensor of shape [num_samples, feature_dim] with CLIP image embeddings
  • attrs: Tensor of shape [num_samples, num_attrs] with attribute labels

Caching System

The function uses MD5 hashing of configuration to create unique cache keys:
keyset = {
    "model_name": "ViT-B-16",
    "dataset_name": "color_mnist",
    "dataset_config": {...},  # Only for color_mnist
    # ... other dataset_kwargs
}

hashkey = md5(json.dumps(keyset, sort_keys=True)).hexdigest()
# Features saved to: .cache/{hashkey}/train.pt, .cache/{hashkey}/test.pt

Examples

Basic usage with caching

from laft.utils import get_clip_cached_features

# First run: computes features and saves to cache
model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist",
    device="cuda",
    verbose=True
)

# Second run: loads from cache (much faster)
model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist",
    device="cuda",
    verbose=True
)
# Output: Loading train from cache '...'
#         Loading test from cache '...'

# Access cached features
train_features, train_attrs = data['train']
test_features, test_attrs = data['test']

print(f"Train features shape: {train_features.shape}")  # [N, 512] for ViT-B-16
print(f"Test features shape: {test_features.shape}")

Load all splits including train-all

model, data = get_clip_cached_features(
    model_name="ViT-L-14",
    dataset_name="color_mnist",
    splits=["train", "train-all", "test"],
    device="cuda"
)

# train: only normal samples
# train-all: all training samples (normal + anomaly)
train_normal, _ = data['train']
train_all, _ = data['train-all']
test, _ = data['test']

print(f"Normal train samples: {len(train_normal)}")
print(f"All train samples: {len(train_all)}")
print(f"Test samples: {len(test)}")

Force recompute with flush

# Recompute features even if cache exists
model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist",
    flush=True,  # Ignore cache
    verbose=True
)

Use custom dataset config

custom_config = {
    "number": {i: i >= 7 for i in range(10)},
    "color": {"red": False, "green": True, "blue": True}
}

# Different config = different cache key
model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist",
    dataset_config=custom_config,
    device="cuda"
)

Use features for anomaly detection

import torch

model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist"
)

train_features, _ = data['train']
test_features, test_attrs = data['test']

# Compute distances to training set
mean_train = train_features.mean(dim=0)
distances = torch.cdist(test_features, mean_train.unsqueeze(0))

# Evaluate anomaly detection
from sklearn.metrics import roc_auc_score

scores = distances.squeeze().cpu().numpy()
labels = test_attrs.any(dim=1).cpu().numpy()  # Any anomaly attribute = anomaly

auroc = roc_auc_score(labels, scores)
print(f"AUROC: {auroc:.4f}")

build_table

Builds a formatted table from experiment metrics for easy comparison.
from laft.utils import build_table

table = build_table(
    metrics=results,
    types=["auroc", "auprc", "fpr95"]
)

Parameters

metrics
Mapping[str, Mapping[str, Mapping[str, float] | Sequence[Mapping[str, float]]]]
required
Nested dictionary of metrics. Structure: {method: {experiment: {metric: value} or [{metric: value}, ...]}}
  • First level: method/model names (e.g., “baseline”, “laft”)
  • Second level: experiment configurations (e.g., “color_mnist/guide_number”)
  • Third level: metric values (dict for single run, list of dicts for multiple runs)
group_headers
Sequence[str] | None
default:"None"
Column headers for grouping (e.g., ["Dataset", "Guidance"]). Length must match the number of levels in experiment names (split by /).
label_headers
Sequence[str] | None
default:"None"
Headers for method names. Defaults to capitalized method names.
types
Sequence[str]
default:"('auroc', 'auprc', 'fpr95')"
Metric types to include in the table. Supported: "auroc", "auprc", "accuracy", "f1", "fpr95"
meanfmt
str
default:"'5.1f'"
Format string for mean values (e.g., "5.1f" renders " 95.3").
stdfmt
str
default:"'3.1f'"
Format string for standard deviation values (e.g., "3.1f" renders "2.1").

Returns

table
str
Formatted table string with metrics.

Examples

Single run metrics

from laft.utils import build_table

metrics = {
    "baseline": {
        "color_mnist/guide_number": {"auroc": 0.851, "auprc": 0.792, "fpr95": 0.423},
        "color_mnist/guide_color": {"auroc": 0.823, "auprc": 0.765, "fpr95": 0.487},
    },
    "laft": {
        "color_mnist/guide_number": {"auroc": 0.923, "auprc": 0.891, "fpr95": 0.231},
        "color_mnist/guide_color": {"auroc": 0.897, "auprc": 0.854, "fpr95": 0.298},
    },
}

table = build_table(
    metrics,
    group_headers=["Dataset", "Guidance"],
    types=["auroc", "auprc", "fpr95"]
)

print(table)
Output:
                                  Baseline                 Laft
Dataset      Guidance     AUROC  AUPRC  FPR95  AUROC  AUPRC  FPR95
color_mnist  guide_number  85.1   79.2   42.3   92.3   89.1   23.1
color_mnist  guide_color   82.3   76.5   48.7   89.7   85.4   29.8

Multiple runs with mean ± std

metrics = {
    "laft": {
        "color_mnist/guide_number": [
            {"auroc": 0.921, "auprc": 0.889, "fpr95": 0.235},
            {"auroc": 0.925, "auprc": 0.893, "fpr95": 0.227},
            {"auroc": 0.923, "auprc": 0.891, "fpr95": 0.231},
        ],
    },
}

table = build_table(
    metrics,
    group_headers=["Dataset", "Guidance"],
    types=["auroc", "auprc", "fpr95"],
    meanfmt="5.1f",
    stdfmt="3.1f"
)

print(table)
# Output: 92.3 ± 0.2  89.1 ± 0.2  23.1 ± 0.4

Custom formatting

# High precision formatting
table = build_table(
    metrics,
    types=["auroc"],
    meanfmt="6.3f",  # More decimal places
    stdfmt="5.3f"
)

print(table)
# Output: 92.300 ± 0.200

save_table

Saves a formatted table to a text file.
from laft.utils import save_table

save_table(table, "results/experiment_results.txt")

Parameters

table
str
required
The formatted table string to save (typically from build_table()).
path
str
required
File path where the table should be saved. Parent directories are created automatically if they don’t exist.

Examples

Save experiment results

from laft.utils import build_table, save_table

metrics = {
    "baseline": {
        "color_mnist/guide_number": {"auroc": 0.851, "auprc": 0.792},
    },
    "laft": {
        "color_mnist/guide_number": {"auroc": 0.923, "auprc": 0.891},
    },
}

table = build_table(metrics, group_headers=["Dataset", "Guidance"])
save_table(table, "results/color_mnist_results.txt")

print("Results saved to results/color_mnist_results.txt")

Save multiple experiment tables

import os
from datetime import datetime

# Create timestamped directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_dir = f"results/{timestamp}"

# Save different metric tables
for metric in ["auroc", "auprc", "fpr95"]:
    table = build_table(metrics, types=[metric])
    save_table(table, os.path.join(results_dir, f"{metric}_results.txt"))

Compare with previous results

from laft.utils import build_table, save_table

# Current experiment
current_metrics = {...}
current_table = build_table(current_metrics)
save_table(current_table, "results/current.txt")

# Previous baseline
baseline_metrics = {...}
baseline_table = build_table(baseline_metrics)
save_table(baseline_table, "results/baseline.txt")

print("Tables saved:")
print("  - results/current.txt")
print("  - results/baseline.txt")

Complete Workflow Example

Here’s a complete example combining all utility functions:
from laft.utils import get_clip_cached_features, build_table, save_table
from laft.prompts import get_prompts, get_labels
import torch
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

def compute_fpr95(labels, scores):
    """Compute FPR at 95% TPR."""
    from sklearn.metrics import roc_curve
    fpr, tpr, _ = roc_curve(labels, scores)
    idx = np.argmin(np.abs(tpr - 0.95))
    return fpr[idx]

# 1. Load cached CLIP features
model, data = get_clip_cached_features(
    model_name="ViT-B-16",
    dataset_name="color_mnist",
    device="cuda",
    verbose=True
)

train_features, _ = data['train']
test_features, test_attrs = data['test']

# 2. Get prompts and encode them
prompts = get_prompts("color_mnist", "guide_number")
text_features = model.encode_text(prompts['exact'])

# 3. Compute anomaly scores
scores = torch.matmul(test_features, text_features.T).max(dim=1).values

# 4. Get labels
_, _, attend_labels, _ = get_labels(
    "color_mnist", test_attrs, "guide_number"
)

# 5. Compute metrics
metrics = {
    "laft": {
        "color_mnist/guide_number": {
            "auroc": roc_auc_score(attend_labels.cpu(), scores.cpu()),
            "auprc": average_precision_score(attend_labels.cpu(), scores.cpu()),
            "fpr95": compute_fpr95(attend_labels.cpu().numpy(), scores.cpu().numpy()),
        }
    }
}

# 6. Build and save results table
table = build_table(
    metrics,
    group_headers=["Dataset", "Guidance"],
    types=["auroc", "auprc", "fpr95"]
)

print(table)
save_table(table, "results/laft_results.txt")

Build docs developers (and LLMs) love