Second-Order Federated Learning: When Newton Beats SGD
Federated learning loves first-order methods. FedAvg, SCAFFOLD, FedProx—they all use gradients (first derivatives). They're simple, memory-efficient, and work reasonably well.
But here's a provocative question: What if we could converge in 10 rounds instead of 100?
Second-order methods—using curvature information (Hessians, second derivatives)—can achieve dramatically faster convergence by taking smarter steps. The classic tradeoff: Fewer iterations, but more computation per iteration.
In federated learning, this tradeoff flips in our favor: Communication is expensive, computation is cheap (especially with modern accelerators). Trading computation for communication is exactly what we want.
This post explores second-order methods for FL and shows when they're worth the extra compute.
Why Second-Order Methods?
The Gradient Descent Problem
First-order (gradient descent):

    w_{t+1} = w_t − η ∇f(w_t)
Limitation: Uses only gradient direction, ignores curvature of loss landscape.
Consequence: Takes many small steps, especially in ill-conditioned problems.
Example: Imagine a valley with steep sides and gentle bottom.
- Gradient descent: Bounces back and forth on steep sides, slow progress along valley
- Needs 1000s of iterations
The Newton's Method Solution
Second-order (Newton's method):

    w_{t+1} = w_t − [∇²f(w_t)]⁻¹ ∇f(w_t)

where ∇²f(w_t) is the Hessian (matrix of second derivatives).
Advantage: Adapts step to local curvature → quadratic convergence near optimum.
Result: Can converge in 10-20 iterations (vs. 1000s for gradient descent).
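To make the contrast concrete, here is a small numerical sketch (plain NumPy, toy problem of our own choosing) comparing the two updates on an ill-conditioned quadratic with condition number 1000:

```python
import numpy as np

# Ill-conditioned quadratic: f(w) = 0.5 * w^T A w, condition number 1000
A = np.diag([1.0, 1000.0])
grad = lambda w: A @ w

# Gradient descent with the largest stable stepsize (~1/L)
w_gd = np.array([1.0, 1.0])
eta = 1.0 / 1000.0
gd_steps = 0
while np.linalg.norm(grad(w_gd)) > 1e-6:
    w_gd = w_gd - eta * grad(w_gd)
    gd_steps += 1

# Newton's method: for a quadratic, the Hessian is A, and one step is exact
w_nt = np.array([1.0, 1.0])
w_nt = w_nt - np.linalg.solve(A, grad(w_nt))

print(gd_steps)                     # Thousands of GD steps
print(np.linalg.norm(grad(w_nt)))  # ~0 after a single Newton step
```

Gradient descent crawls along the gentle direction at rate (1 − 1/κ) per step, while the Newton step rescales both directions at once.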
The FL Tradeoff
Centralized ML: Second-order methods are expensive
- Computing Hessian: O(d²) memory and O(d³) time for d parameters
- Not worth it when gradient computation is cheap
Federated Learning: Tradeoff reverses
- Communication is 100-1000× more expensive than computation
- If second-order methods reduce rounds from 1000 → 100, we win massively
- Modern devices have powerful accelerators (GPU, TPU, Neural Engine)
Stochastic Proximal Point Methods (SPPM)
The Proximal Point Framework (Richtárik et al.)
Key insight: Instead of directly minimizing f(w), solve a sequence of regularized subproblems where each step minimizes the objective plus a proximity term to the previous iterate.
Why this helps: The regularization term makes each subproblem strongly convex (easier to solve) even if original f is non-convex.
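As a minimal illustration (plain NumPy, not the Octomil API): for the L1 norm the proximal subproblem has a closed-form solution, soft-thresholding, which is exactly why proximal methods handle non-smooth regularizers so gracefully:

```python
import numpy as np

def prox_l1(v, lam):
    """Solve min_w lam*||w||_1 + 0.5*||w - v||^2 (soft-thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)

v = np.array([3.0, -0.5, 1.2])
print(prox_l1(v, 1.0))  # approximately [2.0, 0.0, 0.2]
```

The proximity term makes the subproblem strongly convex in w even though ||w||_1 alone is non-smooth.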
A Unified Theory of SPPM (Richtárik et al., 2024)
A unified theory of stochastic proximal point methods without smoothness [1] provides the first rigorous analysis of SPPM for non-smooth, non-convex FL.
Key contributions:
- Works without smoothness assumptions (standard theory requires Lipschitz gradients)
- Unified framework covering SPPM, SPPM-LC, SPPM-NS, SPPM-AS, SPPM*, SPPM-GC
- Convergence for variance-reduced variants (L-SVRP, Point-SAGA)
Algorithm variants:
- SPPM-AS: Adaptive stepsizes (no tuning needed)
- SPPM-LC: Local curvature adaptation
- SPPM-NS: Non-smooth objectives
- SPPM-GC: Gradient clipping for robustness
import octomil

# Proximal point method in FL
client = octomil.OctomilClient(
    project_id="proximal-fl",
    algorithm="sppm",
    # Proximal parameters
    proximal_penalty=0.1,      # Regularization strength
    subproblem_solver="sgd",   # How to solve each subproblem
    # Adaptive variant
    adaptive_stepsize=True     # SPPM-AS
)

client.train(
    model=my_model,
    rounds=50  # Converges faster than first-order methods
)
SPPM Under Expected Similarity (Richtárik, 2024)
Stochastic proximal point methods for monotone inclusions under expected similarity [2] extends SPPM to variational inequalities.

Setting: Beyond optimization, solve equilibrium problems:

    find w* such that ⟨F(w*), w − w*⟩ ≥ 0 for all feasible w

where F is a monotone operator.
Applications: GANs, multi-agent learning, game-theoretic FL
# SPPM for equilibrium problems
client = octomil.OctomilClient(
    algorithm="sppm",
    problem_type="variational-inequality",  # Beyond standard optimization
    # Similarity parameter (problem structure)
    expected_similarity=0.9  # High similarity → fast convergence
)
FedNL: Federated Newton Learning
Unlocking FedNL (Richtárik, 2024)
FedNL is the premier second-order method for federated learning, and Unlocking FedNL [3] provides a production-ready implementation.
Key innovation: Efficient second-order updates without ever forming full Hessian.
Algorithm:
- Each device computes local Hessian-vector product: H_i·v
- Server aggregates: H·v = (1/n) Σ H_i·v
- Use conjugate gradient to solve H·Δw = -∇f iteratively
- Update: w ← w + Δw
Why it's practical:
- Never store full Hessian (O(d²) memory)
- Only compute Hessian-vector products (O(d) operations via automatic differentiation)
- Self-contained, compute-optimized implementation
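A minimal sketch of steps 1–4 above (plain NumPy; a toy stand-in, not the paper's implementation): each "device" exposes only a Hessian-vector product, and the server runs conjugate gradient on the aggregated product to recover the Newton direction without ever forming H:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_devices = 5, 3
# Toy local quadratics f_i(w) = 0.5 w^T A_i w - b_i^T w, with A_i SPD
As = [np.eye(d) + M @ M.T for M in (rng.normal(size=(d, d)) for _ in range(n_devices))]
bs = [rng.normal(size=d) for _ in range(n_devices)]

w = np.zeros(d)
grad = sum(A @ w - b for A, b in zip(As, bs)) / n_devices

def hvp(v):
    """Aggregate H·v = (1/n) Σ H_i·v (each term computed 'on-device')."""
    return sum(A @ v for A in As) / n_devices

def conjugate_gradient(hvp, g, iters=50, tol=1e-10):
    """Solve H·Δw = -g using only Hessian-vector products."""
    x = np.zeros_like(g)
    r = -g - hvp(x)
    p = r.copy()
    for _ in range(iters):
        Hp = hvp(p)
        alpha = (r @ r) / (p @ Hp)
        x += alpha * p
        r_new = r - alpha * Hp
        if np.linalg.norm(r_new) < tol:
            break
        beta = (r_new @ r_new) / (r @ r)
        p = r_new + beta * p
        r = r_new
    return x

delta = conjugate_gradient(hvp, grad)
w = w + delta
# For a quadratic, one Newton step lands at the optimum: gradient ~ 0
final_grad_norm = np.linalg.norm(sum(A @ w - b for A, b in zip(As, bs)) / n_devices)
print(final_grad_norm)
```

Note the memory footprint: only vectors of size d are ever stored, never a d×d matrix.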
# FedNL: Second-order federated learning
client = octomil.OctomilClient(
    project_id="fednl-project",
    algorithm="fednl",
    # Hessian approximation
    hessian_approx="full",  # or "diagonal", "kfac"
    # Subproblem solver
    subproblem_solver="conjugate-gradient",
    cg_iterations=20,
    # Adaptive settings
    adaptive_regularization=True  # Adjust regularization per round
)

# Typically converges in 20-50 rounds (vs. 200-500 for FedAvg)
client.train(model=my_model, rounds=50)
Variants:
- FedNL-LS: Line search for optimal step size
- FedNL-PP: Proximal point formulation
Performance Characteristics
When FedNL wins:
- Communication-dominated workloads (slow networks)
- Well-structured problems (clear curvature)
- Devices with good compute (GPUs, modern phones)
When FedNL loses:
- Compute-dominated (fast networks, weak devices)
- Very non-convex problems (Hessian misleading)
- Extreme scale (Hessian-vector products expensive)
ProxSkip: Probabilistic Communication Skipping
ProxSkip Algorithm (Richtárik et al.)
ProxSkip [4] combines proximal methods with probabilistic communication:
Key idea: Each device probabilistically skips communication rounds but maintains convergence via proximal framework.
Algorithm:
from random import random

# ProxSkip client
def proxskip_client(model, data, skip_prob):
    # Local proximal update
    local_model = proximal_update(model, data)
    # Probabilistically skip communication
    if random() < skip_prob:
        return None  # Skip this round
    # Otherwise, send the update
    return local_model - model
Benefits:
- Reduces expected communication by a factor of 1/(1 - skip_prob)
- Maintains convergence (proximal framework guarantees progress)
- Adaptive: Can tune skip_prob per device based on network quality
# ProxSkip in Octomil
client = octomil.OctomilClient(
    algorithm="proxskip",
    # Skip probability
    skip_probability=0.7,  # Skip 70% of rounds (3.3× communication reduction)
    # Proximal penalty (ensures convergence despite skipping)
    proximal_penalty=0.1
)

# Result: 3× fewer communication rounds with same convergence
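The arithmetic behind the skip-probability setting is simple; a quick sanity check (illustrative only):

```python
# Only a (1 - skip_prob) fraction of rounds communicate in expectation,
# so the communication reduction factor is 1 / (1 - skip_prob).
def comm_reduction(skip_prob):
    return 1.0 / (1.0 - skip_prob)

print(round(comm_reduction(0.7), 1))  # 3.3× fewer transmissions
print(round(comm_reduction(0.5), 1))  # 2.0×
```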
Local Curvature Descent
Squeezing More Curvature from Gradient Descent (Richtárik, NeurIPS 2025)
Local Curvature Descent [5] extracts second-order information from gradients alone (no Hessian computation).
Key insight: Approximate local curvature from gradient history by computing the finite difference quotient of consecutive gradients divided by the change in parameters.
Algorithms:
- LCD1: First-order curvature approximation
- LCD2: Second-order curvature approximation
- LCD3: Adaptive curvature with momentum
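The finite-difference idea above can be sketched with a secant (Barzilai–Borwein-style) estimate; this is an illustrative stand-in, not the paper's exact LCD update rules:

```python
import numpy as np

def bb_stepsize(w_prev, w_curr, g_prev, g_curr):
    """Secant curvature estimate: <s, y> / <s, s> approximates curvature
    along s = w_curr - w_prev, where y = g_curr - g_prev."""
    s, y = w_curr - w_prev, g_curr - g_prev
    curvature = (s @ y) / (s @ s)  # finite-difference quotient of gradients
    return 1.0 / curvature          # Newton-like stepsize, no Hessian needed

# On f(w) = 0.5 * c * w^2, the estimate recovers 1/c exactly
c = 4.0
grad = lambda w: c * w
w0, w1 = np.array([1.0]), np.array([0.8])
eta = bb_stepsize(w0, w1, grad(w0), grad(w1))
print(eta)  # 0.25 == 1/c
```

All quantities involved are gradients and iterates the optimizer already has, so the curvature comes for free.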
# Local curvature descent
client = octomil.OctomilClient(
    algorithm="lcd3",  # Most advanced variant
    # Curvature estimation
    curvature_window=10,  # Use last 10 gradients for estimation
    # No Hessian computation needed!
    hessian_free=True
)

# Gets second-order-like convergence with first-order cost
Benefit: Second-order acceleration without second-order computation.
Riemannian Optimization for Constrained FL
Streamlining in the Riemannian Realm (Richtárik, 2024)
Many FL problems have constraints (e.g., model weights on unit sphere, low-rank matrices).
Riemannian optimization naturally handles constraints by optimizing on manifolds.
Streamlining in the Riemannian realm [6] provides efficient Riemannian methods:
- R-LSVRG: Riemannian loopless variance reduction
- R-PAGE: Riemannian PAGE
- R-MARINA: Riemannian MARINA
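To show the mechanics (generic Riemannian gradient descent on the unit sphere, not the R-MARINA algorithm itself): project the Euclidean gradient onto the tangent space at the current point, take a step, then retract back onto the manifold:

```python
import numpy as np

def sphere_step(w, euclid_grad, eta):
    """One Riemannian gradient step on the unit sphere."""
    # Project onto the tangent space at w: remove the radial component
    riem_grad = euclid_grad - (w @ euclid_grad) * w
    # Step in the tangent space, then retract (renormalize) onto the sphere
    w_new = w - eta * riem_grad
    return w_new / np.linalg.norm(w_new)

w = np.array([1.0, 0.0, 0.0])
g = np.array([0.5, 1.0, 0.0])  # Arbitrary Euclidean gradient
w = sphere_step(w, g, eta=0.1)
print(np.linalg.norm(w))  # Stays on the sphere: ~1.0
```

The constraint is never violated, which is the whole point: no projection penalty, no Lagrange multipliers.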
# Riemannian FL for constrained problems
client = octomil.OctomilClient(
    algorithm="r-marina",
    # Manifold constraint
    manifold="sphere",  # or "stiefel", "grassmann", "spd"
    # Riemannian metric
    metric="canonical"  # Natural metric for the manifold
)

# Example: Low-rank matrix factorization FL
client.train(
    model=matrix_factorization_model,
    constraint="low-rank",
    rank=10
)
Non-Euclidean Proximal Point Method
A Blueprint for Geometry-Aware Optimization (Richtárik, 2026)
Non-Euclidean proximal point method [7] generalizes proximal methods to arbitrary geometries.

Key idea: Replace the squared Euclidean distance in the proximal term with a Bregman divergence

    D_φ(x, y) = φ(x) − φ(y) − ⟨∇φ(y), x − y⟩

generated by a convex function φ.
Benefits:
- Natural for non-Euclidean domains (probabilities, positive definite matrices)
- Exploits problem geometry for faster convergence
- Unified framework: Euclidean, Riemannian, mirror descent
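For intuition, with the KL geometry the proximal step over the probability simplex has a closed form — the familiar multiplicative-weights / entropic mirror descent update (a standard result, sketched here in NumPy; not Octomil's bpm internals):

```python
import numpy as np

def kl_prox_step(w, grad, eta):
    """argmin_x <grad, x> + (1/eta) * KL(x || w) over the probability simplex."""
    x = w * np.exp(-eta * grad)
    return x / x.sum()  # Normalization keeps the iterate on the simplex

w = np.array([0.25, 0.25, 0.25, 0.25])
g = np.array([1.0, 2.0, 3.0, 4.0])
w = kl_prox_step(w, g, eta=0.5)
print(w.sum())     # Still a probability distribution
print(w.argmax())  # Mass shifts toward the lowest-gradient coordinate
```

A Euclidean proximal step would leave the simplex and require a projection; the KL geometry respects it automatically.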
# Non-Euclidean proximal method
client = octomil.OctomilClient(
    algorithm="bpm",  # Ball-proximal ("broximal") point method
    # Geometry
    geometry="kullback-leibler",  # For probability distributions
    # Other geometries:
    # geometry="euclidean"    # Standard
    # geometry="mahalanobis"  # Weighted Euclidean
)
Advanced Curvature Exploitation
LMO-Based Momentum Methods (Richtárik, 2026)
For very large constrained models, even a single projection onto the constraint set can be expensive. Linear Minimization Oracle (LMO) methods replace projections with a much cheaper linear optimization over the constraint set.

Better LMO-based momentum methods with second-order information [8]:

Algorithm (LMO-SOM):
- Instead of a projection step: solve a linear minimization min_{s ∈ C} ⟨g, s⟩ over the constraint set C
- Add second-order momentum for acceleration
- Result: Convergence with only a linear optimization oracle
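The LMO over an L1 ball shows why this oracle is so cheap (a standard Frank–Wolfe-style fact, sketched here): the minimizer is a signed, scaled basis vector at the largest-magnitude gradient coordinate — an O(d) scan, no projection needed:

```python
import numpy as np

def lmo_l1_ball(g, radius=1.0):
    """argmin_{||s||_1 <= radius} <g, s>: put all mass on the largest |g_i|."""
    i = np.argmax(np.abs(g))
    s = np.zeros_like(g)
    s[i] = -radius * np.sign(g[i])
    return s

g = np.array([0.3, -2.0, 1.1])
print(lmo_l1_ball(g, radius=1.0))  # [0. 1. 0.] — opposes the dominant gradient
```

As a bonus, every LMO output is 1-sparse, which is why these methods pair well with sparse models.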
# LMO-based second-order methods
client = octomil.OctomilClient(
    algorithm="lmo-som",
    # Constraint set
    constraint_set="l1-ball",  # Sparse models
    # Second-order momentum
    momentum_type="second-order",
    momentum=0.99
)
When to Use Second-Order Methods
| Method | Best For | Communication Savings | Compute Overhead |
|---|---|---|---|
| FedNL | Well-conditioned, slow networks | 5-10× | 3-5× per round |
| SPPM | Non-smooth objectives | 3-5× | 2× per round |
| ProxSkip | Unreliable communication | 2-4× | Minimal |
| LCD | Want second-order without Hessian | 2-3× | Minimal |
| Riemannian | Constrained problems | 3-7× | 2-4× per round |
Rule of thumb: If communication is 10×+ more expensive than computation, second-order methods are worth it.
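The rule of thumb follows from simple wall-clock arithmetic (illustrative numbers, not measurements):

```python
def wallclock(rounds, comm_cost, compute_cost):
    """Total time when every round pays one communication + one computation."""
    return rounds * (comm_cost + compute_cost)

# Communication 10× more expensive than computation, per round
first_order  = wallclock(rounds=500, comm_cost=10.0, compute_cost=1.0)
second_order = wallclock(rounds=60,  comm_cost=10.0, compute_cost=4.0)  # 4× compute
print(first_order / second_order)  # > 6× faster despite the compute overhead
```

When the comm/compute ratio drops below ~1, the same arithmetic flips against second-order methods.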
Octomil's Second-Order Framework
import octomil

# Automatic second-order method selection
client = octomil.OctomilClient(
    project_id="second-order-fl",
    # High-level: Let Octomil choose
    optimization="second-order",  # vs. "first-order"
    # Octomil profiles:
    # - Network latency
    # - Device compute capability
    # - Problem conditioning
    # Then selects the optimal method (FedNL, SPPM, LCD, etc.)
    # Or explicitly choose:
    # algorithm="fednl",     # Full second-order
    # algorithm="lcd3",      # Second-order-like, first-order cost
    # algorithm="proxskip",  # Probabilistic communication
    # Automatic tuning
    adaptive_hyperparams=True
)

# Train with second-order methods
model = client.train(
    model=my_model,
    rounds=50  # Converges in far fewer rounds
)

# Compare against a first-order baseline
stats = client.compare_with_baseline("fedavg")
print(f"Communication reduction: {stats.comm_reduction}×")
print(f"Compute increase: {stats.compute_increase}×")
print(f"Wall-clock speedup: {stats.wallclock_speedup}×")
Real-World Performance
Case Study 1: Image Classification (ResNet-50)
Setup: 100 devices, 3G/4G mixed network
Results:
- FedAvg: 400 rounds, 8 hours wall-clock
- FedNL: 60 rounds, 3 hours wall-clock (2.7× faster)
- Compute overhead: 4× per round, but the extra compute runs locally in parallel across devices
Key insight: Network latency dominated, so 4× compute was acceptable.
Case Study 2: Language Model Fine-Tuning
Setup: 1,000 devices, LLaMA-7B fine-tuning
Results:
- FedAvg + LoRA: 200 rounds
- FedNL + LoRA: 80 rounds (2.5× fewer)
- Communication: 60% reduction in total bytes transferred
Technique: FedNL with low-rank Hessian approximation (diagonal per LoRA adapter).
Case Study 3: Sparse Model Training
Setup: L1-regularized model, 50 hospitals
Results:
- Proximal SGD: 300 rounds
- SPPM: 100 rounds (3× faster)
- Benefit: SPPM naturally handles non-smooth L1 regularization
Challenges and Solutions
Challenge 1: Hessian Computation Cost
Problem: Full Hessian is O(d²) memory, O(d³) computation.
Solutions:
- Diagonal Hessian approximation (O(d))
- KFAC (Kronecker-factored) approximation
- Hessian-free methods (LCD)
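The diagonal approximation itself can be estimated without ever forming the Hessian, using Hutchinson's estimator diag(H) ≈ E[z ⊙ Hz] with Rademacher probes — a minimal sketch on a toy quadratic (in practice the Hessian-vector product would come from autodiff):

```python
import numpy as np

rng = np.random.default_rng(0)
H = np.diag([1.0, 4.0, 9.0])  # True Hessian of a toy quadratic
hvp = lambda v: H @ v          # In practice: an autodiff Hessian-vector product

def hutchinson_diag(hvp, d, n_samples=2000, rng=rng):
    """Estimate diag(H) from Hessian-vector products only (O(d) memory)."""
    est = np.zeros(d)
    for _ in range(n_samples):
        z = rng.choice([-1.0, 1.0], size=d)  # Rademacher probe
        est += z * hvp(z)
    return est / n_samples

print(hutchinson_diag(hvp, d=3))  # ≈ [1, 4, 9]
```

For a truly diagonal H the estimate is exact; for dense Hessians the off-diagonal terms average out over samples.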
Challenge 2: Ill-Conditioned Problems
Problem: Hessian may be nearly singular or have negative eigenvalues.
Solutions:
- Adaptive regularization (trust region methods)
- Gauss-Newton approximation (always positive semi-definite)
- Cubic regularization
Challenge 3: Heterogeneous Curvature
Problem: Different devices may have very different local curvature.
Solutions:
- Personalized Hessian approximations
- Robust aggregation (median of Hessians)
- Adaptive per-device regularization
Getting Started
pip install octomil

# Initialize with second-order optimization
octomil init my-project --optimization second-order

# Train with FedNL
octomil train \
    --algorithm fednl \
    --hessian-approx diagonal \
    --rounds 50

# Or automatic selection
octomil train \
    --optimization auto \
    --prefer-communication-efficiency
See our Advanced FL Configuration guide for detailed tutorials and benchmarks.