Keras’s high-level training loop (model.fit()) is convenient, but sometimes you need finer control: an unusual loss term, a layer with non-standard state, or a training step that doesn’t fit the standard supervised learning mould. Chapter 12 teaches you to drop down from the Keras abstraction to TensorFlow primitives: tensors, variables, automatic differentiation with tf.GradientTape, and graph compilation with tf.function. You’ll build custom loss functions, custom layers with trainable weights, and complete custom training loops from scratch.

What you’ll learn

  • TensorFlow tensors and how to operate on them much as you would on NumPy arrays
  • Mutable state with tf.Variable
  • Custom loss functions and metrics that plug into model.compile()
  • Custom layers by subclassing tf.keras.layers.Layer
  • Custom models by subclassing tf.keras.Model
  • Automatic differentiation with tf.GradientTape
  • Writing a full custom training loop without model.fit()
  • Accelerating Python functions with tf.function (graph mode compilation)
  • TensorFlow’s additional data types: strings, ragged tensors, sparse tensors

Key concepts

Tensors and variables

A tf.Tensor is an immutable, multi-dimensional array analogous to a NumPy ndarray but capable of running on GPUs and TPUs. TensorFlow operations mirror NumPy closely: indexing, slicing, broadcasting, and most math ops work as expected. A tf.Variable wraps a mutable tensor and is how neural network weights are stored; it supports in-place operations such as assign(), assign_add(), and scatter_nd_update().
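As a minimal sketch of the NumPy-like behaviour and the in-place variable methods named above (the values are arbitrary and chosen only for illustration):

```python
import numpy as np
import tensorflow as tf

t = tf.constant([[1., 2., 3.], [4., 5., 6.]])

# NumPy-style slicing, broadcasting and matrix multiplication
last_two_cols = t[:, 1:]          # shape (2, 2)
shifted = t + 10                  # the scalar is broadcast to every element
gram = t @ tf.transpose(t)        # shape (2, 2)

# Round trip with NumPy: tf.constant keeps NumPy's float64 dtype
as_numpy = t.numpy()
t64 = tf.constant(np.array([2., 4., 5.]))

# In-place updates on a variable
v = tf.Variable([[1., 2., 3.], [4., 5., 6.]])
v.assign_add(tf.ones_like(v))     # adds 1 to every element
v.scatter_nd_update(indices=[[0, 0], [1, 2]], updates=[100., 200.])
print(v.numpy())                  # [[100. 3. 4.] [5. 6. 200.]]
```

Note that t itself is never modified: each operation on a tensor returns a new tensor, whereas the variable v is updated in place.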

Custom loss functions and metrics

The simplest customisation is a Python function that accepts y_true and y_pred tensors and returns a tensor of per-instance losses, which Keras then reduces to a scalar. For more complex needs, such as losses with hyperparameters or losses that need to be serialised, you subclass tf.keras.losses.Loss. The same pattern applies to metrics via tf.keras.metrics.Metric, where you override update_state() and result().
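A sketch of the subclassing route, assuming a Huber-style loss whose threshold is the hyperparameter. The class name HuberLoss and the threshold argument are illustrative choices, not names required by Keras:

```python
import tensorflow as tf

class HuberLoss(tf.keras.losses.Loss):
    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    def call(self, y_true, y_pred):
        # Quadratic below the threshold, linear above it
        error = y_true - y_pred
        is_small_error = tf.abs(error) < self.threshold
        squared_loss = tf.square(error) / 2
        linear_loss = self.threshold * tf.abs(error) - self.threshold ** 2 / 2
        return tf.where(is_small_error, squared_loss, linear_loss)

    def get_config(self):
        # Include the hyperparameter so the loss can be serialised
        base_config = super().get_config()
        return {**base_config, "threshold": self.threshold}

loss = HuberLoss(threshold=2.0)
y_true = tf.constant([[0.], [0.]])
y_pred = tf.constant([[1.], [3.]])
print(float(loss(y_true, y_pred)))   # per-instance losses 0.5 and 4.0, mean 2.25
```

Returning the hyperparameter from get_config() is what allows a saved model to recreate the loss with the same threshold, provided the class is supplied again when the model is loaded.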

Custom layers

Subclassing tf.keras.layers.Layer lets you create layers with arbitrary learnable parameters and non-standard forward passes. The __init__ method is called once, when the layer is constructed; build() is called lazily on the first call, when the actual input shape is known, and is where you create your weight matrices; call() runs the forward pass. Creating weights with self.add_weight() ensures they are tracked by Keras and included in model.trainable_variables.

tf.GradientTape and custom training loops

tf.GradientTape records TensorFlow operations executed inside its context so that automatic differentiation (reverse-mode autodiff) can compute gradients. A minimal custom training step:
  1. Run the forward pass inside a GradientTape block to record the computation.
  2. Compute the loss, still inside the block, so that it is recorded too.
  3. Call tape.gradient(loss, trainable_variables) to get gradients.
  4. Apply gradients with optimizer.apply_gradients().
This explicit loop gives you full control over multi-task losses, gradient accumulation, custom regularisation, and more.
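The four steps can be sketched on a toy function before involving a model. To keep the sketch free of a Keras optimizer, step 4 is written here as a hand-coded gradient descent update (an assumption made for illustration); in a real loop, optimizer.apply_gradients() performs that update for you:

```python
import tensorflow as tf

def f(w1, w2):
    return 3 * w1 ** 2 + 2 * w1 * w2

w1, w2 = tf.Variable(5.), tf.Variable(3.)

# Steps 1 and 2: record the computation of the value to differentiate
with tf.GradientTape() as tape:
    z = f(w1, w2)

# Step 3: df/dw1 = 6*w1 + 2*w2 = 36 and df/dw2 = 2*w1 = 10
gradients = tape.gradient(z, [w1, w2])
print([float(g) for g in gradients])   # [36.0, 10.0]

# Step 4 (hand-written stand-in for an optimizer): one gradient descent step
learning_rate = 0.1
for variable, gradient in zip([w1, w2], gradients):
    variable.assign_sub(learning_rate * gradient)
# w1 is now approximately 1.4 and w2 is 2.0
```

By default a tape can be used for a single gradient() call; pass persistent=True to tf.GradientTape if you need to call it more than once.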

tf.function

Decorating a Python function with @tf.function traces it the first time it is called with a given input signature and compiles it into a TensorFlow graph. Subsequent calls with matching inputs reuse the compiled graph, which can be significantly faster than eager execution, especially for training loops.
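A minimal sketch of the decorator in use; the function name tf_cube is arbitrary:

```python
import tensorflow as tf

@tf.function
def tf_cube(x):
    return x ** 3

# The first call traces the Python code and builds a graph
result = tf_cube(tf.constant(2.0))
print(float(result))                      # 8.0

# The undecorated Python function is still reachable if you need it
py_result = tf_cube.python_function(2)
print(py_result)                          # 8
```

Any speed-up depends on the workload: a function this small gains little, while a training step made of many operations typically benefits more.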

Code examples

TensorFlow tensors and variables

import tensorflow as tf

# Immutable tensor
t = tf.constant([[1., 2., 3.], [4., 5., 6.]])
print(t.shape)   # TensorShape([2, 3])
print(t.dtype)   # tf.float32

# Mutable variable
v = tf.Variable([[1., 2., 3.], [4., 5., 6.]])
v.assign(2 * v)         # in-place multiply
v[0, 1].assign(42)      # in-place element update

Custom loss function

def huber_fn(y_true, y_pred):
    # Quadratic for small errors (|error| < 1), linear beyond that
    error = y_true - y_pred
    is_small_error = tf.abs(error) < 1
    squared_loss = tf.square(error) / 2
    linear_loss = tf.abs(error) - 0.5
    return tf.where(is_small_error, squared_loss, linear_loss)

# Assumes `model` is an existing tf.keras.Model
model.compile(loss=huber_fn, optimizer="nadam")
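The same function can also back a custom metric. The sketch below assumes a simple running mean of the Huber values; the class name HuberMetric and its two state variables are illustrative, and sample_weight is accepted but ignored to keep the example short:

```python
import tensorflow as tf

def huber_fn(y_true, y_pred):
    error = y_true - y_pred
    is_small_error = tf.abs(error) < 1
    squared_loss = tf.square(error) / 2
    linear_loss = tf.abs(error) - 0.5
    return tf.where(is_small_error, squared_loss, linear_loss)

class HuberMetric(tf.keras.metrics.Metric):
    def __init__(self, name="huber", **kwargs):
        super().__init__(name=name, **kwargs)
        # State that persists across batches
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this sketch
        values = huber_fn(y_true, y_pred)
        self.total.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.cast(tf.size(values), tf.float32))

    def result(self):
        return self.total / self.count

metric = HuberMetric()
metric.update_state(tf.constant([[0.], [0.]]), tf.constant([[0.5], [3.]]))
metric.update_state(tf.constant([[0.]]), tf.constant([[1.]]))
mean_huber = float(metric.result())   # (0.125 + 2.5 + 0.5) / 3
```

Keeping a running total and count, rather than averaging per-batch means, gives the correct mean even when batches differ in size.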

Custom layer with trainable weights

class MyDense(tf.keras.layers.Layer):
    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)

    def build(self, batch_input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=[batch_input_shape[-1], self.units],
            initializer="glorot_normal")
        self.bias = self.add_weight(
            name="bias", shape=[self.units], initializer="zeros")

    def call(self, X):
        return self.activation(X @ self.kernel + self.bias)
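The lazy build() behaviour can be observed directly. To stay self-contained, the sketch below uses a smaller hypothetical layer, ScaleShift, rather than MyDense; the same checks apply to any layer that creates its weights in build():

```python
import tensorflow as tf

class ScaleShift(tf.keras.layers.Layer):
    """Hypothetical layer: learned per-feature scale and offset."""

    def build(self, batch_input_shape):
        n_features = batch_input_shape[-1]
        self.scale = self.add_weight(
            name="scale", shape=[n_features], initializer="ones")
        self.offset = self.add_weight(
            name="offset", shape=[n_features], initializer="zeros")

    def call(self, X):
        return X * self.scale + self.offset

layer = ScaleShift()
built_before = layer.built             # False: no weights exist yet
outputs = layer(tf.ones([4, 3]))       # first call runs build() for 3 features
built_after = layer.built              # True
weight_shapes = [tuple(w.shape) for w in layer.trainable_variables]
print(built_before, built_after, weight_shapes)   # False True [(3,), (3,)]
```

Because the weights are sized from the input shape seen at the first call, the layer does not need to be told the number of input features in its constructor.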

Custom training loop with tf.GradientTape

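# Assumes `model` (a tf.keras.Model) and `train_dataset` (a batched
# tf.data.Dataset yielding (X_batch, y_batch) pairs) are already defined
n_epochs = 5
batch_size = 32
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.mean_squared_error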

for epoch in range(1, n_epochs + 1):
    for X_batch, y_batch in train_dataset:
        with tf.GradientTape() as tape:
            y_pred = model(X_batch, training=True)
            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
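For a version that runs end to end, the sketch below fills in the missing pieces under stated assumptions: synthetic data drawn from y = 3x + 2 plus noise, a one-neuron model, and a train_step function wrapped in tf.function. None of these choices come from the notebook; they only make the loop self-contained:

```python
import numpy as np
import tensorflow as tf

# Synthetic regression data (assumption for illustration)
rng = np.random.default_rng(42)
X = rng.normal(size=(256, 1)).astype(np.float32)
noise = 0.1 * rng.normal(size=(256, 1)).astype(np.float32)
y = 3 * X + 2 + noise
train_dataset = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(256).batch(32)

model = tf.keras.Sequential([tf.keras.Input(shape=(1,)),
                             tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
mean_loss = tf.keras.metrics.Mean(name="loss")

@tf.function
def train_step(X_batch, y_batch):
    with tf.GradientTape() as tape:
        y_pred = model(X_batch, training=True)
        main_loss = tf.reduce_mean(tf.square(y_batch - y_pred))
        # model.losses holds regularisation terms (empty for this model)
        loss = tf.add_n([main_loss] + model.losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

history = []
for epoch in range(1, 6):
    mean_loss.reset_state()
    for X_batch, y_batch in train_dataset:
        mean_loss.update_state(train_step(X_batch, y_batch))
    history.append(float(mean_loss.result()))
    print(f"Epoch {epoch}: loss = {history[-1]:.4f}")
```

Wrapping only the per-batch step in tf.function keeps the outer Python loop free for logging, early stopping, or other bookkeeping.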

Running this notebook

  1. Open in Colab.
  2. Install dependencies: pip install -r requirements.txt
  3. Note on Keras version: this chapter uses standard Keras 3 (bundled with TensorFlow ≥ 2.16). No special environment variable is required.
  4. Run cells in order: the notebook includes both the Chapter 12 material and Appendix C material on advanced TensorFlow data types. Run sections in order as later cells depend on variables defined earlier.

Exercises

Exercises ask you to implement a custom layer that normalises its inputs, write a custom training loop that also logs metrics to TensorBoard, and replicate Keras’s built-in Huber loss class from scratch. Solutions are at the end of the notebook.

tf.function traces functions on the first call. If your function conditionally branches on Python values (not tensor values), each distinct Python input will trigger a new trace. Use tf.cond for conditional logic that depends on tensor content.
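Retracing can be made visible with a Python side effect, which runs only while the function is being traced. In this sketch the function scale and the list trace_log are hypothetical names used purely to count traces:

```python
import tensorflow as tf

trace_log = []

@tf.function
def scale(x, factor):
    trace_log.append("traced")    # Python side effect: runs only during tracing
    return x * factor

x = tf.constant([1.0, 2.0])

# Python numbers: each distinct value produces its own trace
scale(x, 2)
scale(x, 3)
traces_after_python_args = len(trace_log)    # 2

# Tensors with the same dtype and shape share a single trace
scale(x, tf.constant(2.0))
scale(x, tf.constant(3.0))
traces_after_tensor_args = len(trace_log)    # 3
```

Passing varying values as tensors rather than Python numbers is therefore one way to avoid repeated tracing inside a training loop.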
