Overview
Offline training enables federated learning when participants cannot be simultaneously online. Syft-Flwr’s architecture naturally supports asynchronous patterns through its file-based synchronization and message queuing.
Asynchronous FL Architecture
Message Queuing
Syft-Flwr uses a futures-based RPC system that enables offline training:
# rpc/protocol.py:5-56
class SyftFlwrRpc(ABC):
    """Protocol for syft-flwr RPC implementations.

    This abstraction lets syft-flwr swap transport mechanisms for FL
    messages without changing callers:
    - syft_core: syft_rpc with a futures database
    - syft_client: P2P file-based RPC via Google Drive sync
    """

    @abstractmethod
    def send(
        self,
        to_email: str,
        app_name: str,
        endpoint: str,
        body: bytes,
        encrypt: bool = False,
    ) -> str:
        """Send a message to a recipient.

        Returns:
            Future ID for tracking the response.
        """

    @abstractmethod
    def get_response(self, future_id: str) -> Optional[bytes]:
        """Get the response for a previously returned future ID.

        Returns:
            Response body as bytes, or None if not ready yet.
        """
How It Works
Server Sends Task
Server sends training task and stores future ID:
# Server sends message
# NOTE(review): persist `future_id` durably (e.g. to disk) so the response
# can still be collected after a server restart — TODO confirm the rpc
# implementation's guarantees here.
future_id = rpc.send(
to_email="client@hospital.org",
app_name="my-fl-app",
endpoint="/messages",
body=serialized_task,
encrypt=True
)
# Server can go offline now
# The message is queued by the transport layer and delivered whenever the
# client next comes online.
Message Persists
Message is stored in file system or futures database:
# SyftBox transport
~/datasites/client@hospital.org/flwr/my-fl-app/futures/{future_id}/
# P2P transport
~/Google Drive/syftbox/flwr/my-fl-app/messages/{future_id}.request
Client Processes Offline
Client comes online later and processes message:
# Client receives message when online
@app.query()
def query(msg: Message, context: Context):
    """Handle a queued FL message: train on local data and reply with results."""
    # Train on local data
    results = train_model(msg.content)
    # Fix: return the computed `results` — the original referenced an
    # undefined name `reply_content`, which would raise NameError.
    return Message(results, reply_to=msg)
Response Queued
Client’s response is stored for server:
# Response persisted to file system
~/datasites/server@openmined.org/flwr/my-fl-app/futures/{future_id}/response
Offline Training Patterns
Pattern 1: Asynchronous Federated Averaging
Server doesn’t wait for all clients simultaneously:
from flwr.server import ServerApp, Grid
from flwr.common import Context, Message, MessageType, RecordDict
import logging
import time

logger = logging.getLogger(__name__)  # Fix: `logger` was used but never defined

app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Run asynchronous federated averaging.

    Each round broadcasts the current global weights to every known node,
    then polls until at least ``min-available-clients`` have replied or the
    round timeout expires. Rounds with too few replies keep the old weights.
    """
    num_rounds = context.run_config["num-server-rounds"]
    min_clients = context.run_config["min-available-clients"]
    timeout = context.run_config.get("round-timeout", 3600)  # seconds (1 hour)

    # Fix: `current_weights` was previously read inside the loop before
    # ever being assigned. Initialize the global model up front.
    current_weights = initialize_model()  # app-defined — TODO confirm helper name

    for server_round in range(num_rounds):
        # Send tasks to all available clients
        all_node_ids = list(grid.get_node_ids())
        messages = [
            Message(
                content=RecordDict({"global_weights": current_weights}),
                message_type=MessageType.TRAIN,
                dst_node_id=node_id,
                group_id=str(server_round),
            )
            for node_id in all_node_ids
        ]

        # Send all messages (returns immediately)
        grid.send_messages(messages)

        # Wait for the minimum number of responses, with a timeout
        replies = []
        start_time = time.time()
        while len(replies) < min_clients:
            # Poll for responses
            new_replies = grid.receive_messages(
                group_id=str(server_round),
                timeout=1.0,
            )
            replies.extend(new_replies)

            # Check timeout
            if time.time() - start_time > timeout:
                logger.warning(f"Round timeout reached with {len(replies)} replies")
                break
            time.sleep(5)  # Poll every 5 seconds

        # Aggregate available responses
        if len(replies) >= min_clients:
            aggregated_weights = federated_average(replies)
            current_weights = aggregated_weights
        else:
            logger.error(f"Insufficient responses: {len(replies)}/{min_clients}")
Pattern 2: Client-Paced Training
Clients train at their own pace and report when ready:
# client_app.py
from flwr.client import ClientApp
# Fix: RecordDict is used in the reply below but was not imported here.
from flwr.common import Context, Message, RecordDict
import time

app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train on local data at a client-chosen time and reply with weights.

    The reply carries metadata (sample count, duration, timestamp) so the
    server can weight the aggregation and track participation.
    """
    # Load local data
    data = load_local_dataset()

    # Client decides when to train (e.g., during off-peak hours)
    wait_for_optimal_time()

    # Train model
    start_time = time.time()
    model = train_on_local_data(msg.content["global_weights"], data)
    duration = time.time() - start_time

    # Report results with metadata
    return Message(
        content=RecordDict({
            "weights": model.get_weights(),
            "num_samples": len(data),
            "training_duration": duration,
            "timestamp": time.time(),
        }),
        reply_to=msg,
    )
