
Federated LLMs

Federated LLM workflows combine privacy with continuous improvement using on-device prompt tuning and lightweight adaptation.

Why federated learning for LLMs

Large language models contain valuable general knowledge, but their real value in production comes from adaptation to specific domains, users, or tasks. Traditional fine-tuning requires centralizing user data on a server -- writing prompts, completions, corrections, and interaction patterns to a training dataset. For applications involving personal data (messaging, health queries, legal documents, financial advice), this centralization is a non-starter under GDPR, HIPAA, and similar regulations.

Federated learning offers a path: keep user interaction data on-device, fine-tune the model locally, and aggregate only the model updates. The user's data never leaves their device. The global model improves from everyone's usage patterns without any single user's data being exposed.

The challenge is that LLMs are orders of magnitude larger than traditional FL models, and the standard federated learning playbook was designed for models that fit comfortably in device memory. Adapting FL to LLMs requires rethinking almost every assumption about model size, communication cost, and on-device compute.

Key challenges

Model size and device memory

A typical FL model (image classifier, next-word predictor) might have 1-10M parameters. A small LLM starts at 1B parameters. Even a quantized 1B-parameter model requires ~500MB-1GB of RAM just for inference, and fine-tuning requires 2-4x the model size in memory for optimizer states and activations.

| Model size | Memory (inference) | Memory (fine-tuning) | Feasible on mobile? |
|---|---|---|---|
| 100M params | ~200MB | ~800MB | Yes |
| 500M params | ~1GB | ~4GB | Marginal (high-end phones) |
| 1B params | ~2GB | ~8GB | No (full fine-tuning) |
| 7B params | ~14GB | ~56GB | No |

Full fine-tuning of LLMs on mobile devices is not feasible. Parameter-efficient methods are required.
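As a rough sanity check on the table above, both footprints can be estimated directly from parameter counts. This is a minimal sketch; the 2 bytes/param inference figure (fp16 weights) and the 4x fine-tuning multiplier (gradients plus Adam optimizer states) are illustrative assumptions, not exact measurements.

```python
def estimate_memory_mb(params: int,
                       inference_bytes_per_param: float = 2.0,
                       finetune_multiplier: float = 4.0) -> tuple[float, float]:
    """Rough memory estimate: fp16 weights for inference; full fine-tuning
    adds gradients and optimizer states (~4x the inference footprint)."""
    inference_mb = params * inference_bytes_per_param / 1e6
    return inference_mb, inference_mb * finetune_multiplier

for n, label in [(100_000_000, "100M"), (500_000_000, "500M"),
                 (1_000_000_000, "1B"), (7_000_000_000, "7B")]:
    inf, ft = estimate_memory_mb(n)
    print(f"{label}: ~{inf:.0f}MB inference, ~{ft:.0f}MB fine-tuning")
```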

Communication cost

Transmitting a full model update for a 1B-parameter model would require sending ~4GB per round per client (at 32-bit precision). Even with quantization to 8-bit, that is 1GB. On mobile networks, this is prohibitive in both time and bandwidth cost.
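The per-round upload figures above follow directly from parameter count and precision. A quick back-of-the-envelope calculation (the 3M adapter-parameter figure is an illustrative assumption for a rank-8 configuration):

```python
def update_size_mb(num_params: int, bits_per_param: int = 32) -> float:
    """Megabytes a client must upload per round for a given update."""
    return num_params * bits_per_param / 8 / 1e6

full_fp32 = update_size_mb(1_000_000_000, 32)  # full 1B model, fp32
full_int8 = update_size_mb(1_000_000_000, 8)   # full model, 8-bit quantized
adapters  = update_size_mb(3_000_000, 32)      # ~3M LoRA params, fp32

print(f"Full model (fp32): ~{full_fp32 / 1000:.0f}GB")
print(f"Full model (int8): ~{full_int8 / 1000:.0f}GB")
print(f"LoRA adapters:     ~{adapters:.0f}MB")
```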

Convergence with heterogeneous text data

Text data is inherently more heterogeneous than image data. Each user has a unique vocabulary, writing style, and topic distribution. This extreme non-IID setting amplifies client drift and makes federated convergence harder than in traditional FL settings.

Parameter-efficient federated fine-tuning

The solution to the size and communication problems is to train only a small subset of parameters while keeping the base model frozen. Several approaches are viable for federated LLM workflows.

LoRA (Low-Rank Adaptation)

LoRA freezes the pretrained model weights and injects small trainable low-rank matrices into each transformer layer. Instead of fine-tuning a weight matrix W (of shape d x d), LoRA trains two smaller matrices A (shape d x r) and B (shape r x d) where r << d. The effective weight becomes W + A * B, typically scaled by alpha/r, where alpha is the LoRA scaling hyperparameter.
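The low-rank update can be sketched in a few lines of NumPy. This is a toy forward pass only, assuming the standard alpha/r scaling and the usual initialization (A small random, B zero, so the adapter contributes nothing until training begins):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 64, 8                            # hidden size, LoRA rank

W = rng.standard_normal((d, d))         # frozen pretrained weight
A = rng.standard_normal((d, r)) * 0.01  # trainable, small random init
B = np.zeros((r, d))                    # trainable, zero init: A @ B = 0 at start

def lora_forward(x: np.ndarray, alpha: float = 16.0) -> np.ndarray:
    """y = x @ (W + (alpha/r) * A @ B); only A and B receive gradients."""
    return x @ W + (alpha / r) * (x @ A) @ B

x = rng.standard_normal((4, d))
# With B initialized to zero, the adapted model exactly matches the base model.
assert np.allclose(lora_forward(x), x @ W)
```

Note that the forward pass computes (x @ A) @ B rather than materializing the full d x d update, which keeps the extra compute proportional to the rank.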

Why LoRA is ideal for federated LLMs:

  • The trainable parameter count is typically 0.1-1% of the full model. A 1B-parameter model with rank-8 LoRA adapters might have only 1-5M trainable parameters.
  • Communication cost drops proportionally: clients send adapter weights (a few MB), not the full model.
  • On-device memory is reduced because only the adapters require gradient computation and optimizer states. The base model weights are loaded in inference mode.
  • Aggregation is straightforward: the server averages LoRA adapter weights using standard FL aggregation strategies.
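The aggregation step in the last bullet can be sketched as a sample-weighted average over adapter tensors, keyed by layer name (a minimal illustration; the tensor names are hypothetical):

```python
import numpy as np

def fedavg_adapters(client_updates: list[dict], client_weights: list[float]) -> dict:
    """Sample-weighted average of LoRA adapter tensors, keyed by layer name."""
    total = sum(client_weights)
    return {
        k: sum(w * upd[k] for upd, w in zip(client_updates, client_weights)) / total
        for k in client_updates[0]
    }

# Two clients, one adapter tensor each, weighted by local sample count.
c1 = {"layer0.lora_A": np.ones((4, 2))}
c2 = {"layer0.lora_A": np.zeros((4, 2))}
avg = fedavg_adapters([c1, c2], client_weights=[3, 1])
assert np.allclose(avg["layer0.lora_A"], 0.75)
```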

QLoRA (Quantized LoRA)

QLoRA combines LoRA with aggressive base model quantization (4-bit NormalFloat). The base model is loaded in 4-bit precision, and LoRA adapters are trained in full precision on top.

Benefits for federated deployment:

  • A 7B-parameter model fits in ~4GB of RAM with 4-bit quantization, making inference feasible on high-end mobile devices.
  • Only the LoRA adapters (full precision, but small) are trained and transmitted.
  • Quality is surprisingly close to full fine-tuning, especially for domain adaptation tasks.

Trade-off: 4-bit quantization introduces a small accuracy penalty on the base model, and not all hardware supports efficient 4-bit inference.
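The ~4GB figure for a 7B model follows from the bit width; a quick check (adapter size of 5M parameters is an illustrative assumption):

```python
def quantized_model_gb(num_params: int, bits: int) -> float:
    """Weight storage in GB at the given quantization bit width."""
    return num_params * bits / 8 / 1e9

base = quantized_model_gb(7_000_000_000, 4)   # 7B base model at 4-bit: 3.5GB
adapters = quantized_model_gb(5_000_000, 32)  # full-precision LoRA adapters: 0.02GB
print(f"~{base + adapters:.1f}GB")            # weights alone, before activations
```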

Prompt tuning and prefix tuning

Instead of modifying model weights, prompt tuning prepends a set of trainable "soft prompt" tokens to the input. These tokens are optimized to steer the model's behavior for a specific task.

  • Parameter count: Extremely small (a few thousand parameters per task).
  • Communication cost: Negligible -- soft prompts are a few KB.
  • Limitation: Less expressive than LoRA. Works well for task adaptation but poorly for domain-specific knowledge injection.

In federated settings, prompt tuning is useful for lightweight personalization (adjusting style, formality, domain terminology) but not for substantial model improvement.
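Mechanically, prompt tuning just prepends trainable vectors to the embedded input while everything else stays frozen. A toy sketch (dimensions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d, n_soft = 1000, 32, 10

embedding = rng.standard_normal((vocab, d))     # frozen token embedding table
soft_prompt = rng.standard_normal((n_soft, d))  # the only trainable parameters

def embed_with_soft_prompt(token_ids: np.ndarray) -> np.ndarray:
    """Prepend trainable soft-prompt vectors to the embedded input sequence."""
    tokens = embedding[token_ids]               # (seq_len, d), frozen
    return np.concatenate([soft_prompt, tokens], axis=0)

seq = embed_with_soft_prompt(np.array([5, 17, 42]))
assert seq.shape == (n_soft + 3, d)
# Trainable parameters: n_soft * d = 320 floats -- a few KB to transmit.
```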

Typical patterns

  • Prompt tuning for domain adaptation.
  • LoRA-style adapter updates.
  • Cascading model strategies for device tiers.

Cascading model strategies

Not all devices in a fleet have the same compute capability. A cascading strategy deploys different model sizes to different device tiers:

  • Tier 1 (high-end phones, tablets): Full LLM with LoRA adapters. On-device fine-tuning with 4-bit base model.
  • Tier 2 (mid-range devices): Distilled smaller model. On-device inference only, no fine-tuning. Benefits from federated improvements via model updates from Tier 1.
  • Tier 3 (low-end / IoT): Prompt-tuned compact model or API fallback.

Use Device Groups to segment your fleet by capability and assign appropriate model variants.
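The tier assignment can be expressed as a simple capability check. This is a sketch with hypothetical RAM thresholds (the 6144MB cutoff mirrors the min_ram_mb device filter used later in this page; the 3072MB cutoff is an assumption):

```python
def assign_tier(ram_mb: int) -> str:
    """Map device RAM to a model-variant tier (illustrative thresholds)."""
    if ram_mb >= 6144:
        return "tier1-lora-finetune"       # full LLM + on-device LoRA training
    if ram_mb >= 3072:
        return "tier2-distilled-inference" # distilled model, inference only
    return "tier3-compact-or-api"          # prompt-tuned compact model or API

assert assign_tier(8192) == "tier1-lora-finetune"
assert assign_tier(4096) == "tier2-distilled-inference"
assert assign_tier(1024) == "tier3-compact-or-api"
```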

Octomil LLM configuration

Federated LoRA fine-tuning

from octomil import Federation, ModelRegistry

registry = ModelRegistry(api_key="edg_...")

# Register a base model with LoRA configuration
registry.upload_version_from_path(
    model_id="domain-llm-v1",
    file_path="./llama-3-1b.onnx",
    version="1.0.0",
    formats=["onnx", "coreml", "tflite"],
)

# Create a federated training job for the adapters
federation = Federation(api_key="edg_...", name="domain-llm-training")

result = federation.train(
    model="domain-llm-v1",
    algorithm="fedavg",  # Standard aggregation works for adapters
    rounds=100,
    min_updates=10,
)

Configure LoRA adapter settings and strategy parameters via the dashboard (Training > Strategy Config) or the REST API:

Configure LoRA adapter:

curl -X PUT https://api.octomil.com/api/v1/federations/domain-llm-training/adapter \
  -H "Authorization: Bearer edg_..." \
  -H "Content-Type: application/json" \
  -d '{
    "type": "lora",
    "rank": 8,
    "alpha": 16,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    "dropout": 0.05
  }'

Configure strategy:

curl -X PUT https://api.octomil.com/api/v1/federations/domain-llm-training/strategy \
  -H "Authorization: Bearer edg_..." \
  -H "Content-Type: application/json" \
  -d '{
    "algorithm": "fedavg",
    "learning_rate": 1e-4,
    "local_epochs": 3,
    "max_grad_norm": 1.0,
    "weight_decay": 0.01,
    "round_config": {
      "timeout_seconds": 900,
      "min_samples_per_client": 50
    },
    "device_filter": {
      "min_ram_mb": 6144,
      "min_storage_mb": 2048
    }
  }'

Server-side API

# POST /api/v1/training-sessions
{
  "model_id": "model-uuid",
  "strategy": "fedavg",
  "adapter_config": {
    "type": "lora",
    "rank": 8,
    "alpha": 16,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"]
  },
  "strategy_params": {
    "learning_rate": 1e-4,
    "local_epochs": 3,
    "max_grad_norm": 1.0
  },
  "round_config": {
    "min_clients": 10,
    "max_rounds": 100,
    "timeout_seconds": 900
  },
  "device_filter": {
    "min_ram_mb": 6144,
    "min_storage_mb": 2048
  }
}

Monitoring LLM training

LLM federated training requires closer monitoring than standard FL due to the risk of catastrophic forgetting (the adapted model loses base capabilities) and the higher per-round cost.

Track these metrics in the Monitoring Dashboard:

  • Adapter weight norms: Should grow slowly and plateau. Explosive growth indicates instability.
  • Perplexity on held-out data: The primary quality metric. Should decrease smoothly.
  • Base model capability retention: Run periodic evaluations on general benchmarks to detect catastrophic forgetting.
  • Per-device training time: LLM fine-tuning is compute-intensive. Monitor for devices that consistently time out.
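A simple server-side check for the first metric above: compare each round's aggregate adapter norm to the previous round and flag sudden jumps. This is a sketch; the 1.5x growth threshold is an illustrative assumption, not a platform default.

```python
def norm_growth_alerts(round_norms: list[float], max_ratio: float = 1.5) -> list[int]:
    """Return round indices where the adapter weight norm grew more than
    max_ratio x over the previous round -- a sign of training instability."""
    return [
        i for i in range(1, len(round_norms))
        if round_norms[i] > max_ratio * round_norms[i - 1]
    ]

norms = [1.0, 1.2, 1.3, 4.0, 4.1]   # round 3 jumps ~3x: unstable
assert norm_growth_alerts(norms) == [3]
```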

Device memory budgets

When planning a federated LLM deployment, calculate the per-device memory budget explicitly:

Total device RAM needed =
    Base model (quantized)             # e.g., 1GB for 1B params at 4-bit
  + LoRA adapters (full precision)     # e.g., 10MB for rank-8 adapters
  + Optimizer states (for adapters)    # e.g., 20MB (Adam: 2x adapter size)
  + Activation memory (per batch)      # e.g., 200MB (depends on sequence length)
  + OS and app overhead                # e.g., 500MB
  -----------------------------------
  ~ 1.7GB total

For a device with 6GB RAM, this leaves ~4.3GB for other applications. Set min_ram_mb in the device filter to ensure only capable devices are selected for training rounds.
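The worked example above can be wrapped in a small helper for planning (a sketch; the default overhead and the Adam 2x multiplier are the same illustrative assumptions as in the breakdown):

```python
def device_ram_budget_mb(base_model_mb: float,
                         adapter_mb: float,
                         activation_mb: float,
                         os_overhead_mb: float = 500.0,
                         optimizer_multiplier: float = 2.0) -> float:
    """Per-device RAM needed: quantized base model + adapters
    + Adam optimizer states (~2x adapter size) + activations + OS overhead."""
    optimizer_mb = optimizer_multiplier * adapter_mb
    return base_model_mb + adapter_mb + optimizer_mb + activation_mb + os_overhead_mb

total = device_ram_budget_mb(base_model_mb=1000, adapter_mb=10, activation_mb=200)
print(f"~{total / 1000:.1f}GB")  # matches the ~1.7GB worked example
```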

Split learning alternative

For models too large for on-device fine-tuning (even with LoRA + quantization), split learning divides the model at a chosen layer: bottom layers run on-device, top layers on the server. The device sends intermediate activations (not raw data) to the server, which completes the forward pass and returns gradients.

Trade-offs: Split learning requires per-batch network round trips (needs low-latency connectivity) and intermediate activations can leak information. Apply differential privacy noise at the split point to mitigate this. Use split learning only when LoRA is infeasible and network conditions allow synchronous communication.

Best practices

  1. Start with LoRA rank 4-8. Higher ranks increase expressiveness but also communication cost and memory. Rank 8 is sufficient for most domain adaptation tasks. Only increase if you see clear underfitting.

  2. Always clip gradients. LLM fine-tuning is prone to gradient spikes, especially in early rounds. Set max_grad_norm to 1.0 as a default and only increase if training is too slow.

  3. Use FedAvg for adapter aggregation. LoRA adapters are small enough that the communication overhead is manageable, and the adapters' low dimensionality reduces the impact of client drift. SCAFFOLD or robust aggregation may be warranted for extreme heterogeneity, but these add complexity that is usually unnecessary.

  4. Filter devices by capability. Not every device in your fleet can run LLM fine-tuning. Use device groups and hardware filters aggressively. It is better to train on 10 capable devices than to include 50 devices where 40 will time out.

  5. Evaluate for catastrophic forgetting. After every N rounds, evaluate the adapted model on general-purpose benchmarks (not just the target task). If general capability degrades, reduce the learning rate or add a regularization term.

  6. Deploy adapters, not full models. The base model is distributed once (or pre-installed). Only the LoRA adapters are updated per round. This reduces deployment bandwidth by 100-1000x. Use the Model Catalog to version base models and adapters independently.

  7. Use canary rollouts for adapter updates. LLM behavior changes can be subtle and hard to catch with automated metrics. Deploy adapter updates to a canary group first using Model Rollouts and monitor user-facing quality metrics before full rollout.

Implementation path

  1. Validate foundation setup from Quickstart.
  2. Use round orchestration from Training Rounds.
  3. Protect production with Model Rollouts.

Further reading