TorchGeo provides specialized loss functions for learning on uncertain labels and prior distributions.

QR Loss Functions

These losses are designed for learning with label uncertainty using implicit generative models, as described in “Resolving label uncertainty with implicit generative models”.

QRLoss

The QR (forward) loss between predicted class probabilities and prior (target) class probabilities.
Parameter: eps (float, default: 1e-8). Small constant for numerical stability that prevents log(0) when computing the loss. Must be greater than or equal to 0.
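Formula: Computes the forward KL-divergence-like loss, with expectations taken under the predictions:
QR = Σ_c q̄_c log q̄_c - E_q[log p]
where:
  • q is the predicted probability distribution at each pixel
  • q̄ is q averaged over the batch and spatial dimensions (the predicted class marginal)
  • p is the target prior distribution
  • E_q[log p] = Σ_c q_c log p_c, averaged over all pixels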
Method:
def forward(self, probs: Tensor, target: Tensor) -> Tensor:
    """
    Args:
        probs: Probabilities of predictions (B x C x H x W)
        target: Prior probabilities (B x C x H x W)

    Returns:
        QR loss value as a scalar (0-dim) tensor
    """
Example:
import torch
from torchgeo.losses import QRLoss

loss_fn = QRLoss(eps=1e-8)

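# Placeholder logits, e.g. from a segmentation model
# (batch_size=4, num_classes=5, height=256, width=256)
logits = torch.randn(4, 5, 256, 256)
probs = torch.softmax(logits, dim=1)

# Placeholder prior probabilities from uncertain labels: same shape as probs,
# with each pixel's class probabilities summing to 1
target = torch.softmax(torch.randn(4, 5, 256, 256), dim=1)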

loss = loss_fn(probs, target)
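The formula is easier to check against a few lines of tensor code. Below is a minimal plain-PyTorch sketch of a QR-style loss; `qr_loss_sketch` is a hypothetical name and not part of TorchGeo. Based on our reading of the TorchGeo source, it evaluates the entropy term on the class marginal q̄ (predictions averaged over the batch and spatial dimensions) and the expected log prior per pixel; placing eps inside both logarithms is our own assumption. Treat it as an illustration only; `QRLoss` remains the reference implementation.

```python
import torch


def qr_loss_sketch(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Illustrative QR-style loss for B x C x H x W inputs (not TorchGeo's code)."""
    # Predicted class marginal: average over batch and spatial dims -> shape (C,)
    q_bar = probs.mean(dim=(0, 2, 3))
    # Negative entropy of the predicted class marginal
    neg_entropy = (q_bar * torch.log(q_bar + eps)).sum()
    # E_q[log p]: sum over classes, then average over every pixel
    q_log_p = (probs * torch.log(target + eps)).sum(dim=1).mean()
    return neg_entropy - q_log_p


# Usage with random placeholder tensors (B=2, C=4, H=W=8)
probs = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
target = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
loss = qr_loss_sketch(probs, target)  # 0-dim tensor
```

With uniform predictions and a uniform prior, both terms equal -log C and the sketch returns 0.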

RQLoss

The RQ (backward) loss between predicted class probabilities and prior (target) class probabilities.
Parameter: eps (float, default: 1e-8). Small constant for numerical stability that prevents division by zero and log(0). Must be greater than or equal to 0.
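Formula: Computes the reverse KL-divergence-like loss:
RQ = E_r[log r - log q]
where:
  • q is the predicted probability distribution
  • p is the target prior distribution
  • r is the normalized product of predictions and targets: r_c ∝ p_c · q_c / Σ_{b,h,w} q_c, renormalized so that it sums to 1 over the classes at each pixel
  • the expectation is a sum over classes, averaged over all pixels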
Method:
def forward(self, probs: Tensor, target: Tensor) -> Tensor:
    """
    Args:
        probs: Probabilities of predictions (B x C x H x W)
        target: Prior probabilities (B x C x H x W)

    Returns:
        RQ loss value as a scalar (0-dim) tensor
    """
Example:
import torch
from torchgeo.losses import RQLoss

loss_fn = RQLoss(eps=1e-8)

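# Predicted probabilities from placeholder logits (B x C x H x W)
logits = torch.randn(4, 5, 256, 256)
probs = torch.softmax(logits, dim=1)

# Placeholder prior probabilities, same shape as probs
target = torch.softmax(torch.randn(4, 5, 256, 256), dim=1)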

loss = loss_fn(probs, target)
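As with QR, a short sketch makes the construction of r concrete. `rq_loss_sketch` is a hypothetical name, not TorchGeo API. Following our reading of the TorchGeo source, each class of q is first divided by its total predicted mass over the batch and spatial dimensions, multiplied by the prior, and renormalized over classes; where eps enters is our assumption. `RQLoss` remains the reference implementation.

```python
import torch
import torch.nn.functional as F


def rq_loss_sketch(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Illustrative RQ-style loss for B x C x H x W inputs (not TorchGeo's code)."""
    q = probs
    # Divide each class by its total predicted mass over batch and spatial dims
    z = q / q.sum(dim=(0, 2, 3), keepdim=True).clamp_min(eps)
    # r: rescaled predictions times the prior, L1-normalized over classes per pixel
    r = F.normalize(z * target, p=1, dim=1, eps=eps)
    # E_r[log r - log q]: sum over classes, then average over every pixel
    return (r * (torch.log(r + eps) - torch.log(q + eps))).sum(dim=1).mean()


# Usage with random placeholder tensors (B=2, C=4, H=W=8)
probs = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
target = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
loss = rq_loss_sketch(probs, target)
```

Because r and q are both distributions over classes at each pixel, each per-pixel term is a KL divergence, so the result is non-negative up to eps effects.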

When to Use Each Loss

QRLoss (Forward)

Use when:
  • You want to minimize forward KL divergence
  • You have prior probability distributions from uncertain labels
  • You want the model predictions to match the prior distribution
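Characteristics:
  • Penalizes predictions that place probability mass where the target has low probability
  • Because the expectation is taken under the predictions, classes the model already assigns near-zero probability contribute little to the loss
  • Tends toward mode-seeking behavior: predictions concentrate on classes the prior rates as likely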

RQLoss (Backward)

Use when:
  • You want to minimize reverse KL divergence
  • You have uncertain labels represented as probability distributions
  • You want to avoid missing modes in the target distribution
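Characteristics:
  • Penalizes missing modes in the target distribution
  • Mass-covering (mean-seeking) behavior, because the expectation is taken under r, which carries the prior's probability mass
  • Sensitive to the full shape of the target distribution, since the prior enters r directly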
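The mode-seeking versus mass-covering distinction can be seen on a toy example. This plain-Python sketch compares the two textbook KL directions on a three-class target with two likely classes. It does not call `QRLoss` or `RQLoss`, whose formulas differ from plain KL in their details, so read it only as intuition for why the side the expectation is taken on matters.

```python
import math


def kl(a, b):
    """KL(a || b) for discrete distributions given as lists of probabilities."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


# Toy target with two likely classes and one unlikely class
p = [0.49, 0.02, 0.49]
q_mode = [0.96, 0.02, 0.02]     # commits to one mode of p
q_wide = [1 / 3, 1 / 3, 1 / 3]  # spreads mass over every class

# Expectation under the prediction (roughly the QR side):
# q_mode scores lower (about 0.58 vs 0.68), i.e. mode-seeking
print(kl(q_mode, p), kl(q_wide, p))

# Expectation under the target (roughly the RQ side):
# q_wide scores lower (about 0.32 vs 1.24), i.e. mass-covering
print(kl(p, q_mode), kl(p, q_wide))
```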

Usage with Trainers

To use these losses during training, write a custom task on top of BaseTask. The example below assumes a dataset whose samples store per-sample prior probabilities (not hard labels) under a 'label_probs' key:
import timm
import torch
import torch.nn.functional as F
from torchgeo.losses import QRLoss
from torchgeo.trainers import BaseTask

class UncertainLabelTask(BaseTask):
    def __init__(self, model='resnet50', in_channels=3, num_classes=10, lr=1e-3):
        self.model_name = model
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.lr = lr
        # BaseTask.__init__ calls configure_models() and configure_losses()
        super().__init__()
    
    def configure_models(self):
        # Image-level classifier: outputs logits of shape (B, num_classes)
        self.model = timm.create_model(
            self.model_name,
            num_classes=self.num_classes,
            in_chans=self.in_channels,
            pretrained=True
        )
    
    def configure_losses(self):
        self.criterion = QRLoss(eps=1e-8)

    def configure_optimizers(self):
        # Plain AdamW using this task's own learning rate
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def _compute_loss(self, batch):
        x = batch['image']
        y = batch['label_probs']  # Prior probabilities (B, C), not hard labels

        # Get logits and convert to probabilities
        probs = F.softmax(self(x), dim=1)

        # QRLoss expects B x C x H x W, so add singleton spatial dims
        return self.criterion(probs[..., None, None], y[..., None, None])

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss)
        return loss
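UncertainLabelTask only needs batches that are dicts with 'image' and 'label_probs' entries. For smoke-testing such a task, a hypothetical random dataset like `RandomPriorDataset` below can stand in for real data. The Dirichlet-sampled priors are placeholders for illustration, not a recommended way to build label distributions.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomPriorDataset(Dataset):
    """Hypothetical dataset yielding random images with per-sample class priors."""

    def __init__(self, length=8, in_channels=3, num_classes=10, size=64):
        self.length = length
        self.in_channels = in_channels
        self.size = size
        # Symmetric Dirichlet, used only to generate random probability vectors
        self.dirichlet = torch.distributions.Dirichlet(torch.ones(num_classes))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = torch.randn(self.in_channels, self.size, self.size)
        label_probs = self.dirichlet.sample()  # shape (num_classes,), sums to 1
        return {'image': image, 'label_probs': label_probs}


# The default collate function stacks the dict entries into batches
loader = DataLoader(RandomPriorDataset(), batch_size=4)
batch = next(iter(loader))  # batch['label_probs'] has shape (4, 10)
```

DataLoaders built this way could then be passed to a Lightning `Trainer`'s `fit` method together with the task.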

Combining with Other Losses

You can combine QR/RQ losses with other losses, such as cross-entropy on hard labels:
import torch.nn as nn
import torch.nn.functional as F
from torchgeo.losses import QRLoss

class HybridLoss(nn.Module):
    def __init__(self, alpha=0.5, eps=1e-8):
        super().__init__()
        self.alpha = alpha
        self.qr_loss = QRLoss(eps=eps)
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, logits, hard_labels, soft_labels):
        # logits: (B, C, H, W); hard_labels: (B, H, W) class indices;
        # soft_labels: (B, C, H, W) prior probabilities
        # Standard cross-entropy on hard labels (expects raw logits)
        ce = self.ce_loss(logits, hard_labels)
        
        # QR loss on soft/uncertain labels (expects probabilities)
        probs = F.softmax(logits, dim=1)
        qr = self.qr_loss(probs, soft_labels)
        
        # Weighted combination: alpha weights cross-entropy, 1 - alpha weights QR
        return self.alpha * ce + (1 - self.alpha) * qr
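When only hard labels are available, the `soft_labels` argument has to come from somewhere. One simple option, swapped in here purely for illustration and not part of the QR/RQ method, is label smoothing. The helper name `smooth_hard_labels` and the `smoothing` parameter are hypothetical.

```python
import torch
import torch.nn.functional as F


def smooth_hard_labels(hard_labels: torch.Tensor, num_classes: int, smoothing: float = 0.1) -> torch.Tensor:
    """Turn (B, H, W) integer labels into (B, C, H, W) soft labels via label smoothing."""
    # one_hot appends the class dim last: (B, H, W, C) -> permute to (B, C, H, W)
    one_hot = F.one_hot(hard_labels, num_classes).permute(0, 3, 1, 2).float()
    # Keep 1 - smoothing on the annotated class, spread smoothing uniformly over all classes
    return one_hot * (1 - smoothing) + smoothing / num_classes


hard_labels = torch.randint(0, 5, (2, 16, 16))
soft_labels = smooth_hard_labels(hard_labels, num_classes=5)
```

Every pixel's distribution still sums to 1, and with smoothing > 0 no class has zero probability, so log p stays finite even before eps is applied.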

Semantic Segmentation with Uncertain Labels

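import torch
import torch.nn.functional as F
from torchgeo.losses import QRLoss

# Placeholder model and inputs; substitute your own segmentation network and data
num_classes = 5
images = torch.randn(2, 3, 64, 64)
model = torch.nn.Conv2d(3, num_classes, kernel_size=1)

# Segmentation with uncertain labels
logits = model(images)  # (B, C, H, W)
probs = F.softmax(logits, dim=1)

# Uncertain labels as per-pixel probability distributions, (B, C, H, W)
uncertain_labels = torch.softmax(torch.randn(2, num_classes, 64, 64), dim=1)

loss_fn = QRLoss(eps=1e-8)
loss = loss_fn(probs, uncertain_labels)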

Best Practices

  1. Input format: Both losses expect probabilities (not logits). Apply F.softmax to logits first.
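  2. Shape: Both losses are documented for (B, C, H, W) inputs. For classification outputs of shape (B, C), add singleton spatial dimensions first, e.g. probs[..., None, None].
  3. Normalization: Ensure target probabilities sum to 1 across the class dimension.
  4. Numerical stability: Keep eps small but positive (the default is 1e-8) to prevent log(0) and division by zero; eps=0 provides no such protection.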
  5. Combined losses: When combining with other losses, tune the weighting carefully.
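Practices 2 to 4 can be bundled into one small preprocessing step for the target. `prepare_prior` is a hypothetical helper; clamping to eps before renormalizing is one choice among several. Lifting (B, C) priors to (B, C, 1, 1) assumes the loss needs four-dimensional inputs, as the documented B x C x H x W shape suggests.

```python
import torch


def prepare_prior(target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Clamp and renormalize class probabilities along dim 1; lift (B, C) to (B, C, 1, 1)."""
    if target.ndim == 2:
        # Classification-style priors: add singleton spatial dims
        target = target[:, :, None, None]
    # Remove exact zeros, then make each pixel's classes sum to 1 again
    target = target.clamp_min(eps)
    return target / target.sum(dim=1, keepdim=True)


raw = torch.tensor([[2.0, 0.0, 2.0], [1.0, 1.0, 0.0]])  # unnormalized, with zeros
prior = prepare_prior(raw)
```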

Common Use Cases

Noisy Labels

When your dataset has noisy or uncertain labels:
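from torchgeo.losses import QRLoss

# Convert noisy labels to soft distributions
# (create_soft_distribution is a placeholder for your own function)
soft_labels = create_soft_distribution(noisy_labels, confidence_scores)
loss = QRLoss()(predictions, soft_labels)  # predictions: softmax probabilities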
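One possible shape for a helper like `create_soft_distribution`, assuming per-pixel confidence scores in [0, 1], puts the confidence on the observed class and splits the remainder evenly among the other classes. This is a hypothetical sketch, not a TorchGeo function, and it takes an extra `num_classes` argument that the call above omits.

```python
import torch
import torch.nn.functional as F


def create_soft_distribution(noisy_labels: torch.Tensor, confidence_scores: torch.Tensor,
                             num_classes: int) -> torch.Tensor:
    """Hypothetical helper: (B, H, W) labels + (B, H, W) confidences -> (B, C, H, W) priors."""
    one_hot = F.one_hot(noisy_labels, num_classes).permute(0, 3, 1, 2).float()
    conf = confidence_scores.unsqueeze(1)  # (B, 1, H, W), broadcasts over classes
    # Observed class gets conf; the remaining 1 - conf is split among the other classes
    return one_hot * conf + (1 - one_hot) * (1 - conf) / (num_classes - 1)


noisy_labels = torch.tensor([[[0, 2], [1, 1]]])               # (1, 2, 2)
confidence_scores = torch.tensor([[[0.9, 0.6], [1.0, 0.5]]])  # (1, 2, 2)
soft_labels = create_soft_distribution(noisy_labels, confidence_scores, num_classes=3)
```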

Multiple Annotators

When you have multiple annotators with disagreement:
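from torchgeo.losses import RQLoss

# Aggregate multiple annotations into a probability distribution
# (aggregate_annotations is a placeholder for your own function)
aggregated_probs = aggregate_annotations(annotations)
loss = RQLoss()(predictions, aggregated_probs)  # predictions: softmax probabilities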
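A simple way to implement something like `aggregate_annotations` is to use each class's vote fraction per pixel. The sketch below assumes annotations are stacked as an (A, B, H, W) integer tensor from A annotators and adds a `num_classes` argument; both the layout and the helper are hypothetical.

```python
import torch
import torch.nn.functional as F


def aggregate_annotations(annotations: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: (A, B, H, W) labels from A annotators -> (B, C, H, W) vote fractions."""
    one_hot = F.one_hot(annotations, num_classes).float()  # (A, B, H, W, C)
    # Fraction of annotators choosing each class, per pixel
    return one_hot.mean(dim=0).permute(0, 3, 1, 2)


# Three annotators, one image of 1 x 2 pixels
annotations = torch.tensor([[[[0, 2]]], [[[0, 2]]], [[[1, 2]]]])  # (3, 1, 1, 2)
aggregated_probs = aggregate_annotations(annotations, num_classes=3)
```

Classes that no annotator picked get probability 0, so the loss relies on eps there; adding a small pseudo-count before normalizing is a common alternative.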

Weakly Supervised Learning

When you have weak supervision signals:
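from torchgeo.losses import QRLoss

# Create prior from weak signals
# (create_prior_from_weak_signals is a placeholder for your own function)
prior_probs = create_prior_from_weak_signals(weak_labels)
loss = QRLoss()(predictions, prior_probs)  # predictions: softmax probabilities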
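What counts as a weak signal depends on the application. A frequent case in remote sensing is a coarse-resolution label map giving class distributions per coarse cell. Assuming that form, a hypothetical `create_prior_from_weak_signals` could upsample it to pixel resolution so that each pixel inherits its cell's distribution; the extra `size` argument is also an assumption.

```python
import torch
import torch.nn.functional as F


def create_prior_from_weak_signals(coarse_probs: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Hypothetical helper: (B, C, h, w) coarse class distributions -> (B, C, H, W) pixel priors."""
    # Nearest-neighbor upsampling copies each coarse cell's distribution to its pixels
    return F.interpolate(coarse_probs, size=size, mode='nearest')


coarse_probs = torch.softmax(torch.randn(2, 4, 4, 4), dim=1)  # 4 x 4 grid of cells
prior_probs = create_prior_from_weak_signals(coarse_probs, (16, 16))
```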

References

For more information about these losses, see the paper that defines them, “Resolving label uncertainty with implicit generative models”.