def wait_for_optimal_time(now=None):
    """Sleep until the off-peak window (22:00 - 06:00 local time) begins.

    Args:
        now: Optional ``time.struct_time`` to use instead of the current
            local time (useful for testing). Defaults to ``time.localtime()``.
    """
    current = now if now is not None else time.localtime()
    current_hour = current.tm_hour
    # Train only during night hours (22:00 - 06:00)
    if 6 <= current_hour < 22:
        # Fix: subtract elapsed minutes/seconds instead of rounding up to
        # whole hours, which overshot past the start of the window.
        wait_seconds = (22 - current_hour) * 3600 - current.tm_min * 60 - current.tm_sec
        logger.info(f"Waiting {wait_seconds/3600:.1f} hours for off-peak time")
        time.sleep(wait_seconds)
Pattern 3: Opportunistic Training
Clients participate when resources are available:
# client_app.py
import psutil
import GPUtil
import time  # Fix: time.sleep() is used below but `time` was not imported

app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
    """Train only when local resources permit; otherwise defer the task."""
    # Check if resources are available
    if not has_sufficient_resources():
        logger.info("Insufficient resources, deferring training")
        # Message stays in queue, will be retried
        time.sleep(300)  # Check again in 5 minutes
        return None

    # Resources available, proceed with training
    results = train_model(msg.content)
    # Fix: return the computed results — the original returned an undefined
    # name `model_results` while the training output was bound to `model`.
    return Message(results, reply_to=msg)
def has_sufficient_resources():
    """Check if client has sufficient resources for training.

    Returns:
        bool: True when CPU, memory and (if present) GPU utilization are
        all below their thresholds; False otherwise.
    """
    # Check CPU usage
    cpu_usage = psutil.cpu_percent(interval=1)
    if cpu_usage > 80:
        return False

    # Check memory availability
    memory = psutil.virtual_memory()
    if memory.percent > 85:
        return False

    # Check GPU availability (if applicable)
    try:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            if gpu.memoryUtil > 0.8 or gpu.load > 0.8:
                return False
    # Fix: bare `except:` also swallowed SystemExit/KeyboardInterrupt;
    # narrow to Exception for the intended "no GPU monitoring" case.
    except Exception:
        pass  # No GPU or monitoring not available
    return True
File-Based Persistence
Message Storage
Messages are persisted to disk, enabling offline operation:
- SyftBox Transport
- P2P Transport
~/datasites/{email}/
└── flwr/{app_name}/
├── futures/
│ ├── {future_id_1}/
│ │ ├── request
│ │ └── response
│ └── {future_id_2}/
│ ├── request
│ └── response
└── events/
└── messages/
~/Google Drive/syftbox/
└── flwr/{app_name}/
└── messages/
├── {future_id_1}.request
├── {future_id_1}.response
├── {future_id_2}.request
└── {future_id_2}.response
Event-Based Processing
Clients process messages when they arrive:
# fl_orchestrator/flower_client.py:162-213
def syftbox_flwr_client(
    client_app: ClientApp,
    context: Context,
    app_name: str,
    project_dir: Optional[Path] = None,
):
    """Run the Flower ClientApp with SyftBox."""
    # Resolve the SyftBox client plus per-app encryption settings.
    client, encryption_enabled, syft_flwr_app_name = setup_client(
        app_name, project_dir=project_dir
    )

    # Create events adapter (watches for new messages)
    events_watcher = create_events_watcher(
        app_name=syft_flwr_app_name,
        client=client,
        cleanup_expiry="1d",
        cleanup_interval="1d",
    )

    # Register the message handler for the FL endpoint.
    # NOTE(review): `processor` is not defined in this excerpt — presumably a
    # message processor created elsewhere in the real module; verify there.
    events_watcher.on_request(
        "/messages",
        handler=lambda body: processor.process(body),
        auto_decrypt=encryption_enabled,
        encrypt_reply=encryption_enabled,
    )

    # Run forever, processing messages as they arrive
    events_watcher.run_forever()
