Skip to main content

get_prompts

Returns text prompts for a specific dataset and guidance type.
from laft.prompts import get_prompts

prompts = get_prompts(
    dataset_name="color_mnist",
    guidance="guide_number"
)

Parameters

dataset_name
str
required
Name of the dataset to get prompts for. Supported datasets:
  • color_mnist: Colored MNIST digits
  • waterbirds: Waterbirds dataset
  • celeba: CelebA dataset
guidance
str
required
Type of guidance to use. Can be prefixed with guide_ or ignore_. For color_mnist:
  • guide_number or number: Guide attention to digit numbers
  • guide_color or color: Guide attention to colors
  • ignore_number: Ignore number variations
  • ignore_color: Ignore color variations
The function automatically adds the guide_ prefix if not present.

Returns

prompts
dict[str, list]
Dictionary mapping prompt categories to lists of text prompts. Keys:
  • normal: Prompts for normal/expected attributes
  • anomaly: Prompts for anomalous attributes
  • half: Prompts for normal + half of anomaly attributes
  • exact: Prompts for normal + all anomaly attributes
  • all: Prompts for normal + anomaly + auxiliary attributes
For guide_* modes, each value is a list of prompt lists (one per word). For ignore_* modes, each value is a flat list of prompts.

Examples

Get prompts for ColorMNIST with number guidance

from laft.prompts import get_prompts

prompts = get_prompts("color_mnist", "guide_number")

# Access different prompt categories
print(f"Normal prompts: {len(prompts['normal'])} word groups")
print(f"Anomaly prompts: {len(prompts['anomaly'])} word groups")

# Example: First normal prompt group
print(prompts['normal'][0])
# Output: ['zero', 'a number zero', 'an image of zero', ...]

# Example: Using exact prompts (normal + anomaly)
for word_prompts in prompts['exact']:
    print(f"{len(word_prompts)} variants for one word")

Get prompts for color guidance

prompts = get_prompts("color_mnist", "guide_color")

# Normal: red color variations
# Anomaly: green and blue color variations + auxiliary colors
print(prompts['normal'][0])  # ['red', 'a number red', ...]
print(prompts['anomaly'][0])  # ['green', 'a number green', ...]

Using ignore mode

# Ignore mode returns flattened prompt lists
prompts = get_prompts("color_mnist", "ignore_color")

# All prompts are flattened into single lists
print(type(prompts['normal']))  # list
print(prompts['normal'][:3])
# Output: ['zero', 'a number zero', 'an image of zero']

Working with different datasets

# Waterbirds dataset
waterbirds_prompts = get_prompts("waterbirds", "guide_background")

# CelebA dataset
celeba_prompts = get_prompts("celeba", "guide_hair_color")

get_labels

Extracts attend and ignore labels from attribute tensors based on guidance type.
from laft.prompts import get_labels

attend_name, ignore_name, attend_labels, ignore_labels = get_labels(
    dataset_name="color_mnist",
    attrs=dataset.attrs,
    guidance="guide_number"
)

Parameters

dataset_name
str
required
Name of the dataset. Supported: color_mnist, waterbirds, celeba.
attrs
torch.Tensor
required
Attribute tensor from the dataset with shape [num_samples, num_attrs]. Boolean values where False = normal, True = anomaly.
guidance
str
required
Type of guidance (e.g., "guide_number", "ignore_color").

Returns

attend_name
str
Name of the attribute to attend to (e.g., "number" or "color").
ignore_name
str
Name of the attribute to ignore.
attend_labels
torch.Tensor
Boolean tensor of shape [num_samples] with labels for the attended attribute.
ignore_labels
torch.Tensor
Boolean tensor of shape [num_samples] with labels for the ignored attribute.

Examples

Extract labels for number guidance

from laft.datasets import build_semantic_dataset
from laft.prompts import get_labels

# Load dataset
dataset = build_semantic_dataset("color_mnist", "test")

# Get labels
attend_name, ignore_name, attend_labels, ignore_labels = get_labels(
    dataset_name="color_mnist",
    attrs=dataset.attrs,
    guidance="guide_number"
)

print(f"Attending to: {attend_name}")  # "number"
print(f"Ignoring: {ignore_name}")      # "color"
print(f"Attend labels shape: {attend_labels.shape}")  # [num_samples]
print(f"Number of anomalies in attended attribute: {attend_labels.sum()}")

