All loss functions return a scalar output of shape [1]. Cross-entropy and BCE have dedicated fused GPU kernels with built-in backward support. MSE and L1 are composed from primitive ops and differentiate through the standard autodiff engine.

g.cross_entropy_loss

Computes the mean cross-entropy loss between predicted logits and target labels. Internally applies log-softmax over the class dimension and then computes the mean negative log-likelihood. Both inputs must have the same shape [batch, classes].
Parameters:

- logits (NodeId, required): Raw (pre-softmax) predictions of shape [batch, classes].
- labels (NodeId, required): Target distribution of shape [batch, classes]. Typically one-hot encoded or soft labels.

Returns: NodeId. Scalar loss of shape [1].
let mut g = Graph::new();
let logits = g.input("logits", &[4, 10]);
let labels = g.input("labels", &[4, 10]);

let loss = g.cross_entropy_loss(logits, labels);
// loss shape: [1]
Full training graph example:
let batch = 32; // example batch size, chosen for illustration
let mut g = Graph::new();
let x = g.input("x", &[batch, 784]);
let labels = g.input("labels", &[batch, 10]);

let fc1 = nn::Linear::new(&mut g, "fc1", 784, 128);
let fc2 = nn::Linear::new(&mut g, "fc2", 128, 10);

let h = fc1.forward(&mut g, x);
let h = g.relu(h);
let logits = fc2.forward(&mut g, h);
let loss = g.cross_entropy_loss(logits, labels);
g.set_outputs(vec![loss]);
cross_entropy_loss has a native GPU kernel with a fused backward pass. It is the recommended loss for multi-class classification.
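The log-softmax plus mean-NLL composition described above can be sketched in plain Rust. `cross_entropy` here is an illustrative standalone helper, not part of the Graph API, and it assumes `f64` data laid out as one `Vec` per batch row:

```rust
/// Illustrative helper (not part of the Graph API): mean cross-entropy
/// from raw logits and one-hot or soft labels, computed as described:
/// log-softmax over the class dimension, then mean negative log-likelihood.
fn cross_entropy(logits: &[Vec<f64>], labels: &[Vec<f64>]) -> f64 {
    let batch = logits.len() as f64;
    logits
        .iter()
        .zip(labels)
        .map(|(row, targets)| {
            // Stable log-softmax: shift by the row max before exponentiating.
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum = row.iter().map(|x| (x - max).exp()).sum::<f64>().ln();
            // Negative log-likelihood under the target distribution.
            -row.iter()
                .zip(targets)
                .map(|(x, t)| t * (x - max - log_sum))
                .sum::<f64>()
        })
        .sum::<f64>()
        / batch
}
```

A handy sanity check: with uniform logits over C classes and a one-hot target, the loss is ln(C), which is what an untrained classifier should report.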

g.bce_loss

Binary cross-entropy loss for binary or multi-label classification. The formula is:
loss = -mean(t * log(p) + (1 - t) * log(1 - p))
pred must contain values in the range (0, 1). Apply g.sigmoid to raw logits before passing them here.
Parameters:

- pred (NodeId, required): Predicted probabilities in (0, 1), typically after sigmoid. Must have the same shape as labels.
- labels (NodeId, required): Binary target values in {0, 1} (or soft labels in [0, 1]). Same shape as pred.

Returns: NodeId. Scalar loss of shape [1].
let batch = 32; // example sizes, chosen for illustration
let num_labels = 5;
let mut g = Graph::new();
let logits = g.input("logits", &[batch, num_labels]);
let labels = g.input("labels", &[batch, num_labels]);

// Apply sigmoid first — bce_loss expects probabilities, not raw logits
let pred = g.sigmoid(logits);
let loss = g.bce_loss(pred, labels);
// loss shape: [1]
bce_loss has a native GPU kernel with a fused backward pass. It is the recommended loss for binary and multi-label classification tasks.
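The formula above reduces to a few lines of plain Rust. `bce` is an illustrative helper, not the library's fused kernel:

```rust
/// Illustrative helper (not the library's fused kernel):
/// loss = -mean(t * ln(p) + (1 - t) * ln(1 - p)).
/// Assumes every p is strictly inside (0, 1), as the docs require;
/// p at exactly 0 or 1 would produce ln(0) = -inf.
fn bce(pred: &[f64], labels: &[f64]) -> f64 {
    let n = pred.len() as f64;
    -pred
        .iter()
        .zip(labels)
        .map(|(p, t)| t * p.ln() + (1.0 - t) * (1.0 - p).ln())
        .sum::<f64>()
        / n
}
```

As with cross-entropy, a maximally uncertain predictor (p = 0.5 everywhere) yields a loss of ln(2) regardless of the targets.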

g.mse_loss

Mean squared error: loss = mean((pred - target)²). Implemented as a composition of primitive ops (neg, add, mul, mean_all) and differentiates through the standard autodiff engine.
Parameters:

- pred (NodeId, required): Predicted values. Any shape, must match target.
- target (NodeId, required): Ground truth values. Same shape as pred.

Returns: NodeId. Scalar loss of shape [1].
let batch = 32; // example sizes, chosen for illustration
let output_dim = 8;
let mut g = Graph::new();
let pred = g.input("pred", &[batch, output_dim]);
let target = g.input("target", &[batch, output_dim]);

let loss = g.mse_loss(pred, target);
// loss shape: [1]
Regression training example:
let batch = 32; // example sizes, chosen for illustration
let in_dim = 16;
let out_dim = 4;
let mut g = Graph::new();
let x = g.input("x", &[batch, in_dim]);
let target = g.input("target", &[batch, out_dim]);

let mlp = nn::Mlp::new(&mut g, "mlp", in_dim, 256, out_dim, nn::Activation::Relu);
let pred = mlp.forward(&mut g, x);
let loss = g.mse_loss(pred, target);
g.set_outputs(vec![loss]);
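The primitive composition described above (neg, add, mul, mean_all) can be sketched with standalone functions. These mirror the documented decomposition but are illustrative stand-ins, not the library's actual ops:

```rust
// Illustrative stand-ins for the primitives named in the docs.
fn neg(v: &[f64]) -> Vec<f64> { v.iter().map(|x| -x).collect() }
fn add(a: &[f64], b: &[f64]) -> Vec<f64> { a.iter().zip(b).map(|(x, y)| x + y).collect() }
fn mul(a: &[f64], b: &[f64]) -> Vec<f64> { a.iter().zip(b).map(|(x, y)| x * y).collect() }
fn mean_all(v: &[f64]) -> f64 { v.iter().sum::<f64>() / v.len() as f64 }

/// mse = mean_all(mul(d, d)) where d = add(pred, neg(target)),
/// i.e. loss = mean((pred - target)^2).
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    let diff = add(pred, &neg(target));
    mean_all(&mul(&diff, &diff))
}
```

Because each primitive has its own derivative, the autodiff engine can backpropagate through this chain without a dedicated kernel.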

g.l1_loss

Mean absolute error: loss = mean(|pred - target|). Implemented as a composition of primitive ops (neg, add, abs, mean_all) and differentiates through the standard autodiff engine.
Parameters:

- pred (NodeId, required): Predicted values. Any shape, must match target.
- target (NodeId, required): Ground truth values. Same shape as pred.

Returns: NodeId. Scalar loss of shape [1].
let batch = 32; // example sizes, chosen for illustration
let out_dim = 4;
let mut g = Graph::new();
let pred = g.input("pred", &[batch, out_dim]);
let target = g.input("target", &[batch, out_dim]);

let loss = g.l1_loss(pred, target);
// loss shape: [1]
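A plain-Rust sketch of the mean-absolute-error formula, together with its per-element gradient, which makes the gradient's behavior at zero concrete. Both functions are illustrative helpers; the zero-at-zero subgradient convention is an assumption, not necessarily what the autodiff engine picks:

```rust
/// Illustrative helper: loss = mean(|pred - target|).
fn l1(pred: &[f64], target: &[f64]) -> f64 {
    let n = pred.len() as f64;
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f64>() / n
}

/// Per-element gradient w.r.t. pred: sign(pred - target) / n.
/// The derivative of |x| is undefined at x == 0; returning 0 there is a
/// common subgradient convention, assumed here for illustration.
fn l1_grad(pred: &[f64], target: &[f64]) -> Vec<f64> {
    let n = pred.len() as f64;
    pred.iter()
        .zip(target)
        .map(|(p, t)| {
            let d = p - t;
            if d == 0.0 { 0.0 } else { d.signum() / n }
        })
        .collect()
}
```

The constant-magnitude gradient is why L1 is less sensitive to outliers than MSE: a large residual contributes the same gradient magnitude as a small one.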

Backward support summary

Loss function      | Backward support        | Notes
cross_entropy_loss | Native fused kernel     | Preferred for multi-class classification.
bce_loss           | Native fused kernel     | Preferred for binary / multi-label classification.
mse_loss           | Via primitive autodiff  | Composed from neg, add, mul, mean_all.
l1_loss            | Via primitive autodiff  | Composed from neg, add, abs, mean_all. Gradient is undefined at zero.