TorchGeo provides specialized loss functions for learning on uncertain labels and prior distributions.
QR Loss Functions
These losses are designed for learning with label uncertainty using implicit generative models, as described in “Resolving label uncertainty with implicit generative models”.
QRLoss
The QR (forward) loss between class probabilities and predictions.
Parameters:
- eps (float): Small constant for numerical stability that prevents log(0) when computing the loss. Must be greater than or equal to 0.
Formula:
Computes the forward KL-divergence-like loss:
QR = E_q[log q] - E_q[log p]
where:
- q is the predicted probability distribution
- p is the target prior distribution
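Here is a minimal PyTorch transcription of the formula above. It assumes each expectation is a sum over the class dimension at every pixel, with the per-pixel values then averaged over batch and spatial positions. `qr_formula` is a hypothetical helper for illustration, not TorchGeo's `QRLoss`, whose reduction may differ.

```python
import torch


def qr_formula(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: QR = E_q[log q] - E_q[log p], averaged over pixels.

    probs, target: (B, C, H, W) tensors whose class dimension (dim=1) sums to 1.
    """
    q, p = probs, target
    # Per-pixel expectations under q: sum over the class dimension
    e_log_q = (q * torch.log(q + eps)).sum(dim=1)
    e_log_p = (q * torch.log(p + eps)).sum(dim=1)
    # Average the per-pixel values over batch, height, and width
    return (e_log_q - e_log_p).mean()


torch.manual_seed(0)
probs = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
target = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
loss = qr_formula(probs, target)  # scalar; grows when q puts mass where p is small
zero = qr_formula(probs, probs)   # exactly 0 when the prediction equals the prior
```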
Method:
def forward(self, probs: Tensor, target: Tensor) -> Tensor:
    """
    Args:
        probs: Probabilities of predictions (B x C x H x W)
        target: Prior probabilities (B x C x H x W)

    Returns:
        QR loss value
    """
Example:
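import torch
from torchgeo.losses import QRLoss
loss_fn = QRLoss(eps=1e-8)
# Predicted probabilities (batch_size=4, num_classes=5, height=256, width=256)
logits = torch.randn(4, 5, 256, 256)  # stand-in for model output
probs = torch.softmax(logits, dim=1)
# Prior probabilities from uncertain labels, same shape as probs
# (random stand-in; each pixel's distribution sums to 1 over classes)
target = torch.softmax(torch.randn(4, 5, 256, 256), dim=1)
loss = loss_fn(probs, target)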
RQLoss
The RQ (backward) loss between class probabilities and predictions.
Parameters:
- eps (float): Small constant for numerical stability that prevents division by zero and log(0). Must be greater than or equal to 0.
Formula:
Computes the reverse KL-divergence-like loss:
RQ = E_r[log r] - E_r[log q]
where:
- q is the predicted probability distribution
- r is computed from the normalized product of predictions and targets
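The reverse loss can be sketched the same way by reading it as RQ = E_r[log r] - E_r[log q], i.e. KL(r || q) at each pixel. This minimal sketch assumes r is the per-pixel product q * p renormalized to sum to 1 over classes, and that per-pixel values are averaged. `rq_formula` is hypothetical, and TorchGeo's `RQLoss` may normalize or reduce differently.

```python
import torch


def rq_formula(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: RQ = E_r[log r] - E_r[log q], averaged over pixels.

    Assumes r is the per-pixel product q * p renormalized over the class dimension.
    """
    q, p = probs, target
    # r: product of prediction and prior, normalized so each pixel sums to 1
    r = q * p
    r = r / r.sum(dim=1, keepdim=True).clamp(min=eps)
    # Per-pixel KL(r || q) = E_r[log r] - E_r[log q]
    kl = (r * (torch.log(r + eps) - torch.log(q + eps))).sum(dim=1)
    return kl.mean()


torch.manual_seed(0)
probs = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
prior = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
loss = rq_formula(probs, prior)
# With a uniform prior, r equals q up to rounding, so the loss is close to 0
flat = rq_formula(probs, torch.full_like(probs, 1 / 3))
```

With a uniform prior, r reduces to q and the loss vanishes, so the prior only has an effect where it is non-uniform.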
Method:
def forward(self, probs: Tensor, target: Tensor) -> Tensor:
    """
    Args:
        probs: Probabilities of predictions (B x C x H x W)
        target: Prior probabilities (B x C x H x W)

    Returns:
        RQ loss value
    """
Example:
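import torch
from torchgeo.losses import RQLoss
loss_fn = RQLoss(eps=1e-8)
# Predicted probabilities (B, C, H, W)
logits = torch.randn(4, 5, 256, 256)  # stand-in for model output
probs = torch.softmax(logits, dim=1)
# Prior probabilities, same shape as probs (random stand-in)
target = torch.softmax(torch.randn(4, 5, 256, 256), dim=1)
loss = loss_fn(probs, target)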
When to Use Each Loss
QRLoss (Forward)
Use when:
- You want to minimize forward KL divergence
- You have prior probability distributions from uncertain labels
- You want the model predictions to match the prior distribution
Characteristics:
- Penalizes predictions that place probability mass where the target has low probability
- More robust to outliers in predictions
- Encourages mode-seeking (zero-forcing) behavior
RQLoss (Backward)
Use when:
- You want to minimize reverse KL divergence
- You have uncertain labels represented as probability distributions
- You want to avoid missing modes in the target distribution
Characteristics:
- Penalizes missing modes in the target distribution
- Mass-covering (mean-seeking) behavior
- More sensitive to target distribution structure
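These characteristics follow from the direction of the KL divergence. The toy three-class example below, in plain Python, compares a prediction that commits to one mode of a two-mode prior with one that spreads its mass. For simplicity it uses the prior p directly in both directions; in RQLoss the reference distribution is r, which also depends on q, so treat this as an illustration of KL asymmetry rather than of either loss exactly. The probability values are made up for illustration.

```python
import math


def kl(a: list[float], b: list[float]) -> float:
    """KL(a || b) for discrete distributions with strictly positive entries."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b))


p = [0.495, 0.495, 0.01]         # prior with two modes
q_one_mode = [0.98, 0.01, 0.01]  # commits to a single mode
q_spread = [0.34, 0.33, 0.33]    # covers both modes, but also the low-prior class

# KL(q || p), the QR-style direction: penalizes q mass where p is small,
# so the single-mode prediction scores lower
print(kl(q_one_mode, p), kl(q_spread, p))
# KL(p || q), the reverse direction: penalizes q being small where p is large,
# so the spread prediction scores lower
print(kl(p, q_one_mode), kl(p, q_spread))
```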
Usage with Trainers
These losses are typically used with custom tasks for uncertain label scenarios:
from torchgeo.trainers import BaseTask
from torchgeo.losses import QRLoss
import torch.nn.functional as F


class UncertainLabelTask(BaseTask):
    def __init__(self, model='resnet50', in_channels=3, num_classes=10, lr=1e-3):
        self.model_name = model
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.lr = lr
        super().__init__()

    def configure_models(self):
        import timm

        self.model = timm.create_model(
            self.model_name,
            num_classes=self.num_classes,
            in_chans=self.in_channels,
            pretrained=True,
        )

    def configure_losses(self):
        self.criterion = QRLoss(eps=1e-8)

    def training_step(self, batch, batch_idx):
        x = batch['image']
        y = batch['label_probs']  # Prior probabilities (B, C), not hard labels
        # Get logits (B, C) and convert to probabilities
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        # QRLoss is documented for (B, C, H, W) inputs, so add singleton spatial dims
        loss = self.criterion(probs[..., None, None], y[..., None, None])
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch['image']
        y = batch['label_probs']
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        loss = self.criterion(probs[..., None, None], y[..., None, None])
        self.log('val_loss', loss)
        return loss
Combining with Other Losses
You can combine QR/RQ losses with other losses:
from torchgeo.losses import QRLoss
import torch.nn as nn
import torch.nn.functional as F


class HybridLoss(nn.Module):
    def __init__(self, alpha=0.5, eps=1e-8):
        super().__init__()
        self.alpha = alpha
        self.qr_loss = QRLoss(eps=eps)
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, logits, hard_labels, soft_labels):
        # logits: (B, C, H, W); hard_labels: (B, H, W) class indices;
        # soft_labels: (B, C, H, W) prior probabilities
        # Standard cross-entropy on hard labels (expects raw logits)
        ce = self.ce_loss(logits, hard_labels)
        # QR loss on soft/uncertain labels (expects probabilities)
        probs = F.softmax(logits, dim=1)
        qr = self.qr_loss(probs, soft_labels)
        # Weighted combination: alpha weights CE, (1 - alpha) weights QR
        return self.alpha * ce + (1 - self.alpha) * qr
Semantic Segmentation with Uncertain Labels
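import torch
import torch.nn.functional as F
from torchgeo.losses import QRLoss

# Stand-ins for illustration: a 1x1 conv "model" and random input images
model = torch.nn.Conv2d(3, 5, kernel_size=1)  # 3 input bands -> 5 classes
images = torch.randn(2, 3, 64, 64)

# Segmentation with uncertain labels
logits = model(images)  # (B, C, H, W)
probs = F.softmax(logits, dim=1)
# Uncertain labels as probability distributions (random stand-in)
uncertain_labels = torch.softmax(torch.randn(2, 5, 64, 64), dim=1)  # (B, C, H, W)
loss_fn = QRLoss(eps=1e-8)
loss = loss_fn(probs, uncertain_labels)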
Best Practices
- Input format: Both losses expect probabilities (not logits). Apply F.softmax to the logits first.
- Shape: forward() is documented for inputs of shape (B, C, H, W). For classification outputs of shape (B, C), add singleton spatial dimensions first, e.g. probs[..., None, None].
- Normalization: Ensure target probabilities sum to 1 across the class dimension.
- Numerical stability: Use eps >= 1e-8 to prevent log(0) and division by zero.
- Combined losses: When combining with other losses, tune the relative weighting carefully, since the individual terms can differ in scale.
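Several of these checks can be run automatically before calling either loss. The sketch below assumes tensors with classes on dim 1; `check_loss_inputs` and its tolerance are hypothetical choices, and the negativity test is only a heuristic for catching raw logits.

```python
import torch


def check_loss_inputs(probs: torch.Tensor, target: torch.Tensor, atol: float = 1e-4) -> None:
    """Hypothetical helper: sanity-check inputs before calling QRLoss or RQLoss."""
    if probs.shape != target.shape:
        raise ValueError(f"shape mismatch: {tuple(probs.shape)} vs {tuple(target.shape)}")
    for name, t in (("probs", probs), ("target", target)):
        # Logits usually contain negative values; probabilities never do
        if (t < 0).any():
            raise ValueError(f"{name} has negative entries; pass probabilities, not logits")
        # Each pixel's distribution should sum to 1 over the class dimension
        sums = t.sum(dim=1)
        if not torch.allclose(sums, torch.ones_like(sums), atol=atol):
            raise ValueError(f"{name} does not sum to 1 over the class dimension")


torch.manual_seed(0)
probs = torch.softmax(torch.randn(2, 5, 8, 8), dim=1)
target = torch.full_like(probs, 1 / 5)
check_loss_inputs(probs, target)  # passes: both are valid distributions over dim 1
```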
Common Use Cases
Noisy Labels
When your dataset has noisy or uncertain labels:
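from torchgeo.losses import QRLoss
# Convert noisy labels to soft distributions (create_soft_distribution is a placeholder)
soft_labels = create_soft_distribution(noisy_labels, confidence_scores)
# predictions must be probabilities, e.g. F.softmax(logits, dim=1)
loss = QRLoss()(predictions, soft_labels)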
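One possible sketch of the placeholder `create_soft_distribution`, assuming per-pixel integer labels and confidence scores in [0, 1], puts the confidence on the given label and spreads the remainder uniformly (a label-smoothing-style mixture). This hypothetical version also needs `num_classes`, which the placeholder call omits.

```python
import torch
import torch.nn.functional as F


def create_soft_distribution(noisy_labels: torch.Tensor, confidence_scores: torch.Tensor,
                             num_classes: int) -> torch.Tensor:
    """Hypothetical helper: mix a one-hot label with a uniform distribution.

    noisy_labels: (B, H, W) int64 class indices
    confidence_scores: (B, H, W) values in [0, 1]
    Returns (B, C, H, W) soft targets that sum to 1 over classes.
    """
    one_hot = F.one_hot(noisy_labels, num_classes).permute(0, 3, 1, 2).float()
    conf = confidence_scores.unsqueeze(1)  # (B, 1, H, W), broadcasts over classes
    uniform = torch.full_like(one_hot, 1.0 / num_classes)
    return conf * one_hot + (1 - conf) * uniform


torch.manual_seed(0)
noisy_labels = torch.randint(0, 4, (2, 8, 8))   # (B, H, W)
confidence_scores = torch.full((2, 8, 8), 0.8)  # same trust in every label
# Labeled class gets 0.8 + 0.2 / 4, each other class gets 0.2 / 4
soft_labels = create_soft_distribution(noisy_labels, confidence_scores, num_classes=4)
```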
Multiple Annotators
When you have multiple annotators with disagreement:
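from torchgeo.losses import RQLoss
# Aggregate multiple annotations into a probability distribution (placeholder helper)
aggregated_probs = aggregate_annotations(annotations)
# predictions must be probabilities, e.g. F.softmax(logits, dim=1)
loss = RQLoss()(predictions, aggregated_probs)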
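A minimal sketch of the placeholder `aggregate_annotations`, assuming the annotations are stacked as an (A, B, H, W) integer tensor from A annotators, uses the fraction of annotators voting for each class as that class's probability. This hypothetical version takes an extra `num_classes` argument.

```python
import torch
import torch.nn.functional as F


def aggregate_annotations(annotations: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: per-pixel vote fractions across annotators.

    annotations: (A, B, H, W) int64 labels from A annotators
    Returns (B, C, H, W): fraction of annotators choosing each class.
    """
    one_hot = F.one_hot(annotations, num_classes).float()  # (A, B, H, W, C)
    votes = one_hot.mean(dim=0)                              # (B, H, W, C)
    return votes.permute(0, 3, 1, 2)                         # (B, C, H, W)


# Three annotators, one image of 1x2 pixels: they agree on pixel 0, split on pixel 1
annotations = torch.tensor([[[[0, 1]]], [[[0, 2]]], [[[0, 1]]]])
aggregated_probs = aggregate_annotations(annotations, num_classes=3)  # (1, 3, 1, 2)
```

Classes that no annotator chose get probability 0, so the loss relies on `eps` to avoid log(0).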
Weakly Supervised Learning
When you have weak supervision signals:
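from torchgeo.losses import QRLoss
# Create prior from weak signals (create_prior_from_weak_signals is a placeholder)
prior_probs = create_prior_from_weak_signals(weak_labels)
# predictions must be probabilities, e.g. F.softmax(logits, dim=1)
loss = QRLoss()(predictions, prior_probs)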
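One simple sketch of the placeholder `create_prior_from_weak_signals`, assuming each pixel carries a coarse weak label (for example, a category from a lower-resolution map), looks up a class distribution for that category in a hand-specified table. Both the helper and its `prior_table` argument are hypothetical, and the table values below are illustrative only.

```python
import torch


def create_prior_from_weak_signals(weak_labels: torch.Tensor,
                                   prior_table: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: look up a class prior for each weak label.

    weak_labels: (B, H, W) int64 indices into the rows of prior_table
    prior_table: (K, C); row k is the assumed class distribution for weak label k
    Returns (B, C, H, W) prior probabilities.
    """
    prior_table = prior_table / prior_table.sum(dim=1, keepdim=True)  # rows sum to 1
    return prior_table[weak_labels].permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)


# Toy table: 2 weak categories over 3 target classes
prior_table = torch.tensor([[0.6, 0.3, 0.1],
                            [0.05, 0.05, 0.9]])
weak_labels = torch.tensor([[[0, 1], [1, 0]]])  # (B=1, H=2, W=2)
prior_probs = create_prior_from_weak_signals(weak_labels, prior_table)  # (1, 3, 2, 2)
```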
References
For more information about these losses, see: