Overview

Federated learning trains a shared model across multiple decentralized clients without collecting their raw data. Each client trains on its own data and sends only model updates to a server, which aggregates them into a global model. The neurenix.federated package provides client and server classes, several aggregation strategies, secure aggregation, differential privacy, and update compression.

Core Components

FederatedClient

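The FederatedClient class represents one participant. It holds a local copy of the model plus its own training and validation data, trains locally each round, and returns its model update to the server.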
from neurenix.federated import FederatedClient, ClientConfig
import neurenix as nx

# Configure client
config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,
    learning_rate=0.01,
    device="cuda"
)

# Create client
client = FederatedClient(config)

# Initialize with model and data (create_model, train_loader and val_loader
# stand for your own model factory and data loaders)
model = create_model()
criterion = nx.nn.CrossEntropyLoss()

client.initialize(
    model=model,
    criterion=criterion,
    train_data=train_loader,
    val_data=val_loader
)
Reference: neurenix/federated/client.py:85

FederatedServer

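The FederatedServer class coordinates training across clients. Each round it selects clients, distributes the global model, aggregates the returned updates, and periodically evaluates the result.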
from neurenix.federated import FederatedServer, ServerConfig, AggregationStrategy

# Configure server
config = ServerConfig(
    num_rounds=10,
    clients_per_round=10,
    client_fraction=0.1,
    aggregation_strategy=AggregationStrategy.FED_AVG,
    min_clients=2,
    eval_every=1,
    device="cuda"
)

# Create server
server = FederatedServer(config)

# Initialize with global model
server.initialize(
    model=global_model,
    criterion=criterion
)

# Register clients
for client in clients:
    server.add_client(client)
Reference: neurenix/federated/server.py:106

Client Configuration

Basic Configuration

config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,                    # Local epochs per round
    learning_rate=0.01,
    momentum=0.9,
    weight_decay=0.0,
    max_grad_norm=1.0,           # Gradient clipping
    device="cuda"
)
Reference: neurenix/federated/client.py:26

FedProx Configuration

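FedProx adds a proximal term, (mu/2) * ||w - w_global||^2, to each client's local loss. The term penalizes drift away from the current global model, which stabilizes training when client data is heterogeneous (non-IID). proximal_mu sets mu: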
config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,
    learning_rate=0.01,
    proximal_mu=0.01,            # Proximal term coefficient
    device="cuda"
)
Reference: neurenix/federated/client.py:37
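The sketch below shows what proximal_mu does to local training. It computes the FedProx penalty and its gradient on plain Python lists and takes one SGD step on "local loss + proximal term". The function names are hypothetical and not part of the Neurenix API. The sketch only illustrates the math of the FedProx objective.

```python
def proximal_term(weights, global_weights, mu):
    """FedProx penalty (mu / 2) * ||w - w_global||^2, added to the local loss."""
    sq_dist = sum((w - g) ** 2 for w, g in zip(weights, global_weights))
    return 0.5 * mu * sq_dist


def proximal_gradient(weights, global_weights, mu):
    """Gradient of the penalty with respect to w: mu * (w - w_global)."""
    return [mu * (w - g) for w, g in zip(weights, global_weights)]


def local_sgd_step(weights, loss_grad, global_weights, mu, lr):
    """One SGD step on the FedProx objective: local loss + proximal term."""
    prox = proximal_gradient(weights, global_weights, mu)
    return [w - lr * (g + p) for w, g, p in zip(weights, loss_grad, prox)]


# With a zero loss gradient, the step only pulls the weights toward the global model
print(local_sgd_step([1.0, -2.0], [0.0, 0.0], [0.0, 0.0], mu=1.0, lr=0.5))  # [0.5, -1.0]
```

A larger mu pulls local models more strongly toward the global model. mu=0 recovers plain local SGD as in FedAvg.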

Privacy-Preserving Configuration

config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,
    learning_rate=0.01,
    # Differential Privacy
    differential_privacy=True,
    dp_epsilon=1.0,
    dp_delta=1e-5,
    dp_mechanism="gaussian",     # 'gaussian' or 'laplace'
    # Secure Aggregation
    secure_aggregation=True,
    # Model Compression
    compression=True,
    compression_ratio=0.1,
    device="cuda"
)
Reference: neurenix/federated/client.py:39

Client Training

Local Training

# Train client for one federated round
metrics = client.train(global_round=0)

print(f"Train loss: {metrics['train_loss']:.4f}")
print(f"Train accuracy: {metrics['train_acc']:.4f}")
print(f"Validation loss: {metrics['val_loss']:.4f}")
print(f"Number of samples: {metrics['n_samples']}")
Reference: neurenix/federated/client.py:128

Model Updates

# Get model update for aggregation
model_update = client.get_model_update()

# Receive global model from server
client.set_model_parameters(global_model_params)
Reference: neurenix/federated/client.py:240

Server Configuration

Aggregation Strategies

from neurenix.federated import AggregationStrategy

# Available strategies
strategies = [
    AggregationStrategy.FED_AVG,      # Federated Averaging
    AggregationStrategy.FED_PROX,     # Federated Proximal
    AggregationStrategy.FED_NOVA,     # Federated Normalized Averaging
    AggregationStrategy.FED_OPT,      # Federated Optimization
    AggregationStrategy.FED_ADAGRAD,  # Federated Adagrad
    AggregationStrategy.FED_ADAM,     # Federated Adam
    AggregationStrategy.FED_YOGI,     # Federated Yogi
]

config = ServerConfig(
    num_rounds=10,
    clients_per_round=10,
    aggregation_strategy=AggregationStrategy.FED_ADAM
)
Reference: neurenix/federated/server.py:30

Client Selection

config = ServerConfig(
    num_rounds=10,
    clients_per_round=10,        # Number of clients per round
    client_fraction=0.1,         # Or fraction of total clients
    min_clients=2,               # Minimum required clients
    min_sample_size=10,          # Minimum samples per client
    accept_failures=True         # Continue if some clients fail
)
Reference: neurenix/federated/server.py:42
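How ServerConfig combines clients_per_round, client_fraction and min_clients is not spelled out here. The sketch below is a hypothetical selection policy, not Neurenix's actual logic. It follows the original FedAvg paper's convention of sampling max(C * K, 1) of K clients per round, with min_clients taking the place of 1.

```python
import random


def sample_clients(client_ids, fraction, min_clients, seed=None):
    """Pick max(round(fraction * K), min_clients) distinct clients out of K."""
    k = len(client_ids)
    if k < min_clients:
        raise ValueError(f"{k} clients registered, at least {min_clients} required")
    # round() guards against float artifacts such as 0.07 * 100 evaluating to 7.000000000000001
    m = min(k, max(round(fraction * k), min_clients))
    rng = random.Random(seed)
    return rng.sample(client_ids, m)


clients = [f"client_{i:03d}" for i in range(100)]
selected = sample_clients(clients, fraction=0.1, min_clients=2, seed=0)
print(len(selected))  # 10
```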

Aggregation Strategies

FedAvg (Federated Averaging)

FedAvg sets the new global model to the average of the client models, weighted by each client's number of training samples:
from neurenix.federated.strategies import FedAvg

strategy = FedAvg()
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:34
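For intuition, here is a minimal pure-Python sketch of the FedAvg rule: w_global = sum over k of (n_k / n) * w_k. It operates on dicts of flat float lists, which stand in for state_dict tensors. It is an illustration, not the Neurenix implementation.

```python
def fed_avg(client_models, client_weights):
    """Sample-weighted average of client parameter dicts.

    client_models: list of {param_name: list[float]}, all with the same keys and shapes
    client_weights: sample count n_k for each client, in the same order
    """
    total = sum(client_weights)
    if total <= 0:
        raise ValueError("total client weight must be positive")
    aggregated = {}
    for name in client_models[0]:
        length = len(client_models[0][name])
        aggregated[name] = [
            sum(n * model[name][i] for model, n in zip(client_models, client_weights)) / total
            for i in range(length)
        ]
    return aggregated


models = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]
print(fed_avg(models, [1000, 3000]))  # {'w': [2.5, 3.5]}
```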

FedProx (Federated Proximal)

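FedProx is aimed at heterogeneous data. Its proximal term acts during local client training (proximal_mu in ClientConfig). In the original FedProx algorithm, the server then combines client models with the same sample-weighted average as FedAvg: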
from neurenix.federated.strategies import FedProx

strategy = FedProx(mu=0.01)  # Proximal coefficient
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:77

FedNova (Federated Normalized Averaging)

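FedNova normalizes each client's update by the number of local steps it took before averaging. This keeps clients that run more local epochs from pulling the global model toward their own objective: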
from neurenix.federated.strategies import FedNova

strategy = FedNova()
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:107
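Here is a minimal sketch of the FedNova rule, assuming clients ran plain local SGD (no momentum). Each client's cumulative update is divided by its step count tau_i. The normalized updates are averaged with data-size weights p_i, then rescaled by tau_eff = sum(p_i * tau_i). The function is hypothetical and not the Neurenix implementation.

```python
def fed_nova(global_weights, client_weights, num_samples, local_steps):
    """FedNova aggregation for clients that ran plain local SGD.

    global_weights: list[float]; client_weights: one list[float] per client;
    num_samples: n_i per client; local_steps: tau_i, local SGD steps per client.
    """
    total = sum(num_samples)
    p = [n / total for n in num_samples]  # data-size weights p_i
    tau_eff = sum(pi * tau for pi, tau in zip(p, local_steps))  # effective step count
    new_weights = []
    for j, g in enumerate(global_weights):
        # Weighted average of per-step updates d_i = (w_global - w_i) / tau_i
        avg_step = sum(
            pi * (g - w[j]) / tau
            for pi, w, tau in zip(p, client_weights, local_steps)
        )
        new_weights.append(g - tau_eff * avg_step)
    return new_weights


# Client A took 1 step along x, client B took 9 steps along y (equal data sizes)
print(fed_nova([0.0, 0.0], [[-1.0, 0.0], [0.0, -9.0]], [100, 100], [1, 9]))  # [-2.5, -2.5]
```

Plain FedAvg would return [-0.5, -4.5] for the same inputs. That tilts the global update toward client B simply because it ran more steps. FedNova moves equally along both clients' directions.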

FedAdam (Federated Adam)

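FedAdam treats the averaged client update as a pseudo-gradient and applies an Adam-style adaptive step to the global model on the server: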
from neurenix.federated.strategies import FedAdam

strategy = FedAdam(
    learning_rate=0.01,
    beta1=0.9,
    beta2=0.99,
    epsilon=1e-8
)

aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:274
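To show what the server does with these hyperparameters, here is a sketch of the FedAdam step from Reddi et al., "Adaptive Federated Optimization". It assumes the client models have already been averaged and uses no bias correction, as in that paper. epsilon plays the role of the paper's tau. The ServerAdam class is hypothetical and not a Neurenix API.

```python
import math


class ServerAdam:
    """Server-side Adam applied to the averaged client delta (the pseudo-gradient)."""

    def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = learning_rate, beta1, beta2, epsilon
        self.m = None  # first-moment estimate, one entry per parameter
        self.v = None  # second-moment estimate, one entry per parameter

    def step(self, global_weights, avg_client_weights):
        # Pseudo-gradient: averaged client model minus the current global model
        delta = [c - g for c, g in zip(avg_client_weights, global_weights)]
        if self.m is None:
            self.m = [0.0] * len(delta)
            self.v = [0.0] * len(delta)
        self.m = [self.beta1 * m + (1 - self.beta1) * d for m, d in zip(self.m, delta)]
        self.v = [self.beta2 * v + (1 - self.beta2) * d * d for v, d in zip(self.v, delta)]
        # Move toward the clients' average with a per-coordinate adaptive step size
        return [
            g + self.lr * m / (math.sqrt(v) + self.eps)
            for g, m, v in zip(global_weights, self.m, self.v)
        ]


server_opt = ServerAdam(learning_rate=0.01)
new_global = server_opt.step([0.0, 0.0], [1.0, -1.0])  # approximately [0.01, -0.01]
```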

FedYogi (Federated Yogi)

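FedYogi follows FedAdam but uses Yogi's second-moment update. In Yogi, the change to the estimate depends only on the current squared update, so the effective learning rate changes more gradually: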
from neurenix.federated.strategies import FedYogi

strategy = FedYogi(
    learning_rate=0.01,
    beta1=0.9,
    beta2=0.99,
    epsilon=1e-8
)
Reference: neurenix/federated/strategies.py:305
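The difference from FedAdam is confined to the second-moment update. The sketch below compares the two per-coordinate rules for a pseudo-gradient delta. The rest of the server step is the same as FedAdam's. Both functions are hypothetical helpers for illustration, not Neurenix APIs.

```python
def adam_second_moment(v, delta, beta2=0.99):
    """Adam: v <- beta2 * v + (1 - beta2) * delta^2, per coordinate."""
    return [beta2 * vi + (1 - beta2) * d * d for vi, d in zip(v, delta)]


def yogi_second_moment(v, delta, beta2=0.99):
    """Yogi: v <- v - (1 - beta2) * delta^2 * sign(v - delta^2), per coordinate."""
    out = []
    for vi, d in zip(v, delta):
        d2 = d * d
        sign = (vi > d2) - (vi < d2)  # -1, 0 or 1
        out.append(vi - (1 - beta2) * d2 * sign)
    return out


# A large stored estimate meets a small update
print(adam_second_moment([100.0], [0.1]))  # roughly [99.0001]
print(yogi_second_moment([100.0], [0.1]))  # roughly [99.9999]
```

Adam shrinks the estimate by a fraction of itself. Yogi shrinks it only by (1 - beta2) * delta^2, so the effective step size grows more slowly.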

Complete Training Example

import neurenix as nx
from neurenix.federated import (
    FederatedServer, FederatedClient,
    ServerConfig, ClientConfig,
    AggregationStrategy
)

# Create global model
global_model = create_model()
criterion = nx.nn.CrossEntropyLoss()

# Configure and initialize server
server_config = ServerConfig(
    num_rounds=50,
    clients_per_round=10,
    aggregation_strategy=AggregationStrategy.FED_ADAM,
    eval_every=5
)

server = FederatedServer(server_config)
server.initialize(global_model, criterion)

# Create and register clients
for i in range(100):
    client_config = ClientConfig(
        client_id=f"client_{i:03d}",
        batch_size=32,
        epochs=5,
        learning_rate=0.01
    )
    
    client = FederatedClient(client_config)
    client.initialize(
        model=global_model.clone(),
        criterion=criterion,
        train_data=client_train_loaders[i],
        val_data=client_val_loaders[i]
    )
    
    server.add_client(client)

# Train federated model
metrics = server.train(test_data=test_loader)

# Get final model
final_model = server.get_model()
Reference: neurenix/federated/server.py:363

Privacy and Security

Differential Privacy

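Differential privacy adds calibrated random noise to client updates, bounding how much any single contribution can influence what the server receives: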
client_config = ClientConfig(
    client_id="client_001",
    differential_privacy=True,
    dp_epsilon=1.0,              # Privacy budget (lower = more privacy)
    dp_delta=1e-5,               # Probability the epsilon bound may fail (keep far below 1 / number of records)
    dp_mechanism="gaussian"      # Noise mechanism
)
Reference: neurenix/federated/client.py:40
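To make dp_epsilon and dp_delta concrete, here is a sketch of the classic Gaussian mechanism from Dwork & Roth (Theorem A.1). The bound is stated for 0 < epsilon < 1, so dp_epsilon=1.0 sits at its edge. The sketch clips a client's whole update to an L2 norm C, which makes the sensitivity C under add/remove-one-client adjacency, and then adds Gaussian noise with sigma = C * sqrt(2 ln(1.25 / delta)) / epsilon. This is an assumption about how such a mechanism is usually built. Neurenix's own calibration and clipping may differ, and the helper names are hypothetical.

```python
import math
import random


def clip_update(update, max_norm):
    """Scale the update down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


def gaussian_sigma(epsilon, delta, sensitivity):
    """Classic Gaussian-mechanism noise scale (Dwork & Roth, Theorem A.1)."""
    if not 0 < epsilon < 1:
        raise ValueError("the classic bound is stated for 0 < epsilon < 1")
    return sensitivity * math.sqrt(2 * math.log(1.25 / delta)) / epsilon


def privatize(update, clip_norm, epsilon, delta, rng=random):
    """Clip to clip_norm (so the sensitivity is clip_norm), then add Gaussian noise."""
    sigma = gaussian_sigma(epsilon, delta, clip_norm)
    return [x + rng.gauss(0.0, sigma) for x in clip_update(update, clip_norm)]


noisy = privatize([3.0, 4.0], clip_norm=1.0, epsilon=0.5, delta=1e-5, rng=random.Random(0))
```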

Secure Aggregation

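Secure aggregation protects individual model updates so that the server learns only their aggregate, not any single client's update. Client and server must both enable it: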
client_config = ClientConfig(
    client_id="client_001",
    secure_aggregation=True      # Enable secure aggregation
)

server_config = ServerConfig(
    secure_aggregation=True      # Server must also enable
)
Reference: neurenix/federated/client.py:39
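One widely used way to build secure aggregation is pairwise masking, the core idea of Bonawitz et al.'s "Practical Secure Aggregation". For every pair of clients, one adds a shared random mask and the other subtracts it, so the masks cancel in the server's sum. The simplified sketch below has no key agreement and no dropout recovery. A shared seeded RNG stands in for pairwise shared keys, and updates are assumed to be already quantized to integers. Neurenix's actual scheme is not described on this page.

```python
import random

MODULUS = 2 ** 32  # all arithmetic is modulo 2^32 on integer-encoded updates


def masked_updates(updates, seed=0):
    """Return each client's update plus pairwise masks that cancel in the sum.

    updates: list of equal-length lists of ints in [0, MODULUS).
    For every pair i < j, client i adds a shared random mask and client j subtracts it.
    """
    rng = random.Random(seed)  # stands in for keys agreed between each pair of clients
    n, dim = len(updates), len(updates[0])
    masked = [list(u) for u in updates]
    for i in range(n):
        for j in range(i + 1, n):
            mask = [rng.randrange(MODULUS) for _ in range(dim)]
            for k in range(dim):
                masked[i][k] = (masked[i][k] + mask[k]) % MODULUS
                masked[j][k] = (masked[j][k] - mask[k]) % MODULUS
    return masked


def server_sum(masked):
    """The server only ever sees masked vectors; summing them cancels the masks."""
    return [sum(col) % MODULUS for col in zip(*masked)]


print(server_sum(masked_updates([[1, 2], [3, 4], [5, 6]])))  # [9, 12]
```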

Model Compression

Reduce communication overhead:
client_config = ClientConfig(
    client_id="client_001",
    compression=True,
    compression_ratio=0.1        # Send only top 10% of gradients
)
Reference: neurenix/federated/client.py:44
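Reading "top 10%" as top-k sparsification by magnitude, which is an assumption about what compression_ratio does, a client could send just (index, value) pairs for the largest entries. Because each kept entry also carries an index, the payload shrinks by somewhat less than 1 / compression_ratio. The helpers below are hypothetical, not Neurenix APIs.

```python
def top_k_sparsify(update, ratio):
    """Keep the round(ratio * n) largest-magnitude entries (at least one)."""
    k = max(1, round(ratio * len(update)))
    top = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)[:k]
    return [(i, update[i]) for i in sorted(top)]  # (index, value) pairs to transmit


def densify(pairs, length):
    """Server side: rebuild a dense update, with zeros for dropped entries."""
    dense = [0.0] * length
    for i, value in pairs:
        dense[i] = value
    return dense


update = [0.1, -5.0, 0.2, 3.0, 0.0, 0.05, -0.3, 0.0, 0.01, 0.4]
sparse = top_k_sparsify(update, ratio=0.2)
print(sparse)  # [(1, -5.0), (3, 3.0)]
```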

Client States

Clients transition through different states during training:
from neurenix.federated.client import ClientState

# Available states
states = [
    ClientState.IDLE,            # Ready for work
    ClientState.TRAINING,        # Training on local data
    ClientState.EVALUATING,      # Evaluating model
    ClientState.SENDING,         # Sending updates to server
    ClientState.RECEIVING        # Receiving global model
]

# Check client state
if client.state == ClientState.IDLE:
    client.train(global_round=0)
Reference: neurenix/federated/client.py:17

Server States

Servers also maintain state during coordination:
from neurenix.federated.server import ServerState

# Available states
states = [
    ServerState.IDLE,
    ServerState.INITIALIZING,
    ServerState.SELECTING_CLIENTS,
    ServerState.DISTRIBUTING,
    ServerState.AGGREGATING,
    ServerState.EVALUATING
]
Reference: neurenix/federated/server.py:20

Advanced Features

Custom Client Selection

# Select clients with the server's built-in logic
selected_clients = server.select_clients(num_clients=15)

# Or implement custom selection
def select_high_quality_clients(clients, num_clients):
    # compute_client_scores is a placeholder for your own scoring (data quality,
    # computation power, etc.); it should return a {client: score} mapping
    scores = compute_client_scores(clients)
    return sorted(clients, key=lambda c: scores[c], reverse=True)[:num_clients]
Reference: neurenix/federated/server.py:149

Asynchronous Federated Learning

# Approximate asynchronous updates: aggregate after every single client
server_config = ServerConfig(
    num_rounds=100,
    clients_per_round=1,         # Update with each client
    accept_failures=True,
    timeout=60.0                 # Wait up to 60s per client
)
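Setting clients_per_round=1 keeps the round-based loop and simply aggregates after each client. Fully asynchronous schemes such as FedAsync (Xie et al., 2019) instead merge each client model as soon as it arrives and down-weight stale contributions. Staleness is the number of global versions since the client fetched its model. The sketch below uses FedAsync's polynomial staleness function. It is a hypothetical illustration, not a Neurenix feature.

```python
def staleness_weight(alpha, staleness, a=0.5):
    """Polynomial staleness function from FedAsync: alpha * (staleness + 1) ** -a."""
    return alpha * (staleness + 1) ** (-a)


def async_merge(global_weights, client_weights, alpha, staleness, a=0.5):
    """Blend one client's model into the global model as soon as it arrives.

    staleness = current server version minus the version the client started from.
    """
    alpha_t = staleness_weight(alpha, staleness, a)
    return [(1 - alpha_t) * g + alpha_t * c for g, c in zip(global_weights, client_weights)]


fresh = async_merge([0.0], [1.0], alpha=0.6, staleness=0)  # mixing weight 0.6
stale = async_merge([0.0], [1.0], alpha=0.6, staleness=3)  # mixing weight 0.6 / 2 = 0.3
```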

Client Weighting

# Weight clients by number of samples
client_weights = {
    "client_001": 1000,  # 1000 training samples
    "client_002": 500,   # 500 training samples
    "client_003": 2000   # 2000 training samples
}

# FedAvg weights each update by its share of all samples (3500 here):
# 1000/3500, 500/3500 and 2000/3500
Reference: neurenix/federated/server.py:270

Best Practices

1. Handle Non-IID Data

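Use FedProx for heterogeneous (non-IID) client data. The proximal term limits how far each local model drifts from the global model, which makes longer local training less prone to client drift:
config = ClientConfig(
    proximal_mu=0.01,  # Proximal regularization toward the global model
    epochs=10          # More local epochs, kept in check by the proximal term
)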

2. Communication Efficiency

# Enable compression and reduce communication frequency
config = ClientConfig(
    compression=True,
    compression_ratio=0.1,
    epochs=10  # More local training between communication
)

3. Privacy Budget Management

# Split a total privacy budget evenly across rounds
# (basic composition: the per-round epsilons add up to total_epsilon)
total_epsilon = 10.0
num_rounds = 100
per_round_epsilon = total_epsilon / num_rounds

config = ClientConfig(
    differential_privacy=True,
    dp_epsilon=per_round_epsilon
)
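Dividing total_epsilon by num_rounds relies on basic composition, which is conservative. The advanced composition theorem (Dwork, Rothblum & Vadhan) says k rounds of an (eps, delta)-DP mechanism are (eps_total, k * delta + delta_slack)-DP. Here eps_total = sqrt(2k ln(1/delta_slack)) * eps + k * eps * (e^eps - 1). The sketch compares the two for the numbers above. The function names are illustrative, and privacy accountants such as RDP usually give tighter bounds still.

```python
import math


def basic_composition(per_round_eps, rounds):
    """Sequential composition: per-round epsilons simply add up."""
    return per_round_eps * rounds


def advanced_composition(per_round_eps, rounds, delta_slack):
    """Total epsilon under the advanced composition theorem (extra delta = delta_slack)."""
    k = rounds
    return (math.sqrt(2 * k * math.log(1 / delta_slack)) * per_round_eps
            + k * per_round_eps * math.expm1(per_round_eps))


basic = basic_composition(0.1, 100)
adv = advanced_composition(0.1, 100, delta_slack=1e-5)
print(round(basic, 2), round(adv, 2))  # 10.0 5.85
```

The per-round deltas also accumulate. With dp_delta=1e-5 over 100 rounds, the total delta is about 1e-3 plus delta_slack.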

4. Client Dropout Handling

server_config = ServerConfig(
    min_clients=5,              # Minimum required for aggregation
    accept_failures=True,       # Continue if some clients fail
    timeout=120.0               # Allow more time
)

5. Evaluation Strategy

server_config = ServerConfig(
    eval_every=5,               # Evaluate every 5 rounds
    num_rounds=100
)

# Use centralized test set on server
metrics = server.train(test_data=central_test_loader)

Performance Optimization

  1. Batch client updates to reduce server overhead
  2. Use compression for large models
  3. Cache model states to avoid repeated serialization
  4. Parallelize client training when possible
  5. Use quantization for model updates
  6. Implement early stopping for fast clients
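As an illustration of item 5, 8-bit min-max quantization sends one byte per value plus two floats (minimum and scale). That is roughly a 4x reduction compared with 32-bit floats. The sketch below is a hypothetical scheme, not a Neurenix API. Each restored value lands within scale / 2 of the original.

```python
def quantize_8bit(update):
    """Min-max quantization of a float update to integer codes in [0, 255]."""
    lo, hi = min(update), max(update)
    scale = (hi - lo) / 255 or 1.0  # constant update: any nonzero scale works
    codes = [min(255, max(0, round((x - lo) / scale))) for x in update]
    return codes, lo, scale


def dequantize_8bit(codes, lo, scale):
    """Server side: map codes back to approximate float values."""
    return [lo + c * scale for c in codes]


update = [-1.0, 0.0, 0.5, 2.0]
codes, lo, scale = quantize_8bit(update)
restored = dequantize_8bit(codes, lo, scale)  # each value within scale / 2 of the original
```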
