Overview
Federated learning enables training machine learning models across multiple decentralized clients without sharing raw data. Neurenix provides a complete framework for implementing federated learning with support for various aggregation strategies, secure aggregation, and differential privacy.
Core Components
FederatedClient
The FederatedClient class represents a client participating in federated learning.
from neurenix.federated import FederatedClient, ClientConfig
import neurenix as nx
# Configure client
config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,
    learning_rate=0.01,
    device="cuda"
)
# Create client
client = FederatedClient(config)
# Initialize with model and data
model = create_model()
criterion = nx.nn.CrossEntropyLoss()
client.initialize(
    model=model,
    criterion=criterion,
    train_data=train_loader,
    val_data=val_loader
)
Reference: neurenix/federated/client.py:85
FederatedServer
The FederatedServer class coordinates training across multiple clients.
from neurenix.federated import FederatedServer, ServerConfig, AggregationStrategy
# Configure server
config = ServerConfig(
    num_rounds=10,
    clients_per_round=10,
    client_fraction=0.1,
    aggregation_strategy=AggregationStrategy.FED_AVG,
    min_clients=2,
    eval_every=1,
    device="cuda"
)
# Create server
server = FederatedServer(config)
# Initialize with global model
server.initialize(
    model=global_model,
    criterion=criterion
)
# Register clients
for client in clients:
    server.add_client(client)
Reference: neurenix/federated/server.py:106
Client Configuration
Basic Configuration
config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,              # Local epochs per round
    learning_rate=0.01,
    momentum=0.9,
    weight_decay=0.0,
    max_grad_norm=1.0,     # Gradient clipping
    device="cuda"
)
Reference: neurenix/federated/client.py:26
FedProx Configuration
FedProx adds a proximal term to handle heterogeneous data:
config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,
    learning_rate=0.01,
    proximal_mu=0.01,  # Proximal term coefficient
    device="cuda"
)
Reference: neurenix/federated/client.py:37
Privacy-Preserving Configuration
config = ClientConfig(
    client_id="client_001",
    batch_size=32,
    epochs=5,
    learning_rate=0.01,
    # Differential Privacy
    differential_privacy=True,
    dp_epsilon=1.0,
    dp_delta=1e-5,
    dp_mechanism="gaussian",  # 'gaussian' or 'laplace'
    # Secure Aggregation
    secure_aggregation=True,
    # Model Compression
    compression=True,
    compression_ratio=0.1,
    device="cuda"
)
Reference: neurenix/federated/client.py:39
Client Training
Local Training
# Train client for one federated round
metrics = client.train(global_round=0)
print(f"Train loss: {metrics['train_loss']:.4f}")
print(f"Train accuracy: {metrics['train_acc']:.4f}")
print(f"Validation loss: {metrics['val_loss']:.4f}")
print(f"Number of samples: {metrics['n_samples']}")
Reference: neurenix/federated/client.py:128
Model Updates
# Get model update for aggregation
model_update = client.get_model_update()
# Receive global model from server
client.set_model_parameters(global_model_params)
Reference: neurenix/federated/client.py:240
Server Configuration
Aggregation Strategies
from neurenix.federated import AggregationStrategy
# Available strategies
strategies = [
    AggregationStrategy.FED_AVG,      # Federated Averaging
    AggregationStrategy.FED_PROX,     # Federated Proximal
    AggregationStrategy.FED_NOVA,     # Federated Normalized Averaging
    AggregationStrategy.FED_OPT,      # Federated Optimization
    AggregationStrategy.FED_ADAGRAD,  # Federated Adagrad
    AggregationStrategy.FED_ADAM,     # Federated Adam
    AggregationStrategy.FED_YOGI,     # Federated Yogi
]
config = ServerConfig(
    num_rounds=10,
    clients_per_round=10,
    aggregation_strategy=AggregationStrategy.FED_ADAM
)
Reference: neurenix/federated/server.py:30
Client Selection
config = ServerConfig(
    num_rounds=10,
    clients_per_round=10,  # Number of clients per round
    client_fraction=0.1,   # Or fraction of total clients
    min_clients=2,         # Minimum required clients
    min_sample_size=10,    # Minimum samples per client
    accept_failures=True   # Continue if some clients fail
)
Reference: neurenix/federated/server.py:42
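How `clients_per_round`, `client_fraction`, and `min_clients` combine is not spelled out above. The sketch below shows one plausible interpretation in plain Python; it is not the Neurenix implementation. The helper `pick_clients` and its precedence rule (an explicit count wins over the fraction, and the result never drops below `min_clients`) are assumptions made for illustration.

```python
import random


def pick_clients(client_ids, clients_per_round=None, client_fraction=None,
                 min_clients=1, seed=None):
    """Sample a round's participants (hypothetical helper, not a Neurenix API)."""
    if len(client_ids) < min_clients:
        raise ValueError("not enough registered clients for min_clients")
    if clients_per_round is not None:
        target = clients_per_round  # assumed: an explicit count takes precedence
    elif client_fraction is not None:
        target = round(len(client_ids) * client_fraction)
    else:
        target = len(client_ids)
    # Clamp the target to [min_clients, number of registered clients].
    target = max(min_clients, min(target, len(client_ids)))
    return random.Random(seed).sample(client_ids, target)


all_ids = [f"client_{i:03d}" for i in range(100)]
by_fraction = pick_clients(all_ids, client_fraction=0.1, seed=0)
by_count = pick_clients(all_ids, clients_per_round=15, seed=0)
print(len(by_fraction), len(by_count))  # 10 15
```

Sampling without replacement keeps each client at most once per round; a production selector might also filter on `min_sample_size` before sampling.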
Aggregation Strategies
FedAvg (Federated Averaging)
Weighted average of client models based on number of samples:
from neurenix.federated.strategies import FedAvg
strategy = FedAvg()
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:34
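To make the weighting concrete, here is a minimal pure-Python sketch of sample-weighted averaging. It illustrates the arithmetic only and is not the code behind `FedAvg.aggregate`; the function `fedavg` and the parameter layout (a dict of flat float lists per client) are assumptions.

```python
def fedavg(client_models, client_weights):
    """Return the sample-weighted average of client parameter dicts.

    client_models: list of {name: list[float]} dicts (assumed layout).
    client_weights: list of sample counts, one per client.
    """
    total = float(sum(client_weights))
    if total <= 0:
        raise ValueError("client_weights must sum to a positive value")
    averaged = {}
    for name in client_models[0]:
        size = len(client_models[0][name])
        averaged[name] = [
            sum(w * m[name][i] for m, w in zip(client_models, client_weights)) / total
            for i in range(size)
        ]
    return averaged


client_models = [
    {"w": [1.0, 2.0]},  # client with 100 samples
    {"w": [3.0, 4.0]},  # client with 300 samples
]
client_sample_counts = [100, 300]
aggregated = fedavg(client_models, client_sample_counts)
print(aggregated)  # {'w': [2.5, 3.5]}
```

The second client holds three times as much data, so the result sits three quarters of the way toward its parameters.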
FedProx (Federated Proximal)
Handles heterogeneous data by adding a proximal term to each client's local objective, which penalizes drift away from the global model:
from neurenix.federated.strategies import FedProx
strategy = FedProx(mu=0.01) # Proximal coefficient
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:77
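The proximal term is easiest to see in a single local update step. The sketch below assumes the standard FedProx objective, local loss plus `(mu / 2) * ||w - w_global||^2`, whose gradient adds `mu * (w - w_global)` to the ordinary gradient. The function `fedprox_step` is hypothetical and works on flat lists of floats; it is not taken from Neurenix.

```python
def fedprox_step(weights, grads, global_weights, lr, mu):
    """One local SGD step on loss + (mu / 2) * ||w - w_global||^2."""
    return [
        w - lr * (g + mu * (w - wg))  # the extra term pulls w back toward wg
        for w, g, wg in zip(weights, grads, global_weights)
    ]


global_weights = [0.0, 0.0]
local_weights = [1.0, -2.0]
zero_grads = [0.0, 0.0]  # isolate the effect of the proximal term
stepped = fedprox_step(local_weights, zero_grads, global_weights, lr=0.1, mu=0.5)
print(stepped)  # each weight shrinks by 5% toward the global value
```

With `mu=0` the step reduces to plain SGD; larger values of `mu` keep local models closer to the global model at the cost of slower local progress.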
FedNova (Federated Normalized Averaging)
Normalizes each client's update by its number of local steps, so that clients which train longer do not dominate the average:
from neurenix.federated.strategies import FedNova
strategy = FedNova()
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:107
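A small numeric sketch of the normalization, assuming plain local SGD so that each client's normalizer is simply its local step count. The function `fednova` and the `local_steps` argument are illustrative assumptions; the Neurenix `FedNova.aggregate` signature shown above does not take step counts explicitly.

```python
def fednova(global_weights, client_weights_list, sample_counts, local_steps):
    """Normalized averaging for plain local SGD (illustrative sketch).

    Each client's change is divided by its local step count, the normalized
    directions are averaged by data share, and the average is rescaled by the
    data-weighted effective number of steps.
    """
    total = float(sum(sample_counts))
    shares = [n / total for n in sample_counts]
    tau_eff = sum(p * tau for p, tau in zip(shares, local_steps))
    new_weights = []
    for j, wg in enumerate(global_weights):
        direction = sum(
            p * (wc[j] - wg) / tau
            for p, wc, tau in zip(shares, client_weights_list, local_steps)
        )
        new_weights.append(wg + tau_eff * direction)
    return new_weights


# Client A moved 2.0 in 2 steps; client B moved 4.0 in 8 steps.
result = fednova(
    global_weights=[0.0],
    client_weights_list=[[2.0], [4.0]],
    sample_counts=[50, 50],
    local_steps=[2, 8],
)
print(result)  # [3.75]
```

A plain average of the two client models would give 3.0, which over-represents the client that ran more steps relative to its per-step progress.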
FedAdam (Federated Adam)
Server-side adaptive optimization:
from neurenix.federated.strategies import FedAdam
strategy = FedAdam(
    learning_rate=0.01,
    beta1=0.9,
    beta2=0.99,
    epsilon=1e-8
)
aggregated = strategy.aggregate(
    global_model=global_model.state_dict(),
    client_models=client_updates,
    client_weights=client_sample_counts
)
Reference: neurenix/federated/strategies.py:274
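In server-side adaptive optimization, the difference between the averaged client model and the current global model is treated as a pseudo-gradient and fed to an Adam-style update. The sketch below follows that general recipe on flat lists of floats. The function `fedadam_update` and its state handling are assumptions, and whether Neurenix applies bias correction is not stated here, so none is applied.

```python
import math


def fedadam_update(global_w, avg_client_w, m, v,
                   lr=0.01, beta1=0.9, beta2=0.99, eps=1e-8):
    """One server step: Adam applied to the pseudo-gradient (avg - global)."""
    new_w, new_m, new_v = [], [], []
    for w, wc, mi, vi in zip(global_w, avg_client_w, m, v):
        delta = wc - w                                  # pseudo-gradient
        mi = beta1 * mi + (1 - beta1) * delta           # first moment
        vi = beta2 * vi + (1 - beta2) * delta * delta   # second moment
        new_w.append(w + lr * mi / (math.sqrt(vi) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_w, new_m, new_v


global_w, m, v = [0.0], [0.0], [0.0]
avg_client_w = [1.0]  # e.g. the output of a weighted client average
global_w, m, v = fedadam_update(global_w, avg_client_w, m, v)
print(global_w)  # moves a small step (about 0.01) toward the client average
```

The moment vectors `m` and `v` live on the server and persist across rounds, which is what distinguishes this from simply overwriting the global model with the average.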
FedYogi (Federated Yogi)
Server-side adaptive optimization using the Yogi update rule, which adjusts the second-moment estimate additively rather than by exponential averaging:
from neurenix.federated.strategies import FedYogi
strategy = FedYogi(
    learning_rate=0.01,
    beta1=0.9,
    beta2=0.99,
    epsilon=1e-8
)
Reference: neurenix/federated/strategies.py:305
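The difference between Adam and Yogi lies only in how the second-moment estimate is updated. The sketch below compares the two rules on a single scalar; the function names are illustrative and not part of Neurenix.

```python
def adam_second_moment(v, delta, beta2=0.99):
    """Adam: exponential moving average of squared updates."""
    return beta2 * v + (1 - beta2) * delta * delta


def yogi_second_moment(v, delta, beta2=0.99):
    """Yogi: move v toward delta^2 by at most (1 - beta2) * delta^2."""
    sq = delta * delta
    sign = (v > sq) - (v < sq)  # +1, 0, or -1
    return v - (1 - beta2) * sq * sign


v_prev, small_delta = 1.0, 0.1  # a small update after a history of large ones
v_adam = adam_second_moment(v_prev, small_delta)
v_yogi = yogi_second_moment(v_prev, small_delta)
print(v_adam, v_yogi)  # Adam decays v faster than Yogi does
```

Because Yogi's estimate shrinks more slowly when updates become small, the effective learning rate `lr / sqrt(v)` grows more gradually than under Adam.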
Complete Training Example
import neurenix as nx
from neurenix.federated import (
    FederatedServer, FederatedClient,
    ServerConfig, ClientConfig,
    AggregationStrategy
)
# Create global model
global_model = create_model()
criterion = nx.nn.CrossEntropyLoss()
# Configure and initialize server
server_config = ServerConfig(
    num_rounds=50,
    clients_per_round=10,
    aggregation_strategy=AggregationStrategy.FED_ADAM,
    eval_every=5
)
server = FederatedServer(server_config)
server.initialize(global_model, criterion)
# Create and register clients
for i in range(100):
    client_config = ClientConfig(
        client_id=f"client_{i:03d}",
        batch_size=32,
        epochs=5,
        learning_rate=0.01
    )
    client = FederatedClient(client_config)
    client.initialize(
        model=global_model.clone(),
        criterion=criterion,
        train_data=client_train_loaders[i],
        val_data=client_val_loaders[i]
    )
    server.add_client(client)
# Train federated model
metrics = server.train(test_data=test_loader)
# Get final model
final_model = server.get_model()
Reference: neurenix/federated/server.py:363
Privacy and Security
Differential Privacy
Add noise to protect individual contributions:
client_config = ClientConfig(
    client_id="client_001",
    differential_privacy=True,
    dp_epsilon=1.0,          # Privacy budget (lower = more privacy)
    dp_delta=1e-5,           # Probability that the epsilon guarantee fails
    dp_mechanism="gaussian"  # Noise mechanism
)
Reference: neurenix/federated/client.py:40
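As an illustration of what the Gaussian mechanism does to a model update, the sketch below clips the update to a fixed L2 norm and adds noise scaled by the classic calibration `sigma = C * sqrt(2 * ln(1.25 / delta)) / epsilon`. That bound is proven for `epsilon < 1` and treats the clipping norm `C` as the L2 sensitivity; both are assumptions of this sketch. The helper names are hypothetical, and Neurenix may calibrate noise differently.

```python
import math
import random


def gaussian_sigma(clip_norm, epsilon, delta):
    """Classic Gaussian-mechanism noise scale for L2 sensitivity clip_norm."""
    return clip_norm * math.sqrt(2.0 * math.log(1.25 / delta)) / epsilon


def clip_update(update, clip_norm):
    """Scale the update down so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


def privatize(update, clip_norm, epsilon, delta, rng):
    """Clip, then add independent Gaussian noise to every coordinate."""
    sigma = gaussian_sigma(clip_norm, epsilon, delta)
    return [x + rng.gauss(0.0, sigma) for x in clip_update(update, clip_norm)]


rng = random.Random(0)
raw_update = [3.0, 4.0]  # L2 norm 5.0
clipped = clip_update(raw_update, clip_norm=1.0)
sigma = gaussian_sigma(clip_norm=1.0, epsilon=0.5, delta=1e-5)
noisy = privatize(raw_update, clip_norm=1.0, epsilon=0.5, delta=1e-5, rng=rng)
print(clipped, sigma)
```

Halving `epsilon` doubles `sigma`, which is the trade-off behind "lower = more privacy": stronger guarantees cost more noise per update.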
Secure Aggregation
Encrypt model updates before aggregation:
client_config = ClientConfig(
    client_id="client_001",
    secure_aggregation=True  # Enable secure aggregation
)
server_config = ServerConfig(
    secure_aggregation=True  # Server must also enable
)
Reference: neurenix/federated/client.py:39
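The protocol Neurenix uses is not described here. One common construction for secure aggregation is pairwise additive masking: every pair of clients agrees on a random mask that one adds and the other subtracts, so each individual upload looks random while the masks cancel in the sum. The sketch below demonstrates only that cancellation. It assumes updates are already encoded as integers (for example fixed-point), and a single seeded generator stands in for the pairwise key agreement a real protocol would need.

```python
import random

MODULUS = 2 ** 32  # all arithmetic is done modulo this value


def pairwise_masks(client_ids, length, seed):
    """Build per-client masks whose sum over all clients is zero mod MODULUS."""
    rng = random.Random(seed)  # stand-in for pairwise shared secrets
    masks = {cid: [0] * length for cid in client_ids}
    for idx, a in enumerate(client_ids):
        for b in client_ids[idx + 1:]:
            shared = [rng.randrange(MODULUS) for _ in range(length)]
            masks[a] = [(m + s) % MODULUS for m, s in zip(masks[a], shared)]
            masks[b] = [(m - s) % MODULUS for m, s in zip(masks[b], shared)]
    return masks


updates = {"client_001": [5, 7], "client_002": [1, 2], "client_003": [10, 20]}
ids = sorted(updates)
masks = pairwise_masks(ids, length=2, seed=0)
masked = {
    cid: [(u + m) % MODULUS for u, m in zip(updates[cid], masks[cid])]
    for cid in ids
}
# The server only ever sees the masked vectors and their sum.
aggregate = [sum(column) % MODULUS for column in zip(*masked.values())]
print(aggregate)  # [16, 29]
```

This sketch does not handle clients that drop out after masks are agreed; practical protocols add secret sharing so the remaining clients can reconstruct the missing masks.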
Model Compression
Reduce communication overhead:
client_config = ClientConfig(
    client_id="client_001",
    compression=True,
    compression_ratio=0.1  # Send only top 10% of gradients
)
Reference: neurenix/federated/client.py:44
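Sending "only the top 10% of gradients" usually means top-k sparsification by magnitude: the client transmits the largest entries together with their indices, and the server fills the rest with zeros. The sketch below shows that idea on a flat list; the helper names are hypothetical and the exact selection rule in Neurenix is an assumption.

```python
def top_k_sparsify(update, ratio):
    """Keep the largest-magnitude entries; return (indices, values)."""
    k = max(1, int(len(update) * ratio))
    by_magnitude = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)
    keep = sorted(by_magnitude[:k])
    return keep, [update[i] for i in keep]


def densify(indices, values, length):
    """Rebuild a dense vector, treating dropped entries as zero."""
    dense = [0.0] * length
    for i, value in zip(indices, values):
        dense[i] = value
    return dense


update = [0.1, -0.5, 0.02, 3.0, -0.3, 0.0, 0.7, -2.0, 0.05, 0.4]
indices, values = top_k_sparsify(update, ratio=0.2)
print(indices, values)  # [3, 7] [3.0, -2.0]
restored = densify(indices, values, len(update))
```

The payload shrinks to `k` values plus `k` indices, so the real saving is somewhat less than the nominal ratio once index overhead is counted.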
Client States
Clients transition through different states during training:
from neurenix.federated.client import ClientState
# Available states
states = [
    ClientState.IDLE,        # Ready for work
    ClientState.TRAINING,    # Training on local data
    ClientState.EVALUATING,  # Evaluating model
    ClientState.SENDING,     # Sending updates to server
    ClientState.RECEIVING    # Receiving global model
]
# Check client state
if client.state == ClientState.IDLE:
    client.train(global_round=0)
Reference: neurenix/federated/client.py:17
Server States
Servers also maintain state during coordination:
from neurenix.federated.server import ServerState
# Available states
states = [
    ServerState.IDLE,
    ServerState.INITIALIZING,
    ServerState.SELECTING_CLIENTS,
    ServerState.DISTRIBUTING,
    ServerState.AGGREGATING,
    ServerState.EVALUATING
]
Reference: neurenix/federated/server.py:20
Advanced Features
Custom Client Selection
# Built-in client selection
selected_clients = server.select_clients(num_clients=15)
# Or implement custom selection
def select_high_quality_clients(clients, num_clients):
    # Select based on data quality, computation power, etc.
    # compute_client_scores is user-supplied and assumed to return {client: score}
    scores = compute_client_scores(clients)
    return sorted(clients, key=lambda c: scores[c], reverse=True)[:num_clients]
Reference: neurenix/federated/server.py:149
Asynchronous Federated Learning
# Configure server for asynchronous updates
server_config = ServerConfig(
    num_rounds=100,
    clients_per_round=1,   # Update with each client
    accept_failures=True,
    timeout=60.0           # Wait up to 60s per client
)
Client Weighting
# Weight clients by number of samples
client_weights = {
    "client_001": 1000,  # 1000 training samples
    "client_002": 500,   # 500 training samples
    "client_003": 2000   # 2000 training samples
}
# FedAvg automatically uses sample counts for weighting
Reference: neurenix/federated/server.py:270
Best Practices
1. Handle Non-IID Data
Use FedProx for heterogeneous client data:
config = ClientConfig(
    proximal_mu=0.01,  # Add regularization
    epochs=10          # More local epochs
)
2. Communication Efficiency
# Enable compression and reduce communication frequency
config = ClientConfig(
    compression=True,
    compression_ratio=0.1,
    epochs=10  # More local training between communication
)
3. Privacy Budget Management
# Track privacy budget across rounds
total_epsilon = 10.0
num_rounds = 100
per_round_epsilon = total_epsilon / num_rounds
config = ClientConfig(
    differential_privacy=True,
    dp_epsilon=per_round_epsilon
)
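Dividing the total budget evenly across rounds relies on basic sequential composition, where the epsilons of successive releases simply add up. This is a conservative bound; advanced composition or a moments-style accountant typically gives tighter totals. A minimal tracker under that additive assumption might look like the sketch below. The class `PrivacyBudget` is hypothetical and not part of Neurenix.

```python
class PrivacyBudget:
    """Track spent epsilon under basic sequential composition (a sketch)."""

    def __init__(self, total_epsilon):
        self.total_epsilon = total_epsilon
        self.spent = 0.0

    def spend(self, epsilon):
        # A small tolerance absorbs floating-point error from repeated addition.
        if self.spent + epsilon > self.total_epsilon + 1e-9:
            raise RuntimeError("privacy budget exhausted")
        self.spent += epsilon

    @property
    def remaining(self):
        return max(0.0, self.total_epsilon - self.spent)


budget = PrivacyBudget(total_epsilon=10.0)
per_round_epsilon = 10.0 / 100
for _ in range(100):
    budget.spend(per_round_epsilon)  # one call per federated round
print(budget.remaining)  # effectively zero after 100 rounds
```

Calling `spend` once more raises `RuntimeError`, which gives the training loop a clear signal to stop before the stated guarantee is exceeded.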
4. Client Dropout Handling
server_config = ServerConfig(
    min_clients=5,         # Minimum required for aggregation
    accept_failures=True,  # Continue if some clients fail
    timeout=120.0          # Allow more time
)
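The interaction between `min_clients` and `accept_failures` can be pictured as a gate in front of aggregation. The sketch below shows one plausible reading: failed clients are dropped when failures are accepted, and the round is skipped if too few updates remain. The helper `collect_updates` and this exact behavior are assumptions, not documented Neurenix semantics.

```python
def collect_updates(results, min_clients, accept_failures):
    """Filter one round's results (hypothetical helper, not a Neurenix API).

    results maps client_id to an update, or to None if the client failed
    or timed out. Returns the usable updates, or None to skip aggregation.
    """
    succeeded = {cid: upd for cid, upd in results.items() if upd is not None}
    if len(succeeded) < len(results) and not accept_failures:
        raise RuntimeError("client failure with accept_failures=False")
    if len(succeeded) < min_clients:
        return None  # too few participants: keep the previous global model
    return succeeded


round_results = {"client_001": [0.1], "client_002": None, "client_003": [0.3]}
usable = collect_updates(round_results, min_clients=2, accept_failures=True)
print(sorted(usable))  # ['client_001', 'client_003']
```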
5. Evaluation Strategy
server_config = ServerConfig(
    eval_every=5,  # Evaluate every 5 rounds
    num_rounds=100
)
# Use centralized test set on server
metrics = server.train(test_data=central_test_loader)
Performance Tips
- Batch client updates to reduce server overhead
- Use compression for large models
- Cache model states to avoid repeated serialization
- Parallelize client training when possible
- Use quantization for model updates
- Implement early stopping for fast clients