Analyze label distribution

import torch

attend_name, ignore_name, attend_labels, ignore_labels = get_labels(
    "color_mnist", dataset.attrs, "guide_color"
)

# Count samples by label combination
normal_normal = (~attend_labels & ~ignore_labels).sum()
normal_anomaly = (~attend_labels & ignore_labels).sum()
anomaly_normal = (attend_labels & ~ignore_labels).sum()
anomaly_anomaly = (attend_labels & ignore_labels).sum()

print(f"Normal {attend_name}, Normal {ignore_name}: {normal_normal}")
print(f"Normal {attend_name}, Anomaly {ignore_name}: {normal_anomaly}")
print(f"Anomaly {attend_name}, Normal {ignore_name}: {anomaly_normal}")
print(f"Anomaly {attend_name}, Anomaly {ignore_name}: {anomaly_anomaly}")

get_words

Returns word lists for a specific dataset and guidance type (without templates).
from laft.prompts import get_words

words = get_words(
    dataset_name="color_mnist",
    guidance="guide_number"
)

Parameters

dataset_name
str
required
Name of the dataset. Supported: color_mnist, waterbirds, celeba.
guidance
str
required
Type of guidance (e.g., "guide_number", "ignore_color").

Returns

words
dict[str, list]
Dictionary mapping word categories to lists of words. Keys: normal, anomaly, half, exact, all. For guide_* modes, values are lists of words; for ignore_* modes, values are flattened lists.

Examples

Get words for number guidance

from laft.prompts import get_words

words = get_words("color_mnist", "guide_number")

print(words['normal'])
# Output: ['zero', '0', 'one', '1', 'two', '2', 'three', '3', 'four', '4']

print(words['anomaly'])
# Output: ['five', '5', 'six', '6', 'seven', '7', 'eight', '8', 'nine', '9', 
#          'ten', '10', 'eleven', '11', ...]

print(words['exact'])
# Output: ['zero', '0', ..., 'four', '4', 'five', '5', ..., 'nine', '9']

Get words for color guidance

words = get_words("color_mnist", "guide_color")

print(words['normal'])
# Output: ['red', 'ruby', 'scarlet', 'crimson', 'maroon', 'carmine', 'vermilion']

print(words['anomaly'])
# Output: ['green', 'lime', 'olive', 'jade', 'blue', 'azure', 'sky blue', 'navy',
#          'yellow', 'gold', 'amber', ...]

Compare words vs prompts

from laft.prompts import get_words, get_prompts

words = get_words("color_mnist", "guide_number")
prompts = get_prompts("color_mnist", "guide_number")

# Words are the base terms
print(f"Number of normal words: {len(words['normal'])}")
# Output: 10 (zero, 0, one, 1, ...)

# Prompts apply templates to each word
print(f"Number of normal prompt groups: {len(prompts['normal'])}")
# Output: 10 (one group per word)

print(f"Prompts per word: {len(prompts['normal'][0])}")
# Output: 19 (number of templates)

Prompt Structure

The prompt system uses templates to generate variations of each word:
TEMPLATES = [
    "{}",
    "a number {}",
    "an image of {}",
    "a picture of {}",
    "a photo of {}",
    "a drawing of {}",
    "a sketch of {}",
    "a figure of {}",
    "{} letter",
    "a number {} letter",
    # ... and more
]
Each word is formatted with all templates to create diverse prompts for better CLIP encoding:
from laft.prompts import get_prompts

prompts = get_prompts("color_mnist", "guide_number")

# First word is "zero"
zero_prompts = prompts['normal'][0]
print(zero_prompts)
# ['zero', 'a number zero', 'an image of zero', 'a picture of zero', ...]

Guidance Modes

  • guide_*: Returns structured prompts (list of lists) for each word
  • ignore_*: Returns flattened prompts (single list) for token-level matching
# Guide mode: structured
guide_prompts = get_prompts("color_mnist", "guide_number")
print(type(guide_prompts['normal']))  # list of lists

# Ignore mode: flattened
ignore_prompts = get_prompts("color_mnist", "ignore_number")
print(type(ignore_prompts['normal']))  # flat list

Build docs developers (and LLMs) love