Federated LLMs
Federated LLM workflows combine privacy with continuous improvement using on-device prompt tuning and lightweight adaptation.
Why federated learning for LLMs
Large language models contain valuable general knowledge, but their real value in production comes from adaptation to specific domains, users, or tasks. Traditional fine-tuning requires centralizing user data on a server -- writing prompts, completions, corrections, and interaction patterns to a training dataset. For applications involving personal data (messaging, health queries, legal documents, financial advice), this centralization is a non-starter under GDPR, HIPAA, and similar regulations.
Federated learning offers a path: keep user interaction data on-device, fine-tune the model locally, and aggregate only the model updates. The user's data never leaves their device. The global model improves from everyone's usage patterns without any single user's data being exposed.
The challenge is that LLMs are orders of magnitude larger than traditional FL models, and the standard federated learning playbook was designed for models that fit comfortably in device memory. Adapting FL to LLMs requires rethinking almost every assumption about model size, communication cost, and on-device compute.
Key challenges
Model size and device memory
A typical FL model (image classifier, next-word predictor) might have 1-10M parameters. A small LLM starts at 1B parameters. Even a quantized 1B-parameter model requires ~500MB-1GB of RAM just for inference, and fine-tuning requires 2-4x the model size in memory for optimizer states and activations.
| Model size | Memory (inference) | Memory (fine-tuning) | Feasible on mobile? |
|---|---|---|---|
| 100M params | ~200MB | ~800MB | Yes |
| 500M params | ~1GB | ~4GB | Marginal (high-end phones) |
| 1B params | ~2GB | ~8GB | No (full fine-tuning) |
| 7B params | ~14GB | ~56GB | No |
Full fine-tuning of LLMs on mobile devices is not feasible. Parameter-efficient methods are required.
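The table's rows follow from simple byte math. A quick sketch, assuming 16-bit weights for inference and a rough 4x multiplier for full fine-tuning (weights, gradients, and optimizer states):

```python
def inference_mb(params_m: float, bytes_per_param: float = 2.0) -> float:
    """Weights-only memory at 16-bit precision, in MB, for a model of params_m million parameters."""
    return params_m * bytes_per_param

def finetune_mb(params_m: float, multiplier: float = 4.0) -> float:
    """Rough full fine-tuning footprint: weights + gradients + optimizer states."""
    return inference_mb(params_m) * multiplier

# Reproduces the table rows above: 100M -> ~200MB / ~800MB; 7B -> ~14GB inference.
print(inference_mb(100), finetune_mb(100), inference_mb(7000))  # 200.0 800.0 14000.0
```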
Communication cost
Transmitting a full model update for a 1B-parameter model would require sending ~4GB per round per client (at 32-bit precision). Even with quantization to 8-bit, that is 1GB. On mobile networks, this is prohibitive in both time and bandwidth cost.
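To make the payload math concrete, a small sketch comparing a full-model update with a parameter-efficient one (the 4M-parameter adapter figure anticipates the LoRA sizes discussed below):

```python
def payload_gb(num_params: int, bits_per_param: int) -> float:
    """Size of one model update in GB (decimal) at the given precision."""
    return num_params * bits_per_param / 8 / 1e9

full_fp32 = payload_gb(1_000_000_000, 32)  # ~4 GB per client per round
full_int8 = payload_gb(1_000_000_000, 8)   # ~1 GB -- still prohibitive on mobile
lora_fp32 = payload_gb(4_000_000, 32)      # rank-8 adapters: ~16 MB

print(full_fp32, full_int8, lora_fp32)  # 4.0 1.0 0.016
```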
Convergence with heterogeneous text data
Text data is inherently more heterogeneous than image data. Each user has a unique vocabulary, writing style, and topic distribution. This extreme non-IID setting amplifies client drift and makes federated convergence harder than in traditional FL settings.
Parameter-efficient federated fine-tuning
The solution to the size and communication problems is to train only a small subset of parameters while keeping the base model frozen. Several approaches are viable for federated LLM workflows.
LoRA (Low-Rank Adaptation)
LoRA freezes the pretrained model weights and injects small trainable low-rank matrices into each transformer layer. Instead of fine-tuning a weight matrix W (of shape d x d), LoRA trains two much smaller matrices A (shape d x r) and B (shape r x d), where r << d. The effective weight becomes W + A * B, with the update scaled by alpha / r in the original LoRA formulation -- which is why adapter configurations specify both a rank and an alpha.
Why LoRA is ideal for federated LLMs:
- The trainable parameter count is typically 0.1-1% of the full model. A 1B-parameter model with rank-8 LoRA adapters might have only 1-5M trainable parameters.
- Communication cost drops proportionally: clients send adapter weights (a few MB), not the full model.
- On-device memory is reduced because only the adapters require gradient computation and optimizer states. The base model weights are loaded in inference mode.
- Aggregation is straightforward: the server averages LoRA adapter weights using standard FL aggregation strategies.
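A minimal numpy sketch of the LoRA forward pass, at toy sizes, using the alpha / r scaling from the original LoRA formulation:

```python
import numpy as np

d, r = 64, 8   # hidden size and LoRA rank (r << d)
alpha = 16     # scaling factor; effective update is (alpha / r) * A @ B

rng = np.random.default_rng(0)
W = rng.standard_normal((d, d))          # frozen pretrained weight
A = rng.standard_normal((d, r)) * 0.01   # trainable, small random init
B = np.zeros((r, d))                     # trainable, zero init -> no change at start

def lora_forward(x):
    # Base path plus low-rank update; only A and B receive gradients in training.
    return x @ W + (alpha / r) * (x @ A) @ B

x = rng.standard_normal((1, d))
# With B initialized to zero, the adapted output equals the base output.
assert np.allclose(lora_forward(x), x @ W)

# Trainable parameters per adapted matrix: d*r + r*d, vs d*d for full fine-tuning.
print(d * r + r * d, d * d)  # 1024 4096
```

At real model scale the same ratio holds: the adapters are a small fraction of the base weights, which is what makes them cheap to train and transmit.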
QLoRA (Quantized LoRA)
QLoRA combines LoRA with aggressive base model quantization (4-bit NormalFloat). The base model is loaded in 4-bit precision, and LoRA adapters are trained in full precision on top.
Benefits for federated deployment:
- A 7B-parameter model fits in ~4GB of RAM with 4-bit quantization, making inference feasible on high-end mobile devices.
- Only the LoRA adapters (full precision, but small) are trained and transmitted.
- Quality is surprisingly close to full fine-tuning, especially for domain adaptation tasks.
Trade-off: 4-bit quantization introduces a small accuracy penalty on the base model, and not all hardware supports efficient 4-bit inference.
Prompt tuning and prefix tuning
Instead of modifying model weights, prompt tuning prepends a set of trainable "soft prompt" tokens to the input. These tokens are optimized to steer the model's behavior for a specific task.
- Parameter count: Extremely small (a few thousand parameters per task).
- Communication cost: Negligible -- soft prompts are a few KB.
- Limitation: Less expressive than LoRA. Works well for task adaptation but poorly for domain-specific knowledge injection.
In federated settings, prompt tuning is useful for lightweight personalization (adjusting style, formality, domain terminology) but not for substantial model improvement.
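A sketch of the mechanism with numpy, at toy dimensions (real soft prompts use the model's embedding size):

```python
import numpy as np

d_model, n_soft, seq_len = 32, 8, 10   # toy sizes for illustration

rng = np.random.default_rng(0)
soft_prompt = rng.standard_normal((n_soft, d_model)) * 0.02   # the only trainable parameters
token_embeddings = rng.standard_normal((seq_len, d_model))    # frozen embedding lookup output

# Prepend the trained soft tokens to every input before the transformer stack.
model_input = np.concatenate([soft_prompt, token_embeddings], axis=0)
assert model_input.shape == (n_soft + seq_len, d_model)

# Per-task payload: n_soft * d_model floats -- a few KB even at full precision.
payload_kb = soft_prompt.size * 4 / 1024
print(f"{payload_kb:.1f} KB")  # 1.0 KB at these toy sizes
```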
Typical patterns
- Prompt tuning for domain adaptation.
- LoRA-style adapter updates.
- Cascading model strategies for device tiers.
Cascading model strategies
Not all devices in a fleet have the same compute capability. A cascading strategy deploys different model sizes to different device tiers:
- Tier 1 (high-end phones, tablets): Full LLM with LoRA adapters. On-device fine-tuning with 4-bit base model.
- Tier 2 (mid-range devices): Distilled smaller model. On-device inference only, no fine-tuning. Benefits from federated improvements via model updates from Tier 1.
- Tier 3 (low-end / IoT): Prompt-tuned compact model or API fallback.
Use Device Groups to segment your fleet by capability and assign appropriate model variants.
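One way to sketch such a tiering policy. The thresholds, capability flags, and tier names here are illustrative, not part of the Octomil SDK:

```python
# Illustrative tier assignment only -- thresholds and names are made up,
# not part of any SDK or platform default.
def assign_tier(ram_mb: int, has_npu: bool) -> str:
    if ram_mb >= 6144 and has_npu:
        return "tier1-lora-finetune"        # full LLM + on-device LoRA training
    if ram_mb >= 3072:
        return "tier2-distilled-inference"  # distilled model, inference only
    return "tier3-compact-or-api"           # prompt-tuned compact model or API fallback

print(assign_tier(8192, True))   # tier1-lora-finetune
print(assign_tier(4096, False))  # tier2-distilled-inference
print(assign_tier(1024, False))  # tier3-compact-or-api
```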
Octomil LLM configuration
Federated LoRA fine-tuning
from octomil import Federation, ModelRegistry

registry = ModelRegistry(api_key="edg_...")

# Register a base model with LoRA configuration
registry.upload_version_from_path(
    model_id="domain-llm-v1",
    file_path="./llama-3-1b.onnx",
    version="1.0.0",
    formats=["onnx", "coreml", "tflite"],
)

# Create a federated training job for the adapters
federation = Federation(api_key="edg_...", name="domain-llm-training")
result = federation.train(
    model="domain-llm-v1",
    algorithm="fedavg",  # Standard aggregation works for adapters
    rounds=100,
    min_updates=10,
)
Configure LoRA adapter settings and strategy parameters via the dashboard (Training > Strategy Config) or the REST API:
Configure LoRA adapter:
- cURL
- Python
- JavaScript
curl -X PUT https://api.octomil.com/api/v1/federations/domain-llm-training/adapter \
  -H "Authorization: Bearer edg_..." \
  -H "Content-Type: application/json" \
  -d '{
    "type": "lora",
    "rank": 8,
    "alpha": 16,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    "dropout": 0.05
  }'
import requests

response = requests.put(
    "https://api.octomil.com/api/v1/federations/domain-llm-training/adapter",
    headers={"Authorization": "Bearer edg_..."},
    json={
        "type": "lora",
        "rank": 8,
        "alpha": 16,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
        "dropout": 0.05,
    },
)
print(response.json())
const response = await fetch("https://api.octomil.com/api/v1/federations/domain-llm-training/adapter", {
  method: "PUT",
  headers: { "Authorization": "Bearer edg_...", "Content-Type": "application/json" },
  body: JSON.stringify({
    type: "lora",
    rank: 8,
    alpha: 16,
    target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"],
    dropout: 0.05,
  }),
});
const data = await response.json();
console.log(data);
Configure strategy:
- cURL
- Python
- JavaScript
curl -X PUT https://api.octomil.com/api/v1/federations/domain-llm-training/strategy \
  -H "Authorization: Bearer edg_..." \
  -H "Content-Type: application/json" \
  -d '{
    "algorithm": "fedavg",
    "learning_rate": 1e-4,
    "local_epochs": 3,
    "max_grad_norm": 1.0,
    "weight_decay": 0.01,
    "round_config": {
      "timeout_seconds": 900,
      "min_samples_per_client": 50
    },
    "device_filter": {
      "min_ram_mb": 6144,
      "min_storage_mb": 2048
    }
  }'
import requests

response = requests.put(
    "https://api.octomil.com/api/v1/federations/domain-llm-training/strategy",
    headers={"Authorization": "Bearer edg_..."},
    json={
        "algorithm": "fedavg",
        "learning_rate": 1e-4,
        "local_epochs": 3,
        "max_grad_norm": 1.0,
        "weight_decay": 0.01,
        "round_config": {
            "timeout_seconds": 900,
            "min_samples_per_client": 50,
        },
        "device_filter": {
            "min_ram_mb": 6144,
            "min_storage_mb": 2048,
        },
    },
)
print(response.json())
const response = await fetch("https://api.octomil.com/api/v1/federations/domain-llm-training/strategy", {
  method: "PUT",
  headers: { "Authorization": "Bearer edg_...", "Content-Type": "application/json" },
  body: JSON.stringify({
    algorithm: "fedavg",
    learning_rate: 1e-4,
    local_epochs: 3,
    max_grad_norm: 1.0,
    weight_decay: 0.01,
    round_config: {
      timeout_seconds: 900,
      min_samples_per_client: 50,
    },
    device_filter: {
      min_ram_mb: 6144,
      min_storage_mb: 2048,
    },
  }),
});
const data = await response.json();
console.log(data);
Server-side API
# POST /api/v1/training-sessions
{
  "model_id": "model-uuid",
  "strategy": "fedavg",
  "adapter_config": {
    "type": "lora",
    "rank": 8,
    "alpha": 16,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"]
  },
  "strategy_params": {
    "learning_rate": 1e-4,
    "local_epochs": 3,
    "max_grad_norm": 1.0
  },
  "round_config": {
    "min_clients": 10,
    "max_rounds": 100,
    "timeout_seconds": 900
  },
  "device_filter": {
    "min_ram_mb": 6144,
    "min_storage_mb": 2048
  }
}
Monitoring LLM training
LLM federated training requires closer monitoring than standard FL due to the risk of catastrophic forgetting (the adapted model loses base capabilities) and the higher per-round cost.
Track these metrics in the Monitoring Dashboard:
- Adapter weight norms: Should grow slowly and plateau. Explosive growth indicates instability.
- Perplexity on held-out data: The primary quality metric. Should decrease smoothly.
- Base model capability retention: Run periodic evaluations on general benchmarks to detect catastrophic forgetting.
- Per-device training time: LLM fine-tuning is compute-intensive. Monitor for devices that consistently time out.
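Two of these checks are easy to compute from quantities the client already reports. A sketch, where the growth threshold is an illustrative choice rather than a platform default:

```python
import math

def perplexity(mean_nll: float) -> float:
    """Perplexity from mean per-token negative log-likelihood (natural log)."""
    return math.exp(mean_nll)

def adapter_l2_norm(adapter_tensors):
    """Global L2 norm across all adapter tensors (flat lists here for simplicity)."""
    return math.sqrt(sum(w * w for tensor in adapter_tensors for w in tensor))

# Perplexity should fall smoothly round over round; lower mean NLL means lower perplexity.
assert perplexity(2.0) < perplexity(2.3)

# Adapter norms should grow slowly and plateau; compare against the previous round.
prev = adapter_l2_norm([[0.1, 0.2]])
curr = adapter_l2_norm([[0.3, 0.6]])
print(f"norm growth: {curr / prev:.1f}x")  # norm growth: 3.0x -- a jump like this warrants a look
```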
Device memory budgets
When planning a federated LLM deployment, calculate the per-device memory budget explicitly:
Total device RAM needed =
    Base model (quantized)             # e.g., 1GB for 1B params at 4-bit
  + LoRA adapters (full precision)     # e.g., 10MB for rank-8 adapters
  + Optimizer states (for adapters)    # e.g., 20MB (Adam: 2x adapter size)
  + Activation memory (per batch)      # e.g., 200MB (depends on sequence length)
  + OS and app overhead                # e.g., 500MB
  -----------------------------------
  ~ 1.7GB total
For a device with 6GB RAM, this leaves ~4.3GB for other applications. Set min_ram_mb in the device filter to ensure only capable devices are selected for training rounds.
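The same budget as a small helper. The defaults mirror the worked example's figures; adjust them for your model, batch size, and sequence length:

```python
def device_memory_budget_gb(
    base_model_gb: float = 1.0,     # 1B params at 4-bit
    adapters_mb: float = 10.0,      # rank-8 LoRA adapters, full precision
    activations_mb: float = 200.0,  # depends on batch size and sequence length
    overhead_mb: float = 500.0,     # OS and app overhead
) -> float:
    optimizer_mb = 2 * adapters_mb  # Adam keeps two states per trainable parameter
    return base_model_gb + (adapters_mb + optimizer_mb + activations_mb + overhead_mb) / 1024

print(f"{device_memory_budget_gb():.1f} GB")  # 1.7 GB, matching the worked example
```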
Split learning alternative
For models too large for on-device fine-tuning (even with LoRA + quantization), split learning divides the model at a chosen layer: bottom layers run on-device, top layers on the server. The device sends intermediate activations (not raw data) to the server, which completes the forward pass and returns gradients.
Trade-offs: Split learning requires per-batch network round trips (needs low-latency connectivity) and intermediate activations can leak information. Apply differential privacy noise at the split point to mitigate this. Use split learning only when LoRA is infeasible and network conditions allow synchronous communication.
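A toy numpy sketch of the device-side half of this pattern. The clipping bound and noise scale are illustrative, not a calibrated differential-privacy mechanism:

```python
import numpy as np

rng = np.random.default_rng(0)

# Device-side: bottom layers produce intermediate activations from private input.
def device_forward(x, W_bottom):
    return np.maximum(x @ W_bottom, 0.0)  # toy bottom "layer" with ReLU

# Protection at the split point: clip the activation norm, then add Gaussian noise.
def protect_split_activations(h, clip_norm=1.0, noise_std=0.1):
    norm = np.linalg.norm(h)
    h = h * min(1.0, clip_norm / norm)              # bound sensitivity
    return h + rng.normal(0.0, noise_std, h.shape)  # noise the server-bound payload

x = rng.standard_normal((1, 16))
W_bottom = rng.standard_normal((16, 8))
h = device_forward(x, W_bottom)
sent = protect_split_activations(h)
# Only noised activations leave the device -- never the raw input x.
assert sent.shape == h.shape
```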
Best practices
- Start with LoRA rank 4-8. Higher ranks increase expressiveness but also communication cost and memory. Rank 8 is sufficient for most domain adaptation tasks. Only increase if you see clear underfitting.
- Always clip gradients. LLM fine-tuning is prone to gradient spikes, especially in early rounds. Set max_grad_norm to 1.0 as a default and only increase it if training is too slow.
- Use FedAvg for adapter aggregation. LoRA adapters are small enough that the communication overhead is manageable, and the adapters' low dimensionality reduces the impact of client drift. SCAFFOLD or robust aggregation may be warranted for extreme heterogeneity but adds complexity that is usually unnecessary.
- Filter devices by capability. Not every device in your fleet can run LLM fine-tuning. Use device groups and hardware filters aggressively. It is better to train on 10 capable devices than to include 50 devices where 40 will time out.
- Evaluate for catastrophic forgetting. After every N rounds, evaluate the adapted model on general-purpose benchmarks (not just the target task). If general capability degrades, reduce the learning rate or add a regularization term.
- Deploy adapters, not full models. The base model is distributed once (or pre-installed). Only the LoRA adapters are updated per round. This reduces deployment bandwidth by 100-1000x. Use the Model Catalog to version base models and adapters independently.
- Use canary rollouts for adapter updates. LLM behavior changes can be subtle and hard to catch with automated metrics. Deploy adapter updates to a canary group first using Model Rollouts and monitor user-facing quality metrics before full rollout.
Implementation path
- Validate foundation setup from Quickstart.
- Use round orchestration from Training Rounds.
- Protect production with Model Rollouts.
Further reading
- Model Lifecycle -- versioning base models and adapters
- Advanced FL Configuration -- model compression and optimization techniques
- Advanced FL Concepts -- communication efficiency and async training
- Device Groups -- segmenting devices by capability
- Privacy Guide -- differential privacy for split learning and adapter updates
- iOS SDK -- CoreML integration for on-device LLMs
- Android SDK -- TFLite/NNAPI for on-device inference