`differentiate()` statically appends backward nodes to the graph at build time. The resulting graph contains both the forward and backward passes as a single static compute sequence.
## The `differentiate()` function
`differentiate()` takes a forward graph that ends in a scalar loss and returns a new graph containing:

- All original forward nodes (copied verbatim).
- New gradient nodes appended after the forward nodes.
- An updated output list, `[loss, grad_param_0, grad_param_1, ...]`, with one gradient output per `Parameter` node, in the order they appear in the forward graph.
## Gradient rules
Each `Op` has a corresponding gradient rule. Here are the most important ones:
### MatMul

For `C = A @ B`:

- `dA = dC @ B^T`, emitted as `MatMulBT(dC, B)`
- `dB = A^T @ dC`, emitted as `MatMulAT(A, dC)`
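As a quick sanity check, a standalone NumPy sketch (not the library's code) verifies the `dA` rule against a central finite difference of the scalar loss `sum(dC * (A @ B))`:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(3, 4))
B = rng.normal(size=(4, 5))
dC = rng.normal(size=(3, 5))   # upstream gradient of C = A @ B

dA = dC @ B.T   # MatMulBT(dC, B)
dB = A.T @ dC   # MatMulAT(A, dC)

# Central finite difference of L = sum(dC * (A @ B)) w.r.t. A[1, 2].
h = 1e-6
E = np.zeros_like(A)
E[1, 2] = h
num = np.sum(dC * ((A + E) @ B - (A - E) @ B)) / (2 * h)
assert np.isclose(dA[1, 2], num)
```

Because the loss is linear in `A`, the finite difference here is exact up to floating-point roundoff.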
### FusedMatMulAdd

For `C = A @ B + D`:

- Same `dA` and `dB` as `MatMul`, plus `dD = dC` (passthrough). The gradient of the addend passes straight through.
### BiasAdd

For `out = input + bias` (broadcast):

- `d_input = d_out`
- `d_bias = SumRows(d_out)`: column sums of the upstream gradient.
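A minimal NumPy sketch (assumed shapes: `input` is `[M, N]`, `bias` is `[N]`) showing why the bias gradient is a column sum:

```python
import numpy as np

# out[i, j] = input[i, j] + bias[j]: each bias[j] feeds every row of
# column j, so its gradient is the sum over that column of d_out.
d_out = np.arange(6.0).reshape(3, 2)   # upstream gradient, shape [3, 2]

d_input = d_out                        # passthrough
d_bias = d_out.sum(axis=0)             # SumRows(d_out): column sums
assert np.array_equal(d_bias, np.array([6.0, 9.0]))
```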
### RmsNorm

For `y = RmsNorm(x, w, eps)`:

- `d_w`: `RmsNormGradW(dy, x, w)`, exact formula `sum_i(dy[i,j] * x[i,j] * rsqrt_i)`
- `d_x`: `RmsNormGradX(dy, x, w)`, exact formula `rsqrt_i * (dy[i,j]*w[j] - x[i,j]*s_i)`
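The term `s_i` is not spelled out above; the standard derivation gives `s_i = rsqrt_i^2 / N * sum_j(dy[i,j] * x[i,j] * w[j])`. The NumPy sketch below (an assumption-checking sketch, not the library's kernels) verifies both formulas under that reading against finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, eps = 4, 8, 1e-6
x = rng.normal(size=(M, N))
w = rng.normal(size=N)
dy = rng.normal(size=(M, N))   # upstream gradient of y

def rmsnorm(x):
    rsqrt = 1.0 / np.sqrt(np.mean(x * x, axis=1, keepdims=True) + eps)
    return x * rsqrt * w

rsqrt = 1.0 / np.sqrt(np.mean(x * x, axis=1, keepdims=True) + eps)  # [M, 1]
d_w = np.sum(dy * x * rsqrt, axis=0)                         # RmsNormGradW
s = rsqrt**2 / N * np.sum(dy * x * w, axis=1, keepdims=True) # assumed s_i
d_x = rsqrt * (dy * w - x * s)                               # RmsNormGradX

# Finite-difference check of d_x against the scalar loss sum(dy * y).
h, num = 1e-5, np.zeros_like(x)
for i in range(M):
    for j in range(N):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += h
        xm[i, j] -= h
        num[i, j] = np.sum(dy * (rmsnorm(xp) - rmsnorm(xm))) / (2 * h)
assert np.allclose(d_x, num, atol=1e-6)
```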
### SwiGLU

For `out = silu(gate) * up`:

- `d_gate`: `SwiGLUGradGate(d_out, gate, up)`
- `d_up`: `SwiGLUGradUp(d_out, gate)`
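Both gradients follow from the product rule. Assuming `SwiGLUGradUp` computes `d_out * silu(gate)` and `SwiGLUGradGate` computes `d_out * up * silu'(gate)` (an inference from the argument lists above, not confirmed from the source), a NumPy sketch:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

rng = np.random.default_rng(2)
gate = rng.normal(size=4)
up = rng.normal(size=4)
d_out = rng.normal(size=4)

# out = silu(gate) * up, so by the product rule:
d_up = d_out * silu(gate)                # SwiGLUGradUp(d_out, gate)
d_gate = d_out * up * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))

# Check d_gate with a central finite difference.
h = 1e-6
num = d_out * up * (silu(gate + h) - silu(gate - h)) / (2 * h)
assert np.allclose(d_gate, num, atol=1e-7)
```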
### SwiGLUConcat

For `out[M,N] = silu(input[:,:N]) * input[:,N:]`:

- `d_input`: `SwiGLUConcatGrad(d_out, input)`, which produces `[M, 2*N]`.
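A shape-level NumPy sketch of the fused layout, assuming the two gradient halves are packed back in `[gate | up]` order to match the forward split (an assumption about the buffer layout):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

M, N = 2, 3
rng = np.random.default_rng(3)
inp = rng.normal(size=(M, 2 * N))   # fused [gate | up] buffer
d_out = rng.normal(size=(M, N))     # upstream gradient, shape [M, N]
gate, up = inp[:, :N], inp[:, N:]

# Gradients of both halves, packed back into one [M, 2*N] buffer.
d_gate = d_out * up * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))
d_up = d_out * silu(gate)
d_input = np.concatenate([d_gate, d_up], axis=1)  # SwiGLUConcatGrad output
assert d_input.shape == (M, 2 * N)
```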
### Silu

- `d_x`: `SiluGrad(d_out, x)`, which computes `d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))`.
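The closed form can be checked numerically with a standalone sketch:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

x = np.linspace(-4.0, 4.0, 9)
grad = sigmoid(x) * (1.0 + x * (1.0 - sigmoid(x)))   # the SiluGrad formula

h = 1e-5
num = (silu(x + h) - silu(x - h)) / (2 * h)          # central difference
assert np.allclose(grad, num, atol=1e-8)
```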
### MultiHeadAttn

For `O = MultiHeadAttn(Q, K, V)`, three separate gradient nodes are emitted:

- `d_Q`: `MultiHeadAttnGradQ(dO, Q, K, V)` (references the forward node id for the LSE buffer)
- `d_K`: `MultiHeadAttnGradK(dO, Q, K, V)`
- `d_V`: `MultiHeadAttnGradV(dO, Q, K, V)`
### CrossEntropyLoss

- `d_logits = softmax(logits) - labels`
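For a single row with a one-hot label, the rule can be checked directly (a sketch; a batch-averaged loss would additionally divide by the batch size):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
labels = np.array([1.0, 0.0, 0.0])   # one-hot target

def loss(z):
    p = np.exp(z - z.max())          # numerically stable softmax
    p /= p.sum()
    return -np.sum(labels * np.log(p))   # cross-entropy

p = np.exp(logits - logits.max())
p /= p.sum()
d_logits = p - labels                # the stated gradient rule

# Central finite difference w.r.t. logits[1].
h = 1e-6
e = np.zeros(3)
e[1] = h
num = (loss(logits + e) - loss(logits - e)) / (2 * h)
assert np.isclose(d_logits[1], num)
```

Note that the gradient sums to zero over the row, since both the softmax and the one-hot label sum to one.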
### Relu

- `d_x = d_out * (x > 0)`, using a `Greater` node to create the mask.
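Equivalently, in NumPy terms:

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.5, 3.0])
d_out = np.ones_like(x)

mask = (x > 0).astype(x.dtype)   # what the Greater node produces
d_x = d_out * mask
assert np.array_equal(d_x, np.array([0.0, 0.0, 1.0, 1.0]))
```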
### MeanAll

- `d_x = (1/N)` broadcast to the shape of `x`, emitted as a constant tensor.
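A one-line check that the constant `1/N` is the right gradient for `loss = mean(x)`:

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)
d_x = np.full_like(x, 1.0 / x.size)   # the constant-tensor gradient

# mean() is linear, so a finite difference recovers 1/N exactly.
h = 1e-6
e = np.zeros_like(x)
e[0, 1] = h
num = (np.mean(x + e) - np.mean(x - e)) / (2 * h)
assert np.isclose(d_x[0, 1], num)
```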
### Embedding

- `d_table`: `ScatterAdd(indices, d_out, vocab_size)`, which accumulates gradient rows by index.
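`ScatterAdd` behaves like NumPy's unbuffered `np.add.at`: gradient rows for the same token index accumulate rather than overwrite. A sketch:

```python
import numpy as np

vocab_size, dim = 5, 3
indices = np.array([1, 3, 1])     # token ids; note the repeated 1
d_out = np.ones((3, dim))         # upstream gradient, one row per lookup

d_table = np.zeros((vocab_size, dim))
np.add.at(d_table, indices, d_out)   # ScatterAdd: repeated indices sum

assert np.array_equal(d_table[1], np.full(dim, 2.0))   # hit twice
assert np.array_equal(d_table[3], np.full(dim, 1.0))   # hit once
assert np.array_equal(d_table[0], np.zeros(dim))       # never hit
```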
### Conv2d

- `d_input`: `Conv2dGradInput(d_out, kernel)`
- `d_kernel`: `Conv2dGradWeight(d_out, input)`
`Input` nodes, `Constant` nodes, and inference-only ops (`RoPE`, `CausalAttention`, `LayerNorm`, `CacheWrite`, etc.) do not accumulate gradients. Calling `differentiate()` on a graph containing inference-only ops emits a log warning and leaves those ops without gradients.

## Autodiff runs on the optimized graph
A critical detail: `differentiate()` is called on the optimized forward graph, not the original one. This means the backward pass differentiates through fused ops like `SwiGLUConcat` and `FusedMatMulAdd` directly, rather than re-deriving them from their unfused components.

If you call `differentiate()` on an unoptimized graph, the optimizer still runs on the full graph afterward and will fuse the backward matmuls, but you will not benefit from `SwiGLUConcat` on the forward pass.
## The full pipeline

1. Optimize the forward graph: fuses SwiGLU, MatMul+Add, and other patterns before differentiation.
2. Differentiate the optimized forward graph with `differentiate()`.
3. Optimize the full graph: fuses MatMulBT+Add patterns and similar opportunities.
4. Build the GPU session with `build_session_with_report()`.

Call `compile_training_graph()` to run the pipeline up to, but not including, GPU session creation in environments without a GPU.
## Gradient outputs and parameter alignment

The output list produced by `differentiate()` has a fixed structure: the loss first, then one gradient per parameter. Every `Parameter` node in the forward graph gets exactly one gradient output, in the same positional order. If the optimizer fused a parameter away (marking it `Nop`), a scalar zero placeholder is emitted so that the positional mapping stays intact for `compile.rs` to build `param_grad_pairs`.
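A toy illustration of the positional contract (the parameter names here are hypothetical; only the ordering matters):

```python
# Hypothetical parameter names; the invariant is purely positional.
params = ["wq", "wk", "wv"]                 # Parameter nodes in graph order
outputs = ["loss", "d_wq", "d_wk", "d_wv"]  # differentiate() output list

# Pairing by position: outputs[1 + i] is the gradient of params[i].
param_grad_pairs = list(zip(params, outputs[1:]))
assert param_grad_pairs == [("wq", "d_wq"), ("wk", "d_wk"), ("wv", "d_wv")]
```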
## Next steps

- GPU execution: how the compiled plan is executed on GPU hardware.
- Trainers and optimizers: how to use SGD and Adam with a compiled training session.