WeightManager
The WeightManager (astraflow/weight_manager/) is an independent component
that handles all weight transfer between Trainer and RaaS.
Design Principle: Independent Transport Layer
WeightManager is a shared component — both Trainer and RaaS import from it, but neither depends on the other. This keeps the three swappable components (Trainer, RaaS, WeightManager) cleanly separated:
train_worker/ ──imports──► weight_manager/
raas/ ──imports──► weight_manager/transfer/
train_worker/ ✗ raas/ (no cross-dependency)
raas/ ✗ train_worker/ (no cross-dependency)
The Trainer’s integration surface is a single call:
wm.offload(model.named_parameters(), version, rank, world_size)
Key Components
- WeightManager: Main class that owns buffer allocation, GPU→CPU copy (shard-direct and all-gather paths), sender agent lifecycle, and double-buffer swap. The trainer calls offload() once per step.
- Transfer modes: POST /request_transfer accepts a mode of full (sends the entire model) or delta (sends only changed elements, ~1-2% of the model). See Delta Weight Transfer for details.
- TransferAgent: Subprocess on the Trainer side that exposes HTTP endpoints and serves TCP weight pulls from RaaS.
- TCPTransferEngine: TCP engine with 6 parallel streams and sendfile() zero-copy. Used by both the sender (Trainer) and the receiver (RaaS).
- TransferBuffer: CPU byte buffer for receiving weights on the RaaS side.
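The idea behind delta mode can be illustrated with a minimal NumPy sketch. It assumes both halves of the double buffer are flat 1-D arrays with the same layout, and it encodes a delta as (indices, new values); compute_delta, apply_delta, and that encoding are hypothetical, not the format described in Delta Weight Transfer.

```python
import numpy as np
from typing import Tuple


def _bits(a: np.ndarray) -> np.ndarray:
    """View an array as unsigned ints of the same width so NaNs compare by bit pattern."""
    return a.view(np.dtype(f"u{a.dtype.itemsize}"))


def compute_delta(prev: np.ndarray, curr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return (indices, values) for the elements that changed between two 1-D halves."""
    if prev.shape != curr.shape or prev.dtype != curr.dtype or prev.ndim != 1:
        raise ValueError("halves must be 1-D arrays with the same shape and dtype")
    changed = np.flatnonzero(_bits(prev) != _bits(curr))  # bit-exact comparison
    return changed, curr[changed]


def apply_delta(weights: np.ndarray, indices: np.ndarray, values: np.ndarray) -> None:
    """Patch the receiver's copy in place so it matches the sender's newer half."""
    weights[indices] = values
```

When only a small fraction of elements change per step, the (indices, values) payload is far smaller than the full buffer, which is what makes delta pulls cheap.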
Weight Sync Flow
1. Trainer calls wm.offload(): GPU weights are copied to the inactive half of a shared-memory double buffer, then the buffer index is swapped.
2. Trainer notifies Dataflow via POST /notify_version.
3. Dataflow's Python-side version barrier waits for all model_ids (multi-model), then fans out one POST /notify_version per model per RaaS instance. Each call carries {model_id, version, sender_endpoint}.
4. On each RaaS, the manager acquires a per-model lock and calls POST /request_transfer on the sender agent to pull weights via 6 parallel TCP streams.
5. RaaS saves the received bytes as safetensors to /dev/shm, pauses inference engines, loads weights, and resumes.
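Steps 4 and 5 on the RaaS side can be pictured with a small thread-based sketch. ModelWeightSync, its pull/load callables, and the rule of skipping stale or duplicate versions are hypothetical; pull stands in for POST /request_transfer plus the TCP streams, and load stands in for the save/pause/load/resume sequence.

```python
import threading
from typing import Callable


class ModelWeightSync:
    """Hypothetical RaaS-side handler for notify_version calls, one lock per model_id."""

    def __init__(self, pull: Callable[[str, int, str], bytes],
                 load: Callable[[str, bytes], None]) -> None:
        self._pull = pull                   # stand-in for POST /request_transfer + TCP pull
        self._load = load                   # stand-in for safetensors save + pause/load/resume
        self._locks = {}                    # model_id -> threading.Lock
        self._locks_guard = threading.Lock()
        self.loaded_version = {}            # model_id -> last loaded version

    def _lock_for(self, model_id: str) -> threading.Lock:
        with self._locks_guard:
            return self._locks.setdefault(model_id, threading.Lock())

    def notify_version(self, model_id: str, version: int, sender_endpoint: str) -> bool:
        """Pull and load `version` for one model; return False if it was already loaded."""
        with self._lock_for(model_id):      # serialize syncs of the same model only
            if self.loaded_version.get(model_id, -1) >= version:
                return False                # assumption: stale/duplicate notices are ignored
            payload = self._pull(model_id, version, sender_endpoint)
            self._load(model_id, payload)
            self.loaded_version[model_id] = version
            return True
```

Keying the lock by model_id lets the actor and verifier sync concurrently while two notifications for the same model never interleave.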
Trainer WeightManager SenderAgent RaaS
│ │ │ │
│ offload(params, v) │ │ │
│───────────────────────────►│ │ │
│ │ GPU→CPU shard copy │ │
│ │ swap buffer index │ │
│ │ mp.Queue: buffer_ready │ │
│ │────────────────────────►│ │
│◄───────────────────────────│ (returns) │ │
│ next train_step ... │ │ │
│ │ │◄── /request_transfer
│ │ │ TCP send ───────►│
│ │ │ ZMQ done ───────►│
│ │ │ safetensors
│ │ │ load + resume
Timeline: What Blocks What
The key insight is that weight transfer never blocks the trainer GPU
directly. The trainer’s GPU is only blocked during the offload() call
(GPU→CPU copy). Everything else — TCP transfer, RaaS weight loading —
happens asynchronously. The trainer is gated by data availability
(the get_batch call), not weight transfer.
Single Training Step
Time ──────────────────────────────────────────────────────────────────────────►
TRAINER (GPU)
│ train_step N │offload│save│notify│ get_batch │ train_step N+1
│[===========================][=====]│chkp│async │ (blocks │[===============
│ forward/backward/optim │GPU→CPU│ │ │ until │
│ │~2s │ │ │ service │
│ │ │ │ │ version │
│ │ not blocked ──► │ catches │
│ │ (trainer does │ up) │
│ │ save, log, │[==========]│
│ │ etc.) │ blocked by │
│ │ │ data, not │
│ │ │ by weight │
│ │ │ transfer │
SENDER AGENT (CPU subprocess)
│ · · · · · · · · · · · · · ·│swap│ack│ delta compute (async) │done│ · · · ·
│ idle │idx │ │[======================]│evt │ idle
│ │ │ │ compare halves (numpy) │ │
│ │ │ │ ~1.8s for 1.7B │ │
DATAFLOW SERVICE (CPU)
│ · · · · · · · · · · · · · · · · │version│RaaS weight load │version │serve
│ │barrier│[================]│updated! │batch
│ │ │notify_version (per model) │─────►
│ │ │+ TCP pull + load │ │
RAAS (GPU)
│ generating rollouts · · · · · · · · · · │pull│pause│load│resume│ generating
│[rollout][rollout][rollout] · · · · · · ·│TCP │[===]│[==]│ │[rollout]..
│ │~3s │ │~5s │ │
Key observations:
- Trainer GPU blocks only during offload() (~0.5s for a 1.7B model, shard copy). After offload, the trainer saves checkpoints, logs stats, waits for the async delta to finish, then fires notify_version_async.
- Delta computation is async (~1.8s for a 1.7B model). It runs in the sender agent subprocess, overlapped with trainer checkpointing and logging. The trainer calls wait_delta_ready() before notify_version so that the delta is ready when RaaS pulls.
- get_batch is the real synchronization point. Dataflow won't serve a batch until service_version >= trainer_version, and the service version updates only after the RaaS weight load completes. So the trainer blocks on data availability, not directly on weight transfer.
- RaaS blocks during pause → load → resume (~3.7s). During this window it cannot serve rollout requests. The TCP pull itself does NOT block inference, because RaaS pulls into a separate buffer first.
- The double buffer eliminates trainer↔TCP contention. The trainer writes to one half while the sender serves the other half. A guard barrier at the start of offload() ensures the previous async delta has finished reading before that half is overwritten.
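The get_batch gating described above amounts to a condition variable keyed on versions. This is a minimal sketch assuming a single-process Dataflow; VersionGate and mark_loaded() are hypothetical names, and the returned dict is a placeholder for a real batch.

```python
import threading
from typing import Optional


class VersionGate:
    """Hypothetical Dataflow-side gate: serve a batch only once service_version >= trainer_version."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self.service_version = 0

    def mark_loaded(self, version: int) -> None:
        """Record that the RaaS weight load for `version` finished and wake waiting trainers."""
        with self._cond:
            self.service_version = max(self.service_version, version)
            self._cond.notify_all()

    def get_batch(self, trainer_version: int, timeout: Optional[float] = None) -> dict:
        """Block until RaaS has caught up to trainer_version, then return a placeholder batch."""
        with self._cond:
            ready = self._cond.wait_for(lambda: self.service_version >= trainer_version, timeout)
            if not ready:
                raise TimeoutError(
                    f"service_version {self.service_version} < trainer_version {trainer_version}")
            return {"version": self.service_version}
```

The trainer never waits on the transfer itself; it only waits inside get_batch until mark_loaded has been called for its version.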
Multi-Model Timeline
With two models (e.g. actor + verifier), each trainer independently
offloads weights and fires notify_version_async. The version barrier
and weight loading happen on the Dataflow side. Trainers are only
blocked when they request the next batch:
Time ──────────────────────────────────────────────────────────────────────────►
TRAINER model0 (actor)
│ train_step │offload│save│notify│ get_batch(v=N+1) │ train_step N+1
│[=================][=====]│chkp│async │ blocks until │[===============
│ │ │ │ service catches up │
│ │ free ───► │[==================]│
│ │ │ blocked by data │
TRAINER model1 (verifier)
│ train_step (slower) │offload│save│notify│ get_batch(v=N+1) │ ..
│[=============================][=====]│chkp│async │ blocks until │[==
│ │ │ │ service catches up│
│ │ free ──► │[=================]│
DATAFLOW SERVICE
│ · · · · · · · · · · · · · · · · · · ·│barrier met!│RaaS load │version │
│ │(both v=N+1)│[=========]│updated!│
│ │ │ │serve │
│ │ │ │batches │
RAAS
│ generating · · · · · · · · · · · · · · · · · · · ·│pull │pause│load│resume
│ │both │[===]│both│
│ │mdls │ │mdls│
│ ◄── batches served ──►
Key observations:
- Neither trainer blocks on weight transfer. Both fire notify_version_async and continue with checkpoint/logging work.
- The version barrier is trainer-to-trainer synchronization, not weight transfer synchronization. It ensures all model_ids reach the same version before triggering the RaaS weight load.
- get_batch is the gating point. Each trainer blocks on its next batch request until Dataflow's service_version catches up (after the version barrier and the RaaS load complete).
- Trainer GPU is free between offload() and get_batch(); this window is used for checkpoint saving, wandb logging, and other I/O.
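The version barrier can be sketched as a per-version set of reporting model_ids. VersionBarrier and its on_ready callback are hypothetical; on_ready stands in for the per-model, per-RaaS fan-out of POST /notify_version.

```python
import threading
from typing import Callable, Iterable


class VersionBarrier:
    """Hypothetical multi-model barrier: fire on_ready once every model_id reports `version`."""

    def __init__(self, model_ids: Iterable[str], on_ready: Callable[[int], None]) -> None:
        self._model_ids = frozenset(model_ids)
        self._on_ready = on_ready          # e.g. trigger the RaaS weight load for this version
        self._reported = {}                # version -> set of model_ids seen so far
        self._lock = threading.Lock()

    def notify(self, model_id: str, version: int) -> bool:
        """Record one trainer's notification; return True if it completed the barrier."""
        if model_id not in self._model_ids:
            raise KeyError(model_id)
        with self._lock:
            seen = self._reported.setdefault(version, set())
            seen.add(model_id)
            ready = seen == self._model_ids
            if ready:
                del self._reported[version]
        if ready:
            self._on_ready(version)        # run outside the lock: loading is slow
        return ready
```

The slower trainer's notification is the one that releases the barrier, which matches the "barrier met! (both v=N+1)" point in the multi-model timeline.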
Eval Steps (Synchronous Exception)
On eval steps, notify_version is called synchronously because the
trainer needs eval results before continuing. This is the only case where
the trainer explicitly waits for the weight load + eval to complete:
│offload│save│notify_version(sync, eval) │ eval results │ next step
│[=====]│chkp│[==================================]│ returned │[=========
│ │ │ barrier + RaaS load + eval run │ │
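The difference between the regular and eval paths can be sketched as a thin trainer-side wrapper. VersionNotifier and the post callable (standing in for the blocking HTTP call to POST /notify_version) are hypothetical; only the names notify_version and notify_version_async come from the text above.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional


class VersionNotifier:
    """Hypothetical wrapper: async notify on regular steps, blocking notify on eval steps."""

    def __init__(self, post: Callable[[int, bool], Optional[dict]]) -> None:
        self._post = post                                   # blocking call: (version, run_eval)
        self._pool = ThreadPoolExecutor(max_workers=1)      # one worker keeps notifies in order

    def notify_version_async(self, version: int) -> Future:
        """Regular step: fire and continue; get_batch() is the real sync point."""
        return self._pool.submit(self._post, version, False)

    def notify_version(self, version: int, run_eval: bool = False) -> Optional[dict]:
        """Eval step: wait for barrier + RaaS load + eval run and return the results."""
        return self._post(version, run_eval)

    def close(self) -> None:
        self._pool.shutdown(wait=True)
```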
GPU→CPU Copy Strategies
WeightManager automatically selects the optimal copy strategy based on FSDP parameter placement:
Shard copy (fast path) — When all parameters use Shard(0) placement
(the standard FSDP2 case), each rank copies only its local shard to the
correct offset. No cross-rank communication. All ranks write in parallel.
All-gather (fallback) — When any parameter has a different placement, FSDP all-gathers the full tensors. Only rank 0 copies to the buffer. Slower but handles any sharding strategy.
| | Shard copy | All-gather |
|---|---|---|
| Network | None | Full model over NCCL |
| PCIe per rank | Local shard only (~1/N of the model) | Full model, rank 0 only |
| Parallelism | All N ranks | Rank 0 only |
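The strategy choice and the shard-copy offsets can be sketched in plain Python. This assumes DTensor-style Shard(0) splitting with torch.chunk semantics (each rank gets ceil(rows / world_size) rows and trailing ranks may get fewer or none), and it represents placements as strings for brevity; real code would inspect the DTensor placement objects. choose_strategy and shard_byte_range are hypothetical names.

```python
from typing import Sequence, Tuple


def choose_strategy(placements: Sequence[str]) -> str:
    """Fast path only if every parameter is Shard(0); otherwise fall back to all-gather."""
    return "shard" if all(p == "Shard(0)" for p in placements) else "all_gather"


def shard_byte_range(param_offset: int, shape: Sequence[int], itemsize: int,
                     rank: int, world_size: int) -> Tuple[int, int]:
    """Return (byte_offset, byte_length) this rank writes for one dim-0-sharded parameter."""
    rows = shape[0]
    row_elems = 1
    for d in shape[1:]:
        row_elems *= d
    rows_per_rank = -(-rows // world_size)         # ceil division, as torch.chunk does
    start = min(rank * rows_per_rank, rows)
    end = min(start + rows_per_rank, rows)
    row_bytes = row_elems * itemsize
    return param_offset + start * row_bytes, (end - start) * row_bytes
```

Because every rank can compute its own byte range from the shape alone, all ranks write to disjoint regions of the shared buffer in parallel with no cross-rank communication.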
Double Buffer
The shared-memory buffer in /dev/shm is 2× model size. The trainer
writes to the inactive half while the sender serves the active half
over TCP — no locking between them:
Step N: trainer writes Half 0, TCP reads Half 1
Step N+1: trainer writes Half 1, TCP reads Half 0
The buffer index swap is a single Python int assignment (atomic under the GIL).
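A single-process sketch of the double buffer, with the offload() guard modeled as a threading.Event. It assumes threads instead of the real sender subprocess (which, per the sequence diagram, learns about the swap via an mp.Queue) and a bytearray instead of the /dev/shm mapping; DoubleBuffer and serve() are hypothetical names.

```python
import threading


class DoubleBuffer:
    """Sketch of a 2x-model-size buffer: trainer fills the inactive half, sender reads the active one."""

    def __init__(self, model_bytes: int) -> None:
        self._buf = bytearray(2 * model_bytes)   # stands in for the /dev/shm mapping
        self._size = model_bytes
        self.active = 1                          # so the first write lands in Half 0
        self._read_done = threading.Event()      # guard: the previous async reader has finished
        self._read_done.set()

    def _half(self, idx: int) -> memoryview:
        return memoryview(self._buf)[idx * self._size:(idx + 1) * self._size]

    def offload(self, data: bytes) -> int:
        """Trainer side: wait for the guard, fill the inactive half, then swap the index."""
        if len(data) != self._size:
            raise ValueError("payload must be exactly one half in size")
        self._read_done.wait()                   # never overwrite a half that is still being read
        target = 1 - self.active
        self._half(target)[:] = data
        self._read_done.clear()                  # a new reader (TCP send / delta) is now pending
        self.active = target                     # the swap: one int assignment
        return target

    def serve(self) -> bytes:
        """Sender side: copy out the active half, then release the guard for the next offload."""
        try:
            return bytes(self._half(self.active))
        finally:
            self._read_done.set()
```

In this sketch a second offload() without an intervening serve() blocks by design; that is the guard preventing the trainer from overwriting data the sender has not finished reading.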
Optimizations
- Zero-copy sendfile() for TCP transfer
- mlock and transparent huge pages for shared-memory buffers
- 6 parallel TCP streams for throughput
- madvise(MADV_SEQUENTIAL | MADV_WILLNEED) hints
- CUDA host registration for pinned DMA transfers
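The sendfile() and parallel-stream items can be sketched with the standard library. socket.socket.sendfile() uses os.sendfile() (kernel-side copy, no user-space buffer) where the platform supports it and falls back to send() otherwise. split_ranges and send_range are hypothetical helpers; in a real multi-stream setup each range would go over its own connection, in parallel.

```python
import socket
from typing import BinaryIO, List, Tuple

NUM_STREAMS = 6  # matches the stream count listed above


def split_ranges(total_bytes: int, num_streams: int = NUM_STREAMS) -> List[Tuple[int, int]]:
    """Split [0, total_bytes) into contiguous (offset, length) ranges, one per stream."""
    base, extra = divmod(total_bytes, num_streams)
    ranges, offset = [], 0
    for i in range(num_streams):
        length = base + (1 if i < extra else 0)   # spread the remainder over the first streams
        if length:
            ranges.append((offset, length))
        offset += length
    return ranges


def send_range(sock: socket.socket, f: BinaryIO, offset: int, length: int) -> None:
    """Send bytes [offset, offset + length) of a regular file over a blocking stream socket."""
    sent = sock.sendfile(f, offset, length)
    if sent != length:
        raise ConnectionError(f"short send: {sent} of {length} bytes")
```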
Project Structure
astraflow/weight_manager/
__init__.py ← exports WeightManager, WeightManagerConfig
weight_manager.py ← main class: buffer mgmt, GPU→CPU copy, sender lifecycle
config.py ← WeightManagerConfig
transfer/ ← shared transport layer
config.py ← TransferEngineConfig, SenderAgentConfig, ReceiverAgentConfig, TransferStatus, ReceiverInfo
sender_agent.py ← sender subprocess (TransferAgent): HTTP, TCP, ZMQ
transfer_engine.py ← TCPTransferEngine (6-stream, sendfile)
receiver_agent.py ← TransferBuffer
Sender Agent HTTP API
The sender agent runs as a subprocess on the trainer node and exposes these endpoints (used by RaaS to pull weights):
| Method | Endpoint | Purpose |
|---|---|---|
| | | Current weight version |
| | | Buffer size and tensor metadata |
| | | Supported strategies, delta readiness |
| | | Register a RaaS receiver |
| POST | /request_transfer | Pull weights via TCP (mode: full or delta) |
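A request to the transfer endpoint might be built as below. This sketch assumes a JSON body; only the POST method, the /request_transfer path, and the full/delta mode values come from this page, while the extra fields (such as version) and the host/port in the usage are hypothetical.

```python
import json
import urllib.request


def build_transfer_request(sender_url: str, mode: str, **extra) -> urllib.request.Request:
    """Build (but do not send) a POST /request_transfer call; `extra` fields are hypothetical."""
    if mode not in ("full", "delta"):
        raise ValueError(f"mode must be 'full' or 'delta', got {mode!r}")
    body = json.dumps({"mode": mode, **extra}).encode()
    return urllib.request.Request(
        sender_url.rstrip("/") + "/request_transfer",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Usage (needs a live sender agent):
#   req = build_transfer_request("http://trainer-0:8000", "delta", version=7)
#   urllib.request.urlopen(req, timeout=30)
```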
Multi-Model Training
In multi-model training (e.g. actor + verifier), each model has its own WeightManager with its own sender agent and shared-memory buffer. Dataflow coordinates them via a version barrier — all models must reach the same version before any RaaS loads weights. This prevents serving rollouts with mismatched model versions.
Multi-RaaS
Weight notifications fan out to all RaaS instances in parallel via
RaaSPool.notify_version(model_id, version, sender_endpoint)
(astraflow/dataflow/raas_pool.py). Each instance independently
pulls, pauses, loads, and resumes. New instances joining mid-training
catch up to the current weights via AstraFlowService.catchup_raas()
before entering the live pool.
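The parallel fan-out and the catch-up-before-join rule can be sketched with a thread pool. fan_out_notify and catchup_then_join are hypothetical stand-ins for RaaSPool.notify_version and AstraFlowService.catchup_raas(); the notify callable represents one per-instance notify_version call, and isolating per-instance failures is an assumption of this sketch.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List, Tuple

Notify = Callable[[str, str, int, str], None]  # (raas_url, model_id, version, sender_endpoint)


def fan_out_notify(instances: List[str], notify: Notify, model_id: str,
                   version: int, sender_endpoint: str) -> Dict[str, bool]:
    """Notify every RaaS instance in parallel; each pulls, pauses, loads, resumes on its own."""
    results: Dict[str, bool] = {}
    with ThreadPoolExecutor(max_workers=max(1, len(instances))) as pool:
        futures = {url: pool.submit(notify, url, model_id, version, sender_endpoint)
                   for url in instances}
        for url, fut in futures.items():
            try:
                fut.result()
                results[url] = True
            except Exception:
                results[url] = False       # one failed instance does not block the others
    return results


def catchup_then_join(members: List[str], new_url: str,
                      latest: Dict[str, Tuple[int, str]], notify: Notify) -> None:
    """Bring a new instance up to the current weights of every model, then add it to the pool."""
    for model_id, (version, endpoint) in latest.items():
        notify(new_url, model_id, version, endpoint)
    members.append(new_url)                # joins the live pool only after catching up
```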