
Python SDK

Programmatic access to federated learning, model registry, and rollouts from Python.

For inference and CLI commands (octomil serve, octomil deploy, etc.), install the binary instead.

Install

pip install octomil-sdk

Quick Start

from octomil import Federation, FederatedClient

# Server-side: orchestrate federated training
federation = Federation(api_key="edg_...")
result = federation.train(model="my-classifier", rounds=10, min_updates=100)

# Edge device: participate in training
client = FederatedClient(api_key="edg_...", device_identifier="hospital-001")
client.register()
client.train(model="my-classifier", data="s3://hospital-data/patients.parquet")

Federation

Orchestrate federated training across devices.

federation = Federation(api_key="edg_...", name="my-federation")

train()

result = federation.train(
    model="my-classifier",
    rounds=10,
    min_updates=100,
    algorithm="fedavg",     # fedavg (default)
    base_version="1.0.0",   # default: latest
    new_version="1.1.0",    # default: auto-generated
    update_format="delta",  # "delta" or "weights"
)

deploy()

federation.deploy(
    version="1.1.0",
    rollout_percentage=10,
    target_percentage=100,
    increment_step=10,
)
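With the parameters above, a rollout starting at 10% and advancing in 10-point steps reaches full coverage after nine advances. The schedule is simple arithmetic (illustrative only; when each step is taken is controlled through the rollouts API, not computed client-side):

```python
def rollout_schedule(start: int, target: int, step: int) -> list[int]:
    """Return the sequence of rollout percentages from start to target."""
    schedule = list(range(start, target, step))
    schedule.append(target)  # always end exactly at the target
    return schedule

print(rollout_schedule(10, 100, 10))
# [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
```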

FederatedClient

Participate in federated training from edge devices.

client = FederatedClient(api_key="edg_...", device_identifier="hospital-001")
client.register()

train()

Accepts multiple data sources — features are auto-detected and aligned:

# S3
client.train(model="cancer-detection", data="s3://bucket/patients.parquet")

# Local file
client.train(model="cancer-detection", data="/data/patients.csv", target_col="diagnosis")

# DataFrame
client.train(model="cancer-detection", data=hospital_df)

# Pre-trained weights
client.train(model="my-model", data=model.state_dict(), sample_count=1000)

Cloud credentials via environment variables: AWS_ACCESS_KEY_ID, GOOGLE_APPLICATION_CREDENTIALS, AZURE_STORAGE_CONNECTION_STRING.
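For example, S3 access can be configured before calling train(). The values below are obviously fake placeholders; in practice, inject real credentials through your deployment environment rather than hard-coding them:

```python
import os

# Placeholder values for illustration only - never commit real credentials.
os.environ["AWS_ACCESS_KEY_ID"] = "AKIAEXAMPLEKEY"
os.environ["AWS_SECRET_ACCESS_KEY"] = "example-secret"
```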

train_from_remote()

Pull model, train locally, submit updates:

def train_locally(base_state_dict):
    model = MyModel()
    model.load_state_dict(base_state_dict)
    train_one_epoch(model, local_dataloader)
    return model.state_dict(), len(local_data), {"loss": 0.42}

client.train_from_remote(model="my-classifier", local_train_fn=train_locally, rounds=5)

pull_model()

model_bytes = client.pull_model(model="my-classifier", version="1.0.0", format="pytorch")

ModelRegistry

Version, upload, and convert models.

from octomil import ModelRegistry

registry = ModelRegistry(api_key="edg_...")

# Create model
model = registry.ensure_model(name="mnist", framework="pytorch", use_case="image_classification")

# Upload version with auto-conversion
registry.upload_version_from_path(
    model_id=model["id"],
    file_path="model.pt",
    version="1.0.0",
    formats="onnx,tflite,coreml",
)

# Publish
registry.publish_version(model_id=model["id"], version="1.0.0")

# Rollout
registry.create_rollout(model_id=model["id"], version="1.0.0", rollout_percentage=10)

Rollouts

# Create
rollout = client.rollouts.create(model_id="model-123", version="2.0.0", rollout_percentage=10)

# Manage
client.rollouts.advance(model_id="model-123", rollout_id=rollout.id)
client.rollouts.pause(model_id="model-123", rollout_id=rollout.id)
client.rollouts.resume(model_id="model-123", rollout_id=rollout.id)
client.rollouts.update_percentage(model_id="model-123", rollout_id=rollout.id, percentage=50)

# Inspect
client.rollouts.list(model_id="model-123", status_filter="active")
client.rollouts.get_affected_devices(model_id="model-123", rollout_id=rollout.id)
client.rollouts.get_status_history(model_id="model-123", rollout_id=rollout.id)

# Delete
client.rollouts.delete(model_id="model-123", rollout_id=rollout.id, force=True)

Error Handling

from octomil import OctomilClientError

try:
    client.train(model="nonexistent", data=weights)
except OctomilClientError as e:
    print(f"Training failed: {e}")

Cloud Inference

The Client class provides a high-level interface for cloud and local inference. It reads OCTOMIL_API_KEY, OCTOMIL_ORG_ID, and OCTOMIL_API_BASE from environment variables, or accepts them as constructor arguments.

import octomil

client = octomil.Client(api_key="edg_...", org_id="my-org")

# One-call predict: downloads the model, loads it, and runs inference
result = client.predict(
    "phi-4-mini",
    messages=[{"role": "user", "content": "Summarise federated learning in one paragraph."}],
    max_tokens=256,
    temperature=0.7,
    top_p=1.0,
)
print(result) # result is a str
print(result.metrics) # access latency and token metrics

load_model()

For repeated inference, load the model once and call predict() directly:

model = client.load_model(
    "phi-4-mini",
    version="1.0.0",   # default: latest
    engine="mlx-lm",   # optional engine override
    cache_size_mb=2048,
)

prediction = model.predict(
    [{"role": "user", "content": "Hello"}],
    max_tokens=128,
)

Streaming Inference

Cloud streaming (SSE)

Stream tokens from the Octomil cloud endpoint via Server-Sent Events. This does not download or run models locally.

# Sync
for token in client.stream_predict(
    "phi-4-mini",
    [{"role": "user", "content": "Write a haiku about edge AI."}],
    parameters={"temperature": 0.8, "max_tokens": 64},
    timeout=120.0,
):
    print(token.token, end="", flush=True)
    if token.done:
        break

# Async
import asyncio

async def main():
    async for token in client.stream_predict_async(
        "phi-4-mini",
        "Explain quantum computing briefly.",
        parameters={"max_tokens": 128},
    ):
        print(token.token, end="", flush=True)

asyncio.run(main())

Each StreamToken contains:

Field        Type          Description
token        str           The generated text fragment
done         bool          True on the final token
provider     str | None    Which backend served the request
latency_ms   float | None  Server-side latency for this token
session_id   str | None    Unique session identifier
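As a sketch of how the fields fit together, the fragments can be accumulated into the full response. The dataclass below is a hypothetical stand-in mirroring StreamToken's token and done fields, not the SDK type itself:

```python
from dataclasses import dataclass

@dataclass
class Token:
    """Stand-in with the same token/done fields as StreamToken."""
    token: str
    done: bool = False

def collect(stream) -> str:
    """Concatenate token fragments until the final token arrives."""
    parts = []
    for t in stream:
        parts.append(t.token)
        if t.done:
            break
    return "".join(parts)

fake_stream = [Token("Edge "), Token("AI"), Token(".", done=True)]
print(collect(fake_stream))  # Edge AI.
```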

Standalone streaming functions

You can also call the streaming functions directly without a Client:

from octomil import stream_inference, stream_inference_async

for token in stream_inference(
    server_url="https://api.octomil.com/api/v1",
    api_key="edg_...",
    model_id="phi-4-mini",
    input_data=[{"role": "user", "content": "Hello"}],
    parameters={"temperature": 0.7},
    timeout=120.0,
):
    print(token.token, end="")

Local streaming (on-device)

For on-device streaming with automatic timing instrumentation, use predict_stream():

import asyncio

async def stream_local():
    async for chunk in client.predict_stream(
        "phi-4-mini",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=256,
        temperature=0.7,
    ):
        print(chunk, end="", flush=True)

asyncio.run(stream_local())

Embeddings

Generate dense vector embeddings via the Octomil cloud endpoint. Useful for semantic search, clustering, and RAG pipelines.

result = client.embed(
    model_id="nomic-embed-text",
    input="Federated learning preserves data privacy.",
    timeout=30.0,
)

print(result.embeddings) # list[list[float]] — one vector per input string
print(result.model) # model name used
print(result.usage) # EmbeddingUsage(prompt_tokens=..., total_tokens=...)

Batch multiple strings in a single call:

result = client.embed(
    model_id="nomic-embed-text",
    input=[
        "On-device inference reduces latency.",
        "Smart routing optimises cost and quality.",
        "Federated learning keeps data on-device.",
    ],
)

for i, vec in enumerate(result.embeddings):
    print(f"String {i}: {len(vec)}-dim vector")
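For semantic search, the returned vectors are typically compared with cosine similarity. A dependency-free sketch, using hard-coded toy vectors in place of result.embeddings:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

# Toy 3-dim "embeddings" standing in for real model output
query = [1.0, 0.0, 1.0]
doc_a = [0.9, 0.1, 0.8]
doc_b = [0.0, 1.0, 0.0]

best = max([doc_a, doc_b], key=lambda v: cosine_similarity(query, v))
print(best is doc_a)  # True - doc_a points in nearly the same direction
```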

Standalone embed() function

from octomil import embed

result = embed(
    server_url="https://api.octomil.com/api/v1",
    api_key="edg_...",
    model_id="nomic-embed-text",
    input=["hello", "world"],
    timeout=30.0,
)

Routing

Route queries to the most appropriate model based on estimated complexity. Simple queries (greetings, short factual questions) go to the smallest/fastest model; complex queries (code generation, multi-step reasoning) go to the largest/most capable model.
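The thresholding idea is straightforward: map an estimated complexity score in [0, 1] to a tier. A toy illustration of that mapping (not the SDK's internal scoring logic):

```python
def pick_tier(complexity: float, thresholds: tuple = (0.3, 0.7)) -> str:
    """Map a complexity score in [0, 1] to a model tier."""
    low, high = thresholds
    if complexity < low:
        return "fast"
    if complexity < high:
        return "balanced"
    return "quality"

print(pick_tier(0.1))  # fast
print(pick_tier(0.5))  # balanced
print(pick_tier(0.9))  # quality
```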

Basic routing

from octomil import QueryRouter, ModelInfo

models = {
    "smollm-360m": ModelInfo(name="smollm-360m", tier="fast", param_b=0.36),
    "phi-4-mini": ModelInfo(name="phi-4-mini", tier="balanced", param_b=3.8),
    "llama-3.2-3b": ModelInfo(name="llama-3.2-3b", tier="quality", param_b=3.0),
}

router = QueryRouter(
    models,
    strategy="complexity",
    thresholds=(0.3, 0.7),      # complexity boundaries: fast | balanced | quality
    enable_deterministic=True,  # intercept arithmetic without invoking a model
)

decision = router.route([{"role": "user", "content": "Write a REST API in Python"}])
print(decision.model_name) # "llama-3.2-3b" (high complexity)
print(decision.complexity_score) # 0.0–1.0
print(decision.tier) # "quality"
print(decision.fallback_chain) # ordered fallback models

Deterministic tier (no model needed)

Arithmetic, unit conversions, and other trivial queries are answered without invoking any model:

from octomil.routing import check_deterministic

result = check_deterministic("what is 2+2?")
if result is not None:
    print(result.answer)      # "4"
    print(result.method)      # "arithmetic"
    print(result.confidence)  # 1.0
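Conceptually, the deterministic tier pattern-matches trivial queries and computes the answer directly. A minimal arithmetic-only sketch of that idea (not the SDK's implementation), using the ast module for safe evaluation:

```python
import ast
import operator
import re

_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}

def _eval(node):
    """Evaluate a parsed expression limited to numbers and + - * /."""
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
    raise ValueError("not simple arithmetic")

def try_arithmetic(query: str):
    """Return the answer for queries like 'what is 2+2?', else None."""
    match = re.search(r"what is ([\d+\-*/ .]+)\?*$", query.lower())
    if not match:
        return None
    try:
        return _eval(ast.parse(match.group(1).strip(), mode="eval").body)
    except (ValueError, SyntaxError):
        return None

print(try_arithmetic("what is 2+2?"))  # 4
```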

Auto-assign tiers

When you have an ordered list of models (smallest to largest), auto-assign tiers:

from octomil import assign_tiers

models = assign_tiers(["smollm-360m", "phi-4-mini", "llama-3.2-3b"])
router = QueryRouter(models)

Decomposed routing

For multi-task queries, each sub-task can be routed independently:

decision = router.route_decomposed([
    {"role": "user", "content": "Explain quantum computing and write a Python sort function"}
])

if hasattr(decision, "sub_decisions"):
    for sub, task in zip(decision.sub_decisions, decision.tasks):
        print(f"Task: {task.text} -> {sub.model_name} ({sub.tier})")

Fallback

next_model = router.get_fallback("phi-4-mini")
# Returns the next model to try if phi-4-mini fails
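A typical use is a retry loop that walks the chain until a model succeeds. The sketch below uses a toy router and a stand-in call_model function (the real call would be your inference entry point, and the real chain comes from QueryRouter):

```python
class ToyRouter:
    """Stand-in for QueryRouter's get_fallback, for illustration only."""
    chain = {"phi-4-mini": "smollm-360m", "smollm-360m": None}

    def get_fallback(self, name):
        return self.chain.get(name)

def predict_with_fallback(router, model_name, call_model):
    """Try a model, falling back down the chain on failure."""
    current = model_name
    while current is not None:
        try:
            return call_model(current)
        except RuntimeError:
            current = router.get_fallback(current)
    raise RuntimeError("all models in the fallback chain failed")

def flaky(name):
    # Fails for the primary model, succeeds for the fallback.
    if name == "phi-4-mini":
        raise RuntimeError("engine busy")
    return f"answer from {name}"

print(predict_with_fallback(ToyRouter(), "phi-4-mini", flaky))
# answer from smollm-360m
```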

Batch Inference

The RequestQueue serialises concurrent inference requests to prevent engine contention. Requests are processed FIFO, and each caller receives results via an async future.

from octomil.batch import RequestQueue, QueueFullError, QueueTimeoutError

queue = RequestQueue(max_depth=32, timeout=60.0)
queue.start()

# Submit a non-streaming request
result = await queue.submit_generate(request, generate_fn=my_engine.generate)

# Submit a streaming request
async for chunk in queue.submit_generate_stream(request, generate_stream_fn=my_engine.stream):
    print(chunk, end="")

# Check queue stats
stats = queue.stats()
print(f"Pending: {stats.pending}, Active: {stats.active}, Max: {stats.max_depth}")

# Shutdown
await queue.stop()
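The FIFO-plus-futures pattern behind the queue can be sketched in plain asyncio. This is a conceptual illustration, not the SDK's implementation: a single worker drains the queue in order, and each caller awaits its own future.

```python
import asyncio

class TinyQueue:
    """Serialise concurrent requests: one worker, FIFO order, futures for results."""

    def __init__(self):
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker = None

    def start(self):
        # Must be called while an event loop is running.
        self._worker = asyncio.ensure_future(self._run())

    async def _run(self):
        while True:
            fn, arg, fut = await self._queue.get()
            try:
                fut.set_result(await fn(arg))
            except Exception as exc:
                fut.set_exception(exc)

    async def submit(self, fn, arg):
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((fn, arg, fut))
        return await fut  # caller waits on its own future

    async def stop(self):
        if self._worker:
            self._worker.cancel()

async def generate(prompt):
    # Stand-in for an engine call.
    await asyncio.sleep(0.01)
    return prompt.upper()

async def main():
    q = TinyQueue()
    q.start()
    results = await asyncio.gather(q.submit(generate, "a"), q.submit(generate, "b"))
    await q.stop()
    print(results)  # ['A', 'B']

asyncio.run(main())
```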

Quality Evaluation

The SDK reports inference events automatically when using StreamingInferenceClient from the internal octomil.inference module. Each streaming session tracks:

  • Time to first chunk (TTFC) — latency until the first token arrives
  • Average chunk latency — mean inter-token time
  • Throughput — chunks per second
  • Total duration — wall-clock time for the full generation
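These metrics derive directly from per-chunk timestamps. A sketch of the arithmetic (illustrative, not the SDK's internals; assumes at least one chunk arrived):

```python
def session_metrics(start: float, chunk_times: list[float]) -> dict:
    """Compute TTFC, mean inter-chunk latency, throughput, and total
    duration from a session start time and per-chunk arrival times (seconds)."""
    ttfc = chunk_times[0] - start
    gaps = [b - a for a, b in zip(chunk_times, chunk_times[1:])]
    total = chunk_times[-1] - start
    return {
        "ttfc_s": ttfc,
        "avg_chunk_latency_s": sum(gaps) / len(gaps) if gaps else 0.0,
        "throughput_chunks_per_s": len(chunk_times) / total,
        "total_duration_s": total,
    }

m = session_metrics(0.0, [0.2, 0.3, 0.4, 0.5])
print(m["ttfc_s"], m["throughput_chunks_per_s"])  # 0.2 8.0
```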

Events (generation_started, generation_completed, generation_failed) are reported to POST /api/v1/inference/events for server-side quality dashboards.

Gotchas

  • SDK is for orchestration, not inference — use octomil serve for local inference. The Python SDK manages models, training, and rollouts via the API.
  • ensure_model is idempotent — safe to call in scripts and CI. It creates the model if it doesn't exist, or returns the existing one.
  • train_from_remote blocks per round — the function runs synchronously for each round. Use threading or async wrappers if you need concurrent training across multiple models.
  • Cloud credentials via env vars — S3/GCS/Azure data sources require standard cloud credentials in the environment. The SDK does not store cloud credentials.
  • formats triggers local conversion — uploading with formats="onnx,coreml,tflite" converts locally via CLI before uploading. The version stays in converting state until all formats are ready.
  • Client caches loaded models — calling predict() or predict_stream() with the same model name reuses the loaded backend. Call client.dispose() to release all cached models.
  • Streaming requires httpx — stream_predict() and embed() use httpx under the hood. Install it with pip install httpx.