Variance Reduction: The Secret to Fast FL Convergence
Why does federated learning take so many communication rounds to converge?
A typical FL training job might require:
- Standard SGD: 1,000+ rounds to converge
- With variance reduction: 100-200 rounds to converge
- Result: 5-10× speedup in wall-clock time
Variance reduction is the algorithmic technique that makes this possible. It's the difference between federated learning being a research curiosity and a production-viable technology.
This post dives into variance reduction methods—MARINA, PAGE, SAGA, and their variants—and explains why they're fundamental to efficient federated learning.
The Variance Problem in Federated Learning
Why Vanilla SGD is Slow
Standard SGD update:
w_{t+1} = w_t − γ ∇f_i(w_t)
where ∇f_i(w_t) is the gradient on device i's local data.
Problem: ∇f_i(w_t) is a noisy estimate of the true gradient ∇f(w_t).
Consequence: High variance in updates → slow convergence
Federated learning amplifies this:
- Each device has different data distribution (heterogeneity)
- Variance from both sampling (SGD) and device selection
- Communication cost makes frequent aggregation expensive
Result: Standard FedAvg can take 1,000s of rounds to converge.
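To make the noise concrete, here is a toy NumPy experiment (entirely synthetic numbers, not output from any FL system): each device sees the true gradient plus device-specific noise, and a single device's gradient is a far worse estimate than the average across devices.

import numpy as np

rng = np.random.default_rng(0)
true_grad = np.ones(10)

# Synthetic per-device gradients: true gradient + heterogeneous noise
device_grads = np.array([true_grad + rng.normal(0.0, 3.0, size=10)
                         for _ in range(100)])

single_err = np.linalg.norm(device_grads[0] - true_grad)
avg_err = np.linalg.norm(device_grads.mean(axis=0) - true_grad)
print(f"single-device error:     {single_err:.2f}")  # large
print(f"100-device average error: {avg_err:.2f}")    # ~10x smaller (std shrinks by sqrt(100))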
What Variance Reduction Does
Variance reduction algorithms use control variates to reduce gradient noise:
Basic idea:
- Maintain a "reference" gradient (from previous iteration or full batch)
- Correct current gradient using the reference
- Result: Lower-variance gradient estimate (sketched in code below)
Impact:
- Faster convergence (fewer iterations)
- Better communication complexity (fewer rounds)
- Can achieve linear convergence (vs. sublinear for SGD)
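A minimal sketch of the control-variate idea in plain Python (grad_fn, w_ref, and mu_ref are hypothetical stand-ins for a per-device gradient oracle, the reference point, and the aggregated reference gradient):

def variance_reduced_grad(grad_fn, w, w_ref, mu_ref):
    # Fresh, noisy gradient at the current point
    g = grad_fn(w)
    # Subtract the same oracle's gradient at the reference point and add
    # back the low-noise aggregate mu_ref: the estimate stays unbiased,
    # and its variance shrinks as w approaches w_ref
    return g - grad_fn(w_ref) + mu_ref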
MARINA: Communication-Efficient Variance Reduction
The MARINA Algorithm (Richtárik et al.)
MARINA [1] is one of the most influential variance reduction methods for FL.
Key innovation: Combines variance reduction with communication compression.
Algorithm:
# Pseudocode for MARINA
def marina_client_update(model, local_data, reference_grad, compressor):
    # Compute local gradient
    local_grad = compute_gradient(model, local_data)
    # Variance reduction: use reference gradient as control variate
    gradient_diff = local_grad - reference_grad
    # Compress the difference (sparse/quantized)
    compressed_diff = compressor(gradient_diff)
    return compressed_diff

def marina_server_aggregate(compressed_diffs, reference_grad):
    # Aggregate compressed differences
    avg_diff = average(compressed_diffs)
    # Add back the reference to get the full gradient estimate
    new_grad = reference_grad + avg_diff
    return new_grad
Why it works:
- Variance reduction: Using reference gradient reduces noise
- Compression-friendly: Compressing differences (which are smaller) instead of full gradients
- Communication-efficient: Can use aggressive compression (top-1% sparsity) without losing convergence (see the Top-K sketch below)
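For concreteness, here is a plain-NumPy sketch of the kind of Top-K compressor such methods plug in (illustrative only, not any library's internal implementation):

import numpy as np

def top_k_compress(v, fraction=0.01):
    # Keep only the largest-magnitude entries (e.g., top 1%); zero the rest
    k = max(1, int(fraction * v.size))
    out = np.zeros_like(v)
    idx = np.argpartition(np.abs(v), -k)[-k:]
    out[idx] = v[idx]
    return out

Because MARINA compresses gradient differences, which shrink as training stabilizes, even fraction=0.01 loses little information.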
Convergence rate: O(1/T) for smooth non-convex functions (vs. O(1/√T) for vanilla SGD)
import octomil

# MARINA in Octomil
client = octomil.OctomilClient(
    project_id="variance-reduced-fl",
    algorithm="marina",
    # Compression settings
    compression="top-k",
    sparsity=0.01,  # Top-1% of gradients
    # Variance reduction settings
    reference_update_frequency=10,  # Update reference every 10 rounds
)

client.train(
    model=my_model,
    rounds=100,  # Converges in ~100 rounds vs. ~1,000 for vanilla FedAvg
)
MARINA-P: Superior Performance in Non-Smooth FL (Richtárik, 2024)
MARINA-P [2] extends MARINA to non-smooth objectives (common in robust optimization):
Key contributions:
- Works with non-differentiable objectives (e.g., L1 regularization, quantile regression)
- Adaptive stepsizes (no need to tune learning rate)
- Superior empirical performance over original MARINA
# MARINA-P for non-smooth objectives
client = octomil.OctomilClient(
    algorithm="marina-p",
    # Non-smooth objective
    objective="quantile-regression",  # Non-differentiable
    # Adaptive optimization
    adaptive_stepsize=True,  # MARINA-P's adaptive stepsizes
)
MARINA meets Matrix Stepsizes (Richtárik, 2024)
det-MARINA [3] uses deterministic matrix stepsizes for faster convergence:
Idea: Instead of a scalar learning rate γ, use a matrix stepsize D:
w_{t+1} = w_t − D g_t
where D is a preconditioning matrix (e.g., a diagonal approximation of the inverse Hessian).
Benefit: Adapts to local geometry of loss landscape → faster convergence.
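A toy sketch of a diagonal matrix stepsize (hess_diag is a hypothetical estimate of the Hessian diagonal; det-MARINA's actual stepsize rule is more involved):

import numpy as np

def diagonal_preconditioned_step(w, g, hess_diag, eps=1e-8):
    # D = diag(1 / (h_i + eps)): each coordinate gets its own stepsize,
    # scaled by an estimate of the local curvature along that coordinate
    return w - g / (hess_diag + eps)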
# MARINA with matrix stepsizes
client = octomil.OctomilClient(
    algorithm="det-marina",
    # Matrix stepsize configuration
    preconditioner="diagonal-hessian",  # Diagonal approximation of the Hessian
    preconditioner_update_frequency=50,  # Update every 50 rounds
)
PAGE: Variance Reduction for Non-Convex Optimization
The PAGE Algorithm (Richtárik et al.)
PAGE (ProbAbilistic Gradient Estimator) [4] is another foundational variance reduction method.
Key difference from MARINA: PAGE uses probabilistic gradient averaging.
Algorithm:
# PAGE algorithm (pseudocode)
def page_update(model, prev_model, data, reference_grad, prob_p):
    local_grad = compute_gradient(model, data)
    # With probability p: fresh full gradient (unbiased, resets accumulated error)
    if random() < prob_p:
        return local_grad
    # With probability (1 - p): variance-reduced update built from the
    # gradient difference between the current and previous model
    grad_diff = local_grad - compute_gradient(prev_model, data)
    return reference_grad + grad_diff
Why probabilistic:
- Balances fresh gradients (high variance, unbiased) with variance-reduced updates
- Allows tuning the bias-variance and compute tradeoff via the probability p (see the cost sketch below)
- Simpler to implement than MARINA (no compression machinery required)
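One consequence of the coin flip is a cheap expected per-step cost. A back-of-the-envelope calculation with illustrative numbers:

# Expected gradient evaluations per PAGE step: with probability p a full
# pass over n samples, otherwise two minibatch evaluations (new and old model)
n, b, p = 10_000, 64, 0.1
expected_cost = p * n + (1 - p) * 2 * b
print(expected_cost)  # 1115.2 evaluations, vs. 10,000 for a full gradient every step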
Convergence: O(1/T) for non-convex objectives with an appropriate choice of the probability p.
# PAGE in Octomil
client = octomil.OctomilClient(
    algorithm="page",
    # PAGE-specific settings
    page_probability=0.1,  # 10% probability of a full gradient
    # Local epochs for amortization
    local_epochs=5,
)
Freya PAGE: Optimal Asynchronous Variance Reduction (Richtárik, NeurIPS 2024)
Freya PAGE [5] achieves the first optimal time complexity for asynchronous variance reduction:
Setting: Heterogeneous devices with different compute/network speeds
Key innovation: Combines PAGE-style variance reduction with asynchronous execution
Result: Optimal convergence in wall-clock time (not just iterations)
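A generic sketch of the staleness gate such asynchronous setups use (this illustrates the general idea behind a staleness threshold, not Freya PAGE's exact rule):

def accept_update(update_round, current_round, staleness_threshold=5):
    # Reject contributions computed against a model that is too many rounds old
    return current_round - update_round <= staleness_threshold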
# Asynchronous PAGE
client = octomil.OctomilClient(
    algorithm="freya-page",
    # Asynchronous settings
    training_mode="asynchronous",
    staleness_threshold=5,
    # Variance reduction
    page_probability=0.15,
)
Point-SAGA: Variance Reduction Without Full Gradient
The Point-SAGA Algorithm (Defazio, 2016)
Problem: Many variance reduction methods require computing full-batch gradients periodically (expensive).
Point-SAGA [6] achieves variance reduction without ever computing full gradients.
Key technique: Maintain per-sample gradient estimates, update incrementally.
Convergence: Linear convergence for strongly convex, O(1/T) for general convex.
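A sketch of the per-sample table mechanic in plain NumPy (table is an n×d array of stored per-sample gradients; this is the classic SAGA update, whereas Point-SAGA replaces the explicit gradient step with a proximal step, omitted here):

import numpy as np

def saga_step(w, i, grad_fn, table, lr):
    # Fresh gradient for sample i only: no full-batch pass is ever needed
    g_new = grad_fn(w, i)
    # Control variate built from the stored table: unbiased, lower variance
    g_vr = g_new - table[i] + table.mean(axis=0)
    table[i] = g_new  # incremental table update, O(1) per step
    return w - lr * g_vr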
# Point-SAGA for variance reduction without full gradients
client = octomil.OctomilClient(
    algorithm="point-saga",
    # Per-sample tracking (memory cost: O(n), where n = dataset size)
    per_sample_tracking=True,
    # Trade memory for computation
    memory_efficient=True,  # Use a sparse representation
)
Loopless Variance Reduction (L-SVRG, MARINA)
Streamlining in the Riemannian Realm (Richtárik, 2024)
Problem: Traditional variance reduction has "epochs" (inner loops over data).
Loopless variance reduction [7] eliminates inner loops → simpler implementation, better suited to FL.
Algorithms:
- L-SVRG: Loopless stochastic variance-reduced gradient (sketched below)
- R-PAGE: Riemannian PAGE (for optimization over Riemannian manifolds)
- R-MARINA: Riemannian MARINA
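A sketch of the loopless mechanic in plain Python (NumPy arrays assumed; grad_fn and full_grad_fn are hypothetical oracles for a per-sample gradient and a full gradient):

import random

def l_svrg_step(w, w_ref, mu_ref, sample, grad_fn, full_grad_fn, lr, p):
    # Standard SVRG-style variance-reduced step
    g = grad_fn(w, sample) - grad_fn(w_ref, sample) + mu_ref
    w = w - lr * g
    # A coin flip replaces the epoch boundary: occasionally refresh the reference
    if random.random() < p:
        w_ref, mu_ref = w, full_grad_fn(w)
    return w, w_ref, mu_ref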
# Loopless variance reduction
client = octomil.OctomilClient(
    algorithm="l-svrg",  # or "r-page", "r-marina"
    # No inner/outer loop structure
    loopless=True,
    # Update probability (replaces the epoch structure)
    update_probability=0.05,
)
Methods for (L₀, L₁)-Smooth Optimization
Clipping-Based Variance Reduction (Richtárik, ICLR 2025)
Setting: Objectives that are smooth only in an extended sense, (L₀, L₁)-smoothness, where the local smoothness constant may grow with the gradient norm.
The paper Methods for convex (L₀, L₁)-smooth optimization [8] introduces clipping combined with variance reduction:
Key algorithms:
- L0L1-SGD: SGD with adaptive clipping (see the stepsize sketch after this list)
- L0L1-STM: Clipped accelerated method (Similar Triangles Method)
- L0L1-AdGD: Adaptive gradient descent
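A sketch of the stepsize rule such clipping methods build on under (L₀, L₁)-smoothness (the standard γ = 1/(L₀ + L₁‖g‖) choice; the exact rules in the paper differ per algorithm):

import numpy as np

def l0l1_clipped_step(w, g, l0, l1):
    # The stepsize shrinks as the gradient grows: roughly 1/l0 for small
    # gradients and 1/(l1 * ||g||) for large ones, i.e., gradient clipping
    step = 1.0 / (l0 + l1 * np.linalg.norm(g))
    return w - step * g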
# Clipping-based variance reduction
client = octomil.OctomilClient(
    algorithm="l0l1-stm",
    # (L0, L1)-smoothness parameters
    l0_smoothness=1.0,
    l1_smoothness=0.1,
    # Adaptive clipping
    adaptive_clipping=True,
)
Error Feedback Under (L₀, L₁)-Smoothness (Richtárik, NeurIPS 2025)
Error feedback under (L₀, L₁)-smoothness: normalization and momentum [9] combines:
- Error feedback (compensating for compression error)
- Clipping/normalization (handling (L₀, L₁)-smoothness)
- Momentum (for acceleration)
Result: State-of-the-art convergence rates for non-convex federated learning.
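A sketch of the EF21-style error-feedback mechanic this family builds on (momentum and normalization omitted; compress is any contractive compressor such as Top-K):

def ef21_client_step(w, g_state, grad_fn, compress):
    # Transmit only a compressed correction to the running estimate g_state;
    # the part lost to compression is implicitly carried over to later rounds
    c = compress(grad_fn(w) - g_state)
    g_state = g_state + c
    return c, g_state  # send c to the server; keep g_state locally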
# Error feedback + clipping + momentum
client = octomil.OctomilClient(
    algorithm="ef21-sgdm",  # EF21 with momentum
    # Smoothness-aware
    smoothness_type="l0l1",
    # Momentum
    momentum=0.9,
    # Normalization
    normalize_gradients=True,
)
Biased SGD and Variance Reduction
A Guide Through the Zoo of Biased SGD (Richtárik, NeurIPS 2023)
Problem: Many variance reduction and compression methods introduce bias (their gradient estimates are no longer unbiased).
A guide through the zoo of biased SGD [10] provides a unified analysis:
Key insight: Small bias can be tolerated if variance reduction is large enough.
Algorithms covered:
- SARAH, SPIDER, STORM (variance-reduced methods)
- Compressed SGD (communication-efficient)
- Quantized SGD (low-precision)
Unified convergence: All achieve O(1/T) under mild bias conditions.
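As one example from this zoo, a sketch of the STORM estimator (a momentum-based variance-reduced gradient; grad_fn is a hypothetical stochastic gradient oracle):

def storm_estimator(d_prev, w, w_prev, sample, grad_fn, a):
    # One stochastic gradient at each of w and w_prev per step; no reference
    # point or periodic full gradient is required
    return grad_fn(w, sample) + (1.0 - a) * (d_prev - grad_fn(w_prev, sample))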
# Biased variance reduction with guarantees
client = octomil.OctomilClient(
    algorithm="storm",  # or "sarah", "spider"
    # Bias-variance tradeoff
    bias_tolerance=0.01,  # Allow 1% bias in exchange for variance reduction
    # Convergence guarantees maintained
    convergence_monitoring=True,
)
Octomil's Variance Reduction Framework
import octomil

# Automatic variance reduction selection
client = octomil.OctomilClient(
    project_id="fast-fl",
    # High-level setting: Octomil picks the best algorithm
    variance_reduction="auto",  # Automatically select MARINA, PAGE, etc.
    # Or explicitly choose:
    # algorithm="marina",      # Best for compression + VR
    # algorithm="page",        # Best for non-convex
    # algorithm="point-saga",  # Best for avoiding full gradients
    # algorithm="l-svrg",      # Best for simplicity (loopless)
    # Communication compression
    compression="top-k",
    sparsity=0.01,
    # Training
    rounds=100,  # Converges 5-10× faster than vanilla FedAvg
)

# Train with variance reduction
model = client.train(
    model=my_model,
    data_loader=federated_data,
)

# Monitor the variance reduction benefit
stats = client.get_training_stats()
print(f"Gradient variance: {stats.gradient_variance}")
print(f"Convergence speedup: {stats.speedup_vs_vanilla}×")
When Variance Reduction Matters Most
| Setting | Variance Reduction Benefit | Recommended Method |
|---|---|---|
| High data heterogeneity | Very high (10×+ speedup) | MARINA, Scafflix |
| Communication-constrained | High (5× fewer rounds) | MARINA, PAGE |
| Non-convex optimization | Medium-high (3-5× speedup) | PAGE, L-SVRG |
| Non-smooth objectives | Medium (2-3× speedup) | MARINA-P, proximal methods |
| Low heterogeneity | Low (1-2× speedup) | May not be worth complexity |
Real-World Performance
Case Study 1: Image Classification (CIFAR-10)
Setup: 100 devices, non-IID data (heterogeneous)
Results:
- Vanilla FedAvg: 500 rounds to 85% accuracy
- FedAvg + local epochs (E=5): 300 rounds
- MARINA: 120 rounds (4.2× faster than FedAvg)
- PAGE: 100 rounds (5× faster)
Communication savings: PAGE sent 80% less data than FedAvg (compression + fewer rounds).
Case Study 2: Language Model Fine-Tuning
Setup: 1,000 mobile devices, keyboard prediction model
Results:
- Vanilla FedAvg: 800 rounds, 2 weeks wall-clock time
- MARINA-P: 150 rounds, 3 days wall-clock time (4.7× faster)
- Convergence: Both reached 95% of optimal performance
Key insight: Variance reduction's impact grows with scale (more devices = more variance).
Case Study 3: Medical Federated Learning
Setup: 20 hospitals, medical image classification
Results:
- FedAvg: 200 rounds to convergence
- Point-SAGA: 80 rounds (2.5× faster)
- Benefit: Avoided computing full-batch gradients (privacy-friendly)
Comparison: MARINA vs. PAGE vs. Point-SAGA
| Feature | MARINA | PAGE | Point-SAGA |
|---|---|---|---|
| Compression-friendly | Excellent | Good | Moderate |
| Non-convex convergence | O(1/T) | O(1/T) | O(1/T) (convex setting) |
| Memory overhead | Low | Low | High (per-sample) |
| Full gradient needed | Periodically | Occasionally (with probability p) | No |
| Implementation complexity | Medium | Low | High |
| Best for | Communication-constrained FL | General non-convex FL | Avoiding full gradients |
Getting Started
pip install octomil
# Initialize with variance reduction
octomil init fast-project --variance-reduction marina
# Train with automatic variance reduction
octomil train \
  --variance-reduction auto \
  --compression top-k \
  --sparsity 0.01
See our Advanced FL Configuration guide for detailed comparisons and tuning.
References
Foundational papers:
- FedAvg: McMahan, H. B., Moore, E., Ramage, D., Hampson, S., & Agüera y Arcas, B. (2017). Communication-efficient learning of deep networks from decentralized data. AISTATS 2017. arXiv:1602.05629
- SCAFFOLD: Karimireddy, S. P., Kale, S., Mohri, M., Reddi, S., Stich, S., & Suresh, A. T. (2020). SCAFFOLD: Stochastic controlled averaging for federated learning. ICML 2020. arXiv:1910.06378