Federated LLMs: Prompting, Cascading, and Fine-Tuning at Scale

11 min read

Large Language Models have changed everything—including federated learning.

The old FL paradigm: Train a small model (~100M parameters) from scratch across devices.

The new FL paradigm: Adapt a massive pre-trained model (7B-70B parameters) using federated techniques.

But LLMs bring unique challenges to federated learning:

  • Size: 7B parameters = 28 GB (won't fit on most devices)
  • Compute: Full fine-tuning requires massive GPU memory
  • Inference cost: Running LLM inference on-device drains battery
  • Privacy: LLM memorization can leak training data

This post explores cutting-edge techniques for federated LLMs, from Virginia Smith's research group and beyond, showing how to make federated learning work in the foundation model era.

The LLM Federated Learning Stack

Three Approaches to Federated LLMs

1. Federated Fine-Tuning

  • Train adapter weights (LoRA, prefix tuning) on devices
  • Base LLM stays frozen
  • Upload only adapter updates (~MB instead of ~GB)

2. Federated Prompt Learning

  • Learn device-specific or global prompt embeddings
  • LLM stays completely frozen
  • Minimal communication (prompts are tiny)

3. Federated Inference Optimization

  • Optimize how LLMs run on federated devices
  • Cascading, caching, quantization
  • Focus on deployment, not training

Efficient LLM Cascades

The Inference Cost Problem

Running LLMs is expensive:

  • GPT-4: ~$0.03 per 1K tokens (input) + $0.06 per 1K tokens (output)
  • Llama-70B: Requires 140 GB VRAM, slow on CPU
  • On-device: Drains battery, only works for small models (7B max)

Solution: Model cascading—route queries to appropriate model size.

Agreement-Based Cascading (Smith et al., TMLR 2025)

Agreement-based cascading for efficient inference [1] optimizes LLM cascades:

Key idea:

  1. Run small model first (fast, cheap)
  2. If confident → return answer
  3. If uncertain → escalate to larger model

Confidence via agreement:

  • Sample multiple outputs from small model
  • If they agree → high confidence
  • If they disagree → route to large model

Algorithm:

from collections import Counter

def agreement_cascade(query, small_model, large_model,
                      num_samples=5, agreement_threshold=0.8):
    # Run the small model multiple times (assumes sampling-based decoding,
    # so repeated calls can produce different outputs)
    small_outputs = [small_model(query) for _ in range(num_samples)]

    # Agreement = fraction of samples matching the majority answer
    answer, count = Counter(small_outputs).most_common(1)[0]
    agreement = count / len(small_outputs)

    if agreement > agreement_threshold:
        return answer              # Cheap answer from the small model
    return large_model(query)      # Expensive but accurate

Result: 3-5× cost reduction with no accuracy loss.

import octomil

# Cascaded LLM inference in FL
client = octomil.OctomilClient(
    project_id="federated-llm-cascade",

    # Model cascade
    models=[
        "llama-3-1b",   # Fast, cheap
        "llama-3-8b",   # Medium
        "llama-3-70b"   # Slow, expensive, accurate
    ],

    # Cascading strategy
    cascade_strategy="agreement",
    agreement_samples=5,
    agreement_threshold=0.8
)

# Queries automatically route to the appropriate model
response = client.infer(query="Explain quantum computing")
# → Uses the 1B model if confident, otherwise escalates

Semantic Agreement Enables Open-Ended Cascading (Smith et al., EMNLP 2025)

Problem: Agreement-based cascading requires exact-match voting, which works for classification but not for open-ended generation.

Semantic agreement enables open-ended LLM cascades [2] extends cascading to generative tasks:

Key innovation: Use semantic similarity instead of exact match:

from itertools import combinations
from statistics import mean

def semantic_agreement(outputs, threshold=0.85):
    # Embed all outputs (embed and cosine_similarity are assumed helpers,
    # e.g. a sentence-transformer encoder and standard cosine similarity)
    embeddings = [embed(output) for output in outputs]

    # Compute pairwise similarities
    similarities = [
        cosine_similarity(emb1, emb2)
        for emb1, emb2 in combinations(embeddings, 2)
    ]

    # Agreement = average pairwise similarity above the threshold
    return mean(similarities) > threshold

Result: Cascading now works for summarization, question answering, creative writing.

# Semantic agreement for generative tasks
client = octomil.OctomilClient(
    cascade_strategy="semantic-agreement",
    similarity_model="all-minilm",  # Lightweight embedding model
    semantic_threshold=0.85
)

# Works for open-ended generation
summary = client.infer(
    query="Summarize this 10-page document",
    task_type="generation"
)
# → Small model generates the summary, checks semantic consistency,
#   escalates to the large model if outputs diverge

Extracting Parallelism from LLM Queries

PARALLELPROMPT (Smith et al., NeurIPS 2025)

Problem: Autoregressive generation is inherently sequential, so batches of related queries are slow to serve.

PARALLELPROMPT: Extracting parallelism from large language model queries [3]:

Key insight: Many real-world queries have shared structure:

  • Summarizing multiple documents (same prompt, different inputs)
  • Evaluating code snippets (same evaluation criteria)
  • Multi-document QA (ask same question on different contexts)

Technique: Batch prompting with parallel execution:

# Traditional: Sequential
for doc in documents:
    summary = llm(f"Summarize: {doc}")  # N sequential calls

# PARALLELPROMPT: Batched
combined_prompt = "Summarize each document below:\n" + "\n".join(
    f"Document {i}: {doc}" for i, doc in enumerate(documents)
)
summary_batch = llm(combined_prompt)  # 1 call, N outputs

Benefit: 5-10× faster inference for structured workloads.

Federated application: Batch queries across devices before sending to server-side LLM.

# Parallel prompt execution in FL
client = octomil.OctomilClient(
    project_id="parallel-llm-fl",

    # Automatic query batching
    batching=True,
    batch_size=32,
    batch_timeout=100,  # ms: wait up to 100 ms to fill a batch

    # Parallelization strategy
    parallel_strategy="prompt-batching"  # PARALLELPROMPT-style
)

# Queries from devices are automatically batched
for device_query in federated_queries:
    result = client.infer(device_query)
    # Batched with other in-flight queries for efficiency

Federated LLM Fine-Tuning with Sparse Gradients

GRASS: Structured Sparse Gradients (Smith et al., EMNLP 2024)

Problem: LLM fine-tuning requires storing gradients for billions of parameters.

Memory requirement:

  • Model: 7B parameters × 4 bytes = 28 GB
  • Gradients: 7B parameters × 4 bytes = 28 GB
  • Optimizer state (Adam): 7B parameters × 8 bytes = 56 GB
  • Total: 112 GB (doesn't fit on consumer GPUs)
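The memory requirement above comes from simple per-parameter accounting, sketched here for a 7B model in fp32:

```python
# Back-of-envelope memory accounting for full fine-tuning of a 7B
# model in fp32, reproducing the figures above.
params = 7e9
weights_gb = params * 4 / 1e9  # fp32 weights: 4 bytes/param
grads_gb = params * 4 / 1e9    # fp32 gradients: 4 bytes/param
adam_gb = params * 8 / 1e9     # Adam first + second moments: 8 bytes/param
total_gb = weights_gb + grads_gb + adam_gb
print(weights_gb, grads_gb, adam_gb, total_gb)  # 28.0 28.0 56.0 112.0
```

Mixed-precision training changes the constants but not the conclusion: optimizer state dominates, which is what sparse-gradient methods attack.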

GRASS: Compute efficient low-memory LLM training with structured sparse gradients [4]:

Key technique: Maintain sparse gradients with structure:

  1. Only compute/store top-k% of gradients (by magnitude)
  2. Use structured sparsity (e.g., entire attention heads, not random weights)
  3. Error feedback compensates for dropped gradients

Memory savings: 5-10× reduction (from 112 GB → 15-20 GB)
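The three steps above can be sketched with a toy magnitude-based, block-wise sparsifier with error feedback. This is an illustrative simplification, not GRASS itself, which uses structured sparse projections over rows/columns; `block_size` and `keep_frac` are made-up parameters:

```python
import numpy as np

def sparsify_with_feedback(grad, residual, block_size=4, keep_frac=0.1):
    """Keep the top gradient blocks by L2 norm; carry the rest forward."""
    corrected = grad + residual                 # add back previously dropped mass
    blocks = corrected.reshape(-1, block_size)  # impose block structure
    norms = np.linalg.norm(blocks, axis=1)
    k = max(1, int(keep_frac * len(blocks)))
    keep = np.argsort(norms)[-k:]               # top-k blocks by magnitude
    sparse = np.zeros_like(blocks)
    sparse[keep] = blocks[keep]
    sparse = sparse.reshape(grad.shape)
    new_residual = corrected - sparse           # error-feedback state
    return sparse, new_residual

rng = np.random.default_rng(0)
grad = rng.normal(size=64)
residual = np.zeros(64)
sparse, residual = sparsify_with_feedback(grad, residual)
print(np.count_nonzero(sparse))  # 4 (one block of 4 kept out of 16)
```

Because the residual is folded into the next step's gradient, no update is permanently lost, which is what preserves convergence at high sparsity.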

# GRASS-based federated LLM fine-tuning
client = octomil.OctomilClient(
    project_id="federated-llama-finetuning",

    # Sparse gradient training
    sparse_gradients=True,
    sparsity_structure="structured",  # Block-wise sparsity
    target_sparsity=0.9,              # 90% sparse (10× memory reduction)

    # Error feedback for convergence
    error_feedback=True
)

# Fine-tune a 7B model on-device within the reduced memory budget
client.train(
    model="llama-3-7b",
    adapter="lora",  # LoRA on top of sparse gradients
    local_epochs=1
)

Worst-Group Robustness for Foundation Models

Prompting as Double-Edged Sword (Smith et al., ICML 2024)

Problem: Prompts improve average accuracy but can hurt worst-group performance.

Prompting is a double-edged sword: Improving worst-group robustness of foundation models [5]:

Key finding: Naive prompting improves performance on majority groups while degrading it on minority groups.

Example:

  • Vanilla model: 80% avg accuracy, 70% worst-group accuracy
  • With prompt "You are an expert": 85% avg accuracy, 60% worst-group accuracy (worse!)
  • With group-aware prompt: 84% avg accuracy, 75% worst-group accuracy (better!)

Solution: Group-aware prompting:

# Bad: Generic prompt
prompt = "You are an expert medical diagnostician. Diagnose this patient."

# Good: Group-aware prompt
prompt = """
You are a medical diagnostician. Consider diverse patient populations,
including those from underrepresented groups. Provide unbiased diagnosis.
Diagnose this patient:
"""

Federated application: Learn group-aware prompts from federated data.

# Federated group-aware prompt learning
client = octomil.OctomilClient(
    project_id="fair-llm-fl",

    # Group-aware prompting
    prompting="group-aware",
    groups=["group_a", "group_b", "group_c"],

    # Fairness objective
    fairness_constraint="bounded-group-loss",
    max_group_disparity=0.1  # Max 10% accuracy gap across groups
)

# Learn prompts that are fair across groups
prompts = client.learn_prompts(
    base_model="llama-3-70b",
    rounds=20
)

LLM Unlearning in Federated Settings

Exact Unlearning via Model Merging (Smith et al., SaTML 2026)

Problem: User requests data deletion (GDPR "right to be forgotten"). How to remove their contribution from LLM?

Exact unlearning of finetuning data via model merging at scale6:

Traditional approach: Retrain from scratch without user's data (expensive: $100K+ for LLMs).

Model merging approach:

  1. Train model with user: M_with
  2. Train model without user: M_without
  3. Unlearned model: M_current - (M_with - M_without)

Benefit: Unlearning in hours instead of weeks of retraining.
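The merging arithmetic above can be sketched on toy weight vectors. M_with and M_without are assumed checkpoints fine-tuned with and without the departing user's data; their difference acts as a task vector that is subtracted from the deployed model:

```python
import numpy as np

# Toy illustration of unlearning via model merging.
m_with = np.array([1.0, 2.0, 3.0])     # fine-tuned incl. user's data
m_without = np.array([1.0, 1.5, 2.5])  # fine-tuned excl. user's data
m_current = np.array([1.2, 2.1, 3.3])  # currently deployed model

user_delta = m_with - m_without        # the user's contribution
m_unlearned = m_current - user_delta   # remove it from the deployed model
print(m_unlearned)  # [1.2 1.6 2.8]
```

In practice the subtraction runs over billions of parameters, but the per-weight arithmetic is exactly this cheap, which is why merging beats retraining on cost.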

# Federated LLM unlearning
client = octomil.OctomilClient(
    project_id="llm-unlearning-fl",
    unlearning_enabled=True
)

# User requests deletion
client.request_unlearning(
    device_id="device-12345",
    method="model-merging",

    # Verification
    verify_unlearning=True  # Check via membership inference attacks
)

# Octomil:
# 1. Identifies affected model checkpoints
# 2. Computes model-with vs. model-without
# 3. Merges to remove the user's influence
# 4. Validates via membership inference attacks

LLM Unlearning Benchmarks are Weak (Smith et al., SaTML 2025)

Warning: LLM unlearning benchmarks are weak measures of progress [7].

Key findings:

  • Existing benchmarks can be "passed" without actual unlearning
  • Adversaries can still extract supposedly-unlearned data
  • Need stronger evaluation metrics

Implication: Federated LLM unlearning is an open problem. Current methods are not foolproof.

# Octomil's conservative unlearning approach
client = octomil.OctomilClient(
    unlearning_enabled=True,

    # Strong verification
    unlearning_verification="adversarial",  # Adversarial extraction attacks
    verification_budget=1000,               # Query budget for verification

    # Fallback: retrain if merging fails verification
    fallback_to_retrain=True
)

Federated Prompt Tuning

Prompting vs Fine-Tuning Tradeoff

Fine-tuning:

  • Pros: High accuracy, full adaptation
  • Cons: Expensive, requires gradient computation

Prompt tuning:

  • Pros: Cheap, no gradient computation, tiny updates
  • Cons: Limited expressiveness

Federated prompt tuning: Learn shared prompt embeddings across devices.

# Federated prompt learning
client = octomil.OctomilClient(
    project_id="prompt-tuning-fl",

    # Prompt configuration
    prompt_tuning=True,
    prompt_length=20,                # 20-token prompt
    prompt_initialization="random",  # or "task-specific"

    # Federated aggregation of prompts
    prompt_aggregation="fedavg"      # Average prompt embeddings
)

# Train prompts (base LLM stays frozen)
prompts = client.train_prompts(
    base_model="llama-3-70b-frozen",
    rounds=50
)

# Prompt size: 20 tokens × 4096 dims × 4 bytes = 320 KB
# vs. LoRA: 8M parameters × 4 bytes = 32 MB
# → 100× smaller updates
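The FedAvg aggregation of prompt embeddings referenced above can be sketched directly: each device uploads its locally tuned 20×4096 prompt and the server takes a weighted average (weighting by local example counts is an assumption here; `aggregate_prompts` is a hypothetical helper, not Octomil API):

```python
import numpy as np

def aggregate_prompts(prompts, weights):
    """Weighted FedAvg over per-device prompt embedding matrices."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()           # normalize to sum to 1
    return np.tensordot(weights, np.stack(prompts), axes=1)

rng = np.random.default_rng(0)
device_prompts = [rng.normal(size=(20, 4096)) for _ in range(3)]
global_prompt = aggregate_prompts(device_prompts, weights=[100, 50, 50])

print(global_prompt.shape)                          # (20, 4096)
print(global_prompt.astype(np.float32).nbytes)      # 327680 (~320 KB)
```

The fp32 payload matches the size math above: each round moves ~320 KB per device instead of the tens of megabytes a LoRA update would cost.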

Personalized Prompts

Challenge: Different devices need different prompts.

Solution: Learn device-specific prompt prefixes + shared prompt suffix.

# Personalized federated prompt learning
client = octomil.OctomilClient(
    prompt_tuning=True,
    personalized_prompts=True,

    # Architecture
    shared_prompt_length=10,   # Global shared prompt
    personal_prompt_length=10  # Device-specific prompt
)

# Each device:
# - Downloads the shared prompt (10 tokens)
# - Maintains a personal prompt (10 tokens, never uploaded)
# - Concatenates for inference: [personal || shared || task input]
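The concatenation step can be sketched in plain NumPy. The embedding dimension (4096) and the 128-token query length are illustrative assumptions; in a real stack these would be the frozen LLM's input embeddings:

```python
import numpy as np

DIM = 4096  # assumed embedding dimension of the frozen LLM

personal = np.zeros((10, DIM))     # trained locally, never uploaded
shared = np.zeros((10, DIM))       # downloaded from the server each round
task_input = np.zeros((128, DIM))  # embedded user query (illustrative length)

# [personal || shared || task input] is fed to the frozen LLM
sequence = np.concatenate([personal, shared, task_input], axis=0)
print(sequence.shape)  # (148, 4096)
```

Because the personal prefix is prepended locally and only the shared block is aggregated, the server never observes any device-specific prompt parameters.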

Production Federated LLM Architecture

Octomil's LLM-FL Stack

import octomil

# Comprehensive federated LLM setup
client = octomil.OctomilClient(
    project_id="production-llm-fl",

    # Model configuration
    base_model="llama-3-70b",   # Server-side
    device_model="llama-3-1b",  # On-device cache

    # Training strategy
    training_method="lora",  # or "prompt-tuning", "full-finetuning"
    lora_rank=16,
    lora_alpha=32,

    # Sparse gradients (memory efficiency)
    sparse_gradients=True,
    sparsity=0.9,  # GRASS-style

    # Cascading (inference efficiency)
    cascade_enabled=True,
    cascade_strategy="semantic-agreement",

    # Batching (throughput)
    query_batching=True,  # PARALLELPROMPT-style
    batch_size=32,

    # Fairness
    group_aware_prompts=True,
    fairness_constraint="bounded-group",

    # Privacy
    unlearning_enabled=True,
    unlearning_method="model-merging"
)

# Federated LLM workflow
client.train(
    rounds=100,
    local_epochs=1
)

# Inference with cascading
response = client.infer(query="Complex question requiring reasoning")
# → Small model tries first, escalates if needed

# Unlearning on request
client.request_unlearning(device_id="device-123")

When to Use Federated LLMs

| Use Case | Approach | Why |
| --- | --- | --- |
| Personal assistant | Prompt tuning | Tiny updates, preserves privacy |
| Domain adaptation | LoRA fine-tuning | Balance accuracy & efficiency |
| Sensitive data | On-device only | Never send data to cloud |
| Cost optimization | Cascading | Route queries to appropriate model size |
| Fairness-critical | Group-aware prompting | Ensure equitable performance |
| Regulatory compliance | Unlearning-enabled | GDPR, CCPA compliance |

Real-World Performance

Case Study 1: Medical LLM Federation

Setup: 50 hospitals, sensitive patient data

Approach:

  • Base model: Llama-3-70B (medical domain)
  • Method: Federated LoRA fine-tuning with GRASS
  • Privacy: On-device only, differential privacy

Results:

  • Accuracy: 92% (vs. 87% without fine-tuning)
  • Memory: 12 GB per device (vs. 140 GB for full fine-tuning)
  • Communication: 32 MB per round (vs. 280 GB for full model)
  • Unlearning: 2 hours per request (vs. 3 weeks retraining)

Case Study 2: Multilingual Keyboard Prediction

Setup: 10M devices, 100+ languages

Approach:

  • Cascade: 1B model on-device → 8B model server-side
  • Prompt tuning: Language-specific prompts
  • Batching: PARALLELPROMPT for throughput

Results:

  • Latency: 50ms p50 (vs. 200ms without cascading)
  • Cost: $0.0001 per query (vs. $0.001 without cascading)
  • Coverage: 95% of queries answered by 1B model
  • Accuracy: 91% (equivalent to always using 8B model)

Research Frontiers

Open problems:

  1. Federated RLHF: Reinforcement learning from human feedback across devices
  2. Mixture of experts: Federated training of MoE architectures
  3. Cross-modal FL: Training LLMs + vision models jointly
  4. Long-context FL: Efficient federated training for 100K+ context windows

Getting Started

pip install octomil

# Initialize a federated LLM project
octomil init llm-project \
    --base-model llama-3-70b \
    --training-method lora

# Train the federated LLM
octomil train-llm \
    --cascade semantic-agreement \
    --sparse-gradients \
    --fairness group-aware

# Deploy with cascading
octomil deploy-llm \
    --cascade-models 1b,8b,70b \
    --batch-queries

See our Federated LLM Guide for complete tutorials.


References

  1. Dennis, D., Kolawole, S., Talwalkar, A., & Smith, V. (2025). Agreement-based cascading for efficient inference. Transactions on Machine Learning Research (TMLR). OpenReview

  2. Soiffer, D., Kolawole, S., & Smith, V. (2025). Semantic agreement enables efficient open-ended LLM cascades. Empirical Methods in Natural Language Processing (EMNLP) Industry Track. arXiv

  3. Kolawole, S., Santhanam, K., Smith, V., & Thaker, P. (2025). PARALLELPROMPT: Extracting parallelism from large language model queries. Neural Information Processing Systems (NeurIPS) Datasets and Benchmarks Track. arXiv

  4. Muhamed, A., Li, O., Woodruff, D., Diab, M., & Smith, V. (2024). GRASS: Compute efficient low-memory LLM training with structured sparse gradients. Empirical Methods in Natural Language Processing (EMNLP). arXiv

  5. Setlur, A., Garg, S., Smith, V., & Levine, S. (2024). Prompting is a double-edged sword: Improving worst-group robustness of foundation models. International Conference on Machine Learning (ICML). arXiv

  6. Kuo, K., Setlur, A., Srinivas, K., Raghunathan, A., & Smith, V. (2026). Exact unlearning of finetuning data via model merging at scale. Conference on Secure and Trustworthy Machine Learning (SaTML). arXiv:2504.04626

  7. Thaker, P., Hu, S., Kale, N., Maurya, Y., Wu, Z. S., & Smith, V. (2025). LLM unlearning benchmarks are weak measures of progress. Conference on Secure and Trustworthy Machine Learning (SaTML). arXiv