Distributed Systems · Reference

Collective Communication for Large-Scale LLM Training

The data-movement primitives that decide whether ten thousand GPUs behave like one machine or ten thousand machines waiting on each other. A technical walkthrough of every collective, the parallelism strategy that summons it, and its cost model — with each algorithm animated step by step.

7 · core collectives
4D · parallelism context
2(N−1)/N · ring all-reduce traffic factor (per device, × n bytes)
α–β · cost model throughout
How to read this. This piece assumes familiarity with plain data parallelism: you know why all-reduce exists because you've seen it average gradients. It builds outward from there, first to how all-reduce is actually implemented at the wire level, then to the collectives that only appear once a model no longer fits on one device (reduce-scatter, all-gather, all-to-all, point-to-point).
SECTION 01

The mental model: parallelism summons collectives

You never pick a collective directly. You pick a way to split the model across devices, and that choice generates a specific communication pattern. Get this mapping straight and the rest is detail.

Every collective in this report is the consequence of one decision: what did you split, and across which devices? When work is divided, the pieces have to be reconciled, and the shape of that reconciliation is the collective. Here is the whole map before we zoom in:

Parallelism | What's split | Communication it generates | Dominant collective
Data (DP) | The batch. Every device holds a full model replica. | Gradients must be averaged across replicas each step. | all-reduce
Sharded DP (ZeRO/FSDP) | The batch and the params/gradients/optimizer state. | Gather the full layer before use; reduce-scatter gradients after. | all-gather + reduce-scatter
Tensor (TP) | Individual weight matrices, within a layer. | Partial activations summed mid-forward and mid-backward. | all-reduce (intra-node)
Pipeline (PP) | The layer stack, into sequential stages. | Activations forward, gradients backward, stage to stage. | point-to-point send/recv
Expert (EP) | MoE experts across devices. | Tokens routed to wherever their expert lives. | all-to-all
Why this framing matters. A common failure mode is reciting "all-reduce sums gradients" without being able to say why a different parallelism produces a different collective. Deriving the collective from the split means you never have to memorize the table: it can be reconstructed from first principles.
Remember this
What you split is what you must reconcile.
Every collective is the repair bill for a division of labor. Split the batch → reconcile gradients (all-reduce). Split the params → reconcile by gathering them (all-gather). Split the experts → reconcile by routing tokens (all-to-all). Find the split, and the collective follows.
SECTION 02

The α–β cost model

One small piece of theory makes every "why is this algorithm optimal" answer fall out automatically. Worth memorizing cold.

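The time to send a message of n bytes between two devices is modeled as two terms:

T(n) = α + n·β

where α is the fixed per-message latency (launch overhead plus hop latency, paid once per message regardless of size) and β is the per-byte transfer time, i.e. the reciprocal of link bandwidth.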

A third term, γ (gamma), is sometimes added for the per-byte cost of the reduction arithmetic (the actual summing), but on modern hardware the network dominates, so we'll keep γ in the background.

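The whole game of collective-algorithm design is this tradeoff: fewer communication steps means fewer α terms, fewer bytes per device means a smaller β term, and the algorithm that minimizes one rarely minimizes the other.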

The key result you should be able to state. Ring all-reduce moves 2·(N−1)/N · n bytes per device, an amount that stays below 2n however large N grows: it is bandwidth-optimal. Its cost is 2(N−1)α + 2·(N−1)/N·nβ. The downside is the 2(N−1) latency term: with thousands of GPUs the α's pile up, which is exactly why large clusters switch to hierarchical/tree schemes for the latency-sensitive part.
ring all-reduce:   2(N−1)α + 2·(N−1)/N·nβ
tree all-reduce:   2·log₂(N)·α + 2nβ
all-gather (ring): (N−1)α + (N−1)/N·nβ
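To see where the ring and tree formulas cross over, here is a minimal Python sketch that simply evaluates them. The α and β values in the demo (5 µs per hop, 25 GB/s per link) are hypothetical round numbers, not measurements, and the tree expression is the idealized formula given in this section, not NCCL's internal tuning model.

```python
import math


def ring_allreduce_time(n_bytes: float, n_ranks: int, alpha: float, beta: float) -> float:
    """Ring all-reduce: 2(N-1) latency steps, 2(N-1)/N * n bytes per rank."""
    return 2 * (n_ranks - 1) * alpha + 2 * (n_ranks - 1) / n_ranks * n_bytes * beta


def tree_allreduce_time(n_bytes: float, n_ranks: int, alpha: float, beta: float) -> float:
    """Idealized tree all-reduce: 2*log2(N) latency steps, ~2n bytes per rank."""
    return 2 * math.log2(n_ranks) * alpha + 2 * n_bytes * beta


def crossover_bytes(n_ranks: int, alpha: float, beta: float) -> float:
    """Message size where both formulas agree; the ring wins above it.

    Setting ring == tree gives 2*alpha*(N - 1 - log2 N) = 2*n*beta/N.
    """
    return alpha * n_ranks * (n_ranks - 1 - math.log2(n_ranks)) / beta


if __name__ == "__main__":
    ALPHA = 5e-6      # hypothetical: 5 microseconds per hop
    BETA = 1 / 25e9   # hypothetical: 25 GB/s per link
    for n in (1e3, 1e6, 1e9):
        ring = ring_allreduce_time(n, 8, ALPHA, BETA)
        tree = tree_allreduce_time(n, 8, ALPHA, BETA)
        print(f"{n:>8.0e} B  ring={ring:.2e}s  tree={tree:.2e}s")
    print("crossover at N=8:", crossover_bytes(8, ALPHA, BETA), "bytes")
```

With these numbers the crossover at N=8 lands at 4 MB: smaller messages favor the tree, larger ones the ring. Because the ring's latency term grows linearly with N while the tree's grows with log₂N, the crossover moves to far larger messages as N grows.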
Remember this
Ring for bytes, tree for hops.
Big messages are bandwidth-bound → ring (every link moves the minimum, but pays 2(N−1) latency steps). Small messages are latency-bound → tree/recursive-doubling (log N steps, slightly more bytes). α counts hops; β counts bytes. Pick the algorithm that minimizes whichever term dominates.
SECTION 03

All-Reduce & the ring decomposition

The collective you already use. Here's what NCCL is actually doing underneath loss.backward() in DDP, where the gradient all-reduce runs before optimizer.step() is ever called.

Definition. All-reduce takes one array per device, applies an associative reduction (sum, for gradients) element-wise across all devices, and leaves the identical reduced result on every device. Input: N arrays. Output: N copies of their sum.

The naive implementation (every device sends its array to one root, which sums and broadcasts the result back) is a disaster at scale: the root's link carries (N−1)·n bytes in each direction while the other links sit mostly idle, so the root becomes the bottleneck. The elegant result is that all-reduce decomposes exactly into two bandwidth-optimal halves:

all-reduce = reduce-scatter + all-gather

Remember this
All-reduce is reduce-scatter, then all-gather.
This one identity unlocks half the field. It's why ZeRO can split DDP's all-reduce in two and shard the middle; why ring all-reduce runs in two phases; why ZeRO-1 and ZeRO-2 move exactly as many bytes as DDP. Internalize it and the rest cascades.
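The identity can be checked directly. The sketch below defines the three collectives purely by their input/output contract (plain Python lists, one list per rank, summation as the reduction). The function names are hypothetical helpers, not NCCL or torch.distributed calls, and the array length is assumed to divide evenly by N.

```python
def all_reduce(arrays):
    """Every rank receives the elementwise sum of all ranks' arrays."""
    total = [sum(col) for col in zip(*arrays)]
    return [list(total) for _ in arrays]


def reduce_scatter(arrays):
    """Rank r keeps only slice r of the elementwise sum."""
    n_ranks = len(arrays)
    size = len(arrays[0]) // n_ranks  # assumes length divisible by N
    total = [sum(col) for col in zip(*arrays)]
    return [total[r * size:(r + 1) * size] for r in range(n_ranks)]


def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


if __name__ == "__main__":
    grads = [[r * 10 + i for i in range(8)] for r in range(4)]  # 4 ranks, 8 elements each
    assert all_gather(reduce_scatter(grads)) == all_reduce(grads)
    print(all_reduce(grads)[0])
```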

Watch it run. Four GPUs in a logical ring, each starting with a chunked array. Phase 1 (reduce-scatter) circulates and accumulates partial sums until each GPU owns one fully-reduced chunk. Phase 2 (all-gather) circulates those final chunks until everyone has the complete result.

[Figure: Ring all-reduce, N=4, chunked array. Phases: idle → reduce-scatter → all-gather → done. Legend: own un-reduced chunk; partial sum (in flight); fully reduced chunk; stale slot (✕, will be overwritten).]

Phase 1 — Reduce-Scatter, step by step

The array on each GPU is split into N chunks, and the ring runs for N−1 steps. On each step, every GPU simultaneously sends one chunk to its clockwise neighbor and receives one from its counter-clockwise neighbor, adding the incoming chunk into its local copy. The chunks are offset so that after N−1 steps each GPU holds the complete sum of exactly one distinct chunk (with a suitable offset, GPU i owns chunk i). Crucially, the work is perfectly balanced: every GPU sends exactly (N−1)/N · n bytes over its outgoing link.

Phase 2 — All-Gather, step by step

Now each GPU owns one finished chunk but is missing the other N−1. The same ring runs N−1 more steps, but this time incoming chunks overwrite (no addition needed — they're already final). After N−1 steps every GPU has every finished chunk: the full reduced array. Same (N−1)/N · n bytes per link.
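Both phases can be simulated directly. This is a minimal sketch in plain Python: ranks are list entries, each step's sends are treated as simultaneous, and every rank's outgoing traffic is counted so the per-link figure can be checked. The chunk schedule is one standard choice (an assumption, not NCCL's exact indexing); with it, rank r finishes reduce-scatter owning chunk (r+1) mod N, which is just a different offset from the one described above.

```python
def _ring_step(buf, sent, chunk_of, combine):
    """One synchronous ring step: every rank r sends chunk chunk_of(r) to rank r+1."""
    n = len(buf)
    msgs = [((r + 1) % n, chunk_of(r), list(buf[r][chunk_of(r)])) for r in range(n)]
    for r, (_, _, payload) in enumerate(msgs):
        sent[r] += len(payload)          # elements on rank r's outgoing link
    for dst, c, payload in msgs:         # deliver only after every send is posted
        buf[dst][c] = combine(buf[dst][c], payload)


def _add(local, incoming):
    return [a + b for a, b in zip(local, incoming)]


def _overwrite(local, incoming):
    return incoming


def ring_allreduce(data):
    """Simulate ring all-reduce (sum). data[r] is rank r's array split into N chunks.

    Returns (buffers, sent): each rank's final chunks and the elements it sent.
    """
    n = len(data)
    buf = [[list(chunk) for chunk in rank] for rank in data]  # one buffer per rank
    sent = [0] * n
    # Phase 1, reduce-scatter: rank r forwards chunk (r - s) and accumulates arrivals.
    for s in range(n - 1):
        _ring_step(buf, sent, lambda r: (r - s) % n, _add)
    # Rank r now owns the fully reduced chunk (r + 1) % n.
    # Phase 2, all-gather: rank r forwards chunk (r + 1 - s); arrivals overwrite.
    for s in range(n - 1):
        _ring_step(buf, sent, lambda r: (r + 1 - s) % n, _overwrite)
    return buf, sent


if __name__ == "__main__":
    N = 4
    data = [[[r * 100 + c * 10 + k for k in range(3)] for c in range(N)] for r in range(N)]
    buf, sent = ring_allreduce(data)
    print(buf[0])
    print(sent)  # per rank: 2(N-1) chunks of 3 elements each
```

Each rank sends 2(N−1) chunks of n/N elements, which is the 2·(N−1)/N · n total from the cost model.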

Why the ✕ slots can be discarded, and why no extra memory is needed. Staleness happens progressively, not all at once. In the animation, a slot flips to ✕ the moment a GPU forwards it for the last time: the ring schedule guarantees that chunk index is never sent from that GPU again, and the live, still-accumulating copy now lives on the neighbor. So one slot per GPU goes stale at each reduce-scatter step, and you can watch the working set shrink as the ring rotates. By the phase boundary, each GPU holds exactly one fully-reduced chunk and N−1 stale slots. All-gather then overwrites the ✕ slots in place with finished chunks arriving from neighbors. The precise rule for staleness is "forwarded for the last time it will ever be needed." In the ring algorithm that coincides with "forwarded at all," but in other algorithms (e.g. recursive-halving) a buffer can be reused across rounds, so forwarding alone wouldn't imply staleness. The payoff: ring all-reduce runs in a single n-sized buffer per rank for the whole operation (plus a small staging buffer for the incoming chunk), making it both bandwidth-optimal and memory-optimal, with "discard" happening implicitly as overwrites.
Connecting to what you know. In PyTorch DDP, this fires automatically. Gradients are bucketed (25 MB by default) and an async all-reduce launches on each bucket as soon as its gradients are ready during backprop, so communication overlaps with the rest of the backward pass. That overlap is the single most important performance property of DDP, and it's why bucket size is a tuning knob. The collective itself is typically the ring above, though NCCL may choose a tree algorithm depending on message size and topology.

Two nested decompositions: buckets vs ring chunks

A common point of confusion: the DDP bucket and the ring chunk are not the same thing — they live at different layers. A bucket is one whole all-reduce; the ring chunking happens inside it. The hierarchy is nested four levels deep:

[Figure: Granularity hierarchy: backward pass → DDP bucket → ring chunk → pipeline slot. Each level subdivides the one above it.]

The two boundaries are independent: bucketing is PyTorch's knob for overlap with compute (start communicating before backward finishes); ring chunking is NCCL's mechanism for bandwidth optimality within a single collective. A 25 MB bucket is never "sent on a link" as a unit — it becomes one all-reduce, and that payload is diced into N chunks of ~25/N MB that rotate around the ring for 2(N−1) steps. Different buckets are separate all-reduces, pipelined against the ongoing backward pass.
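The two layers of decomposition can be made concrete with a toy planner. This sketch assumes a simplified greedy bucketing rule (pack gradients in the order they become ready until the next one would exceed the cap); PyTorch's real bucket assignment has extra rules, so `bucketize` and `ring_plan` are hypothetical illustrations, not DDP internals.

```python
MB = 1024 * 1024


def bucketize(grad_bytes, cap_bytes=25 * MB):
    """Greedily pack gradient tensors (in ready order) into buckets of at most cap_bytes.

    A single tensor larger than the cap still gets a bucket of its own.
    Returns a list of buckets, each a list of tensor indices.
    """
    buckets, current, current_size = [], [], 0
    for i, size in enumerate(grad_bytes):
        if current and current_size + size > cap_bytes:
            buckets.append(current)  # bucket full: this becomes one all-reduce
            current, current_size = [], 0
        current.append(i)
        current_size += size
    if current:
        buckets.append(current)
    return buckets


def ring_plan(bucket_bytes, n_ranks):
    """Inside one bucket's all-reduce: N chunks rotating for 2(N-1) steps."""
    return {"chunk_bytes": bucket_bytes / n_ranks, "steps": 2 * (n_ranks - 1)}


if __name__ == "__main__":
    grads = [10 * MB, 10 * MB, 10 * MB, 30 * MB, 5 * MB]  # hypothetical tensor sizes
    for bucket in bucketize(grads):
        size = sum(grads[i] for i in bucket)
        print(bucket, ring_plan(size, n_ranks=8))
```

Each printed bucket is a separate all-reduce; the ring chunking exists only inside it.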

Why 25 MB — and why not scale it with N?

A natural guess is that bucket size should grow with the DDP degree N to equalize per-link transfer. But per-link transfer is already essentially N-invariant: the all-reduce of a bucket of S bytes moves 2·(N−1)/N·S bytes per link, which asymptotes to 2S and barely changes with N. The ring already equalized the per-link load across all N, so there's nothing to compensate for by scaling S.

What actually sets the 25 MB default is the latency-vs-overlap tradeoff, not N. Too small → you pay the fixed α (launch) cost too often and the per-step chunks S/N are too tiny to reach peak bandwidth (latency-dominated regime). Too large → you wait longer in backprop before a bucket fills, shrinking the compute/comm overlap window. 25 MB is the empirical sweet spot for that balance; it's tunable via bucket_cap_mb.

Where the N-intuition is right: the quantity that degrades with N is the per-step chunk S/N, which shrinks as N grows and can push each ring step into the latency-bound regime. So large clusters do care about N — but the fix is usually to switch algorithms (tree/hierarchical all-reduce, with a log N latency term instead of 2(N−1)) or use multiple NCCL channels, rather than inflating the bucket. The trap: assuming per-link volume grows with N. It doesn't — the ring made it constant; only the per-step chunk shrinks.
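A quick numeric check of that argument, using this section's formulas with a 25 MB bucket (the values of N are arbitrary examples): the per-link volume, normalized by S, creeps toward 2, while the per-step chunk keeps shrinking.

```python
def per_link_bytes(bucket_bytes, n_ranks):
    """Bytes each rank sends for one ring all-reduce of the bucket."""
    return 2 * (n_ranks - 1) / n_ranks * bucket_bytes


def per_step_chunk_bytes(bucket_bytes, n_ranks):
    """Size of the chunk moved in each individual ring step."""
    return bucket_bytes / n_ranks


if __name__ == "__main__":
    S = 25 * 1024 * 1024
    for n in (2, 8, 64, 512, 4096):
        ratio = per_link_bytes(S, n) / S
        chunk_kib = per_step_chunk_bytes(S, n) / 1024
        print(f"N={n:5d}  per-link/S={ratio:.3f}  per-step chunk={chunk_kib:8.1f} KiB")
```

At 512 ranks the per-step chunk of a 25 MB bucket is down to 50 KiB, heading into the latency-bound regime described above, while per-link volume has barely moved.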

SECTION 04

Reduce-Scatter + All-Gather as standalone tools — ZeRO & FSDP

The two halves of all-reduce are also powerful on their own. Splitting them apart is the entire idea behind sharded data parallelism, the technique that lets a model far larger than one GPU's memory still train.

In plain DP, every GPU holds a full copy of parameters, gradients, and optimizer state. For Adam in mixed precision, the optimizer state is the largest consumer of memory: fp32 master weights plus two fp32 moment estimates, i.e. 12 bytes per parameter, the 12Φ term below (Φ = parameter count), out of 16Φ bytes in total once the 2Φ of fp16 parameters and 2Φ of fp16 gradients are added. That replication is pure waste. ZeRO (DeepSpeed) and FSDP (PyTorch) shard those tensors across the data-parallel group so each GPU stores only 1/N of them, then reconstruct what's needed just in time.

Terminology check: reduce-scatter vs scatter-reduce. These live at different levels. reduce-scatter is the name of the collective primitive (MPI_Reduce_scatter, ncclReduceScatter): reduce across ranks, then scatter so each keeps one slice. scatter-reduce is the colloquial name (from Baidu's 2017 ring all-reduce post) for phase 1 of the ring all-reduce algorithm. They describe the same data movement, but when naming the standalone collective the standard term is reduce-scatter, and that's what this report uses.
FSDP ≡ which ZeRO? When people say "FSDP" without qualification they mean the ZeRO-3 equivalent (full sharding of parameters, gradients, and optimizer states). PyTorch FSDP exposes a sharding_strategy knob: FULL_SHARD = ZeRO-3, while SHARD_GRAD_OP behaves like ZeRO-2 (shards gradients and optimizer state; parameters stay unsharded from forward through backward instead of being freed and re-gathered). So "ZeRO/FSDP" as synonyms is shorthand for ZeRO-3 ↔ FSDP FULL_SHARD.

All-Gather — reconstruct the layer right before you need it

Before a layer's forward pass, each GPU holds only its shard of that layer's weights. An all-gather collects all shards so every GPU briefly holds the full layer, runs the compute, then frees the gathered weights immediately. Watch the shards assemble:

[Figure: All-gather, 4 GPUs each holding 1 shard. Legend: this GPU's own shard; shard received from peer; not yet present. Each GPU starts with shard k; after the gather all hold all 4.]

Reduce-Scatter — the gradient mirror image

After the backward pass each GPU has computed gradients for the full layer, but only needs to keep the slice matching its parameter shard. Reduce-scatter sums gradients across GPUs and leaves each GPU with only its slice in one fused operation — strictly cheaper than an all-reduce followed by a discard. Watch chunks accumulate and land:

[Figure: Reduce-scatter, summing across GPUs and keeping one slice each. Legend: local contribution; partial sum in flight; final reduced slice (kept). Each GPU keeps only chunk k of the summed gradient.]

The three stages, precisely

The stages differ by what is sharded — and that, in turn, dictates which collective replaces DDP's all-reduce and how much gradient memory each GPU must hold. This is the distinction worth getting exactly right:

ZeRO-1 — shard optimizer state only

Only the optimizer state is partitioned. Each GPU still computes and holds the full gradient, but it only needs the summed gradient for its own shard to run the optimizer step — so the gradient synchronization is a reduce-scatter, not an all-reduce. (An all-reduce would broadcast every other shard's summed gradient back to every GPU, which ZeRO-1 would immediately throw away — pure waste.) Each GPU then updates its 1/N shard of optimizer state and parameters, and an all-gather republishes the updated weights for the next forward pass. The full gradient is still materialized in memory only because gradients aren't sharded and the backward pass writes the whole tensor before the single end-of-backward reduce-scatter consumes it — a memory fact, not a communication one. The saving is entirely in optimizer state (the 12Φ → 12Φ/N term, the biggest single consumer for Adam). Communication volume equals DDP (reduce-scatter + all-gather = one all-reduce's worth, by the decomposition identity).

ZeRO-2 — additionally shard gradients

The gradient synchronization is the same reduce-scatter as ZeRO-1 — the collective does not change between the two stages. What changes is purely the memory axis: ZeRO-2 shards the gradient itself, so each GPU keeps only the 1/N slice matching its optimizer/parameter shard and frees the rest as it goes. This is the correction to the common mental model — the ZeRO-1→ZeRO-2 difference is not "when the reduce-scatter fires" or "all-reduce vs reduce-scatter," but simply whether the gradient is sharded in memory at all. ZeRO-1 materializes the full gradient tensor and reduce-scatters it once at end-of-backward; ZeRO-2 buckets the reduce-scatter so each slice is reduced and the rest released incrementally during the backward pass, never holding the whole gradient at once. That cuts the gradient-memory term ZeRO-1 leaves untouched (the 2Φ → 2Φ/N term), at no extra communication cost — volume is still exactly DDP's.

ZeRO-3 (= FSDP FULL_SHARD) — additionally shard parameters

Parameters are now sharded too, so no GPU ever holds the whole model. Weights for a layer are all-gathered just before they're needed and discarded right after. Crucially this happens twice per iteration, once in forward and once again in backward, because the gathered weights are freed after the forward pass and must be re-collected to compute gradients. So for every layer a ZeRO-3 step issues three collectives: two all-gathers (weights, fwd + bwd) and one reduce-scatter (gradients). That works out to roughly 1.5× the communication volume of DDP, in exchange for per-GPU model-state memory that shrinks as 1/N (so the trainable model size grows linearly with N). Here is the per-iteration timeline:

[Figure: ZeRO-3 / FSDP collectives issued per training iteration. Legend: all-gather (weights); reduce-scatter (grads); compute. Pattern: 2× all-gather + 1× reduce-scatter.]
The identity that ties it together. Because all-reduce = reduce-scatter + all-gather, ZeRO-1 and ZeRO-2 move the same total bytes as DDP; they only change where the gradient lives and how much memory it costs. ZeRO-3 adds one extra all-gather of parameters: the forward all-gather plays the role of the all-gather half of DDP's all-reduce, while the backward re-gather is new work that DDP never needed because it kept weights resident. That lands ZeRO-3 at ~1.5× DDP's volume. The through-line: memory savings are large and monotonic across stages; communication is free through stage 2 and modestly more expensive at stage 3.
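The 1× versus 1.5× accounting reduces to counting collectives. In this sketch each ring all-gather or reduce-scatter of a Φ-element tensor costs (N−1)/N·Φ elements sent per rank, latency is ignored, and the per-stage collective counts are the ones stated in the text; `comm_volume` is a hypothetical helper.

```python
def comm_volume(stage, n_params, n_ranks):
    """Elements sent per rank per training step, ring algorithms, no latency term."""
    unit = (n_ranks - 1) / n_ranks * n_params  # one ring all-gather or reduce-scatter
    collectives = {
        "ddp": 2,    # all-reduce = reduce-scatter + all-gather
        "zero1": 2,  # reduce-scatter (grads) + all-gather (updated params)
        "zero2": 2,  # same collectives; only gradient memory changes
        "zero3": 3,  # all-gather (fwd) + all-gather (bwd) + reduce-scatter (grads)
    }
    return collectives[stage] * unit


if __name__ == "__main__":
    ddp = comm_volume("ddp", n_params=1e9, n_ranks=64)
    for stage in ("zero1", "zero2", "zero3"):
        print(stage, comm_volume(stage, 1e9, 64) / ddp)
```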
Stage | Sharded | Full gradient in memory? | Collectives replacing DDP's all-reduce | Comm vs DDP
ZeRO-1 | Optimizer state | Yes | reduce-scatter (grads) + all-gather (params) | 1×
ZeRO-2 | + Gradients | No (freed per bucket) | reduce-scatter (grads) + all-gather (params) | 1×
ZeRO-3 = FSDP | + Parameters | No | 2× all-gather (params) + reduce-scatter (grads) | ~1.5×
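The memory side of the same table, under the 16Φ accounting for mixed-precision Adam (2Φ fp16 parameters, 2Φ fp16 gradients, 12Φ optimizer state). This hypothetical helper counts model-state bytes per GPU only; activations, temporarily gathered buffers, and fragmentation are ignored, and the 7.5B-parameter, 64-GPU example is illustrative.

```python
def zero_memory_bytes(stage, n_params, n_ranks):
    """Per-GPU model-state bytes for mixed-precision Adam (stage 0 = plain DDP)."""
    params, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if stage == 0:
        return params + grads + optim
    if stage == 1:  # shard optimizer state
        return params + grads + optim / n_ranks
    if stage == 2:  # + shard gradients
        return params + (grads + optim) / n_ranks
    if stage == 3:  # + shard parameters
        return (params + grads + optim) / n_ranks
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    for stage in range(4):
        gb = zero_memory_bytes(stage, n_params=7.5e9, n_ranks=64) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per GPU")
```

For this example the four stages come out at 120, about 31.4, about 16.6, and about 1.9 GB per GPU: the optimizer-state shard alone accounts for most of the drop.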
Remember this
ZeRO shards memory, not math.
Each stage stores less per GPU (optimizer state → +gradients → +parameters) while computing the exact same thing. Stages 1–2 are free in bandwidth; stage 3 (= FSDP) costs ~1.5× DDP for an extra parameter all-gather. The win is memory; the math never changes.
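To make "memory, not math" concrete, the sketch below runs one SGD-with-momentum step two ways on four simulated ranks: DDP-style (all-reduce the mean gradient, every rank updates the full vector) and ZeRO-style (reduce-scatter, each rank updates only its parameter shard and its shard of momentum, then all-gather). The optimizer, learning rate, and gradients are hypothetical stand-ins chosen to keep it short; Adam behaves the same way because its update is also elementwise.

```python
def sgd_momentum(w, m, g, lr=0.1, mu=0.9):
    """Elementwise SGD with momentum; returns updated (weights, momentum)."""
    m = [mu * mi + gi for mi, gi in zip(m, g)]
    w = [wi - lr * mi for wi, mi in zip(w, m)]
    return w, m


def ddp_step(w, m, grads):
    n = len(grads)
    avg = [sum(col) / n for col in zip(*grads)]  # all-reduce (mean)
    return sgd_momentum(w, m, avg)               # every rank does the full update


def zero_step(w, m, grads):
    n = len(grads)
    size = len(w) // n
    # Simulation shortcut: compute the mean once; a real reduce-scatter
    # only materializes slice r on rank r.
    avg = [sum(col) / n for col in zip(*grads)]
    w_shards, m_shards = [], []
    for r in range(n):
        sl = slice(r * size, (r + 1) * size)
        # Rank r holds only slice r of the gradient and of the momentum state.
        w_r, m_r = sgd_momentum(w[sl], m[sl], avg[sl])
        w_shards.append(w_r)
        m_shards.append(m_r)
    full_w = [x for shard in w_shards for x in shard]  # all-gather updated params
    return full_w, m_shards


if __name__ == "__main__":
    w0 = [0.5 * i for i in range(8)]
    m0 = [0.0] * 8
    grads = [[(r + 1) * (i - 3) for i in range(8)] for r in range(4)]  # 4 ranks
    w_ddp, _ = ddp_step(w0, m0, grads)
    w_zero, _ = zero_step(w0, m0, grads)
    assert w_ddp == w_zero  # identical parameters, 1/N of the optimizer state per rank
```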
Further reading. For an animated walk-through of all-gather specifically, this video is a clear companion to the interactive figure above: youtube.com/watch?v=JFTvY7siOtk. (It shows each rank's shard propagating around the ring until every rank holds the full tensor: the same N−1 step structure animated here.)
SECTION 05

All-Reduce inside tensor parallelism

Same collective as data parallelism, completely different character — high-frequency, latency-critical, and pinned to the fastest links in the box.

Tensor parallelism (Megatron-LM style) splits the weight matrices within a layer across GPUs. Unlike FSDP — which all-gathers shards back to full weights and then computes — TP keeps the weights sharded through the computation itself. That single difference is the source of everything below: because each GPU multiplies only its slice, it produces a partial result, and the shard axis you chose determines exactly what reduction is needed to make it whole.

The transformer MLP block, as matrix algebra

A transformer's feed-forward block is two matmuls with a nonlinearity between: Z = GeLU(X·A)·B, where X is [s × h] (sequence × hidden), A is [h × 4h] (expand), and B is [4h × h] (contract). Megatron shards these two matmuls with conjugate axes — A by columns, B by rows — and that pairing is the whole trick. Watch the block shapes:

[Figure: Megatron MLP with column-parallel A and row-parallel B, drawn as matrix blocks on two GPUs; follow the block shapes through both matmuls.]

Why this specific pairing? Column-parallel A splits the 4h (output) dimension, so GPU i holds columns A_i and computes Y_i = X·A_i — a full-height, narrow slice of the hidden activations. GeLU is element-wise, so it runs locally on each slice with no communication — this only works because the split is along the output dimension, keeping each activation column intact. Then row-parallel B splits the 4h (input) dimension to match Y's split: GPU i computes Y_i·B_i, a full-size but partial output. The true result is the sum Z₁+Z₂, so exactly one all-reduce closes the block. The conjugate split means the only communication in the entire block is that single reduction.
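The claim that a single all-reduce closes the block can be checked numerically. This is a minimal pure-Python sketch with tiny hypothetical matrices, an exact-erf GeLU, and TP degree 2; the helper names are illustrative, and a real implementation would run each shard on its own device and replace the final sum with an all-reduce.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def mlp_reference(X, A, B):
    """Unsharded block: Z = GeLU(X·A)·B."""
    Y = [[gelu(v) for v in row] for row in matmul(X, A)]
    return matmul(Y, B)


def mlp_tensor_parallel(X, A, B, tp):
    """Column-parallel A, row-parallel B, then one sum (g's forward all-reduce)."""
    width = len(A[0]) // tp  # this rank's slice of the 4h dimension
    partials = []
    for i in range(tp):
        A_i = [row[i * width:(i + 1) * width] for row in A]       # columns of A
        B_i = B[i * width:(i + 1) * width]                       # matching rows of B
        Y_i = [[gelu(v) for v in row] for row in matmul(X, A_i)]  # GeLU stays local
        partials.append(matmul(Y_i, B_i))                        # full-size partial Z_i
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[r][c] for p in partials) for c in range(cols)] for r in range(rows)]


if __name__ == "__main__":
    s, h = 3, 2
    X = [[(i + 2 * j) / 5 - 0.4 for j in range(h)] for i in range(s)]
    A = [[((3 * i + j) % 7) / 7 - 0.5 for j in range(4 * h)] for i in range(h)]
    B = [[((i + 5 * j) % 5) / 5 - 0.4 for j in range(h)] for i in range(4 * h)]
    ref, out = mlp_reference(X, A, B), mlp_tensor_parallel(X, A, B, tp=2)
    assert all(abs(a - b) < 1e-12 for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
```

Splitting A by rows instead would leave each rank with a partial sum of X·A, and since GeLU(a + b) ≠ GeLU(a) + GeLU(b), a second all-reduce would be needed before the nonlinearity. That is the communication the conjugate pairing avoids.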

The forward/backward collective asymmetry: what f and g are

Megatron labels two spots in the dataflow where communication might be needed — two "gates" that bracket the parallel region of a block. f sits at the input (where X enters); g sits at the output (where the partial Zs come out). They are not different operations to memorize — each is simply an all-reduce in one direction and a no-op in the other:

  • g (output gate): forward = all-reduce (the Z₁+Z₂ sum you just saw); backward = no-op (the gradient passes straight through to each GPU).
  • f (input gate): forward = no-op (X is already replicated on every GPU, so nothing happens); backward = all-reduce.

So one forward+backward of a block costs exactly one all-reduce forward (at g) and one all-reduce backward (at f) — two total, at opposite ends.

Why does f need an all-reduce in backward, when forward needed nothing there? Because X is one tensor used by both GPUs in forward — each multiplies the same X by its own weight shard. When gradients flow back, each GPU computes a different partial gradient of X (one from each path X fed into). By the chain rule, the true gradient of X is the sum of those partials — and the layers before this block need that complete sum to continue backprop. Summing partials across GPUs is an all-reduce; that is f's backward job.
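The backward all-reduce at f is just the chain rule over column blocks. For Y = X·A with A split by columns, dL/dX = dL/dY · Aᵀ, which decomposes into one partial product per rank. A minimal integer-valued sketch (hypothetical helper names) confirming that summing the per-rank partials reproduces the full gradient:

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def grad_x_full(dY, A):
    """Unsharded: dL/dX = dL/dY · Aᵀ."""
    return matmul(dY, transpose(A))


def grad_x_tensor_parallel(dY, A, tp):
    """Rank i only sees dY_i and A_i; f's backward all-reduce sums the partials."""
    width = len(A[0]) // tp
    partials = []
    for i in range(tp):
        dY_i = [row[i * width:(i + 1) * width] for row in dY]
        A_i = [row[i * width:(i + 1) * width] for row in A]
        partials.append(matmul(dY_i, transpose(A_i)))  # partial dL/dX on rank i
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[r][c] for p in partials) for c in range(cols)] for r in range(rows)]


if __name__ == "__main__":
    A = [[i * 8 + j - 7 for j in range(8)] for i in range(2)]    # h=2, 4h=8
    dY = [[(i + j) % 3 - 1 for j in range(8)] for i in range(3)]  # s=3
    assert grad_x_tensor_parallel(dY, A, tp=2) == grad_x_full(dY, A)
```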

[Figure: The f and g gates, each an all-reduce in one direction and a no-op in the other; toggling forward/backward shows which gate becomes the all-reduce.]
The duality to remember. f and g are mirror images of one rule: replicated forward ⟺ all-reduce backward, and all-reduce forward ⟺ replicated backward. A tensor copied to all GPUs in forward (like X) has a summed gradient in backward (that's f); a tensor summed in forward (like Z) just passes its gradient through in backward (that's g). They're the same conjugate pair seen from the input vs. output side. Megatron names them only because manual TP requires inserting these all-reduces by hand; a compiler like XLA/GSPMD derives them for you, which is why a JAX version written with sharding annotations has no explicit f/g.
Remember this
Replicated one way ⟺ all-reduced the other.
If a tensor is copied across GPUs in the forward pass, its gradient must be summed (all-reduced) in the backward pass — and vice versa. That single symmetry is all f and g encode: one gate per end of the block, each a no-op going one way and an all-reduce going the other.

Attention is even more natural

Self-attention shards along an axis the architecture hands you for free: attention heads. Each head computes softmax(QKᵀ/√d)V completely independently, so Megatron places whole heads on different GPUs — GPU i owns a subset of heads and computes them end to end with no cross-GPU communication inside the attention score/softmax/value path. The per-head outputs are concatenated along the hidden dimension, then the output projection W_O is row-parallel (matching the head split), producing partial outputs that — same as the MLP — need a single all-reduce. So attention is column-parallel on QKV (by heads) and row-parallel on the output projection: structurally identical communication to the MLP.
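The head-parallel layout can be checked the same way. This pure-Python sketch uses hypothetical toy sizes (sequence length 3, hidden 4, 4 heads of dimension 1, TP degree 2) and omits masking and biases. Each rank takes whole heads via column slices of W_Q, W_K, W_V, runs attention locally, and multiplies by its row slice of W_O; summing the per-rank partial outputs (the all-reduce) matches the unsharded layer.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def cols(m, lo, hi):
    return [row[lo:hi] for row in m]


def softmax(row):
    mx = max(row)
    exps = [math.exp(v - mx) for v in row]
    total = sum(exps)
    return [v / total for v in exps]


def attend(Q, K, V):
    """softmax(Q·Kᵀ/√d)·V for a single head."""
    d = len(Q[0])
    scores = [[sum(q * k for q, k in zip(qr, kr)) / math.sqrt(d) for kr in K] for qr in Q]
    return matmul([softmax(r) for r in scores], V)


def attention(X, Wq, Wk, Wv, Wo, n_heads, head_ids):
    """Attention output restricted to head_ids (one rank's heads, or all of them)."""
    d = len(Wq[0]) // n_heads
    outs = []
    for h in head_ids:
        lo, hi = h * d, (h + 1) * d
        # column-parallel QKV: only this head's columns of each projection
        outs.append(attend(matmul(X, cols(Wq, lo, hi)), matmul(X, cols(Wk, lo, hi)),
                           matmul(X, cols(Wv, lo, hi))))
    concat = [[x for o in outs for x in o[r]] for r in range(len(X))]
    # row-parallel W_O: only the rows that multiply these heads' outputs
    wo_rows = [Wo[h * d + j] for h in head_ids for j in range(d)]
    return matmul(concat, wo_rows)


if __name__ == "__main__":
    s, hidden, n_heads, tp = 3, 4, 4, 2
    X = [[math.sin(i + 2 * j) for j in range(hidden)] for i in range(s)]
    Wq, Wk, Wv, Wo = ([[math.cos(k + 3 * i + j) for j in range(hidden)] for i in range(hidden)]
                      for k in range(4))
    full = attention(X, Wq, Wk, Wv, Wo, n_heads, range(n_heads))
    per_rank = n_heads // tp
    partials = [attention(X, Wq, Wk, Wv, Wo, n_heads, range(r * per_rank, (r + 1) * per_rank))
                for r in range(tp)]
    summed = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]  # the all-reduce
    assert all(abs(a - b) < 1e-12 for fr, sr in zip(full, summed) for a, b in zip(fr, sr))
```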

[Figure: Attention with heads partitioned across GPUs (4 heads split across 2 GPUs), from QKV through the output projection.]
The systems consequence. Because these all-reduces are tiny, frequent, and block compute, their cost is α-dominated, not β-dominated. You cannot tolerate inter-node latency for them. This is why tensor parallelism is almost always confined to within a single node (8 GPUs over NVLink/NVSwitch, ~hundreds of GB/s and sub-microsecond hops), while data and pipeline parallelism span the slower inter-node InfiniBand/Ethernet fabric. TP degree rarely exceeds the GPUs-per-node count for exactly this reason.

Why TP is latency-bound: the chain of reasoning

Unpacking "small, frequent, and blocking," the three properties behind the latency-bound regime. They are causally linked, and the real discriminator is the last one, not the first two.

Small. A TP all-reduce sums an activation tensor of shape [b × s × h] — one block's output for the current microbatch — not the whole gradient. The DP gradient all-reduce, by contrast, covers every parameter in the model. So per individual collective call, the TP payload is far smaller, small enough that it doesn't saturate the link: the bandwidth term is tiny.

Frequent. Each transformer layer has two TP blocks (attention + MLP), each ending in an all-reduce in forward, plus the conjugate one in backward, so ≈ 4 × num_layers all-reduces per microbatch, i.e. hundreds. DP fires its gradient reduction effectively once per step (bucketed into a few launches).

Blocking — and this is the one that matters. The all-reduce output is the input to the next operation (the residual add, the next sublayer). There is no independent work to run while it completes, so compute stalls on it. The cost becomes (hundreds) × α — the fixed per-hop latency paid hundreds of times, with nothing hiding it. That is the latency-bound regime, and the only lever is shrinking α → NVLink.

The discriminator is overlappability, not frequency. A natural objection: FSDP also communicates per layer (all-gather weights before each layer, reduce-scatter grads after), so why isn't it latency-bound? Because FSDP's per-layer comm is overlappable: while layer k computes on its already-gathered weights, FSDP prefetches layer k+1's all-gather concurrently. There's independent work to hide the latency behind, so it stays bandwidth-bound despite being frequent. TP's comm sits on a true data dependency inside the block (the next op needs the reduced result), so there's nothing to overlap. Frequency is shared between FSDP and TP; overlappability is what splits them. Message size is a contributing factor, but the causal root of the regime is whether compute can proceed while the bytes move.
Remember this
FSDP's comm hides; TP's comm blocks.
What decides latency- vs bandwidth-bound isn't how often you communicate — it's whether compute can run alongside it. FSDP prefetches the next layer while computing the current one (hidden → bandwidth-bound, tolerates slow links). TP's all-reduce feeds the very next op, so the GPU waits (exposed → latency-bound, needs NVLink). Overlappable beats infrequent.
[Figure: Compute timeline contrasting overlappable and blocking communication, three regimes on the same compute timeline. Legend: compute; comm hidden (overlapped); comm exposed (blocking); GPU idle (stall).]
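The contrast reduces to a two-function timing model. In this sketch (hypothetical per-layer costs in arbitrary time units), the overlappable case issues each layer's communication ahead of time on a separate stream, FSDP-prefetch style, with unlimited prefetch depth as a simplifying assumption (real FSDP bounds how far ahead it gathers to cap memory). The blocking case, TP-style, must finish each transfer before the next compute can start.

```python
def overlapped_time(compute, comm):
    """Comm for layer k is prefetched on its own stream, back to back.

    Compute for layer k starts once its comm has finished and layer k-1 is done,
    so only comm that outruns compute is exposed.
    """
    comm_done, t = 0.0, 0.0
    for c, m in zip(compute, comm):
        comm_done += m             # comm stream never waits for compute
        t = max(t, comm_done) + c  # compute stalls only if its data is late
    return t


def blocking_time(compute, comm):
    """Each layer's comm sits on the critical path: compute, then wait for comm."""
    return sum(c + m for c, m in zip(compute, comm))


if __name__ == "__main__":
    compute, comm = [10.0] * 4, [4.0] * 4
    print("overlapped:", overlapped_time(compute, comm))
    print("blocking:  ", blocking_time(compute, comm))
```

When compute outlasts comm, the overlapped run costs total compute plus the one transfer that cannot be prefetched; the blocking run pays every transfer in full.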

Is TP architecture-specific, or generic? Both — at different layers

This is the crux of a confusion worth resolving cleanly, because Megatron reads as transformer-specific while JAX/GSPMD lets you annotate any dimension as sharded. Both pictures are correct; they live at different levels.

The mechanism is fully generic. At the linear-algebra level, a matmul can be sharded along essentially any dimension of any operand, and each choice has a determinate rule for the communication needed to keep the result correct. Sharding the contraction (inner) dimension of a matmul forces a reduction afterward → all-reduce (or reduce-scatter). Sharding a non-contracted (outer) dimension keeps outputs sharded with no reduction, but may need an all-gather later when a downstream op wants the full tensor. This is exactly what XLA/GSPMD automates: you annotate the "big" tensors with mesh axes, and the compiler derives and inserts the right collectives. Nothing here is transformer-specific — it's the algebra of distributing einsums.
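That algebra fits in one function. This sketch (plain Python, hypothetical helper names, dimensions divisible by the shard count) splits Y = X·W along m, k, or n, computes each shard's local product, and then repairs the result with the operation the rule predicts: a sum (all-reduce) when the contraction dimension k is split, a concatenation (all-gather) when an outer dimension is split.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def sharded_matmul(X, W, axis, parts):
    """Return (per-shard local results, collective needed to rebuild the full Y)."""
    if axis == "m":  # split X's rows: each shard owns a block of output rows
        step = len(X) // parts
        return [matmul(X[i * step:(i + 1) * step], W) for i in range(parts)], "all-gather"
    if axis == "n":  # split W's columns: each shard owns a block of output columns
        step = len(W[0]) // parts
        return [matmul(X, [row[i * step:(i + 1) * step] for row in W])
                for i in range(parts)], "all-gather"
    if axis == "k":  # split the contraction dim: each shard holds a full-size partial
        step = len(W) // parts
        return [matmul([row[i * step:(i + 1) * step] for row in X], W[i * step:(i + 1) * step])
                for i in range(parts)], "all-reduce"
    raise ValueError(f"unknown axis: {axis}")


def repair(shards, collective, axis):
    """Emulate the collective on the shards' local results."""
    if collective == "all-reduce":
        return [[sum(vals) for vals in zip(*rows)] for rows in zip(*shards)]
    if axis == "m":
        return [row for shard in shards for row in shard]
    return [[x for shard in shards for x in shard[r]] for r in range(len(shards[0]))]


if __name__ == "__main__":
    X = [[i + j for j in range(4)] for i in range(4)]      # m=4, k=4
    W = [[i * 4 - j for j in range(6)] for i in range(4)]  # k=4, n=6
    for axis in ("m", "n", "k"):
        shards, collective = sharded_matmul(X, W, axis, 2)
        assert repair(shards, collective, axis) == matmul(X, W)
        print(axis, "->", collective)
```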

The Megatron scheme is a specific, expert-chosen set of those annotations. Its value isn't "TP for transformers exists" — it's the particular conjugate pairing (column-parallel then row-parallel) that makes an entire block cost just one all-reduce per direction, with the nonlinearity on the local side. In raw Megatron you hand-write those collectives; in JAX you'd reach the same result by annotating the same dimensions and letting GSPMD insert them. So: any axis is legal; only specific axes are cheap. The skill is choosing shardings whose collectives are few, small, and overlappable — what Megatron hand-derived, and what GSPMD's propagation tries to discover or you guide with annotations.

[Figure: Same matmul Y = X·W, with X [m×k] and W [k×n]; the axis you pick to shard determines the collective.]
The clean contrast with FSDP. FSDP is shard-axis-agnostic because it always all-gathers shards back to full weights before computing: the math always runs on the unsharded tensor, so the shard axis only decides how storage is sliced, never what's computed. TP is shard-axis-sensitive because the tensor stays sharded through the matmul, so the axis determines what partial result comes out and therefore which collective repairs it. In one sentence: FSDP shards storage; TP shards computation.
SECTION 06

All-to-All for Mixture-of-Experts

The collective that defines MoE training cost. Every device sends a personalized chunk to every other device — the most communication-intensive primitive here.

In an MoE layer, a lightweight router assigns each token to one or a few experts (independent FFNs). With expert parallelism, experts live on different GPUs, so a token computed on GPU 0 may be routed to an expert on GPU 3. The fix is all-to-all: unlike all-gather (where everyone sends the same data to everyone), in all-to-all each device sends a different chunk to each other device. It is a full transpose of data across the group.

MoE uses two all-to-alls per layer:

  1. Dispatch — after routing, send each token to the GPU hosting its assigned expert.
  2. Combine — after the experts compute, send each result back to the GPU where its token originated, to continue the forward pass.
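Dispatch and combine are the same transpose run twice. This sketch (plain Python, hypothetical names, top-1 routing, no capacity limit) emulates all-to-all as a transpose of per-destination buckets, routes each token to the rank hosting its expert, applies a stand-in expert function, and routes the results back to their original positions.

```python
def all_to_all(send):
    """send[i][j] is what rank i sends to rank j; rank j receives recv[j][i] = send[i][j]."""
    n = len(send)
    return [[send[i][j] for i in range(n)] for j in range(n)]


def moe_layer(tokens, expert_of, experts_per_rank, expert_fn):
    n = len(tokens)
    send, origin = [], []
    for r in range(n):  # bucket each rank's tokens by destination rank
        buckets = [[] for _ in range(n)]
        where = [[] for _ in range(n)]
        for pos, (tok, e) in enumerate(zip(tokens[r], expert_of[r])):
            dst = e // experts_per_rank
            buckets[dst].append((tok, e))
            where[dst].append(pos)  # remember each token's original position
        send.append(buckets)
        origin.append(where)

    recv = all_to_all(send)  # dispatch
    out = [[[expert_fn(e, tok) for tok, e in chunk] for chunk in recv[r]] for r in range(n)]
    back = all_to_all(out)   # combine

    result = []
    for r in range(n):  # scatter results back into original token order
        row = [None] * len(tokens[r])
        for dst in range(n):
            for pos, y in zip(origin[r][dst], back[r][dst]):
                row[pos] = y
        result.append(row)
    return result


if __name__ == "__main__":
    tokens = [[1, 2, 3], [4, 5, 6]]     # 2 ranks, 3 tokens each
    expert_of = [[0, 3, 1], [2, 2, 0]]  # 4 experts, 2 per rank
    print(moe_layer(tokens, expert_of, 2, lambda e, t: 10 * t + e))
    # → [[10, 23, 31], [42, 52, 60]]
```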
[Figure: All-to-all dispatch, tokens routed to expert GPUs; each GPU's tokens are color-coded by destination (experts 0/1 vs experts 2/3).]
Why MoE is hard — the key points
  • Load imbalance. Routing is data-dependent, so some experts get far more tokens than others. The all-to-all is bottlenecked by the busiest GPU. This is why MoE training uses an auxiliary load-balancing loss and a capacity factor that drops or pads tokens to keep buffers fixed-size.
  • It doesn't overlap easily. The combine all-to-all sits squarely on the critical path between the experts and the rest of the layer. Hiding it requires careful kernel-level pipelining (e.g. overlapping dispatch of one chunk with compute of another).
  • Bandwidth scaling. All-to-all moves (N−1)/N · n bytes per device, the same volume as a ring all-gather, but with N−1 distinct destinations it is far more sensitive to topology and tends to dominate the comm budget in large MoE models.
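A sketch of the capacity mechanism, assuming the common convention capacity = ceil(tokens / experts × capacity_factor) (popularized by Switch Transformer; exact rounding varies by framework) and first-come-first-served dropping. `route_with_capacity` is a hypothetical helper. It shows how a skewed router both drops tokens on the hot expert and leaves padding on the cold ones.

```python
import math


def expert_capacity(num_tokens, num_experts, capacity_factor):
    return math.ceil(num_tokens / num_experts * capacity_factor)


def route_with_capacity(expert_of, num_experts, capacity):
    """Keep tokens in arrival order until their expert's fixed-size buffer is full."""
    load = [0] * num_experts
    kept, dropped = [], []
    for tok, e in enumerate(expert_of):
        if load[e] < capacity:
            load[e] += 1
            kept.append(tok)
        else:
            dropped.append(tok)  # overflow: token skips the expert computation
    padding = num_experts * capacity - sum(load)  # empty slots in the fixed buffers
    return load, kept, dropped, padding


if __name__ == "__main__":
    routing = [0, 0, 0, 0, 0, 1, 2, 3]  # skewed: expert 0 is hot
    cap = expert_capacity(len(routing), 4, capacity_factor=1.25)
    print(cap, route_with_capacity(routing, 4, cap))
```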
Remember this
All-to-all is a transpose.
All-gather sends everyone the same data; all-to-all sends everyone different data — each device hands every other device a personalized chunk. That's a distributed matrix transpose, and it's why MoE routing (tokens → experts → back) costs two of them per layer and dominates the communication budget.
SECTION 07

Point-to-point & pipeline parallelism

Not strictly a collective, but the backbone of pipeline parallelism and worth understanding as the contrast case — the bubble problem and how scheduling hides it.

Pipeline parallelism splits the layer stack into sequential stages, one group of layers per device. A microbatch flows stage 0 → 1 → 2 → 3 in the forward pass (each handoff a single send/recv of activations), then gradients flow back the other way. There's no group reduction here — just paired transfers between adjacent stages.

The challenge is the pipeline bubble: while stage 0 works on the first microbatch, stages 1–3 sit idle, and similarly at drain. The fix is to split the batch into many microbatches and schedule them so stages stay busy. Watch a naive fill/drain vs. the interleaved schedule:

[Figure: Pipeline schedule, 4 stages × 4 microbatches, showing microbatch flow and the bubble. Legend: forward pass; backward pass; bubble (idle).]
The bubble fraction formula. For a naive (GPipe) schedule with p stages and m microbatches, the fraction of time wasted in the bubble is (p−1)/(m + p−1). The takeaway: increase m relative to p to shrink the bubble. 1F1B (one-forward-one-backward) keeps the same bubble fraction but slashes peak activation memory by running backward passes earlier, which is usually the real constraint. Interleaved 1F1B (Megatron) goes further: with v model chunks per device it shrinks the bubble by roughly a factor of v, at the cost of more point-to-point traffic.
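The GPipe formula can be confirmed by brute force. This sketch lays out a naive fill/drain schedule on a (stage × time-slot) grid, assuming every forward and every backward takes one slot (with any fixed forward and backward costs the fraction comes out the same), counts idle cells, and compares the result against (p−1)/(m+p−1).

```python
def gpipe_grid(p, m):
    """busy[s][t] is True if stage s is working in time slot t (naive GPipe).

    Forward of microbatch j on stage s runs in slot s + j. After the last
    forward drains, backward runs in reverse stage order: microbatch j on
    stage s runs in slot (m + p - 1) + (p - 1 - s) + j.
    """
    total = 2 * (m + p - 1)
    busy = [[False] * total for _ in range(p)]
    for s in range(p):
        for j in range(m):
            busy[s][s + j] = True
            busy[s][(m + p - 1) + (p - 1 - s) + j] = True
    return busy


def simulated_bubble(p, m):
    busy = gpipe_grid(p, m)
    cells = sum(len(row) for row in busy)
    idle = sum(1 for row in busy for b in row if not b)
    return idle / cells


def bubble_fraction(p, m):
    return (p - 1) / (m + p - 1)


if __name__ == "__main__":
    for m in (1, 4, 16, 64):
        print(f"p=4 m={m:3d}  simulated={simulated_bubble(4, m):.3f}  formula={bubble_fraction(4, m):.3f}")
```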
Remember this
More microbatches, smaller bubble.
Pipeline idle time is (p−1)/(m+p−1) — fixed fill/drain cost amortized over m microbatches. Feed more microbatches and the bubble shrinks toward zero. 1F1B doesn't change the fraction; it cuts activation memory so you can afford the larger m that does.
SECTION 08

Composing TP × FSDP on four GPUs

The two techniques run at once in real jobs. Here they are side by side on a 2×2 mesh — TP degree 2 crossed with FSDP degree 2 — one row per GPU, steppable so each moment can be inspected.

The mesh is 2 × 2 = 4 GPUs. The horizontal axis is the TP group (two ranks that split each weight matrix and must all-reduce); the vertical axis is the FSDP group (two replicas that shard parameters and all-gather them on demand). Each of the four rows below is one GPU's local timeline through a single MLP block. Step through and watch the central contrast: the FSDP all-gather overlaps compute (it prefetches), while the TP all-reduce blocks it (the next op needs the result).

TP×FSDP MESH · 4 GPUs · one MLP block: 0 idle → 1 FSDP all-gather (overlaps) → 2 forward compute → 3 TP all-reduce (blocks) → 4 FSDP reduce-scatter (overlaps)

Reading the mesh: the two GPUs in a row (same FSDP replica) form a TP pair. They hold complementary shards of each weight matrix (Megatron-style: the first MLP matmul split by columns, the second by rows) and all-reduce the partial sums that the second matmul produces. The two GPUs in a column (same TP rank) form an FSDP pair: each keeps half of that TP shard's parameters resident in GPU memory, and the pair all-gathers to reconstruct the full shard before compute. So every weight is sharded twice: once by TP (which columns or rows) and once by FSDP (which slice of that shard stays resident between uses).
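The TP half of this picture can be checked numerically. The pure-Python sketch below uses toy sizes and random weights, both assumptions. Splitting W1 by columns lets each rank apply the ReLU locally with no communication. Splitting W2 by rows leaves each rank holding a partial sum. Adding those partials, which is what the TP all-reduce does, reproduces the unsharded MLP. FSDP's second split is not modeled.

```python
import random


def matmul(a, b):
    """Plain-Python product of a (r x k) and b (k x c)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


random.seed(0)
d, f, tp = 4, 8, 2  # hidden size, MLP width, TP degree (toy values)
x = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(3)]   # 3 tokens
w1 = [[random.uniform(-1, 1) for _ in range(f)] for _ in range(d)]  # d x f
w2 = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(f)]  # f x d

reference = matmul(relu(matmul(x, w1)), w2)  # unsharded MLP

partials = []
step = f // tp
for rank in range(tp):
    lo, hi = rank * step, (rank + 1) * step
    w1_shard = [row[lo:hi] for row in w1]  # column shard of W1
    w2_shard = w2[lo:hi]                   # matching row shard of W2
    h = relu(matmul(x, w1_shard))          # column-parallel: no comm needed
    partials.append(matmul(h, w2_shard))   # row-parallel: a partial sum

# The TP all-reduce: elementwise sum of the partial outputs across ranks.
out = [[sum(p[i][j] for p in partials) for j in range(d)] for i in range(len(x))]
assert all(abs(a - b) < 1e-9 for ra, rb in zip(out, reference) for a, b in zip(ra, rb))
```

The ReLU can be applied before any communication only because each rank owns whole columns of the hidden activation. That is why the column split comes first and the row split second.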

Remember this
FSDP shards across replicas; TP shards within the math.
They're orthogonal. FSDP gathers a full shard, computes, and frees it — comm hidden by prefetch. TP keeps its slice sharded through the matmul and pays a blocking all-reduce to stitch partial outputs. One axis trades memory for overlappable bandwidth; the other trades latency for never materializing the full matrix. Stacked, they tile the GPU mesh.
SECTION 09

Putting it together: topology & overlap

Real frontier runs combine all of the above — "3D" or "4D parallelism." The art is mapping each collective onto the right slice of the network and overlapping it with compute.

A frontier training job composes these strategies into nested groups. A representative layout on a cluster of 8-GPU nodes:

  • Tensor parallel (×8) — within a node, over NVLink. Latency-critical all-reduces stay on the fastest fabric.
  • Pipeline parallel (×several) — across a handful of nodes. Only thin activation tensors cross these links (cheap P2P).
  • Data / sharded-data parallel (×many) — across the remaining nodes over InfiniBand. Bandwidth-heavy reduce-scatter/all-gather, but tolerant of latency and overlappable with backprop.
  • Expert parallel — if MoE, all-to-all over a dimension chosen to balance against the others.

The placement principle is consistent and worth stating crisply: match each collective's sensitivity to the link's strength. Put α-sensitive, high-frequency traffic (TP all-reduce) on the lowest-latency fabric; put β-heavy, latency-tolerant traffic (DP gradient reduction) on links where you have bandwidth to spare and computation to hide it behind.
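One way to realize this placement is to make tensor parallelism the fastest-varying rank index, so each TP group is a run of consecutive ranks. The sketch assumes three things: that ordering (tp, then pp, then dp), ranks numbered node by node, and 8 GPUs per node. Real frameworks let you choose the ordering, and `mesh_coords` and `tp_group` are hypothetical helpers.

```python
from typing import Dict, List


def mesh_coords(rank: int, tp: int, pp: int, dp: int) -> Dict[str, int]:
    """Map a global rank to (tp, pp, dp) coordinates, TP varying fastest."""
    assert 0 <= rank < tp * pp * dp
    return {"tp": rank % tp, "pp": (rank // tp) % pp, "dp": rank // (tp * pp)}


def tp_group(rank: int, tp: int) -> List[int]:
    """Ranks that all-reduce together with `rank` in tensor parallelism."""
    base = rank - rank % tp
    return list(range(base, base + tp))


GPUS_PER_NODE = 8
tp, pp, dp = 8, 4, 16  # assumed layout: 512 GPUs on 64 nodes
for r in range(tp * pp * dp):
    nodes = {g // GPUS_PER_NODE for g in tp_group(r, tp)}
    assert len(nodes) == 1  # every TP all-reduce stays inside one node (NVLink)
print(mesh_coords(100, tp, pp, dp))  # {'tp': 4, 'pp': 0, 'dp': 3}
```

Under this ordering, pipeline neighbours are `tp` ranks apart and usually sit on different nodes, which suits the thin P2P traffic. DP peers are `tp * pp` ranks apart and span the rest of the cluster.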

OVERLAP · communication hidden behind computation (exposed vs. overlapped)
The one-sentence thesis. A well-tuned training job is one where the GPUs are almost never waiting on the network: every byte that must move is moving while the GPUs compute something else. Every collective choice, topology decision, and bucket-size knob serves that single goal.
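A toy timing model makes the thesis quantitative. The sketch rests on simplifying assumptions:

- one communication stream;
- gradient buckets all-reduced in the order backprop produces them;
- fixed per-bucket compute and comm times.

With overlap, only the communication that outruns the remaining compute stays exposed. `step_time` is a hypothetical helper.

```python
from typing import Sequence


def step_time(compute_ms: Sequence[float], comm_ms: Sequence[float],
              overlap: bool = True) -> float:
    """Toy model of backward compute plus per-bucket gradient all-reduce.

    compute_ms[i]: backward compute that finishes producing bucket i.
    comm_ms[i]:    time to all-reduce bucket i on a single comm stream.
    """
    if not overlap:
        # Exposed: all communication waits until backward is done.
        return sum(compute_ms) + sum(comm_ms)
    t_compute = 0.0  # when the current bucket becomes ready
    comm_free = 0.0  # when the comm stream finishes its previous bucket
    for c, k in zip(compute_ms, comm_ms):
        t_compute += c
        start = max(t_compute, comm_free)  # bucket ready AND stream idle
        comm_free = start + k
    return max(t_compute, comm_free)


compute = [10.0, 10.0, 10.0, 10.0]
comm = [8.0, 8.0, 8.0, 8.0]
print(step_time(compute, comm, overlap=False))  # 72.0
print(step_time(compute, comm, overlap=True))   # 48.0: only the last bucket is exposed
```

When comm per bucket exceeds compute per bucket, the stream backs up and the exposed time grows again. At that point, bucket size and topology choices start to matter.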
SECTION 10

Self-check

Questions that map directly onto the material above. Try answering each one before reading its answer.

Why is ring all-reduce considered bandwidth-optimal, and what's its weakness?

Each device sends (and receives) exactly 2·(N−1)/N · n bytes over its ring link, which approaches 2n and is essentially independent of N. That is provably the minimum any all-reduce must move, since every device's data must reach every other device and be summed. No link is a hotspot. Its weakness is the latency term: 2(N−1) sequential steps mean 2(N−1) α-costs, so at large N the fixed per-hop latency dominates and tree or hierarchical algorithms (O(log N) steps) become preferable for small messages.
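A single-process simulation makes the byte count checkable. This is a sketch of the textbook ring (reduce-scatter, then all-gather), not NCCL's implementation, and `ring_allreduce` is a hypothetical helper. It verifies that every rank ends with the elementwise sum and counts the elements each rank sends.

```python
from typing import List, Tuple


def ring_allreduce(data: List[List[float]]) -> Tuple[List[List[float]], List[int]]:
    """Simulate ring all-reduce (reduce-scatter then all-gather) in one process.

    data[r] is rank r's vector; its length must be divisible by N.
    Returns every rank's final vector and the number of elements each rank sent.
    """
    n = len(data)
    length = len(data[0])
    assert length % n == 0, "vector must split into N equal chunks"
    size = length // n
    # chunks[r][c] is rank r's current copy of chunk c
    chunks = [[list(v[c * size:(c + 1) * size]) for c in range(n)] for v in data]
    sent = [0] * n

    # Phase 1, reduce-scatter: N-1 steps; each rank adds the chunk it receives.
    for s in range(n - 1):
        msgs = [(r, (r + 1) % n, (r - s) % n) for r in range(n)]
        payloads = [list(chunks[r][c]) for r, _, c in msgs]  # snapshot before applying
        for (src, dst, c), p in zip(msgs, payloads):
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], p)]
            sent[src] += len(p)
    # Now rank r owns the fully reduced chunk (r + 1) % n.

    # Phase 2, all-gather: N-1 steps; each rank overwrites with what it receives.
    for s in range(n - 1):
        msgs = [(r, (r + 1) % n, (r + 1 - s) % n) for r in range(n)]
        payloads = [list(chunks[r][c]) for r, _, c in msgs]
        for (src, dst, c), p in zip(msgs, payloads):
            chunks[dst][c] = p
            sent[src] += len(p)

    return [[x for chunk in chunks[r] for x in chunk] for r in range(n)], sent


n, length = 4, 8
data = [[float(10 * r + i) for i in range(length)] for r in range(n)]
results, sent = ring_allreduce(data)
expected = [sum(col) for col in zip(*data)]
assert all(res == expected for res in results)
print(sent)  # [12, 12, 12, 12], i.e. 2*(N-1)/N * length elements per rank
```

The 2(N−1) loop iterations are the sequential steps that each pay one α in the cost model. That count is the latency weakness, even though each step moves only length/N elements.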

A model fits on one GPU but its optimizer state doesn't. Minimum change?

ZeRO Stage 1: shard only the optimizer state across the data-parallel group. With mixed-precision Adam that's the biggest memory consumer (fp32 master params + 2 fp32 moments, 12 bytes per parameter), so sharding it alone gives a large saving with no extra communication volume. The gradient all-reduce is split into its two halves. A reduce-scatter leaves each GPU the gradient slice its optimizer shard needs, each GPU updates only that slice of the parameters, and an all-gather restores the full parameters everywhere. That is the same bytes as DDP. You only escalate to Stage 2/3 (FSDP) when you also need to shard gradients and parameters.
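The saving can be made concrete with the per-parameter accounting from the ZeRO paper: 2 bytes of fp16 weights, 2 bytes of fp16 gradients, and 12 bytes of optimizer state under mixed-precision Adam. The hypothetical helper below gives per-GPU model-state memory for each stage. It ignores activations, temporary buffers, and fragmentation.

```python
def zero_bytes_per_gpu(num_params: float, n_gpus: int, stage: int) -> float:
    """Model-state bytes per GPU under mixed-precision Adam (ZeRO accounting).

    stage 0 = plain DDP; each higher stage shards one more component.
    """
    params, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim /= n_gpus   # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus   # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n_gpus  # ZeRO-3 / FSDP FULL_SHARD: also shard parameters
    return num_params * (params + grads + optim)


psi, n = 7.5e9, 64  # a 7.5B-parameter model on 64 GPUs
for stage in range(4):
    print(f"stage {stage}: {zero_bytes_per_gpu(psi, n, stage) / 1e9:.2f} GB")
```

Stage 1 alone takes the per-GPU figure from 16 bytes per parameter to 4 + 12/N. That is the large saving the answer refers to, and it comes at no extra communication volume.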

Why is tensor parallelism kept within a node but data parallelism spans nodes?

TP issues all-reduces on the critical path of every block (hundreds per step), and the key property is that they cannot overlap with compute: the reduced output is the input to the next op, so the GPU stalls until it returns. Every one of those α-costs, plus its transfer time, lands directly in step time, so TP demands NVLink's low latency and high bandwidth. DP's gradient reduction overlaps with the backward pass (hidden behind compute), so it's bandwidth-bound and tolerates inter-node InfiniBand latency. The discriminator is overlappability, not frequency. FSDP also communicates per layer, but it prefetches each all-gather under the previous layer's compute, so it stays bandwidth-bound too. Place each collective on the fabric matching its bottleneck; what sets the bottleneck is whether compute can proceed while the bytes move.

What makes all-to-all in MoE harder to optimize than all-reduce in DP?

Three things:

- (1) Data-dependent load imbalance. Routing sends uneven token counts to experts, so the collective stalls on the busiest GPU. This is mitigated with a load-balancing auxiliary loss and a capacity factor.
- (2) Poor overlap. Both the dispatch all-to-all (before the experts) and the combine all-to-all (after them) sit on the critical path, unlike DP's all-reduce, which overlaps with backprop.
- (3) N−1 distinct destinations make it topology-sensitive, so it often dominates the communication budget.
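Point (1) can be quantified with a capacity factor. The minimal sketch below assumes top-1 routing and the common rule capacity = ceil(capacity_factor × tokens / experts); implementations differ in exact rounding and overflow handling. `route_with_capacity` is a hypothetical helper.

```python
import math
from collections import Counter
from typing import Dict, List, Tuple


def route_with_capacity(expert_ids: List[int], num_experts: int,
                        capacity_factor: float) -> Tuple[Dict[int, int], int, int]:
    """Count tokens per expert and how many overflow a fixed capacity.

    expert_ids[t] is the top-1 expert chosen for token t. Tokens beyond an
    expert's capacity are dropped from that expert (the residual path still
    carries them).
    """
    counts = Counter(expert_ids)
    capacity = math.ceil(capacity_factor * len(expert_ids) / num_experts)
    dropped = sum(max(0, counts[e] - capacity) for e in range(num_experts))
    return dict(counts), capacity, dropped


# 16 tokens, 4 experts, skewed routing: expert 0 is "hot".
ids = [0] * 8 + [1] * 4 + [2] * 3 + [3] * 1
counts, cap, dropped = route_with_capacity(ids, 4, 1.25)
print(counts, cap, dropped)  # {0: 8, 1: 4, 2: 3, 3: 1} 5 3
```

Without the cap, the all-to-all would wait on the GPU hosting expert 0, which receives twice the average load. Capping trades dropped tokens for a bounded, predictable message size.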

Derive the GPipe bubble fraction and explain how to reduce it.

Take p stages and one unit of time per microbatch per stage. Stage s waits s slots while the pipeline fills and p−1−s slots while it drains, so every stage idles for p−1 slots and does m slots of useful work. Bubble fraction = idle/(idle+useful) = (p−1)/(m+p−1). The backward phase has the same shape, so the ratio holds for the whole step. Reduce it by raising m relative to p (more microbatches). 1F1B scheduling doesn't change this fraction but cuts activation memory by running backward passes early, which is usually the binding constraint that lets you raise m in the first place.

Why do ZeRO-1 and ZeRO-2 cost the same bandwidth as DDP, while ZeRO-3 costs ~1.5×?

By the identity all-reduce = reduce-scatter + all-gather. DDP does one all-reduce of gradients. ZeRO-1/2 just reorganize that: the reduction is still one reduce-scatter + one all-gather worth of traffic, only now each GPU keeps a shard instead of the full result — same bytes moved, less memory held. ZeRO-3 additionally shards parameters, so it must all-gather weights twice per iteration (once in forward, once in backward, since the gathered weights are freed after forward) on top of the gradient reduce-scatter. That extra parameter all-gather is what pushes it to roughly 1.5× DDP's volume — the price of also sharding parameters.
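The 1.5× figure falls out of counting ring passes. This sketch assumes ring collectives, in which one reduce-scatter or all-gather of the full model sends (N−1)/N · Ψ per GPU. Ψ counts model elements, and `per_gpu_volume` is a hypothetical helper.

```python
from fractions import Fraction


def per_gpu_volume(stage: int, n: int) -> Fraction:
    """Elements each GPU sends per step, in units of the model size Psi."""
    one_pass = Fraction(n - 1, n)  # one ring reduce-scatter OR all-gather
    if stage in (0, 1, 2):
        # DDP all-reduce, or ZeRO-1/2's reduce-scatter + all-gather: same bytes
        return 2 * one_pass
    # ZeRO-3: all-gather params in forward, again in backward, reduce-scatter grads
    return 3 * one_pass


n = 64
print(per_gpu_volume(3, n) / per_gpu_volume(0, n))  # 3/2
```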

Precisely, what changes between ZeRO-1 and ZeRO-2?

What's sharded. ZeRO-1 shards only optimizer state; each GPU still computes and holds the full gradient before synchronizing, so gradient memory equals DDP's. ZeRO-2 additionally shards gradients via reduce-scatter, so each GPU keeps only its slice and frees the rest (bucketed during the backward pass). A common misconception is that the difference is when the reduce-scatter fires — it's actually whether gradients are sharded at all. The practical consequence: ZeRO-2 cuts the gradient-memory term that ZeRO-1 leaves untouched, at no extra communication cost.

Is it "reduce-scatter" or "scatter-reduce"? And is FSDP the same as ZeRO?

reduce-scatter is the standard name of the collective primitive (MPI_Reduce_scatter, ncclReduceScatter). "scatter-reduce" is the informal name for phase 1 of the ring all-reduce algorithm (from Baidu's 2017 post) — same data movement, different level of abstraction. Use "reduce-scatter" when naming the collective. As for FSDP: unqualified, it means the ZeRO-3 equivalent (full sharding). PyTorch's FULL_SHARD = ZeRO-3 and SHARD_GRAD_OP = ZeRO-2, so "ZeRO/FSDP" as synonyms specifically means ZeRO-3 ↔ FSDP FULL_SHARD.

Should DDP's bucket size scale with the number of GPUs N to equalize per-link load?

No — and the reason is the trap. Per-link volume for a bucket's all-reduce is 2·(N−1)/N·S, which asymptotes to 2S and is essentially N-invariant: the ring already equalized it across all N, so there's nothing to compensate for. Bucket size (~25 MB default) is set by the latency-vs-overlap tradeoff instead — small buckets pay α-launch cost too often and underfill bandwidth; large buckets delay the first all-reduce and shrink the compute/comm overlap window. Where N does matter is the per-step chunk S/N, which shrinks with N and can push ring steps into the latency-bound regime — but that's handled by switching to tree/hierarchical algorithms (log N latency term) or more NCCL channels, not by inflating the bucket.
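The two quantities can be compared directly as N grows. This is a small arithmetic sketch that takes a 25 MiB bucket, matching DDP's default `bucket_cap_mb=25`. The helper names are hypothetical.

```python
BUCKET_BYTES = 25 * 2**20  # one ~25 MB DDP bucket


def per_link_bytes(bucket: int, n: int) -> float:
    """Bytes each GPU sends when one bucket is ring all-reduced over n GPUs."""
    return 2 * (n - 1) / n * bucket


def ring_chunk_bytes(bucket: int, n: int) -> float:
    """Size of the chunk that moves in each of the 2(n-1) ring steps."""
    return bucket / n


for n in (8, 64, 512):
    print(f"N={n:4d}  per-link={per_link_bytes(BUCKET_BYTES, n) / 2**20:6.2f} MiB  "
          f"chunk={ring_chunk_bytes(BUCKET_BYTES, n) / 2**10:8.1f} KiB")
```

Going from N = 8 to N = 512, the per-link volume barely moves (43.75 MiB to about 49.90 MiB). Over the same range, the per-step chunk falls from 3200 KiB to 50 KiB, and that shrinking chunk is what pushes the ring toward the latency-bound regime.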

Distinguish a DDP bucket from a ring chunk.

Different layers. A bucket (~25 MB) is a PyTorch-level grouping of gradient tensors; each bucket becomes one complete all-reduce, fired as soon as it fills during backprop — its purpose is overlap with computation. A ring chunk is NCCL-internal: inside one bucket's all-reduce, the payload is split into N chunks (~S/N each) that rotate around the ring for 2(N−1) steps — its purpose is bandwidth optimality within the collective. The boundaries are independent: a bucket is never sent as a unit on a link; it's diced into rotating chunks. (NCCL adds a third level — pipeline slots within each chunk — to keep the wire saturated.)
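The two levels can be sketched side by side. The bucketing below is a simplified greedy packing in reverse parameter order. It is an assumption for illustration: PyTorch DDP's real bucketing has additional rules not modeled here, and `pack_buckets` and the gradient sizes are hypothetical. Each resulting bucket would become one all-reduce, which the ring then dices into N chunks.

```python
from typing import List


def pack_buckets(grad_bytes: List[int], cap: int) -> List[List[int]]:
    """Greedily group gradient tensors into buckets of at most `cap` bytes.

    Tensors are visited in reverse order, roughly the order backprop produces
    them. A tensor larger than `cap` gets a bucket of its own.
    Returns buckets as lists of tensor indices.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    size = 0
    for idx in reversed(range(len(grad_bytes))):
        nbytes = grad_bytes[idx]
        if current and size + nbytes > cap:
            buckets.append(current)  # bucket full: this becomes one all-reduce
            current, size = [], 0
        current.append(idx)
        size += nbytes
    if current:
        buckets.append(current)
    return buckets


MB = 2**20
grads = [4 * MB, 10 * MB, 10 * MB, 8 * MB, 30 * MB, 2 * MB]
buckets = pack_buckets(grads, 25 * MB)
n = 8
for b in buckets:
    total = sum(grads[i] for i in b)
    # Inside the collective, each bucket's all-reduce is split into n ring chunks.
    print(b, total // MB, "MB ->", n, "chunks of", total / n / MB, "MB")
```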