What You’ll Build
You’ll create a diabetes prediction model trained collaboratively across multiple data owners without centralizing their data. The complete project structure is the one created by the setup commands that follow the prerequisites.
Prerequisites
- Python 3.12 or later
- Basic understanding of PyTorch and machine learning
- Access to at least 2 machines or environments (for data owners)
# Create the project root together with its Python package directory.
mkdir -p fl-diabetes-prediction/fl_diabetes_prediction
cd fl-diabetes-prediction
# Create the empty package modules and the project manifest in a single call.
touch \
    fl_diabetes_prediction/__init__.py \
    fl_diabetes_prediction/client_app.py \
    fl_diabetes_prediction/server_app.py \
    fl_diabetes_prediction/task.py \
    pyproject.toml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict
class Net(nn.Module):
    """Feed-forward binary classifier for tabular features.

    Three hidden stages (32, 24 and 16 units) each apply
    ``Linear -> BatchNorm1d -> LeakyReLU(0.1)``; the first two also apply
    dropout (0.2 and 0.25).  The output stage is ``Linear(16, 1) -> Sigmoid``,
    so ``forward`` returns one value in the unit interval per sample.

    The attribute names and the order of modules inside each ``nn.Sequential``
    determine the ``state_dict`` keys, which ``get_weights``/``set_weights``
    exchange by position; keep them stable.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim: int = 6):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(32, 24),
            nn.BatchNorm1d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(24, 16),
            nn.BatchNorm1d(16),
            nn.LeakyReLU(0.1),
        )
        self.output_layer = nn.Sequential(
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the three hidden stages and the sigmoid output stage."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.output_layer(x)
        return x
def train(model, train_loader, local_epochs=1):
    """Train ``model`` in place on ``train_loader`` for ``local_epochs`` passes.

    Uses ``nn.BCELoss`` on the model output and an Adam optimizer
    (``lr=0.001``, ``weight_decay=0.0005``).  The optimizer is created on every
    call, so no optimizer state is carried over between calls.

    Args:
        model: Module whose output is a probability compatible with
            ``nn.BCELoss``; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.  Labels
            must be float tensors shaped like the model output, as
            ``nn.BCELoss`` requires.
        local_epochs: Number of full passes over ``train_loader``.

    Returns:
        None. The model's parameters are updated in place.
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
    model.train()
    # The epoch index itself is never used, only the number of passes.
    for _ in range(local_epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
def evaluate(model, data_loader):
    """Return the mean BCE loss and the accuracy of ``model`` on ``data_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking.  A prediction counts as positive when the output is strictly
    greater than 0.5.

    Both averages are taken over the samples actually yielded by the loader,
    so the loss stays a true per-sample mean even when the loader skips part
    of its dataset (for example with ``drop_last=True``).

    Args:
        model: Module whose output is a probability compatible with
            ``nn.BCELoss``.
        data_loader: Iterable yielding ``(inputs, labels)`` batches; labels
            must be shaped like the model output.

    Returns:
        A ``(loss, accuracy)`` tuple of floats.  ``(0.0, 0.0)`` is returned
        when the loader yields no samples, instead of dividing by zero.
    """
    model.eval()
    criterion = nn.BCELoss()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # BCELoss averages over the batch; re-weight by batch size so the
            # final division yields a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total
def get_weights(model):
    """Return every ``state_dict`` entry of ``model`` as a NumPy array.

    The list follows ``state_dict`` order and includes buffers (such as
    batch-norm running statistics) as well as trainable parameters, so it can
    be fed back through ``set_weights`` on a model with the same architecture.
    """
    # Only the tensors are needed; the keys are recovered by position later.
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]
def set_weights(model, parameters):
    """Load ``parameters`` into ``model``, matching arrays to keys by position.

    Args:
        model: Module to update in place.
        parameters: Iterable of array-likes in ``model.state_dict()`` order,
            as produced by ``get_weights``.

    Raises:
        ValueError: If more arrays are supplied than the model has
            ``state_dict`` entries (they would otherwise be dropped silently).
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when entries
            are missing or shapes do not match.
    """
    keys = list(model.state_dict().keys())
    parameters = list(parameters)
    if len(parameters) > len(keys):
        raise ValueError(
            f"Received {len(parameters)} arrays for a model with "
            f"{len(keys)} state_dict entries"
        )
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    model.load_state_dict(state_dict, strict=True)
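Add the data loading function in task.py. This is the key difference from standard Flower—it loads data from SyftBox. Note that dataset_processing is not shown in this tutorial: define it in task.py so that it turns the two DataFrames into the (train, test) DataLoaders that client_fn unpacks.
def load_syftbox_dataset():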
"""Load dataset from SyftBox private data directory"""
import pandas as pd
from syft_flwr.utils import get_syftbox_dataset_path
# Get the private dataset path set by SyftBox
data_dir = get_syftbox_dataset_path()
# Load train and test data
train_df = pd.read_csv(data_dir / "train.csv")
test_df = pd.read_csv(data_dir / "test.csv")
# Process and return DataLoaders
return dataset_processing(train_df, test_df)
The get_syftbox_dataset_path() function retrieves the path to private data that only the data owner can access. This ensures data never leaves the owner’s machine.
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from fl_diabetes_prediction.task import (
Net,
evaluate,
get_weights,
set_weights,
train,
load_syftbox_dataset,
)
class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates ``net`` on local data loaders.

    Args:
        net: Model to train; its weights are overwritten with the parameters
            received at the start of each ``fit``/``evaluate`` call.
        trainloader: Loader passed to ``train``.
        testloader: Loader passed to ``evaluate``.
    """

    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        """Load the received parameters, train locally and return the update.

        Returns:
            ``(weights, num_examples, metrics)``.  ``num_examples`` is the
            number of training samples (``len(loader.dataset)``), not the
            number of batches, because aggregation weights each client by
            the amount of data it trained on.
        """
        set_weights(self.net, parameters)
        train(self.net, self.trainloader)
        return get_weights(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the received parameters and evaluate them on the local test data.

        Returns:
            ``(loss, num_examples, metrics)`` where ``num_examples`` is the
            number of test samples and ``metrics`` holds ``"accuracy"``.
        """
        set_weights(self.net, parameters)
        loss, accuracy = evaluate(self.net, self.testloader)
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}
def client_fn(context: Context):
    """Build the Flower client for the data owner running this app.

    ``context`` is accepted because ``ClientApp`` passes it, but it is not
    read here.
    """
    # Both loaders come from the owner's private SyftBox dataset.
    private_train, private_test = load_syftbox_dataset()
    numpy_client = FlowerClient(Net(), private_train, private_test)
    return numpy_client.to_client()
# Client application object; pyproject.toml points at it as
# "fl_diabetes_prediction.client_app:app".
app = ClientApp(client_fn=client_fn)
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from syft_flwr.strategy import FedAvgWithModelSaving
from pathlib import Path
import os
from fl_diabetes_prediction.task import Net, get_weights
def weighted_average(metrics):
    """Aggregate client accuracies, weighting each by its number of examples.

    Args:
        metrics: Iterable of ``(num_examples, metrics_dict)`` pairs where each
            ``metrics_dict`` contains an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``.  When no examples were reported (an
        empty iterable or all counts zero) the accuracy is ``0.0`` instead of
        raising ``ZeroDivisionError``.
    """
    total_examples = 0
    weighted_sum = 0.0
    for num_examples, client_metrics in metrics:
        total_examples += num_examples
        weighted_sum += num_examples * client_metrics["accuracy"]
    if total_examples == 0:
        return {"accuracy": 0.0}
    return {"accuracy": weighted_sum / total_examples}
def server_fn(context: Context):
    """Build the server components: aggregation strategy and round config.

    The strategy starts from the weights of a freshly initialised ``Net`` and
    is given ``<OUTPUT_DIR>/weights`` as its save path, where ``OUTPUT_DIR``
    falls back to ``~/.syftbox/rds/`` when the environment variable is unset.

    Client-sampling settings are read from ``context.run_config`` (the
    ``[tool.flwr.app.config]`` table of pyproject.toml) and fall back to the
    previously hard-coded values, so the configuration file and the strategy
    cannot drift apart.

    Raises:
        KeyError: If ``num-server-rounds`` is missing from the run config.
    """
    # Initialize the model
    net = Net()
    params = ndarrays_to_parameters(get_weights(net))

    # Set up model save path
    output_dir = os.getenv("OUTPUT_DIR", Path.home() / ".syftbox/rds/")
    save_path = Path(output_dir) / "weights"

    run_config = context.run_config

    # Configure the aggregation strategy
    strategy = FedAvgWithModelSaving(
        save_path=save_path,
        fraction_fit=run_config.get("fraction-fit", 1.0),
        fraction_evaluate=run_config.get("fraction-evaluate", 1.0),
        min_available_clients=run_config.get("min-available-clients", 2),
        min_fit_clients=run_config.get("min-fit-clients", 2),
        min_evaluate_clients=run_config.get("min-evaluate-clients", 2),
        initial_parameters=params,
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    num_rounds = run_config["num-server-rounds"]
    config = ServerConfig(num_rounds=num_rounds)
    return ServerAppComponents(config=config, strategy=strategy)
# Server application object; pyproject.toml points at it as
# "fl_diabetes_prediction.server_app:app".
app = ServerApp(server_fn=server_fn)
FedAvgWithModelSaving is a custom strategy that saves the global model to disk after each training round, making it easy to track progress and recover from failures.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
# Package metadata and runtime dependencies of the FL project.
[project]
name = "fl-diabetes-prediction"
version = "1.0.0"
requires-python = ">=3.12"
description = "Federated Learning for Diabetes Prediction"
license = "Apache-2.0"
dependencies = [
"flwr-datasets>=0.5.0",
"torch>=2.8.0",
# NOTE(review): the library imported as `imblearn` is published on PyPI as
# "imbalanced-learn"; confirm this requirement resolves as intended.
"imblearn",
"pandas",
"scikit-learn==1.6.1",
"loguru",
"syft_flwr",
]
[tool.flwr.app]
publisher = "YourName"
# Import paths ("module:attribute") of the app objects defined in
# server_app.py and client_app.py.
[tool.flwr.app.components]
serverapp = "fl_diabetes_prediction.server_app:app"
clientapp = "fl_diabetes_prediction.client_app:app"
# Run configuration; server_fn reads "num-server-rounds" from
# context.run_config.
[tool.flwr.app.config]
num-server-rounds = 3
min-available-clients = 2
min-fit-clients = 2
min-evaluate-clients = 2
fraction-fit = 1.0
fraction-evaluate = 1.0
import syft_flwr
from pathlib import Path
# Location of the project directory, relative to the current working directory.
# Later snippets reuse `project_path` when submitting jobs.
project_path = Path("./fl-diabetes-prediction")
aggregator_email = "ds@openmined.org" # Data scientist email
datasite_emails = ["do1@openmined.org", "do2@openmined.org"] # Data owner emails
# Configure the participants (one aggregator, several datasites) for the project.
syft_flwr.bootstrap(
project_path,
aggregator=aggregator_email,
datasites=datasite_emails
)
import syft_rds as sy
# Connect to every data owner's datasite. The server strategy waits for two
# clients, so the job has to reach each datasite listed in `datasite_emails`,
# not only the first one.
do_clients = [
    sy.init_session(host=do_email, email=aggregator_email)
    for do_email in datasite_emails
]
# Submit the same client job to each data owner.
# NOTE(review): assumes every datasite registered its data under the same
# dataset name; adjust per owner if that is not the case.
for do_client in do_clients:
    do_client.job.submit(
        name="fl-diabetes-prediction",
        user_code_path=project_path,
        dataset_name="pima-indians-diabetes-database",
        entrypoint="main.py",
    )
# Keep the first data owner's session available under its original name.
do1_client = do_clients[0]
# Submit to yourself to run the server: host and email are both the
# aggregator's address.
ds_client = sy.init_session(
host="ds@openmined.org",
email="ds@openmined.org"
)
# Unlike the data-owner submission, no dataset_name is passed here.
job = ds_client.job.submit(
name="fl-diabetes-prediction-server",
user_code_path=project_path,
entrypoint="main.py",
)
# Approve and run
# The job is approved first and then run on this datasite.
# NOTE(review): blocking=True presumably waits for the run to finish; confirm.
ds_client.job.approve(job)
ds_client.run_private(job, blocking=True)
weights/
├── parameters_round_1.safetensors
├── parameters_round_2.safetensors
└── parameters_round_3.safetensors
Key Differences from Standard Flower
Syft-Flwr requires only minimal changes to a standard Flower project:
- Data Loading: Use load_syftbox_dataset() instead of loading from public datasets
- Bootstrap Step: Run syft_flwr.bootstrap() to configure participants
- Communication: Messages are exchanged via file sync instead of network connections
What’s Next?
- Learn how to run FL in Google Colab
- Set up a local SyftBox environment
- Implement custom aggregation strategies
Common Issues
FileNotFoundError: DATA_DIR does not exist
This means the DATA_DIR environment variable is not set. Make sure:
- You’re running the code through SyftBox job execution
- The dataset is properly registered with SyftBox
- You’re using load_syftbox_dataset() in your client code
No clients connecting to server
Check that:
- Data owners have approved the job requests
- All participants are running their respective code
- The app_name in pyproject.toml matches across all participants
- SyftBox is running and syncing files properly
Model weights not saving
Verify:
- The OUTPUT_DIR environment variable points to a writable location
- You’re using the FedAvgWithModelSaving strategy
- The save path directory exists and has write permissions