Configuration for Offline Training
pyproject.toml Settings
[tool.flwr.app.config]
num-server-rounds = 10
min-available-clients = 2 # Minimum clients needed per round
round-timeout = 3600 # 1 hour timeout per round
client-check-interval = 30 # Check for responses every 30 seconds
[tool.syft_flwr]
transport = "syftbox" # Recommended for offline training
encryption = true # Keep messages secure while queued
Environment Variables
# Enable longer timeouts for offline training
export SYFT_FLWR_MESSAGE_TIMEOUT=7200 # 2 hours
# Configure cleanup intervals
export SYFT_FLWR_CLEANUP_EXPIRY=7d # Keep messages for 7 days
export SYFT_FLWR_CLEANUP_INTERVAL=1d # Clean up daily
Handling Stragglers
Server-Side Straggler Handling
from collections import defaultdict
import time


class AsyncFederatedAveraging:
    """Federated averaging that tolerates stragglers via soft/hard timeouts."""

    def __init__(self, min_clients=2, round_timeout=3600):
        self.min_clients = min_clients
        self.round_timeout = round_timeout
        # Bookkeeping for rounds still awaiting replies.
        self.pending_rounds = defaultdict(dict)

    def aggregate_round(self, grid: "Grid", round_num: int, global_weights):
        """Aggregate with straggler handling.

        Sends a task to every node, then polls until one of three exits:
        all nodes replied, the quorum is met after the soft timeout, or the
        hard timeout (twice the soft timeout) expires.
        """
        # Send tasks
        node_ids = list(grid.get_node_ids())
        outbound = self._create_messages(node_ids, round_num, global_weights)
        grid.send_messages(outbound)

        # Collect responses with timeout
        collected = []
        started = time.time()
        while True:
            # Check for new responses
            fresh = grid.receive_messages(group_id=str(round_num), timeout=5.0)
            collected.extend(fresh)
            waited = time.time() - started

            # Success conditions
            if len(collected) >= len(node_ids):
                logger.info(f"All {len(collected)} clients responded")
                break
            elif len(collected) >= self.min_clients and waited > self.round_timeout:
                logger.warning(
                    f"Timeout reached with {len(collected)}/{len(node_ids)} responses"
                )
                break
            elif waited > self.round_timeout * 2:
                logger.error(
                    f"Hard timeout reached with only {len(collected)} responses"
                )
                break
            time.sleep(5)

        # Fall back to the previous weights when nobody answered.
        if collected:
            return self._aggregate_weights(collected)
        logger.error("No responses received, keeping previous weights")
        return global_weights

    def _aggregate_weights(self, replies):
        """Weighted aggregation based on number of samples."""
        total = sum(r.content["num_samples"] for r in replies)
        result = None
        for r in replies:
            fraction = r.content["num_samples"] / total
            layers = r.content["weights"]
            if result is None:
                result = [layer * fraction for layer in layers]
            else:
                result = [
                    acc + layer * fraction
                    for acc, layer in zip(result, layers)
                ]
        return result
Monitoring Offline Training
Tracking Client Participation
from datetime import datetime, timedelta
import json
from pathlib import Path


