This tutorial shows you how to implement custom aggregation strategies for federated learning, going beyond the default FedAvg algorithm.

Why Custom Strategies?

While FedAvg (Federated Averaging) works well for many scenarios, you might need custom strategies for:
  • Weighted aggregation based on data quality or quantity
  • Differential privacy with noise injection
  • Robust aggregation that handles outliers or malicious clients
  • Personalized models with client-specific parameters
  • Custom checkpointing for fault tolerance

Built-in Strategy: FedAvgWithModelSaving

Syft-Flwr provides a custom strategy that extends Flower’s FedAvg with automatic model checkpointing.

Source Code

Here’s the complete implementation from syft_flwr/strategy/fedavg.py:
from pathlib import Path
from loguru import logger
from safetensors.numpy import save_file
from flwr.common import parameters_to_ndarrays
from flwr.server.strategy import FedAvg

class FedAvgWithModelSaving(FedAvg):
    """FedAvg strategy that saves global model to disk after each round.
    
    This strategy behaves exactly like FedAvg but additionally stores
    the state of the global model to disk after each round for:
    - Progress tracking
    - Fault tolerance  
    - Model analysis
    """

    def __init__(self, save_path: str, *args, **kwargs):
        self.save_path = Path(save_path)
        self.save_path.mkdir(exist_ok=True, parents=True)
        super().__init__(*args, **kwargs)

    def _save_global_model(self, server_round: int, parameters):
        """Save model parameters to disk as safetensors."""
        ndarrays = parameters_to_ndarrays(parameters)
        tensor_dict = {f"layer_{i}": array for i, array in enumerate(ndarrays)}
        filename = self.save_path / f"parameters_round_{server_round}.safetensors"
        
        if not self.save_path.exists():
            logger.error(
                f"Save directory {self.save_path} does NOT exist! "
                f"Maybe it's deleted or moved."
            )
        else:
            save_file(tensor_dict, str(filename))
            logger.info(f"Checkpoint saved to: {filename}")

    def evaluate(self, server_round: int, parameters):
        """Evaluate and save model parameters."""
        self._save_global_model(server_round, parameters)
        return super().evaluate(server_round, parameters)

Key Features

  1. Automatic Checkpointing: Saves model after each round
  2. SafeTensors Format: Uses safe, fast serialization format
  3. Error Handling: Checks directory exists before saving
  4. Extensible: Easy to customize for your needs
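
Since checkpoints are named `parameters_round_<N>.safetensors`, resuming from the most recent one is a matter of parsing the round number back out of the filename. Here is a stdlib-only sketch; the `latest_checkpoint` helper is ours, not part of syft_flwr:

```python
from pathlib import Path
from typing import Optional
import re
import tempfile

def latest_checkpoint(save_dir: Path) -> Optional[Path]:
    """Return the checkpoint from the highest round, or None if none exist."""
    pattern = re.compile(r"parameters_round_(\d+)\.safetensors")
    best, best_round = None, -1
    for path in save_dir.glob("parameters_round_*.safetensors"):
        match = pattern.fullmatch(path.name)
        if match and int(match.group(1)) > best_round:
            best_round, best = int(match.group(1)), path
    return best

# Demo with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    save_dir = Path(tmp)
    for r in (1, 2, 10):
        (save_dir / f"parameters_round_{r}.safetensors").touch()
    print(latest_checkpoint(save_dir).name)  # parameters_round_10.safetensors
```

Parsing the round as an integer matters: a plain lexicographic sort of filenames would rank `round_9` above `round_10`.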

Creating a Custom Strategy

Let’s walk through implementing custom strategies for different use cases.
Basic Custom Strategy Template
Start with this template:
from flwr.server.strategy import FedAvg
from flwr.server.client_proxy import ClientProxy
from flwr.common import (
    Parameters,
    Scalar,
    FitRes,
    EvaluateRes,
    parameters_to_ndarrays,
    ndarrays_to_parameters,
)
from typing import List, Tuple, Optional, Dict
from loguru import logger

class CustomStrategy(FedAvg):
    """Template for custom FL strategy."""
    
    def __init__(self, *args, **kwargs):
        # Initialize your custom parameters
        super().__init__(*args, **kwargs)
        logger.info("Initialized CustomStrategy")
    
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model updates from clients."""
        
        # Custom aggregation logic here
        # Default: call parent's FedAvg implementation
        return super().aggregate_fit(server_round, results, failures)
    
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation metrics from clients."""
        
        # Custom metric aggregation here
        return super().aggregate_evaluate(server_round, results, failures)
Weighted Aggregation Strategy
Aggregate based on dataset size and quality:
from flwr.common import Scalar
import numpy as np

class WeightedFedAvg(FedAvg):
    """FedAvg with custom weighted aggregation."""
    
    def __init__(self, quality_weights: Optional[Dict[str, float]] = None, *args, **kwargs):
        """
        Args:
            quality_weights: Dict mapping client IDs to quality scores (0-1)
        """
        super().__init__(*args, **kwargs)
        self.quality_weights = quality_weights or {}
    
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate with quality-weighted averaging."""
        
        if not results:
            return None, {}
        
        # Calculate weights: dataset size * quality score
        total_weight = 0
        weighted_params = None
        
        for idx, (client, fit_res) in enumerate(results):
            # Get quality weight (default to 1.0 if not specified)
            quality = self.quality_weights.get(client.cid, 1.0)
            
            # Combined weight: size * quality
            weight = fit_res.num_examples * quality
            total_weight += weight
            
            # Get parameters
            params = parameters_to_ndarrays(fit_res.parameters)
            
            # Weighted sum
            if weighted_params is None:
                weighted_params = [w * weight for w in params]
            else:
                weighted_params = [
                    wp + (w * weight) 
                    for wp, w in zip(weighted_params, params)
                ]
        
        # Average
        aggregated = [w / total_weight for w in weighted_params]
        
        logger.info(
            f"Round {server_round}: Aggregated {len(results)} clients "
            f"with total weight {total_weight:.2f}"
        )
        
        return ndarrays_to_parameters(aggregated), {}
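
To see the arithmetic in isolation, here is the same quality-weighted average on toy numbers, using only numpy (the values are made up for illustration):

```python
import numpy as np

# Two clients reporting a single-layer model
client_params = [np.array([1.0, 2.0]), np.array([3.0, 6.0])]
num_examples = [100, 50]
quality = [1.0, 0.5]  # client 2's data is considered noisier

# Combined weight per client: dataset size * quality score
weights = [n * q for n, q in zip(num_examples, quality)]  # [100.0, 25.0]
total = sum(weights)  # 125.0

aggregated = sum(w * p for w, p in zip(weights, client_params)) / total
print(aggregated)  # [1.4 2.8]
```

Client 2 contributes only 25/125 of the result despite holding a third of the examples, because its quality score halves its influence.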
Differential Privacy Strategy
Add noise for privacy preservation:
class DPFedAvg(FedAvg):
    """FedAvg with Differential Privacy."""
    
    def __init__(
        self,
        noise_multiplier: float = 0.1,
        clip_norm: float = 1.0,
        *args,
        **kwargs
    ):
        """
        Args:
            noise_multiplier: Scale of Gaussian noise to add
            clip_norm: Gradient clipping threshold
        """
        super().__init__(*args, **kwargs)
        self.noise_multiplier = noise_multiplier
        self.clip_norm = clip_norm
    
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate with differential privacy."""
        
        # First, perform standard aggregation
        parameters, metrics = super().aggregate_fit(
            server_round, results, failures
        )
        
        if parameters is None:
            return None, metrics
        
        # Add Gaussian noise for privacy
        params = parameters_to_ndarrays(parameters)
        noisy_params = []
        
        for param in params:
            # Clip gradients
            norm = np.linalg.norm(param)
            if norm > self.clip_norm:
                param = param * (self.clip_norm / norm)
            
            # Add Gaussian noise
            noise = np.random.normal(
                0, 
                self.noise_multiplier * self.clip_norm,
                param.shape
            )
            noisy_params.append(param + noise)
        
        logger.info(
            f"Round {server_round}: Added DP noise "
            f"(σ={self.noise_multiplier}, C={self.clip_norm})"
        )
        
        return ndarrays_to_parameters(noisy_params), metrics
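
The clip-then-noise step can be checked on its own with numpy. This sketch mirrors the loop body above on a single toy array (values chosen for illustration):

```python
import numpy as np

clip_norm = 1.0
noise_multiplier = 0.1

param = np.array([3.0, 4.0])  # L2 norm = 5.0, above the threshold
norm = np.linalg.norm(param)
if norm > clip_norm:
    param = param * (clip_norm / norm)  # rescaled so its norm equals clip_norm

rng = np.random.default_rng(seed=0)
noise = rng.normal(0.0, noise_multiplier * clip_norm, param.shape)
noisy = param + noise

print(np.linalg.norm(param))  # ≈ 1.0 -> clipping worked
```

Clipping bounds each array's contribution so that the fixed noise scale `noise_multiplier * clip_norm` is large relative to any single value, which is what makes the noise meaningful for privacy.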
Robust Aggregation Strategy
Handle outliers and malicious clients:
class RobustFedAvg(FedAvg):
    """FedAvg with robust aggregation using median."""
    
    def __init__(self, trimmed_mean_ratio: float = 0.1, *args, **kwargs):
        """
        Args:
            trimmed_mean_ratio: Fraction of extreme values to trim (0-0.5)
        """
        super().__init__(*args, **kwargs)
        self.trim_ratio = trimmed_mean_ratio
    
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate using trimmed mean (remove outliers)."""
        
        if not results:
            return None, {}
        
        # Extract all parameter arrays
        all_params = [
            parameters_to_ndarrays(fit_res.parameters)
            for _, fit_res in results
        ]
        
        # Trim the same number of clients from each end for every layer
        num_trim = int(len(all_params) * self.trim_ratio)
        
        # For each layer, compute the trimmed mean across clients
        num_layers = len(all_params[0])
        aggregated = []
        
        for layer_idx in range(num_layers):
            # Stack all clients' parameters for this layer
            layer_params = np.array([
                params[layer_idx] for params in all_params
            ])
            
            if num_trim > 0:
                # Sort along the client axis and drop the extremes
                sorted_params = np.sort(layer_params, axis=0)
                trimmed = sorted_params[num_trim:-num_trim]
                # Average the remaining clients
                layer_avg = np.mean(trimmed, axis=0)
            else:
                # Plain average if there is nothing to trim
                layer_avg = np.mean(layer_params, axis=0)
            
            aggregated.append(layer_avg)
        
        logger.info(
            f"Round {server_round}: Robust aggregation with "
            f"{num_trim} clients trimmed from each end"
        )
        
        return ndarrays_to_parameters(aggregated), {}
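
A quick numpy sketch shows why trimming matters. Five mock clients report a one-weight layer, one of them a wild outlier (all values invented for illustration):

```python
import numpy as np

# Five clients report a one-weight layer; one client is a wild outlier
layer_params = np.array([[0.9], [1.0], [1.1], [1.0], [100.0]])
trim_ratio = 0.2
num_trim = int(len(layer_params) * trim_ratio)  # trim 1 from each end

sorted_params = np.sort(layer_params, axis=0)
trimmed = sorted_params[num_trim:-num_trim]  # drops 0.9 and 100.0
robust_avg = np.mean(trimmed, axis=0)

print(robust_avg)                     # ≈ [1.033]
print(np.mean(layer_params, axis=0))  # [20.8] -- dragged up by the outlier
```

The plain mean is pulled to 20.8 by a single malicious client, while the trimmed mean stays near the honest consensus.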
Strategy with Custom Checkpointing
Extend FedAvgWithModelSaving with additional features:
import json
from datetime import datetime

class AdvancedCheckpointing(FedAvgWithModelSaving):
    """Strategy with enhanced checkpointing and metrics tracking."""
    
    def __init__(self, save_path: str, *args, **kwargs):
        super().__init__(save_path, *args, **kwargs)
        self.metrics_history = []
    
    def _save_global_model(self, server_round: int, parameters):
        """Save model with additional metadata."""
        # Save model weights
        super()._save_global_model(server_round, parameters)
        
        # Save metadata
        metadata = {
            "round": server_round,
            "timestamp": datetime.now().isoformat(),
            "num_parameters": sum(
                p.size for p in parameters_to_ndarrays(parameters)
            ),
        }
        
        metadata_file = self.save_path / f"metadata_round_{server_round}.json"
        with open(metadata_file, "w") as f:
            json.dump(metadata, f, indent=2)
        
        logger.info(f"Metadata saved to: {metadata_file}")
    
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation and track metrics history."""
        
        loss, metrics = super().aggregate_evaluate(
            server_round, results, failures
        )
        
        # Track metrics over time
        round_metrics = {
            "round": server_round,
            "loss": loss,
            "metrics": metrics,
            "num_clients": len(results),
        }
        self.metrics_history.append(round_metrics)
        
        # Save metrics history
        history_file = self.save_path / "metrics_history.json"
        with open(history_file, "w") as f:
            json.dump(self.metrics_history, f, indent=2)
        
        return loss, metrics
Using Your Custom Strategy
Integrate the custom strategy in your server app:
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from pathlib import Path
import os

from your_project.task import Net, get_weights
from your_project.strategies import WeightedFedAvg  # Your custom strategy

def server_fn(context: Context):
    # Initialize model
    net = Net()
    params = ndarrays_to_parameters(get_weights(net))
    
    # Define quality weights for clients
    quality_weights = {
        "client_1": 1.0,  # High quality
        "client_2": 0.8,  # Medium quality
        "client_3": 0.6,  # Lower quality
    }
    
    # Use custom strategy
    strategy = WeightedFedAvg(
        quality_weights=quality_weights,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=params,
    )
    
    num_rounds = context.run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)
    
    return ServerAppComponents(config=config, strategy=strategy)

app = ServerApp(server_fn=server_fn)

Strategy Comparison

| Strategy | Use Case | Pros | Cons |
| --- | --- | --- | --- |
| FedAvg | General purpose | Simple, well-tested | Treats all clients equally |
| FedAvgWithModelSaving | Production FL | Checkpointing, recovery | Small overhead |
| WeightedFedAvg | Heterogeneous data quality | Accounts for quality | Requires quality metrics |
| DPFedAvg | Privacy-sensitive data | Privacy guarantees | Reduced accuracy |
| RobustFedAvg | Adversarial clients | Outlier resistant | Slower convergence |

Advanced Techniques

Client Selection

Select clients based on custom criteria:
from flwr.common import FitIns
from flwr.server.client_manager import ClientManager

class SelectiveFedAvg(FedAvg):
    def __init__(self, quality_weights: Dict[str, float] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quality_weights = quality_weights or {}
    
    def configure_fit(
        self,
        server_round: int,
        parameters: Parameters,
        client_manager: ClientManager,
    ):
        """Configure clients for training round."""
        
        # All currently connected clients
        clients = list(client_manager.all().values())
        sample_size = max(int(self.fraction_fit * len(clients)), 1)
        
        # Select clients with the highest data quality
        selected_clients = sorted(
            clients,
            key=lambda c: self.quality_weights.get(c.cid, 0),
            reverse=True,
        )[:sample_size]
        
        fit_ins = FitIns(parameters, {})
        return [(client, fit_ins) for client in selected_clients]
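
The selection logic itself is just a sort-and-slice, which you can verify without a running server. A plain-Python sketch with invented client IDs and scores:

```python
# Mock client IDs and quality scores (names are illustrative)
clients = ["client_1", "client_2", "client_3", "client_4"]
quality_weights = {"client_1": 0.9, "client_2": 0.4, "client_3": 0.7}

fraction_fit = 0.5
sample_size = max(int(fraction_fit * len(clients)), 1)  # 2

# Clients without a quality score default to 0 and are picked last
selected = sorted(
    clients,
    key=lambda cid: quality_weights.get(cid, 0),
    reverse=True,
)[:sample_size]
print(selected)  # ['client_1', 'client_3']
```

Note that always picking the top-quality clients can starve low-quality clients of participation; in practice you may want to mix in random sampling.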

Adaptive Learning Rates

Adjust aggregation based on training progress:
class AdaptiveFedAvg(FedAvg):
    def __init__(self, initial_lr: float = 1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.server_lr = initial_lr
        self.prev_params = None  # previous global model as a list of ndarrays
    
    def aggregate_fit(self, server_round, results, failures):
        # Standard aggregation
        parameters, metrics = super().aggregate_fit(
            server_round, results, failures
        )
        if parameters is None:
            return parameters, metrics
        
        new_params = parameters_to_ndarrays(parameters)
        if self.prev_params is not None:
            # Apply the server learning rate to the *update*, not the raw
            # weights (scaling the weights themselves would shrink the model):
            # new = old + lr * (aggregated - old)
            new_params = [
                old + self.server_lr * (new - old)
                for old, new in zip(self.prev_params, new_params)
            ]
        self.prev_params = new_params
        
        # Decay the server learning rate each round
        self.server_lr *= 0.95
        
        return ndarrays_to_parameters(new_params), metrics
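
The key point is that the server learning rate scales the update, not the weights themselves; multiplying the aggregated weights directly by a decaying factor would shrink the model toward zero. A numpy sketch of the update rule on invented values:

```python
import numpy as np

old = np.array([1.0, 1.0])         # previous global model
aggregated = np.array([2.0, 0.0])  # freshly aggregated round result
server_lr = 0.5

# A damped step toward the aggregate: new = old + lr * (aggregated - old)
new = old + server_lr * (aggregated - old)
print(new)  # [1.5 0.5]

# As lr decays toward 0, the global model stops moving entirely
print(old + 0.0 * (aggregated - old))  # [1. 1.]
```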

Personalization

Maintain both global and local models:
class PersonalizedFedAvg(FedAvg):
    def __init__(self, personalization_layers: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pers_layers = personalization_layers
        self.client_local_params = {}  # Store per-client parameters
    
    def aggregate_fit(self, server_round, results, failures):
        # Only aggregate shared layers (not personalization layers)
        shared_results = []
        
        for client, fit_res in results:
            params = parameters_to_ndarrays(fit_res.parameters)
            
            # Store personalization layers
            self.client_local_params[client.cid] = params[-self.pers_layers:]
            
            # Only use shared layers for aggregation
            shared_params = params[:-self.pers_layers]
            fit_res.parameters = ndarrays_to_parameters(shared_params)
            shared_results.append((client, fit_res))
        
        # Aggregate shared layers only
        return super().aggregate_fit(server_round, shared_results, failures)
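
The split itself is a simple list slice over the per-layer arrays. A plain-Python sketch with toy "layers" (values invented):

```python
# Toy model: three "layers" as flat lists; the last one stays personal
params = [[0.1, 0.2], [0.3], [9.9]]
pers_layers = 1

shared = params[:-pers_layers]    # sent into FedAvg aggregation
personal = params[-pers_layers:]  # kept per-client, never aggregated

print(shared)    # [[0.1, 0.2], [0.3]]
print(personal)  # [[9.9]]
```

One caveat: with `pers_layers = 0` the slice `params[:-0]` would be empty, so guard against a zero value if you make the layer count configurable.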

Testing Custom Strategies

Test your strategy with simulation:
test_strategy.py
import syft_flwr
from pathlib import Path

# Test with mock data
project_path = Path("./fl-diabetes-prediction")
mock_paths = ["mock1", "mock2"]

# Override strategy in server_app.py to use your custom one
syft_flwr.run(project_path, mock_paths)
Or test with unit tests:
import pytest
from your_strategy import WeightedFedAvg

def test_weighted_aggregation():
    strategy = WeightedFedAvg(
        quality_weights={"client_1": 1.0, "client_2": 0.5}
    )
    
    # Mock results
    results = [...]
    
    # Test aggregation
    params, metrics = strategy.aggregate_fit(1, results, [])
    
    assert params is not None

Best Practices

Start Simple

Begin with FedAvg or FedAvgWithModelSaving before implementing complex strategies.

Test Thoroughly

Use simulation mode to validate your strategy before deploying to real clients.

Log Everything

Add detailed logging to understand aggregation behavior during training.

Handle Failures

Always check for empty results and handle client failures gracefully.

Common Pitfalls

Always call the parent class initializer to set up base functionality:
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)  # Don't forget this!
Always check if results is empty before aggregating:
if not results:
    return None, {}
Ensure all clients return parameters with the same shape. Log shapes when debugging:
for param in params:
    logger.debug(f"Parameter shape: {param.shape}")
