Custom Trainer Integration
AstraFlow’s trainer is a swappable component. You can replace the built-in PPO trainer with your own training framework as long as it speaks the correct HTTP protocol. This guide documents the minimal API your custom trainer must implement.
This guide assumes TCP weight transfer mode, which is the recommended mode for custom trainers because the trainer only needs to talk to AstraFlow — no direct RaaS communication required for data or coordination.
Architecture
A trainer has two communication surfaces:
                 ┌───────────────────────┐
Outbound HTTP    │                       │    Inbound HTTP
(3 endpoints)    │    Custom Trainer     │    (3 endpoints)
                 │                       │
Trainer ────────►│                       │◄──────── RaaS
calls AstraFlow  │  ┌─────────────────┐  │    calls trainer's
                 │  │  Sender Agent   │  │    weight sender
                 │  │  (reusable lib) │  │
                 │  └─────────────────┘  │
                 └───────────────────────┘

Registration   ──►  AstraFlow            RaaS       ──►  Weight connection
Data pull      ──►  Service              instances       Weight pull
Version notify ──►                       (1..N)
Outbound (trainer → AstraFlow): 3 HTTP calls for registration, data, and version notification.
Inbound (RaaS → trainer): 3 HTTP endpoints served by the weight sender agent. RaaS connects to pull updated weights via TCP.
The weight sender agent (astraflow.train_worker.weight_transfer.sender_agent)
is provided as a reusable library. Custom trainers do not need to reimplement
the TCP/ZMQ transfer machinery.
API Reference
Wire formats
- Pickle / cloudpickle (Content-Type: application/octet-stream): all three Trainer → AstraFlow endpoints. Payload is a Python dict before cloudpickle.dumps.
- JSON: all three RaaS → Trainer sender endpoints.
Note on envelope shapes: unlike RaaS, AstraFlow’s pickle endpoints do
not use a nested {ok, result} wrapper. Responses are flat dicts
with an ok field alongside whatever other keys the endpoint returns.
GET /batch returns its payload with no ok at all — the HTTP status
is the only success signal. Below, python code fences show the
Python dict before pickling; json fences show literal JSON text.
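The two encodings can be sketched as a few small helpers. This is a minimal illustration, not AstraFlow code: the helper names are hypothetical, and treating `"ok": False` as the failure marker is an assumption, since the guide only shows successful responses. The sketch falls back to the standard `pickle` module when cloudpickle is not installed; for plain dicts of built-in values the two produce streams that `pickle.loads` can read.

```python
import json
import pickle

try:
    import cloudpickle as _pickler  # the serializer this guide names
except ImportError:
    _pickler = pickle  # fallback for this sketch; fine for plain dicts


def encode_pickle(payload: dict) -> bytes:
    """Serialize a request dict for a Trainer -> AstraFlow endpoint."""
    return _pickler.dumps(payload)


def decode_pickle(body: bytes) -> dict:
    """Deserialize a pickled response. Only unpickle data from a trusted service."""
    return pickle.loads(body)


def check_flat_response(resp: dict) -> dict:
    """Validate a flat response dict.

    Assumption: failures are reported as ok=False. Responses that carry no
    'ok' key at all (such as GET /batch) pass through unchanged.
    """
    if resp.get("ok", True) is not True:
        raise RuntimeError(f"request failed: {resp!r}")
    return resp


def encode_json(payload: dict) -> bytes:
    """Serialize a body for a RaaS -> Trainer sender endpoint."""
    return json.dumps(payload).encode("utf-8")
```

Keeping the `ok` check separate from decoding lets the same decoder serve GET /batch, where the HTTP status is the only success signal.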
Trainer → AstraFlow (outbound)
These are HTTP calls your trainer makes to the AstraFlow service.
Registration: POST /ready
Signal that the trainer is ready to receive data. AstraFlow starts data acquisition only after both RaaS and trainer have signalled readiness.
Request fields:
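| Field | Type | Required | Description |
|---|---|---|---|
| `train_batch_size` | int | Yes | Number of examples per training batch |
| `model_id` | str | No | Model identifier for multi-model setups |
| `sender_endpoint` | str | Yes (TCP) | Weight sender address, e.g. `10.0.0.1:18861` |
| `recovered_version` | int | No | Weight version to resume from after checkpoint recovery |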
Request (Python dict, pickled):
{
    "train_batch_size": 512,
    "model_id": "model0",               # omit for single-model
    "sender_endpoint": "10.0.0.1:18861",
    "recovered_version": 0,             # omit for fresh start
}
Response (flat, no envelope — pickled):
{"ok": True}
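For trainers that do not use the bundled client, the registration call might be assembled with the standard library as follows. This is a sketch under assumptions: `build_ready_request` is a hypothetical helper, the base URL is a placeholder, and stdlib `pickle` stands in for cloudpickle (equivalent on the wire for a dict of plain values).

```python
import pickle
import urllib.request


def build_ready_request(base_url, train_batch_size, sender_endpoint,
                        model_id=None, recovered_version=None):
    """Build (but do not send) the POST /ready request."""
    payload = {
        "train_batch_size": train_batch_size,
        "sender_endpoint": sender_endpoint,
    }
    # Optional fields are omitted entirely rather than sent as None.
    if model_id is not None:
        payload["model_id"] = model_id
    if recovered_version is not None:
        payload["recovered_version"] = recovered_version
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/ready",
        data=pickle.dumps(payload),
        headers={"Content-Type": "application/octet-stream"},
        method="POST",
    )
```

Sending it is then a matter of `urllib.request.urlopen(req)` followed by `pickle.loads(resp.read())`, which should yield the flat `{"ok": True}` dict shown above.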
Data: GET /batch
Pull a training batch. Blocks until sufficient data is available in the buffer.
Query parameters (not a pickle body — real URL query string):
| Param | Type | Required | Description |
|---|---|---|---|
| `model_id` | str | No | Which model’s data to pull (multi-model) |
| `version` | int | No | Trainer’s current version (used for multi-model barrier sync) |
Request — a plain HTTP GET:
GET /batch?model_id=model0&version=42
Response — pickled Python dict, no envelope (the payload is the batch directly):
{
    "batch": {
        "input_ids": Tensor,    # [batch, seq_len]
        "rewards": Tensor,      # [batch, seq_len]
        "logprobs": Tensor,     # [batch, seq_len]
        "loss_mask": Tensor,    # [batch, seq_len]
        # ... other fields depending on workflow
    },
    "buffer_stats": {
        "buffer/size": 1024,
        "buffer/staleness_mean": 1.2,
        # ... other buffer/filter metrics for wandb
    },
}
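A minimal sketch of the client side of this call, assuming the response body is an ordinary pickle stream. `batch_url` and `split_batch_response` are illustrative names, not part of AstraFlow. Unpickling a real response also requires that the tensor library used by the service is importable in the trainer process.

```python
import pickle
from urllib.parse import urlencode


def batch_url(base_url, model_id=None, version=None):
    """Build the GET /batch URL; both query parameters are optional."""
    params = {}
    if model_id is not None:
        params["model_id"] = model_id
    if version is not None:
        params["version"] = version
    url = base_url.rstrip("/") + "/batch"
    query = urlencode(params)
    return f"{url}?{query}" if query else url


def split_batch_response(body: bytes):
    """Unpickle the response and return (batch, buffer_stats).

    There is no 'ok' field to check here: a non-2xx HTTP status is the
    failure signal and should be handled by the HTTP layer before this call.
    """
    resp = pickle.loads(body)
    return resp["batch"], resp.get("buffer_stats", {})
```

Because the call blocks until the buffer holds enough data, an HTTP client used with it should have a generous read timeout or none at all.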
Weights: POST /notify_version
Notify AstraFlow that new weights are available. AstraFlow’s Python-side
barrier waits until every registered model_id has reported the same
version, then fans out one POST /notify_version per model to every
live RaaS instance.
Request fields:
| Field | Type | Required | Description |
|---|---|---|---|
| `version` | int | Yes | New model version after training step |
| `run_eval` | bool | No | If … |
| `model_id` | str | No | Model identifier (multi-model) |
Request (Python dict, pickled):
{
    "version": 42,
    "run_eval": False,
    "model_id": "model0",   # omit for single-model
}
Response (flat, no nested envelope — pickled):
# Non-eval step
{
    "ok": True,
    "eval_results": None,
    "weight_transfer_info": {"use_full": 1},   # only present for TCP mode
}

# Eval step
{
    "ok": True,
    "eval_results": {
        "math/acc": 0.512,
        "math/pass@1": 0.47,
        # ... benchmark-specific keys
    },
    "weight_transfer_info": {"use_full": 0},
}
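The barrier behaviour described above (wait until every registered model_id has reported the same version, then fan out) can be illustrated with a small non-blocking model. This is not AstraFlow's implementation, which may block the caller and handle timeouts; `VersionBarrier` is a hypothetical class that only shows the release condition.

```python
class VersionBarrier:
    """Tracks the latest version reported by each registered model."""

    def __init__(self, model_ids):
        # None means "nothing reported yet" for that model.
        self._reported = {model_id: None for model_id in model_ids}

    def report(self, model_id, version):
        """Record a report and return True once all models agree on `version`."""
        if model_id not in self._reported:
            raise KeyError(f"unregistered model_id: {model_id!r}")
        self._reported[model_id] = version
        return all(v == version for v in self._reported.values())


barrier = VersionBarrier(["model0", "model1"])
first = barrier.report("model0", 42)    # model1 has not reported yet
second = barrier.report("model1", 42)   # every model is now at version 42
```

In a single-model setup the barrier degenerates to one entry, so every report releases immediately.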
RaaS → Trainer (inbound, weight sender)
These are HTTP endpoints that RaaS calls on the trainer’s weight sender server. The sender agent library handles these automatically — listed here for completeness.
GET /get_buffer_info
RaaS queries the trainer’s model tensor layout to allocate its receive buffer. Called once when a RaaS instance first connects.
Response (JSON):
{
  "single_buffer_length": 3489660928,
  "tensors_meta": [
    ["model.layers.0.self_attn.q_proj.weight", [[2048, 2048], "bfloat16"]],
    ["model.layers.0.self_attn.k_proj.weight", [[512, 2048], "bfloat16"]]
  ]
}
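To show how a receiver might sanity-check this response, the sketch below sums the packed byte size of the listed tensors. It assumes tensors are laid out back to back with no alignment padding, which the guide does not state, and the dtype size table is an assumption beyond `bfloat16`. Note that the sample response lists only two tensors of a larger model, so its `single_buffer_length` is not expected to match this two-tensor sum.

```python
import math

# Bytes per element; only "bfloat16" appears in this guide, the rest are assumed spellings.
DTYPE_BYTES = {"bfloat16": 2, "float16": 2, "float32": 4}


def packed_nbytes(tensors_meta):
    """Sum element_count * element_size over [name, [shape, dtype]] entries."""
    total = 0
    for _name, (shape, dtype) in tensors_meta:
        total += math.prod(shape) * DTYPE_BYTES[dtype]
    return total


sample_meta = [
    ["model.layers.0.self_attn.q_proj.weight", [[2048, 2048], "bfloat16"]],
    ["model.layers.0.self_attn.k_proj.weight", [[512, 2048], "bfloat16"]],
]
sample_bytes = packed_nbytes(sample_meta)  # 8388608 + 2097152 = 10485760
```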
POST /register_sglang_instance
RaaS registers itself as a weight receiver. Both sides exchange TCP session
IDs and buffer pointers to establish a persistent TCP connection. Called
once per RaaS instance, after get_buffer_info.
Request (JSON):
{
  "sglang_http_host": "10.0.0.5",
  "sglang_http_port": 19190,
  "session_ids": ["abc123"],
  "buffer_ptr": 140234567890,
  "buffer_length": 3489660928,
  "zmq_endpoint": "10.0.0.5",
  "zmq_port": 45678,
  "handshake_ports": [21000]
}
Response (JSON):
{
  "trainer_global_rank": 0,
  "trainer_world_size": 2,
  "trainer_session_ids": [["def456"]],
  "trainer_buffer_ptr": 140111222333,
  "trainer_buffer_length": 6979321856,
  "trainer_hostname": "10.0.0.1",
  "trainer_rpc_port": 21000
}
POST /request_transfer
RaaS requests a weight pull. The sender acquires a buffer lock, performs TCP bulk copy from the active buffer half to the receiver, and sends a ZMQ completion signal. Called every training step (or when RaaS detects it is behind the latest version).
Request (JSON):
{
  "instance_id": "10.0.0.5:19190"
}
Response (JSON):
{
  "ok": true,
  "version": 42
}
Connection Lifecycle
The weight transfer connection between a trainer and each RaaS instance follows this lifecycle:
RaaS                                Trainer Sender
│                                                │
│ 1. Query tensor layout                         │
├──── GET /get_buffer_info ─────────────────────►│
│◄─── {tensors_meta, buffer_length} ─────────────┤
│                                                │
│ 2. Allocate receive buffer locally             │
│ 3. Create TCP engine locally                   │
│ 4. Start ZMQ listener locally                  │
│                                                │
│ 5. Exchange TCP details, establish connection  │
├──── POST /register_sglang_instance ───────────►│
│◄─── {trainer_session_ids, rpc_port, ...} ──────┤
│                                                │
│ ═══════ persistent TCP connection open ════════│
│                                                │
│ 6. Per training step: pull weights             │
├──── POST /request_transfer ───────────────────►│
│◄════ TCP bulk copy (shared-mem → shared-mem) ══┤
│◄──── ZMQ "transfer complete" ──────────────────┤
│                                                │
│ (repeat step 6 each training step)             │
Steps 1-5 happen once per RaaS instance. Step 6 repeats every training step. The TCP connection persists across all transfers.
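The order of HTTP calls in this lifecycle can be sketched from the receiver's point of view. Everything here is illustrative: `run_receiver_lifecycle` and the injected `call(method, path, body)` transport are hypothetical, the local steps 2-4 are reduced to copying the buffer length, and building `instance_id` as `host:port` is an assumption drawn from the sample payloads in this guide.

```python
def run_receiver_lifecycle(call, receiver_info, num_steps):
    """Walk through steps 1, 5 and 6 against a sender.

    `call(method, path, body)` must perform one JSON request and return the
    decoded response dict. The TCP bulk copy and ZMQ completion signal that
    accompany each transfer are outside the scope of this sketch.
    """
    # Step 1: query the tensor layout.
    info = call("GET", "/get_buffer_info", None)

    # Steps 2-4 happen locally (buffer, TCP engine, ZMQ listener). Here we
    # only carry the buffer length over into the registration payload.
    registration = dict(receiver_info)
    registration["buffer_length"] = info["single_buffer_length"]

    # Step 5: exchange connection details, once per RaaS instance.
    trainer_info = call("POST", "/register_sglang_instance", registration)

    # Step 6: one pull per training step over the persistent connection.
    instance_id = f'{receiver_info["sglang_http_host"]}:{receiver_info["sglang_http_port"]}'
    versions = []
    for _ in range(num_steps):
        reply = call("POST", "/request_transfer", {"instance_id": instance_id})
        if not reply.get("ok"):
            raise RuntimeError(f"transfer failed: {reply!r}")
        versions.append(reply["version"])
    return trainer_info, versions
```

Injecting the transport keeps the sequence testable with a stub and leaves the choice of HTTP client open.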
Weight Transfer: Double-Buffer Design
The sender agent uses a double-buffer in shared memory (/dev/shm) so
that weight copying and weight serving never block each other:
1. Trainer writes updated weights to the inactive buffer half via copy_weights_to_buffer().
2. Trainer atomically swaps the active buffer index; new transfers now read from the freshly written half.
3. RaaS calls POST /request_transfer; the sender reads from the active half and performs a TCP bulk copy to the receiver.
4. RaaS saves the received bytes as safetensors in /dev/shm and tells the inference engine to reload.
This means training can continue writing the next set of weights while RaaS is still pulling the current set.
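The swap logic can be illustrated with an in-process model. This is a simplified sketch, not the sender agent's code: `DoubleBuffer` is a hypothetical class, it uses `bytearray` halves instead of /dev/shm, and a read returns a copy where the real sender performs a TCP bulk copy. It assumes a single writer thread, which matches the trainer being the only producer.

```python
import threading


class DoubleBuffer:
    """Two equally sized halves; readers see only the active one."""

    def __init__(self, half_size):
        self._halves = [bytearray(half_size), bytearray(half_size)]
        self._active = 0
        self._version = 0
        self._lock = threading.Lock()

    def write_inactive(self, data: bytes):
        """Fill the inactive half. Readers never touch it, so no lock is held."""
        inactive = 1 - self._active
        if len(data) > len(self._halves[inactive]):
            raise ValueError("data does not fit in one buffer half")
        self._halves[inactive][:len(data)] = data

    def swap(self, version):
        """Make the freshly written half active. Waits for any in-flight read."""
        with self._lock:
            self._active = 1 - self._active
            self._version = version

    def read_active(self):
        """Return (version, bytes) from the active half as one consistent snapshot."""
        with self._lock:
            return self._version, bytes(self._halves[self._active])
```

The lock covers only the index swap and the read, so writing the next set of weights proceeds while a receiver is still being served from the other half.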
Complete Example
from astraflow.train_worker.trainer.astraflow_client import AstraFlowClient

# Names that are not imported here (start_sender_agent, copy_weights_to_buffer,
# train_step, log_eval, model, training, eval_freq) stand for code supplied by
# your trainer or by the sender agent library.

# --- Startup ---

# 1. Start the weight sender agent (reusable library).
#    This launches a subprocess that serves /get_buffer_info,
#    /register_sglang_instance, and /request_transfer.
sender = start_sender_agent(model)

# 2. Connect to AstraFlow and signal readiness.
client = AstraFlowClient(
    service_url="http://astraflow-host:8000",
)
client.initialize()  # wait for AstraFlow service to be ready
client.signal_ready(
    train_batch_size=512,
    sender_endpoint=sender.endpoint,  # e.g. "10.0.0.1:18861"
)

# --- Training Loop ---
version = 0
while training:
    # Pull batch (blocks until data is available)
    batch, buffer_stats = client.get_batch(version=version)

    # Your training logic
    loss = train_step(batch)

    # Copy updated weights to shared-memory buffer
    copy_weights_to_buffer()
    sender.notify_buffer_ready(version + 1)

    # Notify AstraFlow, which broadcasts to all RaaS instances
    version += 1
    should_eval = (version % eval_freq == 0)
    eval_results = client.notify_version(
        version=version,
        run_eval=should_eval,
    )
    if eval_results:
        log_eval(eval_results)

# --- Shutdown ---
client.drain_pending_notifications()
client.shutdown_service()
API Summary
| Direction | Category | Method | Endpoint | Frequency |
|---|---|---|---|---|
| Trainer → AstraFlow | Registration | POST | `/ready` | Once |
| Trainer → AstraFlow | Data | GET | `/batch` | Every step |
| Trainer → AstraFlow | Weights | POST | `/notify_version` | Every step |
| RaaS → Trainer | Weights | GET | `/get_buffer_info` | Once per RaaS |
| RaaS → Trainer | Weights | POST | `/register_sglang_instance` | Once per RaaS |
| RaaS → Trainer | Weights | POST | `/request_transfer` | Every step |
Total: 6 APIs. The 3 inbound endpoints are handled by the reusable sender agent library.