Prompts

Prompts are the foundation of LAFT’s language-assisted feature transformation. They define the semantic concepts that the transformation either attends to or ignores.

Overview

LAFT uses text prompts to:
  1. Define concepts: Normal vs. anomalous states or semantic attributes
  2. Build subspaces: Compute pairwise differences between prompt embeddings
  3. Guide transformation: Project features onto or away from concept subspaces

Prompt System API

The laft.prompts module provides three main functions:

get_prompts(dataset_name, guidance)

Returns a dictionary of prompt templates for a dataset.
import laft

# Get prompts for Color MNIST, guiding towards "number"
prompts = laft.prompts.get_prompts("color_mnist", "guide_number")

# Available prompt types:
# - "normal": Prompts for normal class
# - "anomaly": Prompts for anomaly class
# - "half": Normal + half of anomaly prompts
# - "exact": Normal + exact anomaly prompts
# - "all": Normal + anomaly + auxiliary prompts

print(prompts.keys())  # dict_keys(['normal', 'anomaly', 'half', 'exact', 'all'])
Parameters:
  • dataset_name (str): One of "color_mnist", "waterbirds", "celeba"
  • guidance (str): Strategy in format "guide_{concept}" or "ignore_{concept}"
Returns:
  • Dictionary mapping prompt types to lists of prompt strings

get_labels(dataset_name, attrs, guidance)

Extracts binary anomaly labels from the attribute tensor according to the guidance strategy.
import laft

# Load test attributes
model, data = laft.get_clip_cached_features(
    "ViT-B-16-quickgelu:dfn2b",
    "color_mnist",
    splits=["test"]
)
_, test_attrs = data["test"]

# Get labels for "guide number" task
attend_name, ignore_name, attend_labels, ignore_labels = \
    laft.prompts.get_labels("color_mnist", test_attrs, "guide_number")

print(attend_name)    # "number"
print(ignore_name)    # "color"
print(attend_labels.shape)  # [N] - binary labels for number anomaly
print(ignore_labels.shape)  # [N] - binary labels for color anomaly
Parameters:
  • dataset_name (str): Dataset identifier
  • attrs (Tensor): Attribute tensor from dataset [N, num_attributes]
  • guidance (str): Guidance strategy
Returns:
  • Tuple of (attend_name, ignore_name, attend_labels, ignore_labels)
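To make the return values concrete, here is a toy sketch of what this call computes for ("color_mnist", "guide_number"). The [digit, color] attribute layout is an assumption made for illustration only; the actual tensor layout comes from the dataset.

```python
# Toy sketch of the labels get_labels returns for
# ("color_mnist", "guide_number"). The [digit, color] attribute
# layout here is an assumption made for illustration only.
attrs = [
    [3, 0],  # digit 3 (normal: 0-4), color id 0
    [7, 1],  # digit 7 (anomaly: 5-9), color id 1
    [1, 1],
    [9, 0],
]

attend_name, ignore_name = "number", "color"
attend_labels = [int(digit >= 5) for digit, _ in attrs]  # number anomaly
ignore_labels = [color for _, color in attrs]            # color anomaly (binary ids)

print(attend_labels)  # [0, 1, 0, 1]
print(ignore_labels)  # [0, 1, 1, 0]
```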

get_words(dataset_name, guidance)

Returns raw word lists (without templates) for a dataset.
import laft

words = laft.prompts.get_words("color_mnist", "guide_number")
print(words["normal"])  # ['zero', '0', 'one', '1', ...]
print(words["anomaly"]) # ['five', '5', 'six', '6', ...]
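These word lists are the raw inputs that get_prompts combines with templates. A minimal sketch of that combination, where the template strings are illustrative rather than the library's exact set:

```python
# Illustrative templates; the library's actual template set may differ.
TEMPLATES = ["a photo of {}", "an image of {}"]
normal_words = ["zero", "0", "one", "1"]

# Every template applied to every word
normal_prompts = [t.format(w) for w in normal_words for t in TEMPLATES]
print(normal_prompts[:3])
# ['a photo of zero', 'an image of zero', 'a photo of 0']
```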

Available Datasets

Color MNIST

Detect anomalies in the digit class (5-9 are anomalies).
prompts = laft.prompts.get_prompts("color_mnist", "guide_number")

# Normal: digits 0-4
# Anomaly: digits 5-9
# Templates: "a photo of {}", "an image of {}", etc.
Example prompts:
  • Normal: “a photo of zero”, “a number 1”, “a sketch of 2 letter”
  • Anomaly: “a photo of five”, “a number 6”, “a sketch of 7 letter”

Waterbirds

Detect landbirds (anomalies) vs. waterbirds (normal).
prompts = laft.prompts.get_prompts("waterbirds", "guide_bird")

# Normal: waterbird species
# Anomaly: landbird species
Example prompts:
  • Normal: “a photo of a Black footed Albatross”, “a blurry photo of a Laysan Albatross, a type of bird”
  • Anomaly: “a photo of a Groove billed Ani”, “a good photo of a Brewer Blackbird, a type of bird”

CelebA

Detect blond hair (anomaly) vs. non-blond (normal).
prompts = laft.prompts.get_prompts("celeba", "guide_blond")

Industrial Prompts

For industrial anomaly detection (MVTec AD, VisA), use the industrial prompt modules:
from laft.prompts import industrial1

# Get normal and anomaly prompts for a category
normal_prompts, anomaly_prompts = industrial1.get_prompts("bottle")

print(normal_prompts[:3])
# ['bottle', 'flawless bottle', 'perfect bottle']

print(anomaly_prompts[:3])
# ['damaged bottle', 'bottle with flaw', 'bottle with defect']
Industrial templates:
NORMAL_STATES = [
    "{}",
    "flawless {}",
    "perfect {}",
    "unblemished {}",
    "{} without flaw",
    "{} without defect",
    "{} without damage",
]
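Expanding these templates for a category is plain string formatting. The expand helper below is an illustrative sketch, not a laft API:

```python
# NORMAL_STATES as shown above; expand() is an illustrative helper,
# not part of the laft library.
NORMAL_STATES = [
    "{}",
    "flawless {}",
    "perfect {}",
    "unblemished {}",
    "{} without flaw",
    "{} without defect",
    "{} without damage",
]

def expand(states, category):
    """Fill the category name into every state template."""
    return [s.format(category) for s in states]

print(expand(NORMAL_STATES, "bottle")[:3])
# ['bottle', 'flawless bottle', 'perfect bottle']
```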

Guide vs. Ignore Strategies

The guidance string determines the transformation type:
# Format: "guide_{concept}"
# Projects features ONTO concept subspace
# Use: Detect anomalies based on a specific concept

prompts = laft.prompts.get_prompts("color_mnist", "guide_number")
text_features = model.encode_text(prompts["all"])
pair_diffs = laft.prompt_pair(text_features)
concept_basis = laft.pca(pair_diffs)

# Inner projection: keep only concept-related information
guided_features = laft.inner(image_features, concept_basis)
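The geometry behind "guide" vs. "ignore" can be shown without the library: projecting onto an orthonormal concept basis keeps concept-related information, while subtracting that projection discards it. A dependency-free toy sketch (this is not laft's implementation, just the underlying math on small vectors):

```python
# Toy illustration of projecting a feature vector onto / away from
# a 1-D concept subspace. Assumes the basis vectors are orthonormal.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def project_onto(v, basis):
    """Project v onto the span of the orthonormal basis vectors."""
    out = [0.0] * len(v)
    for b in basis:
        c = dot(v, b)
        out = [o + c * bi for o, bi in zip(out, b)]
    return out

def project_away(v, basis):
    """Remove the component of v that lies in the concept subspace."""
    p = project_onto(v, basis)
    return [vi - pi for vi, pi in zip(v, p)]

basis = [[1.0, 0.0, 0.0]]   # concept direction (e.g. first PCA axis)
feature = [0.6, 0.8, 0.0]

print(project_onto(feature, basis))  # [0.6, 0.0, 0.0] - "guide": keep concept info
print(project_away(feature, basis))  # [0.0, 0.8, 0.0] - "ignore": discard concept info
```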

Prompt Types Explained

1. Normal: contains only prompts for the normal class.
prompts["normal"]  # Minimal set, fastest

2. Exact: contains prompts for both the normal and exact anomaly classes.
prompts["exact"]  # Normal + exact anomaly set

3. Half: contains normal prompts plus half of the anomaly prompts.
prompts["half"]  # Normal + 50% anomaly

4. All: contains normal, anomaly, and auxiliary prompts.
prompts["all"]  # Comprehensive, best performance

Creating Custom Prompts

You can define custom prompts for new datasets:
# Assumes `laft` is imported and `model` is a loaded CLIP model
# Define your templates
TEMPLATES = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a high contrast photo of a {}.",
]

# Define concept words
NORMAL_WORDS = ["apple", "orange", "banana"]
ANOMALY_WORDS = ["rotten apple", "moldy orange", "bruised banana"]

# Generate prompts: apply every template to every concept word,
# then flatten into single lists so they can be encoded directly
normal_prompts = [t.format(word) for word in NORMAL_WORDS for t in TEMPLATES]
anomaly_prompts = [t.format(word) for word in ANOMALY_WORDS for t in TEMPLATES]

# Encode with CLIP
all_prompts = normal_prompts + anomaly_prompts
text_features = model.encode_text(all_prompts)

# Build concept subspace
pair_diffs = laft.prompt_pair(text_features)
concept_basis = laft.pca(pair_diffs)

Prompt Pair Construction

The prompt_pair() function computes pairwise differences:
import laft

# Single set of prompts: compute all pairwise differences
text_features = model.encode_text(prompts["all"])  # [N, D]
pair_diffs = laft.prompt_pair(text_features)  # [(N*(N-1))/2, D]

# Multiple sets: compute differences between sets
normal_features = model.encode_text(normal_prompts)  # [N1, D]
anomaly_features = model.encode_text(anomaly_prompts)  # [N2, D]
pair_diffs = laft.prompt_pair(normal_features, anomaly_features)  # [N1*N2, D]

# PCA on pairwise differences
concept_basis = laft.pca(pair_diffs, n_components=24)
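The two pairing modes differ only in which pairs are formed. A dependency-free sketch using scalar "embeddings" for readability (real embeddings are D-dimensional vectors; this mirrors the pair counts, not laft's implementation):

```python
# Toy re-implementation of the two pairing modes that prompt_pair()
# computes, using scalar "embeddings" for readability.
from itertools import combinations, product

def pair_diffs_single(feats):
    """All pairwise differences within one set: N*(N-1)/2 results."""
    return [a - b for a, b in combinations(feats, 2)]

def pair_diffs_cross(feats_a, feats_b):
    """Differences across two sets, one per pair: N1*N2 results."""
    return [a - b for a, b in product(feats_a, feats_b)]

single = pair_diffs_single([1.0, 2.0, 4.0])
print(len(single))  # 3 == 3*(3-1)/2

cross = pair_diffs_cross([1.0, 2.0], [0.0, 3.0, 5.0])
print(len(cross))   # 6 == 2*3
```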

Best Practices

Prompt Consistency: Ensure prompts are balanced between normal and anomaly classes to avoid bias in the concept subspace.
Template Diversity: Use diverse templates (e.g., “a photo of”, “a blurry photo of”) to improve robustness.
Auxiliary Prompts: Including auxiliary prompts (the “all” type) often improves performance by providing richer semantic information.
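The balance point above can be sanity-checked mechanically. check_balance below is an illustrative helper with an arbitrary tolerance, not a laft function:

```python
# Illustrative helper (not part of laft): flag prompt sets where one
# class has far more prompts than the other.
def check_balance(normal_prompts, anomaly_prompts, tolerance=0.5):
    """Return True when the smaller class has at least `tolerance`
    times as many prompts as the larger class."""
    n, a = len(normal_prompts), len(anomaly_prompts)
    return min(n, a) / max(n, a) >= tolerance

print(check_balance(["a photo of zero"] * 10, ["a photo of five"] * 8))  # True
print(check_balance(["a photo of zero"] * 10, ["a photo of five"] * 2))  # False
```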

Example Workflows

Semantic anomaly detection (Waterbirds)

import laft
import torch

torch.set_grad_enabled(False)

# Load model and data
model, data = laft.get_clip_cached_features(
    "ViT-B-16-quickgelu:dfn2b", "waterbirds", splits=["train", "test"]
)
train_features, _ = data["train"]
test_features, test_attrs = data["test"]

# Get prompts and labels
attend_name, _, attend_labels, _ = \
    laft.prompts.get_labels("waterbirds", test_attrs, "guide_bird")
prompts = laft.prompts.get_prompts("waterbirds", "guide_bird")

# Encode prompts
text_features = model.encode_text(prompts["all"])
pair_diffs = laft.prompt_pair(text_features)
concept_basis = laft.pca(pair_diffs, n_components=30)

# Transform features
train_guided = laft.inner(train_features, concept_basis)
test_guided = laft.inner(test_features, concept_basis)

# Evaluate
scores = laft.knn(train_guided, test_guided, n_neighbors=30)
metrics = laft.binary_metrics(scores, attend_labels)
print(f"AUROC: {metrics['auroc']:.3f}")

Industrial anomaly detection

from laft.prompts import industrial1
import laft
import torch

torch.set_grad_enabled(False)

# Load model
model, transform = laft.load_clip("ViT-B-16-plus-240:laion400m_e31")

# Get industrial prompts
normal_prompts, anomaly_prompts = industrial1.get_prompts("bottle")

# Encode prompts
normal_features = model.encode_text(normal_prompts)
anomaly_features = model.encode_text(anomaly_prompts)

# Compute pairwise differences between normal and anomaly
pair_diffs = laft.prompt_pair(normal_features, anomaly_features)
concept_basis = laft.pca(pair_diffs, n_components=16)

# Continue with feature transformation...

Next Steps

Basic Usage

Learn the complete LAFT workflow

Evaluation

Understand metrics and evaluation strategies
