HybridLossFunction
Hybrid loss function combining multiple loss components for robust MOS prediction.
Components:
- Smooth L1 loss for basic regression
- Ranking loss for preserving relative order
- Scale-aware loss for emphasizing extreme quality values
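The first two components above can be sketched in plain Python. This is an illustrative sketch, not the class's actual implementation: the hinge-style pairwise formulation of the ranking loss is an assumption, as are the function names.

```python
def smooth_l1(pred, target, beta=1.0):
    """Smooth L1 (Huber-style) loss for one scalar pair: quadratic
    near zero, linear beyond beta."""
    d = abs(pred - target)
    return 0.5 * d * d / beta if d < beta else d - 0.5 * beta

def pairwise_ranking_loss(preds, targets, margin=0.5):
    """Hinge-style pairwise ranking loss: penalize pairs whose predicted
    order contradicts the target order by more than the margin."""
    total, pairs = 0.0, 0
    n = len(preds)
    for i in range(n):
        for j in range(i + 1, n):
            # sign of the target difference defines the desired order
            sign = 1.0 if targets[i] > targets[j] else -1.0
            total += max(0.0, margin - sign * (preds[i] - preds[j]))
            pairs += 1
    return total / max(pairs, 1)
```

In the real class these would operate on batched tensors rather than scalars, but the per-element math is the same.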
Parameters
Beta parameter for smooth L1 loss
Margin for ranking loss
Weights for different quality ranges. Default: {'low_quality': 1.5, 'high_quality': 1.5, 'normal': 1.0}
Whether to use adaptive loss weighting that adjusts during training
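One plausible reading of the quality-range weights is a per-sample multiplier selected by where the target MOS falls. The thresholds below (and the assumption of a 1-5 MOS scale) are illustrative, not taken from the source:

```python
def quality_weight(target_mos, weights=None, low_thr=2.0, high_thr=4.0):
    """Per-sample weight emphasizing extreme quality values.
    Thresholds are illustrative assumptions; MOS assumed on a 1-5 scale."""
    if weights is None:
        weights = {'low_quality': 1.5, 'high_quality': 1.5, 'normal': 1.0}
    if target_mos <= low_thr:
        return weights['low_quality']
    if target_mos >= high_thr:
        return weights['high_quality']
    return weights['normal']
```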
Methods
__call__(pred, target, epoch=0)
Compute hybrid loss.
Parameters:
pred (torch.Tensor): Predicted MOS scores with shape (B, 5)
target (torch.Tensor): Target MOS scores with shape (B, 5)
epoch (int): Current training epoch (default: 0)
Tuple[torch.Tensor, Dict[str, float]] containing:
total_loss (torch.Tensor): Combined loss value
loss_components (dict): Dictionary with:
- total_loss (float): Total loss value
- smooth_l1_loss (float): Smooth L1 component
- ranking_loss (float): Ranking component
- scale_loss (float): Scale-aware component
- alpha (float): Weight for smooth L1 loss
- beta (float): Weight for ranking loss
- gamma (float): Weight for scale loss
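Given that the returned dictionary reports both the three components and the weights alpha, beta, gamma, the total is presumably their weighted sum. A minimal sketch of that combination (the default weight values here are made up for illustration):

```python
def combine(components, alpha=1.0, beta=0.5, gamma=0.25):
    """Assumed combination rule: weighted sum of the three loss
    components reported in loss_components."""
    return (alpha * components['smooth_l1_loss']
            + beta * components['ranking_loss']
            + gamma * components['scale_loss'])
```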
train_epoch
Train model for one epoch with gradient accumulation and mixed precision.
Parameters
Model to train
Training data loader
Optimizer for training
Learning rate scheduler (optional)
Gradient scaler for mixed precision training
Loss function
Number of gradient accumulation steps
Current epoch number
Device to use for training
Maximum gradient norm for clipping
Logging interval in batches
Returns
Dictionary containing:
train_loss (float): Average total loss
train_smooth_l1 (float): Average smooth L1 loss
train_ranking (float): Average ranking loss
train_scale (float): Average scale loss
num_batches (int): Number of batches processed
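The accumulation-and-clipping pattern that train_epoch presumably follows can be sketched with scalars standing in for gradients. The scaling by 1/accumulation_steps, the step cadence, and the scalar "norm" clipping are all assumptions about the common pattern, not the actual implementation:

```python
def run_accumulation(batch_grads, accumulation_steps, max_grad_norm, step_fn):
    """Sketch of gradient accumulation with norm clipping: scale each
    micro-batch gradient by 1/accumulation_steps so the accumulated value
    equals the mean gradient, then clip and step every
    accumulation_steps batches."""
    acc = 0.0
    for i, g in enumerate(batch_grads, start=1):
        acc += g / accumulation_steps
        if i % accumulation_steps == 0:
            # clip the accumulated gradient (scalar analogue of clipping
            # the gradient norm to max_grad_norm)
            clipped = max(-max_grad_norm, min(max_grad_norm, acc))
            step_fn(clipped)
            acc = 0.0
```

With mixed precision, the real loop would additionally scale the loss via the gradient scaler before backward and unscale before clipping.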
evaluate
Evaluate model on validation set.
Parameters
Model to evaluate
Validation data loader
Loss function
Device to use for evaluation
Returns
Dictionary containing:
val_loss (float): Average validation loss
val_smooth_l1 (float): Average smooth L1 loss
val_ranking (float): Average ranking loss
val_scale (float): Average scale loss
num_val_batches (int): Number of batches processed
List of predicted overall MOS scores
List of target overall MOS scores
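The returned prediction and target lists are what a caller would typically feed into a scalar metric. A minimal example (an illustrative downstream use, not part of evaluate itself):

```python
def mos_mae(preds, targets):
    """Mean absolute error between the predicted and target overall
    MOS lists returned by evaluate."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)
```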
create_optimizer
Create optimizer with optional discriminative learning rates.
Parameters
Model to optimize
Base learning rate
Weight decay for regularization
Dictionary with component-specific LR multipliers. Keys: 'text', 'video', 'head'
Returns
Configured AdamW optimizer
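Discriminative learning rates are usually implemented by assigning each parameter to a param group whose lr is the base rate times a component multiplier. A sketch of that mapping, assuming components are identified by name prefix (the prefix-matching rule is an assumption):

```python
def lr_for_param(name, base_lr, multipliers):
    """Resolve the learning rate for a named parameter from the
    component-specific multipliers ('text', 'video', 'head')."""
    for key in ('text', 'video', 'head'):
        if name.startswith(key):
            return base_lr * multipliers.get(key, 1.0)
    return base_lr  # parameters outside the three components use base_lr
```

In the real function, these per-parameter rates would become AdamW param groups, e.g. `torch.optim.AdamW([{'params': ps, 'lr': lr}, ...], weight_decay=wd)`.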
create_scheduler
Create learning rate scheduler.
Parameters
Optimizer to schedule
Total number of training steps
Number of warmup steps
Type of scheduler: 'cosine', 'linear', or 'constant'
Returns
Configured learning rate scheduler
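The three scheduler types can be written as a single closed-form function of the step index. This is an illustrative version of the standard warmup-then-decay shapes; the exact formulas create_scheduler uses may differ:

```python
import math

def lr_at_step(step, total_steps, warmup_steps, base_lr, kind='cosine'):
    """Learning rate at a given step: linear warmup to base_lr, then
    cosine decay to 0, linear decay to 0, or constant."""
    if step < warmup_steps:
        return base_lr * step / max(warmup_steps, 1)  # linear warmup
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    if kind == 'cosine':
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    if kind == 'linear':
        return base_lr * (1.0 - progress)
    return base_lr  # 'constant'
```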
AdaptiveLossManager
Adaptive loss weight manager that dynamically adjusts loss component weights during training.
Parameters
Initial weight for smooth L1 loss
Initial weight for ranking loss
Rate of adaptation for weight updates
Methods
update_weights(mae_loss, ranking_loss)
Update loss weights based on recent loss trends.
Parameters:
mae_loss (float): Current MAE loss value
ranking_loss (float): Current ranking loss value
get_weights()
Get current loss weights.
Returns:
Tuple[float, float] - Current (alpha, beta) weights
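One common adaptive-weighting rule that fits this interface is to nudge each weight toward that component's share of the total loss, at a speed set by the adaptation rate. The update rule below is an assumption for illustration, not AdaptiveLossManager's actual one:

```python
class AdaptiveWeightsSketch:
    """Sketch of an adaptive loss-weight manager: each call to
    update_weights moves (alpha, beta) a small step toward the current
    relative magnitudes of the two loss components."""

    def __init__(self, alpha=1.0, beta=0.3, adaptation_rate=0.01):
        self.alpha, self.beta = alpha, beta
        self.rate = adaptation_rate

    def update_weights(self, mae_loss, ranking_loss):
        total = mae_loss + ranking_loss
        if total > 0:
            # move each weight toward that component's share of the total
            self.alpha += self.rate * (mae_loss / total - self.alpha)
            self.beta += self.rate * (ranking_loss / total - self.beta)

    def get_weights(self):
        return self.alpha, self.beta
```

A small adaptation rate keeps the weights stable against per-batch noise while still tracking slow drifts between the two components.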