A concept subspace is an orthonormal basis that captures the semantic direction of a concept in CLIP’s embedding space. LAFT constructs concept subspaces by:

1. Encoding text prompts with CLIP’s text encoder
2. Computing pairwise differences between embeddings (prompt_pair())
3. Extracting principal components via PCA (pca())
4. Optionally aligning vectors to ensure a consistent direction (align_vectors())
The resulting subspace can then be used with inner() or orthogonal() projections to transform image features.
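As an illustration of what these projections do geometrically, here is a minimal sketch in plain PyTorch (toy functions and a random basis, not LAFT's actual inner()/orthogonal() implementation): the inner projection keeps only the component of each feature vector inside the concept subspace, and the orthogonal projection keeps only what is left over.

```python
import torch

def inner_projection(features, basis):
    """Project features onto the subspace spanned by the basis rows."""
    # basis: [k, d] with orthonormal rows; features: [n, d]
    return (features @ basis.T) @ basis

def orthogonal_projection(features, basis):
    """Keep only the component of features orthogonal to the subspace."""
    return features - inner_projection(features, basis)

# Toy basis: 3 orthonormal rows in a 512-dim space
basis = torch.linalg.qr(torch.randn(512, 3))[0].T  # [3, 512]
features = torch.randn(5, 512)

# The two parts are complementary: they sum back to the original features
recon = inner_projection(features, basis) + orthogonal_projection(features, basis)
print(torch.allclose(recon, features, atol=1e-5))  # True
```

The projection is idempotent: applying inner_projection twice gives the same result as applying it once.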
Simply averaging text embeddings loses important semantic information. Instead, LAFT uses pairwise differences to capture the semantic direction that separates concepts.
```python
import laft

# Load CLIP model
model, transform = laft.load_clip("ViT-B-16-quickgelu:dfn2b")

# Encode multiple prompts
prompts = [
    "a photo of a waterbird",
    "a photo of a landbird",
    "a photo of a seagull",
    "a photo of a sparrow",
]
text_features = model.encode_text(prompts)  # [4, 512]

# Compute all pairwise differences
pairs = laft.prompt_pair(text_features)  # [6, 512] = C(4,2)
```
With multiple groups, prompt_pair() computes differences between groups, not within. This is useful for capturing the semantic direction that separates distinct concepts.
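To make the between-group behavior concrete, here is a minimal sketch in plain PyTorch (between_group_pairs is a hypothetical helper written for illustration, not LAFT's internals): with two groups of m and n embeddings, between-group differencing yields m × n pairs rather than C(m+n, 2).

```python
import torch

def between_group_pairs(group_a, group_b):
    """All difference vectors from one group's embeddings to the other's."""
    # group_a: [m, d], group_b: [n, d] -> [m * n, d]
    diffs = group_a.unsqueeze(1) - group_b.unsqueeze(0)  # [m, n, d]
    return diffs.reshape(-1, diffs.shape[-1])

water = torch.randn(4, 512)  # stand-in for waterbird prompt embeddings
land = torch.randn(4, 512)   # stand-in for landbird prompt embeddings
pairs = between_group_pairs(water, land)
print(pairs.shape)  # torch.Size([16, 512])
```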
Difference vectors can point in opposite directions (e.g., a - b vs. b - a). The align_vectors() function flips vectors to ensure they point in a consistent direction relative to a reference.
```python
import laft
import torch

# Load CLIP model
model, transform = laft.load_clip("ViT-B-16-quickgelu:dfn2b")

# Create comprehensive prompt set
prompts = [
    "a photo of a flawless bottle",
    "a photo of a perfect bottle",
    "a photo of an unblemished bottle",
    "a photo of a damaged bottle",
    "a photo of a bottle with defect",
    "a photo of a broken bottle",
]

# Encode prompts
text_features = model.encode_text(prompts)  # [6, 512]

# Compute pairwise differences
pairs = laft.prompt_pair(text_features)  # [15, 512] = C(6,2)

# Extract top 24 principal components
concept_basis = laft.pca(pairs, n_components=24)  # [24, 512]

# Verify orthonormality
identity = concept_basis @ concept_basis.T
print(torch.allclose(identity, torch.eye(24), atol=1e-5))  # True
```
The first few principal components typically capture the most salient semantic directions. Start with n_components=24 as a reasonable default, then tune based on your task.
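One way to inform that tuning is to inspect how much variance each component explains. The sketch below uses plain SVD on random stand-in data (pairs here is a placeholder for the output of laft.prompt_pair(), and the 90% threshold is an arbitrary example, not a LAFT recommendation):

```python
import torch

# Stand-in for laft.prompt_pair(...) output; shape [n_pairs, 512]
pairs = torch.randn(28, 512)

# Explained variance per principal component via SVD
centered = pairs - pairs.mean(dim=0, keepdim=True)
singular_values = torch.linalg.svdvals(centered)
explained = singular_values**2 / (singular_values**2).sum()
cumulative = torch.cumsum(explained, dim=0)

# Smallest number of components covering, e.g., 90% of the variance
k = int(torch.searchsorted(cumulative, torch.tensor(0.9))) + 1
print(k)
```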
Here’s a complete example from the Waterbirds dataset (laft/prompts/waterbirds.py):
```python
import laft
import torch

torch.set_grad_enabled(False)

# Load CLIP and dataset
model, data = laft.get_clip_cached_features(
    "ViT-B-16-quickgelu:dfn2b", "waterbirds", splits=["train", "test"]
)
train_features, _ = data["train"]
test_features, test_attrs = data["test"]

# Define prompts for bird species
WATER_BIRDS = ["seagull", "pelican", "tern", "cormorant"]
LAND_BIRDS = ["sparrow", "robin", "cardinal", "finch"]

# Create templated prompts
templates = [
    "a photo of a {}.",
    "a photo of a {}, a type of bird.",
    "a blurry photo of a {}.",
]
water_prompts = [[t.format(bird) for t in templates] for bird in WATER_BIRDS]
land_prompts = [[t.format(bird) for t in templates] for bird in LAND_BIRDS]

# Encode prompts (with ensemble averaging)
water_features = model.encode_text(water_prompts)  # [4, 512]
land_features = model.encode_text(land_prompts)  # [4, 512]

# Compute pairwise differences between water and land birds
all_features = torch.cat([water_features, land_features])  # [8, 512]
pairs = laft.prompt_pair(all_features)  # [28, 512] = C(8,2)

# Extract concept basis
concept_basis = laft.pca(pairs, n_components=32)  # [32, 512]

# Transform features to guide toward bird type
guided_train = laft.inner(train_features, concept_basis)
guided_test = laft.inner(test_features, concept_basis)

# Compute anomaly scores
scores = laft.knn(guided_train, guided_test, n_neighbors=30)
```
Sequential transformations are not commutative: inner(orthogonal(f, B1), B2) ≠ orthogonal(inner(f, B2), B1). Choose the order based on which concept should be processed first.
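A quick numerical check of this non-commutativity, with explicit projection formulas standing in for inner() and orthogonal() and random toy bases in place of real concept subspaces:

```python
import torch

def inner_proj(f, basis):
    # Project onto the subspace spanned by the orthonormal basis rows
    return (f @ basis.T) @ basis

def orth_proj(f, basis):
    # Remove the subspace component
    return f - inner_proj(f, basis)

torch.manual_seed(0)
f = torch.randn(5, 64)
b1 = torch.linalg.qr(torch.randn(64, 4))[0].T  # [4, 64] orthonormal rows
b2 = torch.linalg.qr(torch.randn(64, 4))[0].T

order_a = inner_proj(orth_proj(f, b1), b2)
order_b = orth_proj(inner_proj(f, b2), b1)
print(torch.allclose(order_a, order_b))  # False: the order matters
```

The two orders agree only when the projections commute, which random (and most real) subspaces do not.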
Problem: Using only a few prompts yields very few pairwise differences.

Solution: Increase prompt diversity. With n prompts, you get C(n, 2) = n(n−1)/2 pairs. Aim for at least 20-30 pairs.
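As a quick reference for how the pair count scales with prompt-set size (standard-library arithmetic only; the prompt counts are arbitrary examples):

```python
from math import comb

for n in (4, 6, 8, 10):
    print(n, "prompts ->", comb(n, 2), "pairs")
# 4 prompts -> 6 pairs
# 6 prompts -> 15 pairs
# 8 prompts -> 28 pairs
# 10 prompts -> 45 pairs
```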
Problem: Later principal components seem random or noisy.

Solution: This is expected. Use fewer components (n_components=24 instead of 100+).
```python
# Extract top components only
concept_basis = laft.pca(pairs, n_components=24)
```
Opposite-Direction Vectors
Problem: Some difference vectors point in opposite directions, causing PCA to fail.

Solution: prompt_pair() automatically calls align_vectors(). If you compute differences manually, align them:
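The sketch below shows the idea in plain PyTorch (align_to_reference is a hypothetical helper mirroring what align_vectors() does conceptually, not its actual signature): flip the sign of any vector whose dot product with a reference direction is negative.

```python
import torch

def align_to_reference(vectors, reference=None):
    """Flip each vector so it points the same way as a reference direction."""
    if reference is None:
        reference = vectors[0]  # default: use the first vector as the reference
    signs = torch.sign(vectors @ reference)
    signs[signs == 0] = 1.0  # leave vectors orthogonal to the reference untouched
    return vectors * signs.unsqueeze(1)

# Toy difference vectors: the second one points the "wrong" way
diffs = torch.tensor([[1.0, 0.0], [-1.0, 0.1], [0.9, -0.2]])
aligned = align_to_reference(diffs)
print(aligned)  # the second row is sign-flipped to [1.0, -0.1]
```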