class ParticipationTracker:
    """Track per-client participation in FL rounds via JSONL log files."""

    def __init__(self, log_dir: Path):
        self.log_dir = Path(log_dir)
        # Fix: create missing intermediate directories too — exist_ok alone
        # raises FileNotFoundError when a parent directory is absent.
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def log_client_response(self, client_id: str, round_num: int, metadata: dict):
        """Log when a client responds.

        Args:
            client_id: Identifier used to name the per-client log file.
            round_num: FL round the response belongs to.
            metadata: Response metadata; missing keys are stored as null.
        """
        log_file = self.log_dir / f"participation_{client_id}.jsonl"
        entry = {
            "timestamp": datetime.now().isoformat(),
            "round": round_num,
            "num_samples": metadata.get("num_samples"),
            "training_duration": metadata.get("training_duration"),
            "message_latency": metadata.get("latency"),
        }
        with open(log_file, "a") as f:
            f.write(json.dumps(entry) + "\n")

    def get_client_stats(self, client_id: str, days: int = 7):
        """Get participation statistics for a client.

        Returns:
            dict with round count, averages and last-seen timestamp over the
            last ``days`` days, or None when there is no recent activity.
        """
        log_file = self.log_dir / f"participation_{client_id}.jsonl"
        if not log_file.exists():
            return None

        cutoff = datetime.now() - timedelta(days=days)
        recent_entries = []
        with open(log_file) as f:
            for line in f:
                entry = json.loads(line)
                if datetime.fromisoformat(entry["timestamp"]) > cutoff:
                    recent_entries.append(entry)
        if not recent_entries:
            return None

        n = len(recent_entries)
        return {
            "total_rounds": n,
            # Fix: entries logged without num_samples/training_duration are
            # stored as null and previously crashed sum(); treat them as 0.
            "avg_samples": sum(e["num_samples"] or 0 for e in recent_entries) / n,
            "avg_duration": sum(e["training_duration"] or 0 for e in recent_entries) / n,
            "last_seen": recent_entries[-1]["timestamp"],
        }
Best Practices
Set Appropriate Timeouts
Set Appropriate Timeouts
Configure timeouts based on expected client availability:
# For daily check-ins
round_timeout = 24 * 3600 # 24 hours
# For weekly participation
round_timeout = 7 * 24 * 3600 # 7 days
Use Message Persistence
Use Message Persistence
Enable file-based persistence to survive restarts:
# Messages automatically persisted by SyftBox
# No additional configuration needed
Implement Graceful Degradation
Implement Graceful Degradation
Handle incomplete rounds:
if len(replies) < min_clients:
logger.warning("Insufficient responses, skipping aggregation")
# Keep previous model weights
continue
Monitor Client Health
Monitor Client Health
Track participation patterns:
tracker = ParticipationTracker(log_dir=Path("./logs"))
for client_id in all_clients:
stats = tracker.get_client_stats(client_id, days=7)
if stats and stats["total_rounds"] < 3:
logger.warning(f"Client {client_id} participated in only {stats['total_rounds']} rounds")
Example: Healthcare FL with Offline Training
# server_app.py
from flwr.server import ServerApp, Grid
from flwr.common import Context
import time

app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
    """Healthcare FL with hospitals that are intermittently online."""
    num_rounds = 100           # Long-running study
    min_hospitals = 3          # Need at least 3 hospitals per round
    round_timeout = 24 * 3600  # 24 hours per round

    current_weights = initialize_model()

    for round_num in range(num_rounds):
        logger.info(f"\nRound {round_num + 1}/{num_rounds}")

        # Ask every registered hospital to participate this round.
        all_hospitals = list(grid.get_node_ids())
        logger.info(f"Total hospitals: {len(all_hospitals)}")

        outgoing = create_training_messages(
            hospitals=all_hospitals,
            round_num=round_num,
            global_weights=current_weights,
        )
        grid.send_messages(outgoing)

        # Block (up to 24 hours) until enough hospitals reply.
        replies = wait_for_responses(
            grid=grid,
            round_num=round_num,
            min_required=min_hospitals,
            timeout=round_timeout,
        )
        logger.info(f"Received {len(replies)}/{len(all_hospitals)} responses")

        # Only fold in the round when the quorum was reached.
        if len(replies) >= min_hospitals:
            current_weights = federated_average(replies)
            logger.success(f"Updated global model with {len(replies)} hospitals")
        else:
            logger.warning(f"Insufficient responses, keeping previous model")

        # Evaluate every 10 rounds
        if (round_num + 1) % 10 == 0:
            accuracy = evaluate_model(current_weights)
            logger.info(f"Global model accuracy: {accuracy:.2%}")
def wait_for_responses(grid, round_num, min_required, timeout):
    """Wait for minimum number of responses with timeout."""
    collected = []
    started = time.time()
    while len(collected) < min_required:
        fresh = grid.receive_messages(
            group_id=str(round_num),
            timeout=30.0,  # Poll every 30 seconds
        )
        collected.extend(fresh)

        waited = time.time() - started
        if waited > timeout:
            logger.warning(f"Timeout after {waited/3600:.1f} hours")
            break
        if fresh:
            logger.info(f"Progress: {len(collected)}/{min_required} responses")
        time.sleep(60)  # Check every minute
    return collected
Next Steps
Run Simulations
Test offline patterns locally
Multi-Client Setup
Deploy offline training to production