Personalized Federated Learning: One Global Model, Many Local Needs
The fundamental premise of federated learning is to train a single global model across diverse devices. But what happens when "one size fits all" doesn't fit anyone particularly well?
The personalization dilemma: A global keyboard prediction model trained on millions of devices might be mediocre for everyone—users who text in multiple languages, users with specialized vocabularies (medical, legal), or users with unique writing styles all suffer from a lowest-common-denominator model.
This post explores how personalized federated learning enables Octomil to deliver both collective intelligence and individual adaptation.
The Case for Personalization
Real-World Heterogeneity
Consider a mobile health app for diabetes management:
- Patient A (Type 1, athlete): Needs aggressive insulin adjustments for exercise
- Patient B (Type 2, sedentary): Requires different meal-based recommendations
- Patient C (gestational): Has unique pregnancy-related patterns
A single global model averages across these populations and serves none well. Personalized FL lets each device maintain a local model variant optimized for its specific context while still benefiting from global knowledge.
Fairness Concerns
Standard FL can inadvertently harm minority groups:
- Data bias: Devices with common patterns dominate the global model
- Accuracy disparity: Minority devices see lower accuracy (the "worst-group" problem)
- Example: A speech recognition model trained primarily on American English fails on regional accents
Personalization addresses fairness by ensuring every device gets a model that works well for its data distribution.
Personalization Strategies
1. Fine-Tuning (The Simple Baseline)
The most straightforward approach:
- Train a global model via standard FL
- Each device fine-tunes the model on its local data
- Never upload personalized weights (privacy-preserving)
```python
import octomil

# Global training
client = octomil.OctomilClient(project_id="diabetes-prediction")
global_model = client.train(
    model=base_model,
    rounds=50,
)

# Local personalization
personalized_model = client.fine_tune(
    global_model=global_model,
    local_data=my_device_data,
    epochs=5,
    upload=False,  # Keep personalized model local
)
```
Pros: Simple and privacy-preserving (personalized weights never leave the device).
Cons: No global feedback from personalization; requires sufficient local data.
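Framework aside, the local step is just a few epochs of training that start from the global weights and stay on-device. A minimal sketch using NumPy logistic regression (the data and all names here are illustrative, not Octomil internals):

```python
import numpy as np

def fine_tune(global_w, X, y, epochs=5, lr=0.05):
    """Fine-tune global weights on local data; the result stays on-device."""
    w = global_w.copy()  # never mutate the shared model
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))  # sigmoid predictions
        grad = X.T @ (p - y) / len(y)       # logistic-loss gradient
        w -= lr * grad                      # local SGD step
    return w  # personalized weights, kept local

# Synthetic "device" data
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
true_w = np.array([1.5, -2.0, 0.5, 0.0])
y = (X @ true_w + 0.1 * rng.normal(size=200) > 0).astype(float)

global_w = np.zeros(4)  # stand-in for the FL-trained global model
local_w = fine_tune(global_w, X, y, epochs=20)

def loss(w):
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    return -np.mean(y * np.log(p + 1e-9) + (1 - y) * np.log(1 - p + 1e-9))

print(loss(local_w) < loss(global_w))  # → True: fine-tuning reduces local loss
```

The `upload=False` idea above corresponds to simply never transmitting `local_w`.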
2. Explicit Personalization with Scafflix
Virginia Smith's research group has extensively studied personalization-aware FL algorithms. Scafflix [1] combines explicit personalization layers with communication-efficient local training:
Key idea: Separate model into:
- Global layers: Shared feature extractors (trained via FL)
- Personal layers: Device-specific prediction heads (local only)
```python
# Scafflix-style architecture in Octomil
import torch.nn as nn
from octomil import PersonalizedModel

input_dim, num_classes = 10, 2  # example dimensions

class PersonalizedDiabetesModel(PersonalizedModel):
    def __init__(self):
        super().__init__()
        # Global layers (shared across devices)
        self.global_layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )
        # Personal layers (device-specific)
        self.personal_head = nn.Linear(64, num_classes)

    def forward(self, x):
        shared_features = self.global_layers(x)
        return self.personal_head(shared_features)

# Octomil handles synchronization automatically
client.train_personalized(
    model=PersonalizedDiabetesModel(),
    sync_layers=["global_layers"],   # Only sync global layers
    local_layers=["personal_head"],  # Keep personal layers local
)
```
Scafflix provides:
- 2× communication reduction compared to full model sync
- Improved local accuracy via personalized heads
- Control variates to correct for data heterogeneity
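The control-variate idea can be illustrated outside any framework: each local gradient is shifted by the difference between the server's and the client's control variates, which cancels client drift under heterogeneous data. A toy SCAFFOLD-style simulation on scalar quadratic losses (illustrative only, not Scafflix's or Octomil's exact update):

```python
# Two clients with quadratic losses f_i(w) = 0.5 * (w - a_i)^2;
# the global optimum is the mean of the a_i.
targets = [2.0, -1.0]
lr, local_steps, rounds = 0.1, 5, 50

w = 0.0               # server model
c = 0.0               # server control variate
c_local = [0.0, 0.0]  # per-client control variates

for _ in range(rounds):
    new_w, new_c = [], []
    for i, a in enumerate(targets):
        y = w
        for _ in range(local_steps):
            grad = y - a                       # ∇f_i(y)
            y -= lr * (grad - c_local[i] + c)  # drift-corrected step
        # Refresh the client's control variate from its observed drift
        c_i = c_local[i] - c + (w - y) / (local_steps * lr)
        new_w.append(y)
        new_c.append(c_i)
    w = sum(new_w) / len(new_w)      # aggregate models
    c_local = new_c
    c = sum(new_c) / len(new_c)      # aggregate control variates

print(round(w, 3))  # → 0.5, the global optimum (mean of targets)
```

The correction term `- c_local[i] + c` is what keeps each client's multi-step local update pointed at the global objective rather than its own minimum.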
3. Maximizing Global Model Appeal
Not every device needs personalization—some benefit more from the global model. Cho et al. [2] introduce the concept of global model appeal:
Goal: Maximize the fraction of devices for which the global model performs well, while personalizing for the rest.
Octomil implements this via appeal-aware selection:
```python
# Octomil automatically identifies devices needing personalization
client = octomil.OctomilClient(
    project_id="keyboard-prediction",
    personalization_strategy="appeal-aware",
)

# After global training, Octomil measures per-device accuracy
metrics = client.evaluate_global_model()

# Devices with low accuracy get personalization recommendations
threshold = 0.05  # tolerated accuracy gap before personalizing
if metrics.local_accuracy < metrics.global_accuracy - threshold:
    print("Personalization recommended for this device")
    client.enable_personalization()
```
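The appeal criterion itself reduces to a simple count: the fraction of devices whose global-model accuracy is within some margin of what personalization would give them. A sketch with made-up per-device numbers (the function name and margin are illustrative):

```python
def global_model_appeal(global_acc, personal_acc, margin=0.02):
    """Fraction of devices for which the global model is 'good enough':
    within `margin` of what personalization would achieve."""
    appealing = [g >= p - margin for g, p in zip(global_acc, personal_acc)]
    return sum(appealing) / len(appealing)

# Hypothetical per-device accuracies
global_acc = [0.91, 0.88, 0.62, 0.90, 0.55]
personal_acc = [0.92, 0.89, 0.80, 0.90, 0.78]

print(global_model_appeal(global_acc, personal_acc))  # → 0.6
```

Here devices 3 and 5 fall well short of their personalized accuracy, so they are the ones a system like this would flag for personalization.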
4. Federated Learning with Model Pruning
For resource-constrained devices, FedP3 [3] (Personalized and Privacy-friendly Federated network Pruning) combines personalization with model compression:
Approach:
- Each device learns a personalized sparse subnetwork
- Devices with different resource constraints can use different sparsity levels
- Privacy-preserving: Only pruning masks (not weights) are shared
```python
# Personalized pruning in Octomil
client.train_personalized(
    model=my_model,
    pruning=True,
    target_sparsity=0.7,        # 70% of weights pruned
    personalize_sparsity=True,  # Each device learns custom masks
)
```
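Mechanically, a personalized magnitude-based mask keeps each device's largest weights and zeroes the rest; only the boolean mask would ever be shared. A generic NumPy sketch (illustrative; not FedP3's exact pruning criterion):

```python
import numpy as np

def magnitude_mask(weights, sparsity):
    """Boolean mask keeping the top (1 - sparsity) fraction of weights
    by absolute magnitude."""
    k = int(weights.size * (1.0 - sparsity))  # number of weights to keep
    threshold = np.sort(np.abs(weights).ravel())[-k]
    return np.abs(weights) >= threshold

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8))  # stand-in for a layer's weight matrix

mask = magnitude_mask(w, sparsity=0.7)  # device-specific sparsity level
sparse_w = w * mask                     # pruned weights stay on-device

print(round(mask.mean(), 2))  # fraction of weights kept ≈ 0.3
```

A device with tighter memory could pass a higher `sparsity`, which is the "different resource constraints, different sparsity levels" point above.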
Fairness-Aware Federated Learning
The Worst-Group Problem
Standard FedAvg minimizes the average loss across devices, which can leave minority groups behind. Smith's work on bounded group loss [4] provides fairness guarantees:
Bounded Group Loss: Ensure every predefined group (e.g., demographic, language) achieves loss below a threshold.
```python
# Fair FL in Octomil
client = octomil.OctomilClient(
    project_id="speech-recognition",
    fairness_constraint="bounded-group-loss",
    groups=["us-english", "uk-english", "indian-english"],
    max_group_loss=0.15,  # Upper bound on any group's loss
)
```
Octomil monitors per-group metrics in real-time and reweights updates from underperforming groups.
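One simple reweighting scheme consistent with this idea is a multiplicative update: groups whose loss exceeds the bound get proportionally more weight in the next aggregation. A sketch (illustrative; not necessarily Octomil's exact rule):

```python
def reweight_groups(group_losses, max_loss, weights, step=0.5):
    """Multiplicatively upweight groups that violate the loss bound,
    then renormalize so aggregation weights still sum to one."""
    new = {g: weights[g] * (1.0 + step * max(0.0, group_losses[g] - max_loss))
           for g in weights}
    total = sum(new.values())
    return {g: w / total for g, w in new.items()}

# Hypothetical per-group losses after a round, with the 0.15 bound from above
weights = {"us-english": 1 / 3, "uk-english": 1 / 3, "indian-english": 1 / 3}
losses = {"us-english": 0.08, "uk-english": 0.10, "indian-english": 0.30}

weights = reweight_groups(losses, max_loss=0.15, weights=weights)
print(max(weights, key=weights.get))  # → indian-english (gets more weight)
```

Only the group violating the bound (`indian-english`, at 0.30 > 0.15) is boosted; the compliant groups shrink slightly after renormalization.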
Tilted Loss for Fairness
Li et al. [5] introduce tilted empirical risk minimization, which interpolates between the average-case (FedAvg) and worst-case (minimax) objectives:
Tilted loss (parameter t controls tradeoff):
- t = 0: Standard average loss (FedAvg)
- t → ∞: Minimax (optimize for worst device)
- t ∈ (0, ∞): Smooth tradeoff
```python
# Tilted loss for the fairness-accuracy tradeoff
client.train(
    model=my_model,
    loss_function="tilted",
    tilt_parameter=2.0,  # Balanced fairness-accuracy tradeoff
)
```
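The tilted objective itself is L_t = (1/t) · log(mean_i exp(t · L_i)); a few lines of Python make the interpolation concrete:

```python
import math

def tilted_loss(losses, t):
    """Tilted aggregate: (1/t) * log(mean(exp(t * L_i))).
    Recovers the mean as t → 0 and the max as t → ∞."""
    if t == 0:
        return sum(losses) / len(losses)
    return math.log(sum(math.exp(t * l) for l in losses) / len(losses)) / t

device_losses = [0.1, 0.2, 0.9]  # one struggling device

print(round(tilted_loss(device_losses, 0), 2))   # → 0.4  (plain average)
print(round(tilted_loss(device_losses, 50), 2))  # → 0.88 (≈ worst device's 0.9)
```

Larger `t` makes the exponential increasingly dominated by the worst-off device, which is why tuning it trades average accuracy against worst-case fairness.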
Robustness Under Distribution Shift
Setlur et al. [6] show that prompting improves worst-group robustness in foundation models, with implications for personalized FL:
Key finding: Carefully designed prompts can surface latent personalization in pre-trained models without fine-tuning.
Octomil's approach for LLM-based applications:
```python
# Prompt-based personalization for federated LLMs
client = octomil.OctomilClient(
    model_type="llm",
    personalization_method="prompt-tuning",
)

# Each device learns a small prompt embedding;
# the global LLM remains frozen, only prompts are personalized
personalized_prompt = client.train_prompt(
    base_model="llama-3-8b",
    local_data=device_data,
    prompt_length=20,  # 20-token prompt
)
```
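Under the hood, prompt tuning prepends a small trainable embedding matrix to the frozen model's input sequence; only that matrix is trained on-device. A generic NumPy sketch with a random projection standing in for the frozen model (all names illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim, prompt_len, seq_len = 16, 20, 12

# Frozen base model: here just a fixed random projection as a stand-in
W_frozen = rng.normal(size=(embed_dim, embed_dim))

# The only trainable parameters on the device: the prompt embeddings
prompt = rng.normal(scale=0.01, size=(prompt_len, embed_dim))

def forward(token_embeddings):
    """Prepend the learned prompt, then run the frozen model."""
    x = np.concatenate([prompt, token_embeddings], axis=0)
    return x @ W_frozen  # frozen weights receive no gradient updates

tokens = rng.normal(size=(seq_len, embed_dim))
out = forward(tokens)
print(out.shape)  # → (32, 16): prompt_len + seq_len positions
```

Since only `prompt` (20 × 16 values here) is device-specific, the personalized state is tiny compared with the frozen model, which is what makes this attractive for federated LLM settings.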
Octomil's Personalization Framework
Octomil provides a unified API for personalization:
```python
import octomil

# 1. Initialize with personalization enabled
client = octomil.OctomilClient(
    project_id="my-personalized-fl",
    personalization_enabled=True,
)

# 2. Define personalization strategy
client.set_personalization_strategy(
    method="scafflix",         # or "fine-tune", "pruning", "prompt"
    personal_layers=["head"],  # Which layers to personalize
    fairness_constraint="bounded-group-loss",
    groups_file="device_groups.json",
)

# 3. Train with automatic personalization
client.train(
    model=my_model,
    rounds=50,
    local_epochs=5,
)

# 4. Evaluate personalization benefit
metrics = client.evaluate_personalization()
print(f"Global accuracy: {metrics.global_acc}")
print(f"Personalized accuracy: {metrics.personalized_acc}")
print(f"Worst-group improvement: {metrics.worst_group_improvement}")
```
When to Personalize?
Use personalization when:
- High data heterogeneity (device distributions differ significantly)
- Fairness is critical (healthcare, finance, public services)
- Devices have sufficient local data (>100 examples for fine-tuning)
- Privacy constraints allow local model updates
Stick with global models when:
- Data is relatively homogeneous (e.g., identical IoT sensors)
- Devices have minimal local data
- Simplicity is paramount (global model is easier to maintain)
Research Frontiers
Active areas in personalized FL:
- Multi-task personalization: Jointly optimize multiple related tasks per device [7]
- Meta-learning for personalization: Learn initialization that adapts quickly to new devices
- Hierarchical personalization: Group-level + device-level personalization
- Personalization under differential privacy: Balancing privacy and adaptation
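To make the meta-learning direction concrete, here is a Reptile-style sketch on scalar quadratic device objectives: adapt briefly on each device, then nudge the shared initialization toward the average of the adapted weights (a toy illustration under simplified assumptions, not a production recipe):

```python
# Reptile-style meta-learning on scalar losses f_i(w) = 0.5 * (w - a_i)^2
device_optima = [2.0, -1.0, 0.5]  # each device's ideal weight
w_init = 0.0                      # meta-learned initialization
inner_lr, meta_lr, inner_steps = 0.1, 0.5, 5

for _ in range(100):
    adapted = []
    for a in device_optima:
        w = w_init
        for _ in range(inner_steps):
            w -= inner_lr * (w - a)  # brief local adaptation
        adapted.append(w)
    # Move the initialization toward the average adapted weights
    mean_adapted = sum(adapted) / len(adapted)
    w_init += meta_lr * (mean_adapted - w_init)

print(round(w_init, 2))  # → 0.5: an init every device adapts from quickly
```

For these symmetric quadratics the learned initialization lands at the mean of the device optima, so each device reaches its own optimum in just a few local steps.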
Getting Started
Try personalized FL in Octomil:
```bash
pip install octomil
octomil init my-personalized-project --personalization
octomil train --personalization scafflix --fairness bounded-group-loss
```
See our Advanced FL Concepts guide for detailed examples.
References
Additional foundational papers:
- Ditto: Li, T., Hu, S., Beirami, A., & Smith, V. (2021). Ditto: Fair and robust federated learning through personalization. ICML 2021. arXiv:2012.04221
- pFedMe: Dinh, C. T., Tran, N. H., & Nguyen, T. D. (2020). Personalized federated learning with Moreau envelopes. NeurIPS 2020. arXiv:2006.08848