Basic Usage

LAFT (Language-Assisted Feature Transformation) enables anomaly detection by transforming image features based on language-guided concept subspaces. This guide walks you through a complete workflow.

Overview

The LAFT workflow consists of these key steps:
1. Load data and model: load your dataset and CLIP model, with pre-computed features or raw images.
2. Encode features: extract image and text embeddings using the CLIP encoder.
3. Build concept subspace: generate prompt pairs and compute their principal components.
4. Transform features: project image features using inner() (guide) or orthogonal() (ignore).
5. Evaluate anomaly scores: use k-NN distance in the transformed space to compute anomaly scores.
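
Conceptually, steps 3 and 4 boil down to a few lines of linear algebra. The sketch below is illustrative only: the names prompt_pair_diffs and pca_basis are hypothetical, not the laft API, and the centering choice is an assumption.

```python
import torch

# Hypothetical sketch of steps 3-4 (not the laft API): prompt pairs that
# differ only in the target concept are differenced, and PCA over those
# differences spans the concept subspace.
def prompt_pair_diffs(text_features: torch.Tensor, pairs) -> torch.Tensor:
    # One difference vector per (i, j) prompt pair
    return torch.stack([text_features[i] - text_features[j] for i, j in pairs])

def pca_basis(diffs: torch.Tensor, n_components: int) -> torch.Tensor:
    # Top right-singular vectors of the centered differences
    centered = diffs - diffs.mean(dim=0, keepdim=True)
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    return vh[:n_components]  # (n_components, dim), orthonormal rows
```

The returned rows are orthonormal, which is what makes the projection in step 4 a simple matrix product.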

Semantic Anomaly Detection

For datasets with semantic attributes (e.g., Color MNIST, Waterbirds, CelebA), LAFT can guide or ignore specific concepts.

Complete Example: Color MNIST

import torch
import laft

# Disable gradients to prevent OOM
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")

# Load model and cached features
model, data = laft.get_clip_cached_features(
    "ViT-B-16-quickgelu:dfn2b",
    "color_mnist",
    splits=["train", "test"],
    dataset_kwargs={"seed": 42}
)

# Extract features and attributes
train_features, _ = data["train"]
test_features, test_attrs = data["test"]

# Get labels and prompts for "guide number" task
attend_name, ignore_name, attend_labels, ignore_labels = \
    laft.prompts.get_labels("color_mnist", test_attrs, "guide_number")

prompts = laft.prompts.get_prompts("color_mnist", "guide_number")

# Encode prompts and build concept subspace
text_features = model.encode_text(prompts["all"])
pair_diffs = laft.prompt_pair(text_features)
concept_basis = laft.pca(pair_diffs, n_components=24)

# Transform features to guide towards number concept
train_guided = laft.inner(train_features, concept_basis)
test_guided = laft.inner(test_features, concept_basis)

# Compute anomaly scores using k-NN
scores = laft.knn(train_guided, test_guided, n_neighbors=30)

# Evaluate metrics
metrics = laft.binary_metrics(scores, attend_labels)
print(f"AUROC: {metrics['auroc']:.3f}")
print(f"AUPRC: {metrics['auprc']:.3f}")
print(f"FPR95: {metrics['fpr95']:.3f}")
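
The laft.knn call above can be read as plain k-NN distance scoring. A minimal sketch, assuming the score is the mean distance to the k nearest training points (the actual laft.knn may differ, e.g. in metric or reduction):

```python
import torch

# Illustrative k-NN anomaly scoring (assumed behavior, not laft's code):
# a test point far from every training point gets a high score.
def knn_scores(train: torch.Tensor, test: torch.Tensor,
               n_neighbors: int = 30) -> torch.Tensor:
    dists = torch.cdist(test, train)                       # (n_test, n_train)
    knn_dists, _ = dists.topk(n_neighbors, largest=False, dim=1)
    return knn_dists.mean(dim=1)                           # higher = more anomalous
```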

Guide vs Ignore

LAFT supports two transformation strategies:
# Guide: Project features ONTO the concept subspace
# Use when you want to detect anomalies BASED ON a specific concept
guided_features = laft.inner(image_features, concept_basis)

# Ignore: Remove the concept-subspace component from the features
# Use when you want anomaly detection to be INVARIANT to a concept
ignored_features = laft.orthogonal(image_features, concept_basis)

# Example: Detect anomalies in the "number" attribute
prompts = laft.prompts.get_prompts("color_mnist", "guide_number")
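
Both transforms are plausibly just complementary projections with respect to the concept basis. A sketch, assuming the basis rows are orthonormal; inner_sketch and orthogonal_sketch are hypothetical names, not the laft API:

```python
import torch

# Assumed semantics (illustrative): with an orthonormal basis B,
# "inner" keeps the component inside span(B), "orthogonal" removes it.
def inner_sketch(features: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
    return features @ basis.T @ basis              # projection onto the subspace

def orthogonal_sketch(features: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
    return features - features @ basis.T @ basis   # complement projection
```

Under this reading, the two sketches partition each feature vector: their sum reconstructs the input exactly.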

Waterbirds Example

import laft
import torch

torch.set_grad_enabled(False)

# Load Waterbirds dataset
model, data = laft.get_clip_cached_features(
    "ViT-B-16-quickgelu:dfn2b",
    "waterbirds",
    splits=["train", "test"]
)

train_features, _ = data["train"]
test_features, test_attrs = data["test"]

# Guide towards "bird" type (landbirds are anomalies)
attend_name, ignore_name, attend_labels, ignore_labels = \
    laft.prompts.get_labels("waterbirds", test_attrs, "guide_bird")

prompts = laft.prompts.get_prompts("waterbirds", "guide_bird")
text_features = model.encode_text(prompts["all"])

# Build concept subspace
pair_diffs = laft.prompt_pair(text_features)
concept_basis = laft.pca(pair_diffs, n_components=50)

# Transform and evaluate
train_guided = laft.inner(train_features, concept_basis)
test_guided = laft.inner(test_features, concept_basis)
scores = laft.knn(train_guided, test_guided, n_neighbors=30)

metrics = laft.binary_metrics(scores, attend_labels)
print(f"Bird anomaly detection - AUROC: {metrics['auroc']:.3f}")

Industrial Anomaly Detection

For industrial defect detection (e.g., MVTec AD, VisA), LAFT works with normal/anomaly prompts.

Complete Example: MVTec AD

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import laft
import baselines

torch.set_grad_enabled(False)

# Load WinCLIP model (for industrial AD)
model, transform = baselines.load_winclip("ViT-B-16-plus-240:laion400m_e31")

# Setup category-specific prompts
category = "bottle"
model.setup_prompts(class_name=category)

# Load dataset
train_dataset = laft.build_industrial_dataset(
    "mvtec", category, split="train", transform=transform
)
test_dataset = laft.build_industrial_dataset(
    "mvtec", category, split="test", transform=transform
)

test_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)

# Few-shot setup: sample reference images
n_shot = 4
rng = torch.Generator().manual_seed(42)
idxs = torch.randperm(len(train_dataset), generator=rng)[:n_shot].tolist()
reference_images = torch.stack([train_dataset[i][0] for i in idxs]).cuda()
model.setup_images(reference_images)

# Inference
scores_list, labels_list = [], []
for images, masks, labels in test_loader:
    scores, heatmaps = model(images.cuda())
    scores_list.append(scores.cpu())
    labels_list.append(labels)

# Evaluate
all_scores = torch.cat(scores_list)
all_labels = torch.cat(labels_list)
metrics = laft.binary_metrics(all_scores, all_labels, types=["auroc"])
print(f"MVTec {category} - AUROC: {metrics['auroc']:.3f}")
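
Zero-shot scoring with normal/anomaly prompts, conceptually, compares the image embedding against the two prompt groups. A rough, self-contained sketch in the spirit of WinCLIP-style text matching; the softmax formulation, temperature value, and all function names here are assumptions, not the baselines implementation:

```python
import torch
import torch.nn.functional as F

# Illustrative two-class text matching: the anomaly score is the softmax
# weight assigned to the "anomaly" prompt embedding.
def text_anomaly_score(image_feat: torch.Tensor,
                       normal_feat: torch.Tensor,
                       anomaly_feat: torch.Tensor,
                       temperature: float = 0.07) -> torch.Tensor:
    image_feat = F.normalize(image_feat, dim=-1)
    sims = torch.stack([
        image_feat @ F.normalize(normal_feat, dim=-1),    # cosine to "normal"
        image_feat @ F.normalize(anomaly_feat, dim=-1),   # cosine to "anomaly"
    ])
    return torch.softmax(sims / temperature, dim=0)[1]    # in [0, 1]
```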

Component Selection

The number of principal components affects performance. Typically:
  • Semantic datasets: 10-50 components work well
  • Industrial datasets: Experiment with 2-100 components
# Sweep over component counts: fit the basis once, then slice it
max_components = 100
full_basis = laft.pca(pair_diffs, n_components=max_components)
for n_components in range(2, max_components + 1):
    concept_basis = full_basis[:n_components]
    train_guided = laft.inner(train_features, concept_basis)
    test_guided = laft.inner(test_features, concept_basis)
    scores = laft.knn(train_guided, test_guided, n_neighbors=30)
    metrics = laft.binary_metrics(scores, labels)
    print(f"Components: {n_components}, AUROC: {metrics['auroc']:.3f}")

Working with Raw Images

If you don’t have pre-computed features:
import laft
import torch
from PIL import Image

# Load model and transform
model, transform = laft.load_clip("ViT-B-16-quickgelu:dfn2b")

# Load and transform image
image = Image.open("path/to/image.jpg")
image_tensor = transform(image).unsqueeze(0).cuda()

# Prompts for the concept of interest (see the semantic examples above)
prompts = laft.prompts.get_prompts("color_mnist", "guide_number")

# Encode image and text
with torch.inference_mode():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(prompts["all"])

# Continue with the LAFT workflow...
# ... rest of the pipeline

Best Practices

Always disable gradients for inference; setting the matmul precision to "high" additionally enables faster TF32 matrix multiplies on supported GPUs:
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")
Use get_clip_cached_features() to cache extracted features and avoid recomputation:
model, data = laft.get_clip_cached_features(
    model_name="ViT-B-16-quickgelu:dfn2b",
    dataset_name="color_mnist",
    cache_root="./.cache",
    flush=False  # Set True to recompute
)
For large datasets, use DataLoader for efficient processing:
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=32, num_workers=4)

features_list = []
for batch in loader:
    features = model.encode_image(batch[0].cuda())
    features_list.append(features.cpu())

all_features = torch.cat(features_list)

Next Steps

Prompts

Learn about the prompt system and how to create custom prompts

Evaluation

Understand metrics and how to evaluate your models
