
Overview

Offline training enables federated learning when participants cannot be simultaneously online. Syft-Flwr’s architecture naturally supports asynchronous patterns through its file-based synchronization and message queuing.

Asynchronous FL Architecture

Message Queuing

Syft-Flwr uses a futures-based RPC system that enables offline training:
# rpc/protocol.py:5-56
from abc import ABC, abstractmethod
from typing import Optional

class SyftFlwrRpc(ABC):
    """Protocol for syft-flwr RPC implementations.
    
    This abstraction allows syft-flwr to work with different
    transport mechanisms for sending FL messages:
    - syft_core: syft_rpc with futures database
    - syft_client: P2P File-based RPC via Google Drive sync
    """
    
    @abstractmethod
    def send(
        self,
        to_email: str,
        app_name: str,
        endpoint: str,
        body: bytes,
        encrypt: bool = False,
    ) -> str:
        """Send a message to a recipient.
        
        Returns:
            Future ID for tracking the response
        """
    
    @abstractmethod
    def get_response(self, future_id: str) -> Optional[bytes]:
        """Get response for a future ID.
        
        Returns:
            Response body as bytes, or None if not ready yet
        """
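The abstract protocol above leaves the transport to concrete implementations. As an illustration of the futures contract (a hypothetical in-memory sketch for local testing, not part of syft-flwr), messages can be queued under a generated future ID and replies looked up later:

```python
import uuid
from typing import Optional


class InMemoryRpc:
    """Hypothetical in-memory SyftFlwrRpc sketch: requests and responses
    live in dicts instead of a file system or futures database."""

    def __init__(self):
        self.outbox: dict[str, dict] = {}      # future_id -> pending request
        self.responses: dict[str, bytes] = {}  # future_id -> response body

    def send(self, to_email: str, app_name: str, endpoint: str,
             body: bytes, encrypt: bool = False) -> str:
        # Queue the request and hand back a future ID for later polling
        future_id = uuid.uuid4().hex
        self.outbox[future_id] = {
            "to": to_email, "app": app_name,
            "endpoint": endpoint, "body": body,
        }
        return future_id

    def get_response(self, future_id: str) -> Optional[bytes]:
        # None signals "not ready yet", mirroring the abstract contract
        return self.responses.get(future_id)

    def deliver(self, future_id: str, body: bytes) -> None:
        """Test helper: simulate a client replying while the server is away."""
        self.responses[future_id] = body
```

The key property is that `send` returns immediately and `get_response` is safe to call any number of times, which is what lets either side go offline between the two calls.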

How It Works

Step 1: Server Sends Task

The server sends a training task and stores the future ID:
# Server sends message
future_id = rpc.send(
    to_email="client@hospital.org",
    app_name="my-fl-app",
    endpoint="/messages",
    body=serialized_task,
    encrypt=True
)
# Server can go offline now
Step 2: Message Persists

The message is stored in the file system or futures database:
# SyftBox transport
~/datasites/client@hospital.org/flwr/my-fl-app/futures/{future_id}/

# P2P transport
~/Google Drive/syftbox/flwr/my-fl-app/messages/{future_id}.request
Step 3: Client Processes Offline

The client comes online later and processes the queued message:
# Client receives message when online
@app.query()
def query(msg: Message, context: Context):
    # Train on local data
    results = train_model(msg.content)
    
    # Return the results to the server
    return Message(content=results, reply_to=msg)
Step 4: Response Queued

The client's response is stored for the server:
# Response persisted to file system
~/datasites/server@openmined.org/flwr/my-fl-app/futures/{future_id}/response
Step 5: Server Retrieves Later

The server checks for the response when it comes back online:
# Server polls for response
response = rpc.get_response(future_id)

if response is not None:
    # Response ready, process it
    aggregated_weights = aggregate([response])
else:
    # Still waiting, check later
    pass
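The check-and-come-back-later loop above can be wrapped in a small polling helper with a timeout. This is a sketch (`wait_for_response` is our name, not a syft-flwr API); it takes the `get_response` callable so it works against any transport:

```python
import time
from typing import Callable, Optional


def wait_for_response(
    get_response: Callable[[str], Optional[bytes]],
    future_id: str,
    timeout: float = 3600.0,
    poll_interval: float = 5.0,
) -> Optional[bytes]:
    """Poll get_response until a reply arrives or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        response = get_response(future_id)
        if response is not None:
            return response
        time.sleep(poll_interval)
    return None  # timed out; the caller decides whether to keep waiting
```

Using `time.monotonic()` rather than `time.time()` keeps the deadline correct even if the wall clock is adjusted during a long wait.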

Offline Training Patterns

Pattern 1: Asynchronous Federated Averaging

Server doesn’t wait for all clients simultaneously:
from flwr.server import ServerApp, Grid
from flwr.common import Context, Message, MessageType, RecordDict
from loguru import logger
import time

app = ServerApp()

@app.main()
def main(grid: Grid, context: Context) -> None:
    num_rounds = context.run_config["num-server-rounds"]
    min_clients = context.run_config["min-available-clients"]
    timeout = context.run_config.get("round-timeout", 3600)  # 1 hour
    current_weights = initialize_model()  # starting global model
    
    for server_round in range(num_rounds):
        # Send tasks to all available clients
        all_node_ids = list(grid.get_node_ids())
        
        messages = []
        for node_id in all_node_ids:
            message = Message(
                content=RecordDict({"global_weights": current_weights}),
                message_type=MessageType.TRAIN,
                dst_node_id=node_id,
                group_id=str(server_round),
            )
            messages.append(message)
        
        # Send all messages (returns immediately)
        grid.send_messages(messages)
        
        # Wait for minimum number of responses with timeout
        replies = []
        start_time = time.time()
        
        while len(replies) < min_clients:
            # Poll for responses
            new_replies = grid.receive_messages(
                group_id=str(server_round),
                timeout=1.0
            )
            replies.extend(new_replies)
            
            # Check timeout
            if time.time() - start_time > timeout:
                logger.warning(f"Round timeout reached with {len(replies)} replies")
                break
            
            time.sleep(5)  # Poll every 5 seconds
        
        # Aggregate available responses
        if len(replies) >= min_clients:
            aggregated_weights = federated_average(replies)
            current_weights = aggregated_weights
        else:
            logger.error(f"Insufficient responses: {len(replies)}/{min_clients}")

Pattern 2: Client-Paced Training

Clients train at their own pace and report when ready:
# client_app.py
from flwr.client import ClientApp
from flwr.common import Context, Message, RecordDict
from loguru import logger
import time

app = ClientApp()

@app.train()
def train(msg: Message, context: Context):
    # Load local data
    data = load_local_dataset()
    
    # Client decides when to train (e.g., during off-peak hours)
    wait_for_optimal_time()
    
    # Train model
    start_time = time.time()
    model = train_on_local_data(msg.content["global_weights"], data)
    duration = time.time() - start_time
    
    # Report results with metadata
    return Message(
        content=RecordDict({
            "weights": model.get_weights(),
            "num_samples": len(data),
            "training_duration": duration,
            "timestamp": time.time()
        }),
        reply_to=msg
    )

def wait_for_optimal_time():
    """Wait until off-peak hours for training."""
    current_hour = time.localtime().tm_hour
    
    # Train only during night hours (22:00 - 06:00)
    if 6 <= current_hour < 22:
        wait_seconds = ((22 - current_hour) * 3600)
        logger.info(f"Waiting {wait_seconds/3600:.1f} hours for off-peak time")
        time.sleep(wait_seconds)
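The off-peak wait above rounds to whole hours. A minute-aware variant can be factored into a pure helper (hypothetical, ours) that is easy to unit-test because it takes the current time as arguments instead of reading the clock:

```python
def seconds_until_hour(target_hour: int, current_hour: int, current_minute: int) -> int:
    """Seconds from now until target_hour:00, wrapping past midnight."""
    minutes_now = current_hour * 60 + current_minute
    minutes_target = target_hour * 60
    # Modulo a full day handles "target is earlier today" by wrapping to tomorrow
    wait_minutes = (minutes_target - minutes_now) % (24 * 60)
    return wait_minutes * 60
```

`wait_for_optimal_time` could then call `time.sleep(seconds_until_hour(22, now.tm_hour, now.tm_min))` when the current hour falls inside peak time.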

Pattern 3: Opportunistic Training

Clients participate when resources are available:
# client_app.py
import time

import GPUtil
import psutil
from flwr.client import ClientApp
from flwr.common import Context, Message
from loguru import logger

app = ClientApp()

@app.train()
def train(msg: Message, context: Context):
    # Check if resources are available
    if not has_sufficient_resources():
        logger.info("Insufficient resources, deferring training")
        # Message stays in queue, will be retried
        time.sleep(300)  # Check again in 5 minutes
        return None
    
    # Resources available, proceed with training
    model = train_model(msg.content)
    return Message(model_results, reply_to=msg)

def has_sufficient_resources():
    """Check if client has sufficient resources for training."""
    # Check CPU usage
    cpu_usage = psutil.cpu_percent(interval=1)
    if cpu_usage > 80:
        return False
    
    # Check memory availability
    memory = psutil.virtual_memory()
    if memory.percent > 85:
        return False
    
    # Check GPU availability (if applicable)
    try:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            if gpu.memoryUtil > 0.8 or gpu.load > 0.8:
                return False
    except Exception:
        pass  # No GPU or monitoring not available
    
    return True

File-Based Persistence

Message Storage

Messages are persisted to disk, enabling offline operation:
~/datasites/{email}/
└── flwr/{app_name}/
    ├── futures/
    │   ├── {future_id_1}/
    │   │   ├── request
    │   │   └── response
    │   └── {future_id_2}/
    │       ├── request
    │       └── response
    └── events/
        └── messages/
Messages sync via SyftBox when participants are online.
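Given that layout, a small monitoring script can tell which messages are still awaiting a reply by scanning for futures that have a request file but no response (a sketch; `pending_futures` is our helper name, not a syft-flwr API):

```python
from pathlib import Path


def pending_futures(app_dir: Path) -> list[str]:
    """Return future IDs that have a request on disk but no response yet."""
    futures_dir = app_dir / "futures"
    if not futures_dir.exists():
        return []
    return sorted(
        d.name
        for d in futures_dir.iterdir()
        if d.is_dir()
        and (d / "request").exists()
        and not (d / "response").exists()
    )
```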

Event-Based Processing

Clients process messages when they arrive:
# fl_orchestrator/flower_client.py:162-213
def syftbox_flwr_client(
    client_app: ClientApp,
    context: Context,
    app_name: str,
    project_dir: Optional[Path] = None,
):
    """Run the Flower ClientApp with SyftBox."""
    client, encryption_enabled, syft_flwr_app_name = setup_client(
        app_name, project_dir=project_dir
    )
    
    # Create events adapter (watches for new messages)
    events_watcher = create_events_watcher(
        app_name=syft_flwr_app_name,
        client=client,
        cleanup_expiry="1d",
        cleanup_interval="1d",
    )
    
    # Register message handler
    events_watcher.on_request(
        "/messages",
        handler=lambda body: processor.process(body),
        auto_decrypt=encryption_enabled,
        encrypt_reply=encryption_enabled,
    )
    
    # Run forever, processing messages as they arrive
    events_watcher.run_forever()

Configuration for Offline Training

pyproject.toml Settings

[tool.flwr.app.config]
num-server-rounds = 10
min-available-clients = 2      # Minimum clients needed per round
round-timeout = 3600           # 1 hour timeout per round
client-check-interval = 30     # Check for responses every 30 seconds

[tool.syft_flwr]
transport = "syftbox"           # Recommended for offline training
encryption = true              # Keep messages secure while queued

Environment Variables

# Enable longer timeouts for offline training
export SYFT_FLWR_MESSAGE_TIMEOUT=7200  # 2 hours

# Configure cleanup intervals
export SYFT_FLWR_CLEANUP_EXPIRY=7d     # Keep messages for 7 days
export SYFT_FLWR_CLEANUP_INTERVAL=1d   # Clean up daily
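Values like `7d` and `1d` are shorthand durations. A parser for that shorthand might look like the sketch below (the real syft-flwr parsing may differ; this only illustrates the convention):

```python
def parse_duration(value: str) -> int:
    """Convert shorthand like '30s', '5m', '2h', '7d' into seconds."""
    units = {"s": 1, "m": 60, "h": 3600, "d": 86400}
    value = value.strip().lower()
    if value and value[-1] in units:
        return int(value[:-1]) * units[value[-1]]
    # Bare numbers are plain seconds, e.g. SYFT_FLWR_MESSAGE_TIMEOUT=7200
    return int(value)
```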

Handling Stragglers

Server-Side Straggler Handling

from collections import defaultdict
import time

from flwr.server import Grid
from loguru import logger

class AsyncFederatedAveraging:
    def __init__(self, min_clients=2, round_timeout=3600):
        self.min_clients = min_clients
        self.round_timeout = round_timeout
        self.pending_rounds = defaultdict(dict)
    
    def aggregate_round(self, grid: Grid, round_num: int, global_weights):
        """Aggregate with straggler handling."""
        # Send tasks
        all_nodes = list(grid.get_node_ids())
        messages = self._create_messages(all_nodes, round_num, global_weights)
        grid.send_messages(messages)
        
        # Collect responses with timeout
        replies = []
        start_time = time.time()
        
        while True:
            # Check for new responses
            new_replies = grid.receive_messages(
                group_id=str(round_num),
                timeout=5.0
            )
            replies.extend(new_replies)
            
            elapsed = time.time() - start_time
            
            # Success conditions
            if len(replies) >= len(all_nodes):
                logger.info(f"All {len(replies)} clients responded")
                break
            elif len(replies) >= self.min_clients and elapsed > self.round_timeout:
                logger.warning(
                    f"Timeout reached with {len(replies)}/{len(all_nodes)} responses"
                )
                break
            elif elapsed > self.round_timeout * 2:
                logger.error(
                    f"Hard timeout reached with only {len(replies)} responses"
                )
                break
            
            time.sleep(5)
        
        # Aggregate available responses
        if replies:
            return self._aggregate_weights(replies)
        else:
            logger.error("No responses received, keeping previous weights")
            return global_weights
    
    def _aggregate_weights(self, replies):
        """Weighted aggregation based on number of samples."""
        total_samples = sum(r.content["num_samples"] for r in replies)
        
        aggregated = None
        for reply in replies:
            weight = reply.content["num_samples"] / total_samples
            client_weights = reply.content["weights"]
            
            if aggregated is None:
                aggregated = [w * weight for w in client_weights]
            else:
                aggregated = [
                    a + w * weight 
                    for a, w in zip(aggregated, client_weights)
                ]
        
        return aggregated
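The `_aggregate_weights` method above is a sample-weighted mean: each client's weights are scaled by its share of the total samples and summed. A standalone version with plain lists (no Flower types) makes the arithmetic easy to verify:

```python
def weighted_average(weight_lists: list[list[float]], sample_counts: list[int]) -> list[float]:
    """Sample-weighted mean of per-client weight vectors (toy, list-based)."""
    total = sum(sample_counts)
    aggregated = [0.0] * len(weight_lists[0])
    for weights, n in zip(weight_lists, sample_counts):
        share = n / total  # this client's fraction of all samples
        for i, w in enumerate(weights):
            aggregated[i] += w * share
    return aggregated
```

With two clients holding weights `[1.0, 2.0]` (1 sample) and `[3.0, 4.0]` (3 samples), the result is `[2.5, 3.5]`: the larger client pulls the average three quarters of the way toward its values.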

Monitoring Offline Training

Tracking Client Participation

from datetime import datetime, timedelta
import json
from pathlib import Path

class ParticipationTracker:
    def __init__(self, log_dir: Path):
        self.log_dir = log_dir
        self.log_dir.mkdir(exist_ok=True)
    
    def log_client_response(self, client_id: str, round_num: int, metadata: dict):
        """Log when a client responds."""
        log_file = self.log_dir / f"participation_{client_id}.jsonl"
        
        entry = {
            "timestamp": datetime.now().isoformat(),
            "round": round_num,
            "num_samples": metadata.get("num_samples"),
            "training_duration": metadata.get("training_duration"),
            "message_latency": metadata.get("latency"),
        }
        
        with open(log_file, "a") as f:
            f.write(json.dumps(entry) + "\n")
    
    def get_client_stats(self, client_id: str, days: int = 7):
        """Get participation statistics for a client."""
        log_file = self.log_dir / f"participation_{client_id}.jsonl"
        if not log_file.exists():
            return None
        
        cutoff = datetime.now() - timedelta(days=days)
        recent_entries = []
        
        with open(log_file) as f:
            for line in f:
                entry = json.loads(line)
                entry_time = datetime.fromisoformat(entry["timestamp"])
                if entry_time > cutoff:
                    recent_entries.append(entry)
        
        if not recent_entries:
            return None
        
        return {
            "total_rounds": len(recent_entries),
            "avg_samples": sum(e["num_samples"] for e in recent_entries) / len(recent_entries),
            "avg_duration": sum(e["training_duration"] for e in recent_entries) / len(recent_entries),
            "last_seen": recent_entries[-1]["timestamp"],
        }

Best Practices

Configure timeouts based on expected client availability:
# For daily check-ins
round_timeout = 24 * 3600  # 24 hours

# For weekly participation
round_timeout = 7 * 24 * 3600  # 7 days
Enable file-based persistence to survive restarts:
# Messages automatically persisted by SyftBox
# No additional configuration needed
Handle incomplete rounds:
if len(replies) < min_clients:
    logger.warning("Insufficient responses, skipping aggregation")
    # Keep previous model weights
    continue
Track participation patterns:
tracker = ParticipationTracker(log_dir=Path("./logs"))

for client_id in all_clients:
    stats = tracker.get_client_stats(client_id, days=7)
    if stats and stats["total_rounds"] < 3:
        logger.warning(f"Client {client_id} participated in only {stats['total_rounds']} rounds")

Example: Healthcare FL with Offline Training

# server_app.py
from flwr.server import ServerApp, Grid
from flwr.common import Context
from loguru import logger
import time

app = ServerApp()

@app.main()
def main(grid: Grid, context: Context) -> None:
    """Healthcare FL with hospitals that are intermittently online."""
    
    num_rounds = 100  # Long-running study
    min_hospitals = 3  # Need at least 3 hospitals per round
    round_timeout = 24 * 3600  # 24 hours per round
    
    current_weights = initialize_model()
    
    for round_num in range(num_rounds):
        logger.info(f"\nRound {round_num + 1}/{num_rounds}")
        
        # Get available hospitals
        all_hospitals = list(grid.get_node_ids())
        logger.info(f"Total hospitals: {len(all_hospitals)}")
        
        # Send training tasks
        messages = create_training_messages(
            hospitals=all_hospitals,
            round_num=round_num,
            global_weights=current_weights
        )
        grid.send_messages(messages)
        
        # Wait for responses (up to 24 hours)
        replies = wait_for_responses(
            grid=grid,
            round_num=round_num,
            min_required=min_hospitals,
            timeout=round_timeout
        )
        
        logger.info(f"Received {len(replies)}/{len(all_hospitals)} responses")
        
        # Aggregate if we have enough responses
        if len(replies) >= min_hospitals:
            current_weights = federated_average(replies)
            logger.success(f"Updated global model with {len(replies)} hospitals")
        else:
            logger.warning("Insufficient responses, keeping previous model")
        
        # Evaluate every 10 rounds
        if (round_num + 1) % 10 == 0:
            accuracy = evaluate_model(current_weights)
            logger.info(f"Global model accuracy: {accuracy:.2%}")

def wait_for_responses(grid, round_num, min_required, timeout):
    """Wait for minimum number of responses with timeout."""
    replies = []
    start_time = time.time()
    
    while len(replies) < min_required:
        new_replies = grid.receive_messages(
            group_id=str(round_num),
            timeout=30.0  # Poll every 30 seconds
        )
        replies.extend(new_replies)
        
        elapsed = time.time() - start_time
        if elapsed > timeout:
            logger.warning(f"Timeout after {elapsed/3600:.1f} hours")
            break
        
        if new_replies:
            logger.info(f"Progress: {len(replies)}/{min_required} responses")
        
        time.sleep(60)  # Check every minute
    
    return replies

Next Steps

Run Simulations

Test offline patterns locally

Multi-Client Setup

Deploy offline training to production
