Skip to main content

build_semantic_dataset

Builds a semantic anomaly detection dataset for training or evaluation.
from laft.datasets import build_semantic_dataset

dataset = build_semantic_dataset(
    name="color_mnist",
    split="train",
    root="./data",
    transform=None,
    config=None
)

Parameters

name
Literal['color_mnist', 'waterbirds', 'celeba']
required
Name of the semantic anomaly dataset to build. Supported datasets:
  • color_mnist: MNIST digits with color attributes
  • waterbirds: Birds in different backgrounds
  • celeba: Celebrity faces with attributes
split
Literal['train', 'valid', 'test']
required
Dataset split to load.
root
str
default:"./data"
Root directory where the dataset files are stored or will be downloaded.
transform
Callable | None
default:"None"
Optional transform to apply to images. Typically a torchvision transform pipeline.
config
dict | None
default:"None"
Dataset-specific configuration dictionary. When None, the default configuration is used. For color_mnist, the config has the following structure:
{
    "number": {0: False, 1: False, ..., 9: True},  # False=normal, True=anomaly
    "color": {"red": False, "green": True, "blue": True}
}
**kwargs
dict
Additional keyword arguments. For color_mnist, accepts a seed parameter that makes the train/valid split reproducible.

Returns

dataset
SemanticAnomalyDataset
A dataset instance that inherits from SemanticAnomalyDataset.

Examples

Load ColorMNIST with default config

from laft.datasets import build_semantic_dataset
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])

# Load training set (normal samples only)
train_dataset = build_semantic_dataset(
    name="color_mnist",
    split="train",
    transform=transform
)

# Load test set (normal and anomaly samples)
test_dataset = build_semantic_dataset(
    name="color_mnist",
    split="test",
    transform=transform
)

Load with custom config

# Custom configuration: digits 0-7 are normal, 8-9 are anomalies
custom_config = {
    "number": {
        0: False, 1: False, 2: False, 3: False,
        4: False, 5: False, 6: False, 7: False,
        8: True, 9: True
    },
    "color": {
        "red": False,
        "green": True,
        "blue": True
    }
}

dataset = build_semantic_dataset(
    name="color_mnist",
    split="train",
    config=custom_config,
    seed=123  # Custom seed for train/valid split
)

Load other semantic datasets

# Waterbirds dataset
waterbirds = build_semantic_dataset(
    name="waterbirds",
    split="test",
    transform=transform
)

# CelebA dataset
celeba = build_semantic_dataset(
    name="celeba",
    split="test",
    transform=transform
)

build_industrial_dataset

Builds an industrial anomaly detection dataset for a specific product category.
from laft.datasets import build_industrial_dataset

dataset = build_industrial_dataset(
    name="mvtec",
    category="bottle",
    split="test",
    root="./data"
)

Parameters

name
Literal['mvtec', 'visa']
required
Name of the industrial anomaly dataset. Supported datasets:
  • mvtec: MVTec Anomaly Detection dataset
  • visa: VisA (Visual Anomaly) dataset
category
str
required
Product category to load. For MVTec, the supported categories are: bottle, cable, capsule, carpet, grid, hazelnut, leather, metal_nut, pill, screw, tile, toothbrush, transistor, wood, zipper
split
Literal['train', 'test']
required
Dataset split to load.
root
str
default:"./data"
Root directory where the dataset files are stored.
transform
Callable | None
default:"None"
Optional transform to apply to images.
mask_transform
Callable | None
default:"None"
Optional transform to apply to anomaly masks.

Returns

dataset
IndustrialAnomalyDataset
A dataset instance that inherits from IndustrialAnomalyDataset.

Examples

Load MVTec dataset

from laft.datasets import build_industrial_dataset
from torchvision import transforms

image_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])

mask_transform = transforms.Compose([
    transforms.Resize(224)
])

# Load test set for bottle category
dataset = build_industrial_dataset(
    name="mvtec",
    category="bottle",
    split="test",
    root="./data",
    transform=image_transform,
    mask_transform=mask_transform
)

# Iterate through dataset
for image, mask, label in dataset:
    print(f"Image shape: {image.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Is anomaly: {label.item()}")
    break

Load multiple categories

categories = ["bottle", "cable", "capsule"]

datasets = {}
for category in categories:
    datasets[category] = build_industrial_dataset(
        name="mvtec",
        category=category,
        split="test",
        transform=image_transform
    )

SemanticAnomalyDataset

Base class for semantic anomaly detection datasets.

Attributes

attr_names
list[str]
List of attribute names (e.g., ["number", "color"] for ColorMNIST).
attrs
torch.Tensor
Boolean tensor of shape [num_samples, num_attrs] where False indicates normal and True indicates anomaly for each attribute.
root
str
Root directory of the dataset.
split
str
Dataset split (train, valid, or test).
transform
Callable | None
Transform applied to images.

Methods

get_normal_subset()

Returns a subset containing only normal samples (all attributes are False).
dataset = build_semantic_dataset("color_mnist", "train")
normal_subset = dataset.get_normal_subset()
print(f"Total samples: {len(dataset)}")
print(f"Normal samples: {len(normal_subset)}")
subset
torch.utils.data.Subset
A PyTorch Subset containing only samples where all attributes are normal.

load_image(index: int)

Loads the image at the given index. Must be implemented by subclasses.
image
PIL.Image.Image
The loaded image.

__getitem__(index: int)

Returns a tuple of (image, target) where target is the attribute tensor.
image, attrs = dataset[0]
print(f"Image shape: {image.shape}")
print(f"Attributes: {attrs}")  # e.g., tensor([False, True])
item
tuple[Any, torch.Tensor]
Tuple of (transformed image, attribute tensor).

IndustrialAnomalyDataset

Base class for industrial anomaly detection datasets.

Attributes

labels
torch.Tensor
Boolean tensor of shape [num_samples] where False indicates normal and True indicates anomaly.
root
str
Root directory of the dataset.
split
str
Dataset split (train or test).
transform
Callable | None
Transform applied to images.
mask_transform
Callable | None
Transform applied to anomaly masks.

Methods

load_image(index: int)

Loads the image at the given index. Must be implemented by subclasses.
image
PIL.Image.Image
The loaded image.

load_mask(index: int)

Loads the anomaly mask at the given index. Must be implemented by subclasses.
mask
torch.Tensor
Boolean tensor indicating anomalous regions.

__getitem__(index: int)

Returns a tuple of (image, mask, label).
image, mask, label = dataset[0]
print(f"Image shape: {image.shape}")
print(f"Mask shape: {mask.shape}")
print(f"Is anomaly: {label.item()}")
item
tuple[Any, torch.Tensor, torch.Tensor]
Tuple of (transformed image, mask, label). For normal samples, mask is all zeros.

Example Usage

from laft.datasets import build_industrial_dataset
from torch.utils.data import DataLoader

dataset = build_industrial_dataset(
    name="mvtec",
    category="bottle",
    split="test"
)

loader = DataLoader(dataset, batch_size=4, shuffle=False)

for images, masks, labels in loader:
    anomaly_count = labels.sum().item()
    print(f"Batch has {anomaly_count} anomalies")
    break

Build docs developers (and LLMs) love