
Overview

A concept subspace is an orthonormal basis that captures the semantic direction of a concept in CLIP’s embedding space. LAFT constructs concept subspaces by:
  1. Encoding text prompts with CLIP’s text encoder
  2. Computing pairwise differences between embeddings (prompt_pair())
  3. Aligning vectors to ensure consistent direction (align_vectors(), called automatically by prompt_pair())
  4. Extracting principal components via PCA (pca())
The resulting subspace can then be used with inner() or orthogonal() projections to transform image features.

Why Pairwise Differences?

Simply averaging text embeddings loses important semantic information. Instead, LAFT uses pairwise differences to capture the semantic direction that separates concepts.

Intuition

Consider prompts for normal and anomalous states:
normal = ["flawless bottle", "perfect bottle", "unblemished bottle"]
anomaly = ["damaged bottle", "broken bottle", "defective bottle"]
The difference vectors:
"flawless bottle" - "damaged bottle"
"perfect bottle" - "broken bottle"
...
point in the semantic direction from “anomalous” to “normal”. These differences, when aggregated via PCA, form a robust concept subspace.

Mathematical Formulation

Given $n$ text embeddings $\mathbf{t}_1, \ldots, \mathbf{t}_n \in \mathbb{R}^d$, compute all pairwise differences: $\mathbf{D} = \{\mathbf{t}_i - \mathbf{t}_j : 1 \leq i < j \leq n\}$. This produces $\binom{n}{2} = \frac{n(n-1)}{2}$ difference vectors.
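As a quick sanity check on the pair count, the enumeration can be sketched in plain torch (random vectors stand in for CLIP text embeddings here):

```python
import torch

torch.manual_seed(0)
n, d = 4, 8
t = torch.randn(n, d)  # stand-ins for CLIP text embeddings

# All pairwise differences t_i - t_j with i < j
diffs = torch.stack([t[i] - t[j] for i in range(n) for j in range(i + 1, n)])

print(diffs.shape)  # torch.Size([6, 8]), i.e. C(4, 2) = 6 pairs
```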

The prompt_pair() Function

The prompt_pair() function computes pairwise differences and aligns them for consistent direction.

Implementation

From laft/laft.py:88-103:
def prompt_pair(*prompts_list: Tensor) -> Tensor:
    if len(prompts_list) == 1:
        prompts = prompts_list[0]
        length = prompts.size(0)
        idxs = torch.tensor([i * length + j for i in range(length) for j in range(i + 1, length)]).to(prompts.device)
        pairwise_diff = (prompts.unsqueeze(1) - prompts.unsqueeze(0)).flatten(0, 1).index_select(0, idxs)
        pairwise_diff = align_vectors(pairwise_diff)
        return pairwise_diff
    else:
        pairwise_diff = torch.cat([
            (prompts_list[i].unsqueeze(1) - prompts_list[j].unsqueeze(0)).flatten(0, 1)
            for i in range(len(prompts_list))
            for j in range(i + 1, len(prompts_list))
        ])
        pairwise_diff = align_vectors(pairwise_diff)
        return pairwise_diff

Single Prompt Group

When called with a single tensor of embeddings:
import laft

# Encode multiple prompts
prompts = [
    "a photo of a waterbird",
    "a photo of a landbird",
    "a photo of a seagull",
    "a photo of a sparrow",
]

text_features = model.encode_text(prompts)  # [4, 512]

# Compute all pairwise differences
pairs = laft.prompt_pair(text_features)  # [6, 512] = C(4,2)
This computes $\binom{4}{2} = (4 \times 3)/2 = 6$ pairwise differences:
  • prompts[0] - prompts[1]
  • prompts[0] - prompts[2]
  • prompts[0] - prompts[3]
  • prompts[1] - prompts[2]
  • prompts[1] - prompts[3]
  • prompts[2] - prompts[3]

Multiple Prompt Groups

When called with multiple groups (e.g., normal vs. anomaly):
# Separate normal and anomaly prompts
normal_prompts = ["flawless bottle", "perfect bottle"]
anomaly_prompts = ["damaged bottle", "broken bottle"]

normal_features = model.encode_text(normal_prompts)    # [2, 512]
anomaly_features = model.encode_text(anomaly_prompts)  # [2, 512]

# Compute cross-group pairwise differences
pairs = laft.prompt_pair(normal_features, anomaly_features)  # [4, 512]
This computes all differences between groups:
  • normal[0] - anomaly[0]
  • normal[0] - anomaly[1]
  • normal[1] - anomaly[0]
  • normal[1] - anomaly[1]
With multiple groups, prompt_pair() computes differences between groups, not within. This is useful for capturing the semantic direction that separates distinct concepts.

Parameters

prompts_list
Tensor
required
One or more tensors of text embeddings:
  • Single tensor [n, d]: Computes all $\binom{n}{2}$ pairwise differences
  • Multiple tensors: Computes all cross-group differences
Returns
Tensor
Aligned pairwise difference vectors. Shape: [num_pairs, feature_size]

Vector Alignment

Difference vectors can point in opposite directions (e.g., a - b vs. b - a). The align_vectors() function flips vectors to ensure they point in a consistent direction relative to a reference.

Implementation

From laft/laft.py:81-85:
def align_vectors(vectors: torch.Tensor, reference_idx: int = 0) -> Tensor:
    reference = vectors[reference_idx].half()
    sim = F.cosine_similarity(vectors.half(), reference)
    aligned = torch.where((sim < 0).unsqueeze(dim=1), -vectors, vectors)
    return aligned

How It Works

  1. Use the first vector (index 0) as the reference direction
  2. Compute cosine similarity of each vector with the reference
  3. If similarity is negative (opposite direction), flip the vector
  4. Return aligned vectors
Example:
import torch
import laft

vectors = torch.tensor([
    [1.0, 0.0],   # Reference
    [0.5, 0.5],   # Similar direction
    [-0.8, 0.1],  # Opposite direction (will be flipped)
])

aligned = laft.align_vectors(vectors)
# Result:
# [[1.0, 0.0],
#  [0.5, 0.5],
#  [0.8, -0.1]]  # Flipped to align with reference
Alignment is crucial for PCA to work effectively. Without alignment, opposite-pointing vectors would cancel out during dimensionality reduction.

The pca() Function

After computing pairwise differences, PCA extracts the principal components that capture the most variance in the semantic direction.

Implementation

From laft/laft.py:68-78:
def pca(
    vectors: Tensor,
    n_components: int | None = None,
    *,
    center: bool = False,
    niter: int = 5,
) -> Tensor:
    min_d = min(vectors.size(0), vectors.size(1))
    d = min_d if n_components is None else min(n_components, min_d)
    components = torch.pca_lowrank(vectors, q=d, center=center, niter=niter)[2].T
    return components
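As a standalone illustration of what this returns (random vectors stand in for pairwise differences; this mirrors, rather than calls, the library function): torch.pca_lowrank yields right-singular vectors V with orthonormal columns, so V.T is an orthonormal row basis.

```python
import torch

torch.manual_seed(0)
vectors = torch.randn(15, 64)  # stand-ins for pairwise difference vectors

q = 10
# V ([64, 10]) has orthonormal columns; its transpose is an orthonormal row basis
_, _, V = torch.pca_lowrank(vectors, q=q, center=False, niter=5)
components = V.T  # [10, 64]

print(torch.allclose(components @ components.T, torch.eye(q), atol=1e-4))  # True
```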

Parameters

vectors
Tensor
required
Input vectors (typically from prompt_pair()). Shape: [num_vectors, feature_size]
n_components
int
default:"None"
Number of principal components to extract. If None, extracts all possible components (up to min(num_vectors, feature_size)).
center
bool
default:"False"
If True, centers the vectors by subtracting their mean before PCA. LAFT uses False because difference vectors are already approximately centered.
niter
int
default:"5"
Number of iterations for the randomized SVD algorithm. Higher values are more accurate but slower.
Returns
Tensor
Orthonormal basis vectors (principal components). Shape: [n_components, feature_size]. Each row is a unit vector, and rows are mutually orthogonal.

Usage Example

import laft

# Load CLIP model
model, transform = laft.load_clip("ViT-B-16-quickgelu:dfn2b")

# Create comprehensive prompt set
prompts = [
    "a photo of a flawless bottle",
    "a photo of a perfect bottle",
    "a photo of an unblemished bottle",
    "a photo of a damaged bottle",
    "a photo of a bottle with defect",
    "a photo of a broken bottle",
]

# Encode prompts
text_features = model.encode_text(prompts)  # [6, 512]

# Compute pairwise differences
pairs = laft.prompt_pair(text_features)  # [15, 512] = C(6,2)

# Extract principal components (capped at min(num_pairs, feature_size) = 15 here)
concept_basis = laft.pca(pairs, n_components=24)  # [15, 512]

# Verify orthonormality
import torch
identity = concept_basis @ concept_basis.T
print(torch.allclose(identity, torch.eye(identity.size(0)), atol=1e-5))  # True
The first few principal components typically capture the most salient semantic directions. Start with n_components=24 as a reasonable default, then tune based on your task.

Complete Workflow Example

Here’s a complete example from the Waterbirds dataset (laft/prompts/waterbirds.py):
import laft
import torch

torch.set_grad_enabled(False)

# Load CLIP and dataset
model, data = laft.get_clip_cached_features(
    "ViT-B-16-quickgelu:dfn2b",
    "waterbirds",
    splits=["train", "test"]
)

train_features, _ = data["train"]
test_features, test_attrs = data["test"]

# Define prompts for bird species
WATER_BIRDS = ["seagull", "pelican", "tern", "cormorant"]
LAND_BIRDS = ["sparrow", "robin", "cardinal", "finch"]

# Create templated prompts
templates = [
    "a photo of a {}.",
    "a photo of a {}, a type of bird.",
    "a blurry photo of a {}.",
]

water_prompts = [[t.format(bird) for t in templates] for bird in WATER_BIRDS]
land_prompts = [[t.format(bird) for t in templates] for bird in LAND_BIRDS]

# Encode prompts (with ensemble averaging)
water_features = model.encode_text(water_prompts)  # [4, 512]
land_features = model.encode_text(land_prompts)    # [4, 512]

# Compute pairwise differences between water and land birds
all_features = torch.cat([water_features, land_features])  # [8, 512]
pairs = laft.prompt_pair(all_features)  # [28, 512] = C(8,2)

# Extract concept basis
concept_basis = laft.pca(pairs, n_components=32)  # [32, 512]

# Transform features to guide toward bird type
guided_train = laft.inner(train_features, concept_basis)
guided_test = laft.inner(test_features, concept_basis)

# Compute anomaly scores
scores = laft.knn(guided_train, guided_test, n_neighbors=30)
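The scores are typically evaluated against ground-truth anomaly labels with AUROC (the LAFT scripts use laft.binary_auroc for this). A minimal rank-based equivalent, sketched in plain torch purely for illustration:

```python
import torch

def rank_auroc(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """Rank-based (Mann-Whitney) AUROC for binary labels."""
    order = scores.argsort()
    ranks = torch.empty(len(scores), dtype=torch.float)
    ranks[order] = torch.arange(1, len(scores) + 1, dtype=torch.float)
    pos = labels.bool()
    n_pos = pos.sum().item()
    n_neg = (~pos).sum().item()
    # Sum of positive ranks minus its minimum possible value, over all pos/neg pairs
    return (ranks[pos].sum().item() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

# Higher scores should mark anomalies (label 1)
scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
labels = torch.tensor([0, 0, 1, 1])
print(rank_auroc(scores, labels))  # 0.75
```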

Prompt Design Best Practices

Combine multiple phrasings to create robust embeddings:
templates = [
    "a photo of a {}.",
    "a cropped photo of the {}.",
    "a blurry photo of a {}.",
    "a good photo of the {}.",
    "a bad photo of a {}.",
]
The LAFT encode_text() wrapper automatically averages embeddings when you pass lists of lists.
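Conceptually, this ensemble averaging amounts to L2-normalizing each per-template embedding, taking the mean, and re-normalizing the result. A sketch with random stand-ins (the wrapper's exact implementation may differ in details):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-ins for 5 per-template CLIP embeddings of one class (dim 64)
template_embeddings = torch.randn(5, 64)

# Normalize each embedding, average across templates, re-normalize the ensemble
normalized = F.normalize(template_embeddings, dim=-1)
ensemble = F.normalize(normalized.mean(dim=0), dim=-1)

print(ensemble.shape)  # torch.Size([64])
```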
For industrial defect detection, vary both the object and the state:
NORMAL_STATES = [
    "{}",
    "flawless {}",
    "perfect {}",
    "unblemished {}",
    "{} without flaw",
    "{} without defect",
]

ANOMALY_STATES = [
    "damaged {}",
    "{} with flaw",
    "{} with defect",
    "{} with damage",
]
From laft/prompts/industrial1.py:24-38
Include comparable numbers of normal and anomalous descriptions:
# Good: Balanced
normal = ["perfect item", "flawless item", "good item"]
anomaly = ["damaged item", "defective item", "broken item"]

# Avoid: Imbalanced
normal = ["perfect item"]
anomaly = ["damaged", "broken", "cracked", "scratched", "dented", ...]  # Too many
For semantic datasets like Waterbirds, use actual species names:
WATER_BIRD_WORDS = [
    "Black footed Albatross",
    "Laysan Albatross",
    "Sooty Albatross",
    # ... 70+ water bird species
]

LAND_BIRD_WORDS = [
    "Groove billed Ani",
    "Brewer Blackbird",
    "Red winged Blackbird",
    # ... 200+ land bird species
]
From laft/prompts/waterbirds.py:72-277

Handling Multiple Concepts

For complex scenarios, you can construct separate subspaces and combine them:
# Concept 1: Bird type
bird_prompts = model.encode_text(bird_descriptions)
bird_pairs = laft.prompt_pair(bird_prompts)
bird_basis = laft.pca(bird_pairs, n_components=16)

# Concept 2: Background
back_prompts = model.encode_text(background_descriptions)
back_pairs = laft.prompt_pair(back_prompts)
back_basis = laft.pca(back_pairs, n_components=16)

# Option 1: Apply transformations sequentially
features_1 = laft.inner(features, bird_basis)       # Guide bird
features_2 = laft.orthogonal(features_1, back_basis) # Ignore background

# Option 2: Combine bases
combined_basis = torch.cat([bird_basis, back_basis])  # [32, 512]
features_combined = laft.inner(features, combined_basis)
Sequential transformations are not commutative: inner(orthogonal(f, B1), B2) ≠ orthogonal(inner(f, B2), B1). Choose the order based on which concept should be processed first.
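To see concretely why order matters, model inner() as projection onto a subspace's span and orthogonal() as projection onto its complement (consistent with how they are described here); for non-orthogonal subspaces the two compositions differ:

```python
import torch

torch.manual_seed(0)
d = 16
f = torch.randn(d)

# Two random (generally non-orthogonal) unit concept directions
b1 = torch.randn(d); b1 = b1 / b1.norm()
b2 = torch.randn(d); b2 = b2 / b2.norm()

P1 = torch.outer(b1, b1)                 # projection onto span(b1), like inner(., B1)
Q2 = torch.eye(d) - torch.outer(b2, b2)  # projection onto complement, like orthogonal(., B2)

a = Q2 @ (P1 @ f)  # inner first, then orthogonal
b = P1 @ (Q2 @ f)  # orthogonal first, then inner

print(torch.allclose(a, b))  # False: the compositions disagree
```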

Subspace Dimensionality Selection

The number of components significantly impacts performance:
import matplotlib.pyplot as plt

# Sweep over component counts
results = []
for k in range(2, 100):
    basis_k = concept_basis[:k]  # use the first k components
    train_k = laft.inner(train_features, basis_k)
    test_k = laft.inner(test_features, basis_k)
    scores = laft.knn(train_k, test_k)
    auroc = laft.binary_auroc(scores, labels)  # labels: ground-truth test labels
    results.append((k, auroc))

# Plot
ks, aurocs = zip(*results)
plt.plot(ks, aurocs)
plt.xlabel('Number of Components')
plt.ylabel('AUROC')
plt.title('Performance vs. Subspace Dimensionality')
plt.show()
From scripts/semantic/laft.py:50, the script iterates n_components from 2 to 384 to find the optimal value.
Rule of thumb: Start with $k \approx \sqrt{d}$ where $d$ is the feature dimensionality. For CLIP ViT-B/16 ($d = 512$), try $k \in [16, 32]$.

Visualizing the Concept Subspace

You can project the concept basis back into text space to understand what it represents:
import torch
import laft

# Get concept basis
concept_basis = laft.pca(pairs, n_components=5)

# Define candidate words
candidates = [
    "bird", "animal", "feather", "beak", "wing",
    "water", "ocean", "lake", "land", "forest",
    "damage", "broken", "defect", "perfect", "flawless",
]

# Encode candidates
candidate_features = model.encode_text(candidates)

# Compute projection magnitudes for each basis vector
for i, basis_vec in enumerate(concept_basis[:5]):
    scores = candidate_features @ basis_vec
    top_idx = scores.argsort(descending=True)[:3]
    print(f"Component {i}: {[candidates[j] for j in top_idx]}")
This helps verify that the subspace captures the intended semantic concept.

Troubleshooting

Problem: Only a few prompts result in very few pairwise differences.
Solution: Increase prompt diversity. With $n$ prompts, you get $\binom{n}{2}$ pairs. Aim for at least 20-30 pairs.
# Too few: 3 prompts → 3 pairs
prompts = ["cat", "dog", "bird"]

# Better: 8 prompts → 28 pairs
prompts = ["cat", "dog", "bird", "fish", "rabbit", "hamster", "turtle", "snake"]
Problem: Later principal components seem random or noisy.
Solution: This is expected. Use fewer components (n_components=24 instead of 100+).
# Extract top components only
concept_basis = laft.pca(pairs, n_components=24)
Problem: Some difference vectors point in opposite directions, causing PCA to fail.
Solution: prompt_pair() automatically calls align_vectors(). If you compute differences manually, align them:
diffs = prompt_a - prompt_b
aligned_diffs = laft.align_vectors(diffs)
concept_basis = laft.pca(aligned_diffs)

Advanced: Custom Subspace Construction

For specialized use cases, you can construct subspaces manually:
import torch
import laft

# Option 1: Single semantic direction
positive = model.encode_text("a photo of a perfect product")
negative = model.encode_text("a photo of a damaged product")
direction = positive - negative

# Use single vector (will be automatically unsqueezed)
guided = laft.inner(features, direction)

# Option 2: Manually construct basis from domain knowledge
import torch.nn.functional as F

v1 = model.encode_text("defect") - model.encode_text("normal")
v2 = model.encode_text("scratch") - model.encode_text("smooth")
v3 = model.encode_text("crack") - model.encode_text("intact")

# Orthogonalize using QR decomposition
vectors = torch.stack([v1, v2, v3])
Q, R = torch.linalg.qr(vectors.T)
custom_basis = Q.T  # [3, 512] orthonormal basis

# Use with projections
guided = laft.inner(features, custom_basis)
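With random stand-ins for the difference vectors, you can verify that the QR step yields an orthonormal basis:

```python
import torch

torch.manual_seed(0)
# Stand-ins for three concept difference vectors in a 512-dim CLIP space
vectors = torch.randn(3, 512)

Q, R = torch.linalg.qr(vectors.T)  # reduced QR: Q is [512, 3] with orthonormal columns
custom_basis = Q.T                 # [3, 512] orthonormal row basis

print(torch.allclose(custom_basis @ custom_basis.T, torch.eye(3), atol=1e-5))  # True
```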

See Also

Feature Transformation

Learn how to use the concept subspace with inner() and orthogonal() projections

Overview

Understand the full LAFT methodology and workflow
