get_dataset
Loads semantic anomaly datasets with automatic train subset filtering.
from laft.utils import get_dataset
data = get_dataset(
dataset_name="color_mnist",
dataset_config=None,
dataset_root="./data"
)
Parameters
dataset_name
Literal['color_mnist', 'waterbirds', 'celeba']
required
Name of the semantic anomaly dataset to load.
dataset_config
dict | None
default:"None"
Dataset-specific configuration. When None, uses default configuration.
dataset_kwargs
dict | None
default:"None"
Additional keyword arguments passed to dataset constructor.
transform
Callable | None
default:"None"
Transform to apply to images.
dataset_root: Root directory for dataset files.
splits
Sequence[Literal['train', 'valid', 'test']]
default:"('train', 'test')"
Dataset splits to load.
verbose: Whether to print progress information.
Function to use for printing (defaults to built-in print).
Returns
Dictionary mapping split names to (subset, attrs) tuples.
- For the train split: subset contains only normal samples, attrs are all False
- For other splits: subset is the full dataset, attrs are the original attributes
Examples
Load train and test splits
from laft.utils import get_dataset
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
data = get_dataset(
dataset_name="color_mnist",
transform=transform,
dataset_root="./data"
)
# Access splits
train_subset, train_attrs = data['train']
test_subset, test_attrs = data['test']
print(f"Train samples (normal only): {len(train_subset)}")
print(f"Test samples (all): {len(test_subset)}")
Load all splits with custom config
custom_config = {
"number": {i: i >= 5 for i in range(10)},
"color": {"red": False, "green": True, "blue": True}
}
data = get_dataset(
dataset_name="color_mnist",
dataset_config=custom_config,
splits=["train", "valid", "test"],
verbose=True
)
for split, (subset, attrs) in data.items():
print(f"{split}: {len(subset)} samples")
Use with DataLoader
from torch.utils.data import DataLoader
data = get_dataset("color_mnist", transform=transform)
train_subset, _ = data['train']
train_loader = DataLoader(
train_subset,
batch_size=32,
shuffle=True,
num_workers=4
)
for images, targets in train_loader:
print(f"Batch shape: {images.shape}")
break
get_clip_cached_features
Computes and caches CLIP features for datasets, enabling fast repeated experiments.
from laft.utils import get_clip_cached_features
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist",
device="cuda"
)
Parameters
model_name: Name of the CLIP model to use (e.g., "ViT-B-16", "ViT-L-14").
dataset_name
Literal['color_mnist', 'waterbirds', 'celeba']
required
Name of the dataset.
splits
Sequence[Literal['train', 'train-all', 'valid', 'test']]
default:"('train', 'test')"
Dataset splits to compute features for.
train: Only normal samples from training set
train-all: All samples from training set
valid: Validation set
test: Test set
device
str | torch.device
default:"'cuda' if available else 'cpu'"
Device to run model on.
model_root
str | None
default:"./checkpoints/open_clip"
Directory to download/load CLIP model from.
dataset_root: Root directory for dataset files.
dataset_config
dict | None
default:"None"
Dataset-specific configuration.
dataset_kwargs
dict | None
default:"None"
Additional keyword arguments for dataset.
verbose: Whether to print progress information.
Function to use for printing.
Directory to store cached features.
flush: Whether to recompute features even if cache exists.
Returns
data
dict[str, tuple[torch.Tensor, torch.Tensor]]
Dictionary mapping split names to (features, attrs) tuples.
features: Tensor of shape [num_samples, feature_dim] with CLIP image embeddings
attrs: Tensor of shape [num_samples, num_attrs] with attribute labels
Caching System
The function uses MD5 hashing of configuration to create unique cache keys:
keyset = {
"model_name": "ViT-B-16",
"dataset_name": "color_mnist",
"dataset_config": {...}, # Only for color_mnist
# ... other dataset_kwargs
}
hashkey = md5(json.dumps(keyset, sort_keys=True).encode()).hexdigest()
# Features saved to: .cache/{hashkey}/train.pt, .cache/{hashkey}/test.pt
Examples
Basic usage with caching
from laft.utils import get_clip_cached_features
# First run: computes features and saves to cache
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist",
device="cuda",
verbose=True
)
# Second run: loads from cache (much faster)
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist",
device="cuda",
verbose=True
)
# Output: Loading train from cache '...'
# Loading test from cache '...'
# Access cached features
train_features, train_attrs = data['train']
test_features, test_attrs = data['test']
print(f"Train features shape: {train_features.shape}") # [N, 512] for ViT-B-16
print(f"Test features shape: {test_features.shape}")
Load all splits including train-all
model, data = get_clip_cached_features(
model_name="ViT-L-14",
dataset_name="color_mnist",
splits=["train", "train-all", "test"],
device="cuda"
)
# train: only normal samples
# train-all: all training samples (normal + anomaly)
train_normal, _ = data['train']
train_all, _ = data['train-all']
test, _ = data['test']
print(f"Normal train samples: {len(train_normal)}")
print(f"All train samples: {len(train_all)}")
print(f"Test samples: {len(test)}")
Force recompute with flush
# Recompute features even if cache exists
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist",
flush=True, # Ignore cache
verbose=True
)
Use custom dataset config
custom_config = {
"number": {i: i >= 7 for i in range(10)},
"color": {"red": False, "green": True, "blue": True}
}
# Different config = different cache key
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist",
dataset_config=custom_config,
device="cuda"
)
Use features for anomaly detection
import torch
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist"
)
train_features, _ = data['train']
test_features, test_attrs = data['test']
# Compute distances to training set
mean_train = train_features.mean(dim=0)
distances = torch.cdist(test_features, mean_train.unsqueeze(0))
# Evaluate anomaly detection
from sklearn.metrics import roc_auc_score
scores = distances.squeeze().cpu().numpy()
labels = test_attrs.any(dim=1).cpu().numpy() # Any anomaly attribute = anomaly
auroc = roc_auc_score(labels, scores)
print(f"AUROC: {auroc:.4f}")
build_table
Builds a formatted table from experiment metrics for easy comparison.
from laft.utils import build_table
table = build_table(
metrics=results,
types=["auroc", "auprc", "fpr95"]
)
Parameters
metrics
Mapping[str, Mapping[str, Mapping[str, float] | Sequence[Mapping[str, float]]]]
required
Nested dictionary of metrics. Structure: {method: {experiment: {metric: value} or [{metric: value}, ...]}}
- First level: method/model names (e.g., “baseline”, “laft”)
- Second level: experiment configurations (e.g., “color_mnist/guide_number”)
- Third level: metric values (dict for single run, list of dicts for multiple runs)
group_headers: Column headers for grouping (e.g., ["Dataset", "Guidance"]). Length must match the number of levels in experiment names (split by /).
Headers for method names. Defaults to capitalized method names.
types
Sequence[str]
default:"('auroc', 'auprc', 'fpr95')"
Metric types to include in the table. Supported: "auroc", "auprc", "accuracy", "f1", "fpr95"
meanfmt: Format string for mean values (e.g., "5.1f" → " 95.3").
stdfmt: Format string for standard deviation values (e.g., "3.1f" → "0.2").
Returns
Formatted table string with metrics.
Examples
Single run metrics
from laft.utils import build_table
metrics = {
"baseline": {
"color_mnist/guide_number": {"auroc": 0.851, "auprc": 0.792, "fpr95": 0.423},
"color_mnist/guide_color": {"auroc": 0.823, "auprc": 0.765, "fpr95": 0.487},
},
"laft": {
"color_mnist/guide_number": {"auroc": 0.923, "auprc": 0.891, "fpr95": 0.231},
"color_mnist/guide_color": {"auroc": 0.897, "auprc": 0.854, "fpr95": 0.298},
},
}
table = build_table(
metrics,
group_headers=["Dataset", "Guidance"],
types=["auroc", "auprc", "fpr95"]
)
print(table)
Output:
Baseline Laft
Dataset Guidance AUROC AUPRC FPR95 AUROC AUPRC FPR95
color_mnist guide_number 85.1 79.2 42.3 92.3 89.1 23.1
color_mnist guide_color 82.3 76.5 48.7 89.7 85.4 29.8
Multiple runs with mean ± std
metrics = {
"laft": {
"color_mnist/guide_number": [
{"auroc": 0.921, "auprc": 0.889, "fpr95": 0.235},
{"auroc": 0.925, "auprc": 0.893, "fpr95": 0.227},
{"auroc": 0.923, "auprc": 0.891, "fpr95": 0.231},
],
},
}
table = build_table(
metrics,
group_headers=["Dataset", "Guidance"],
types=["auroc", "auprc", "fpr95"],
meanfmt="5.1f",
stdfmt="3.1f"
)
print(table)
# Output: 92.3 ± 0.2 89.1 ± 0.2 23.1 ± 0.3
# High precision formatting
table = build_table(
metrics,
types=["auroc"],
meanfmt="6.3f", # More decimal places
stdfmt="5.3f"
)
print(table)
# Output: 92.300 ± 0.163
save_table
Saves a formatted table to a text file.
from laft.utils import save_table
save_table(table, "results/experiment_results.txt")
Parameters
The formatted table string to save (typically from build_table()).
File path where the table should be saved. Parent directories are created automatically if they don’t exist.
Examples
Save experiment results
from laft.utils import build_table, save_table
metrics = {
"baseline": {
"color_mnist/guide_number": {"auroc": 0.851, "auprc": 0.792},
},
"laft": {
"color_mnist/guide_number": {"auroc": 0.923, "auprc": 0.891},
},
}
table = build_table(metrics, group_headers=["Dataset", "Guidance"])
save_table(table, "results/color_mnist_results.txt")
print("Results saved to results/color_mnist_results.txt")
Save multiple experiment tables
import os
from datetime import datetime
# Create timestamped directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_dir = f"results/{timestamp}"
# Save different metric tables
for metric in ["auroc", "auprc", "fpr95"]:
table = build_table(metrics, types=[metric])
save_table(table, os.path.join(results_dir, f"{metric}_results.txt"))
Compare with previous results
from laft.utils import build_table, save_table
# Current experiment
current_metrics = {...}
current_table = build_table(current_metrics)
save_table(current_table, "results/current.txt")
# Previous baseline
baseline_metrics = {...}
baseline_table = build_table(baseline_metrics)
save_table(baseline_table, "results/baseline.txt")
print("Tables saved:")
print(" - results/current.txt")
print(" - results/baseline.txt")
Complete Workflow Example
Here’s a complete example combining all utility functions:
from laft.utils import get_clip_cached_features, build_table, save_table
from laft.prompts import get_prompts, get_labels
import torch
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
def compute_fpr95(labels, scores):
"""Compute FPR at 95% TPR."""
from sklearn.metrics import roc_curve
fpr, tpr, _ = roc_curve(labels, scores)
idx = np.argmin(np.abs(tpr - 0.95))
return fpr[idx]
# 1. Load cached CLIP features
model, data = get_clip_cached_features(
model_name="ViT-B-16",
dataset_name="color_mnist",
device="cuda",
verbose=True
)
train_features, _ = data['train']
test_features, test_attrs = data['test']
# 2. Get prompts and encode them
prompts = get_prompts("color_mnist", "guide_number")
text_features = model.encode_text(prompts['exact'])
# 3. Compute anomaly scores
scores = torch.matmul(test_features, text_features.T).max(dim=1).values
# 4. Get labels
_, _, attend_labels, _ = get_labels(
"color_mnist", test_attrs, "guide_number"
)
# 5. Compute metrics
metrics = {
"laft": {
"color_mnist/guide_number": {
"auroc": roc_auc_score(attend_labels.cpu(), scores.cpu()),
"auprc": average_precision_score(attend_labels.cpu(), scores.cpu()),
"fpr95": compute_fpr95(attend_labels.cpu().numpy(), scores.cpu().numpy()),
}
}
}
# 6. Build and save results table
table = build_table(
metrics,
group_headers=["Dataset", "Guidance"],
types=["auroc", "auprc", "fpr95"]
)
print(table)
save_table(table, "results/laft_results.txt")