ZeRO and FSDP: Model Sharding
*Part 2 of 4 — Distributed Training series. Start with the overview.*
1. Distributed Data Parallel: How It Actually Works. Scale throughput when the model fits on one GPU.
2. ZeRO and FSDP: Model Sharding (this post). Fit models that don't fit, by sharding weights, gradients, and optimizer state.
3. Tensor Parallelism and Sequence Parallelism. Shard inside each layer; wins when interconnect is slow or a single layer is oversized.
4. Pipeline Parallelism: How It Actually Works. Shard across layers; the axis that cheaply spans nodes.
Training a model with Adam costs far more memory than the model itself. A 4B model in BF16 is 8 GB of weights, but training it requires 48 GB total: the 8 GB of weights, another 8 GB of gradients, and 32 GB of Adam's two FP32 running averages. Five sixths of the footprint isn't the model — it's the state needed to update it.
The core idea behind ZeRO (Zero Redundancy Optimizer, Rajbhandari et al., 2020) is to stop duplicating that state. If every GPU ends up with the same optimizer state after each step, there is no reason for every GPU to store the full copy. Each GPU can hold 1/N of the state and reconstruct what it needs on demand. ZeRO eliminates the redundant copies in stages, trading communication for memory. Stage 1 shards optimizer states. Stage 2 adds gradients. Stage 3 shards the weights themselves, so no single GPU holds a complete model. PyTorch's implementation is called Fully Sharded Data Parallel (FSDP).
To see why this matters, consider standard data parallelism (→ DDP). DDP runs one copy of the model per GPU, splits the data across them, and synchronizes gradients with AllReduce at each step. It scales throughput nearly linearly, but does nothing about memory: every GPU still holds the full 48 GB independently. On a cluster of 8 GPUs, that's 384 GB allocated, of which 336 GB is pure duplication. The per-GPU footprint is the same whether you have 1 GPU or 64. FSDP exists to eliminate that duplication.
This post starts with where the memory goes, walks through each ZeRO stage, and then gets practical: how FSDP implements the ZeRO stages, what every configuration knob does, and the gotchas that silently degrade or crash training. The code snippets are from a real training script where we finetuned Qwen3-4B on AMI meeting transcripts from knkarthick/AMI with LoRA.
| Setting | Value |
|---|---|
| Model | Qwen3-4B (BF16) |
| Method | LoRA (r=16, alpha=32, all linear layers) |
| Dataset | knkarthick/AMI (237 samples) |
| GPUs | 2x RTX 4090 (PCIe, Vast.ai) |
| Sharding strategy | FULL_SHARD (ZeRO-3) |
| Effective batch size | 16 (per-GPU=1, grad accum=8, 2 GPUs) |
| Sequence length | 512 |
| Optimizer | AdamW (lr=2e-4, cosine decay to 2e-5) |
| Epochs | 3 |
---
Where the Memory Goes
Full Adam finetuning of a model with P parameters in BF16 requires:
| Component | Memory | For 4B model |
|---|---|---|
| Parameters (BF16) | 2P bytes | 8 GB |
| Gradients (BF16) | 2P bytes | 8 GB |
| Adam m (FP32) | 4P bytes | 16 GB |
| Adam v (FP32) | 4P bytes | 16 GB |
| Total | 12P bytes | 48 GB |
Weights and gradients combined are one third of the total (8 + 8 = 16 GB). The optimizer alone is the other two thirds (32 GB). This ratio is what makes the redundancy problem so severe: the biggest memory consumer — optimizer state — is exactly the component that's identical across every GPU in a data-parallel setup.
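The table's arithmetic fits in a small helper. It uses the same simplified accounting: BF16 weights and gradients plus two FP32 Adam moments. It leaves out FP32 master weights and activations, both of which real setups often add. The function name is illustrative.

```python
def adam_training_bytes(num_params: float) -> dict:
    """Per-component memory for full Adam training under the 12P accounting above."""
    components = {
        "params_bf16": 2 * num_params,  # 2 bytes per BF16 weight
        "grads_bf16": 2 * num_params,   # 2 bytes per BF16 gradient
        "adam_m_fp32": 4 * num_params,  # first moment, FP32
        "adam_v_fp32": 4 * num_params,  # second moment, FP32
    }
    components["total"] = sum(components.values())
    return components


mem = adam_training_bytes(4e9)
for name, nbytes in mem.items():
    print(f"{name:>12}: {nbytes / 1e9:5.1f} GB")
print(f"share that is not weights: {1 - mem['params_bf16'] / mem['total']:.3f}")
```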
In standard data parallelism, every GPU holds the full 48 GB independently. With 8 GPUs, that's 384 GB of memory allocated cluster-wide, of which 336 GB is pure duplication. Every GPU maintains its own copy of Adam's running mean and variance, its own full gradient tensor, and its own full set of weights. After AllReduce averages the gradients and every GPU applies the same optimizer update, the states are all identical again. The duplication serves no purpose. It's just the default.
---
ZeRO: Eliminating Redundancy in Stages
Zero Redundancy Optimizer or ZeRO doesn't change the math of training. The same gradients are computed, the same optimizer updates are applied, the same model results. What it changes is who holds what. Instead of every GPU maintaining a complete, redundant copy of everything, each GPU holds a 1/N slice and the group coordinates to reconstruct what's needed, when it's needed.
The paper introduced three cumulative stages, each sharding an additional component:
ZeRO-1: Shard Optimizer States
The optimizer states are the largest single memory consumer: 8P bytes for Adam (m and v, both in FP32). In standard data parallelism, every GPU keeps its own full copy. After the optimizer step, all copies are identical.
ZeRO-1 partitions the optimizer states across N GPUs. Each GPU owns 1/N of the parameters and maintains Adam states only for that slice. After AllReduce averages the gradients (same as standard data parallelism), each GPU runs the optimizer update only on its assigned parameters, then AllGathers the updated weights so every GPU has the full model for the next forward pass.
Per-GPU memory drops from 12P to 4P + 8P/N. With 8 GPUs, that's 4P + P = 5P, down from 12P. For a 4B model: 20 GB instead of 48 GB. The weights and gradients are still fully replicated.
ZeRO-2: Shard Optimizer States + Gradients
Gradients are 2P bytes, and like optimizer states, they're identical across GPUs after reduction. ZeRO-2 eliminates this redundancy too: instead of AllReduce (which gives every GPU the full averaged gradient), it uses ReduceScatter (which gives each GPU only the 1/N slice corresponding to its assigned parameters). Each GPU receives just enough gradient to update its own optimizer shard. The rest is never materialized.
Per-GPU memory drops to 2P + (2P + 8P)/N. With 8 GPUs: 2P + 1.25P = 3.25P. For a 4B model: 13 GB. The weights are still fully replicated, but gradients and optimizer states are sharded.
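A toy simulation in plain Python shows that sharding changes who stores what, not the math. It uses two "ranks" as lists, an elementwise Adam update, and hand-written stand-ins for AllReduce, ReduceScatter, and AllGather. Two simplifications: bias correction is omitted, and the parameter count must divide evenly by the number of ranks (real implementations pad). All function names are hypothetical.

```python
def adam_elem(p, g, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One elementwise Adam update (bias correction omitted to keep the sketch short)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    return p - lr * m / (v ** 0.5 + eps), m, v


def all_reduce_mean(grads_per_rank):
    """Stand-in for AllReduce: every rank gets the full averaged gradient."""
    n = len(grads_per_rank)
    return [sum(g[i] for g in grads_per_rank) / n for i in range(len(grads_per_rank[0]))]


def reduce_scatter_mean(grads_per_rank, rank):
    """Stand-in for ReduceScatter: `rank` receives only its 1/N slice of the average."""
    n = len(grads_per_rank)
    chunk = len(grads_per_rank[0]) // n
    return [sum(g[i] for g in grads_per_rank) / n for i in range(rank * chunk, (rank + 1) * chunk)]


def ddp_step(params, grads_per_rank, m, v):
    """Replicated update: full gradient and full optimizer state on every rank."""
    g = all_reduce_mean(grads_per_rank)
    out = [adam_elem(*t) for t in zip(params, g, m, v)]
    return [o[0] for o in out], [o[1] for o in out], [o[2] for o in out]


def zero2_step(params, grads_per_rank, m_shards, v_shards):
    """ZeRO-2 update: each rank sees 1/N of the gradient and owns 1/N of the Adam state."""
    n = len(grads_per_rank)
    chunk = len(params) // n
    slices, new_m, new_v = [], [], []
    for r in range(n):
        g = reduce_scatter_mean(grads_per_rank, r)
        p = params[r * chunk:(r + 1) * chunk]
        out = [adam_elem(*t) for t in zip(p, g, m_shards[r], v_shards[r])]
        slices.append([o[0] for o in out])
        new_m.append([o[1] for o in out])
        new_v.append([o[2] for o in out])
    full_params = [x for s in slices for x in s]  # AllGather of the updated slices
    return full_params, new_m, new_v


params = [0.1 * i for i in range(8)]
grads = [[0.01 * i for i in range(8)], [-0.02 * i + 0.05 for i in range(8)]]  # 2 ranks
p_ddp, m_ddp, _ = ddp_step(params, grads, [0.0] * 8, [0.0] * 8)
p_z2, m_z2, _ = zero2_step(params, grads, [[0.0] * 4 for _ in range(2)], [[0.0] * 4 for _ in range(2)])
assert p_ddp == p_z2                           # identical weights after the step
assert m_ddp == [x for s in m_z2 for x in s]   # same Adam state, just split across ranks
```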
The communication pattern is different from ZeRO-1. ZeRO-1 uses AllReduce on gradients (every GPU gets the full averaged gradient, cost 2(N-1)/N × gradient size), then each GPU updates only its optimizer shard and AllGathers the updated weights (another (N-1)/N × weight size). ZeRO-2 uses ReduceScatter on gradients (cost (N-1)/N × gradient size), skipping the full gradient materialization entirely, followed by the same AllGather of updated weights. ZeRO-2 moves strictly less data than ZeRO-1 under this scheme — roughly one gradient-tensor's worth less, because ReduceScatter is half an AllReduce — on top of using less memory.
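The standard ring-collective cost model can check this accounting. Apply it to the 4B example: 8 GB of BF16 gradients, 8 GB of BF16 weights, 8 GPUs. This is a sketch that assumes bandwidth-optimal ring collectives and ignores latency. The helper name is hypothetical.

```python
def bytes_sent_per_rank(stage: int, n: int, grad_bytes: float, weight_bytes: float) -> float:
    """Per-rank bytes sent per optimizer step under a ring-collective cost model.

    Ring AllReduce costs 2(N-1)/N x size; ReduceScatter and AllGather cost (N-1)/N x size each.
    """
    ring = (n - 1) / n
    if stage == 0:  # DDP: AllReduce gradients
        return 2 * ring * grad_bytes
    if stage == 1:  # AllReduce gradients + AllGather updated weights
        return 2 * ring * grad_bytes + ring * weight_bytes
    if stage == 2:  # ReduceScatter gradients + AllGather updated weights
        return ring * grad_bytes + ring * weight_bytes
    raise ValueError("stage must be 0, 1, or 2 (ZeRO-3 communicates per layer)")


for stage in (0, 1, 2):
    gb = bytes_sent_per_rank(stage, n=8, grad_bytes=8e9, weight_bytes=8e9) / 1e9
    print(f"ZeRO-{stage}: {gb:.1f} GB per rank per step")
```

Under this model, ZeRO-2 sends exactly what DDP sends. A ReduceScatter followed by an AllGather is the same two halves that make up a ring AllReduce, with the optimizer update placed between them. ZeRO-1, as scheduled above, pays for one extra AllGather costing (N-1)/N of the weight size.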
ZeRO-3: Shard Everything
ZeRO-3 takes the final step: the weights themselves are sharded. No GPU holds a complete copy of the model. Each GPU stores 1/N of the parameters, 1/N of the gradients, and 1/N of the optimizer states.
Per-GPU memory drops to 12P/N. With 8 GPUs: 1.5P. For a 4B model: 6 GB. The entire 48 GB training footprint, distributed across the cluster, with each GPU holding only its fraction.
The cost is communication. In ZeRO-1 and ZeRO-2, every GPU has the full weights at all times, so forward and backward passes run locally. In ZeRO-3, the full weights don't exist on any single device. Before each layer's forward pass, the GPUs must AllGather that layer's complete parameters from everyone's shards. After forward, the reconstructed weights are discarded. During backward, the same AllGather happens again (the weights are needed to compute gradients), followed by a ReduceScatter to distribute the gradient shards.
This is the fundamental tradeoff: ZeRO-3 achieves the maximum memory reduction but introduces communication at every layer boundary in both the forward and backward pass.
The Full Picture
| Stage | What is sharded | Per-GPU memory | Communication pattern |
|---|---|---|---|
| ZeRO-0 (DDP) | nothing | 12P | AllReduce once per step |
| ZeRO-1 | optimizer states | 4P + 8P/N | AllReduce + AllGather per step |
| ZeRO-2 | optimizer states + gradients | 2P + 10P/N | ReduceScatter + AllGather per step |
| ZeRO-3 | everything | 12P/N | AllGather + ReduceScatter per layer, per pass |
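Memory drops with each stage. Communication does not simply rise in lockstep. ZeRO-2 moves less data than ZeRO-1, as described above. The large jump comes at ZeRO-3, where collectives run at every layer in both passes. The right choice depends on how tight memory is and how fast the interconnect is.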
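The per-GPU memory column can be evaluated directly. This sketch uses the same 12P accounting as above and excludes activations, temporary buffers, and framework overhead. It is evaluated for the 4B model on 8 GPUs:

```python
def per_gpu_bytes(stage: int, P: float, N: int) -> float:
    """Per-GPU training memory under the 12P accounting (weights 2P, grads 2P, Adam 8P)."""
    if stage == 0:
        return 12 * P              # everything replicated
    if stage == 1:
        return 4 * P + 8 * P / N   # optimizer states sharded
    if stage == 2:
        return 2 * P + 10 * P / N  # optimizer states + gradients sharded
    if stage == 3:
        return 12 * P / N          # everything sharded
    raise ValueError("stage must be 0-3")


for stage in range(4):
    print(f"ZeRO-{stage}: {per_gpu_bytes(stage, P=4e9, N=8) / 1e9:4.1f} GB per GPU")
```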
---
FSDP: ZeRO-3 in PyTorch
FSDP is PyTorch's native implementation of the ZeRO family, with FULL_SHARD corresponding to ZeRO-3. It also exposes ZeRO-2 and ZeRO-0 through different sharding strategies:
| FSDP ShardingStrategy | ZeRO equivalent |
|---|---|
| NO_SHARD | ZeRO-0 (identical to DDP) |
| SHARD_GRAD_OP | ZeRO-2 |
| FULL_SHARD | ZeRO-3 |
| HYBRID_SHARD | ZeRO-3 intra-node, DDP across nodes |
ZeRO-1 (shard only optimizer states) has no direct FSDP equivalent. It can be approximated, but rarely needs to be.
HYBRID_SHARD is worth calling out. In multi-node setups, cross-node bandwidth (InfiniBand, typically 200-400 Gb/s, i.e. 25-50 GB/s) is much lower than intra-node bandwidth (NVLink, 600-900 GB/s). HYBRID_SHARD runs FULL_SHARD within each node for memory savings, but replicates the model across nodes like DDP. Cross-node traffic therefore drops to one gradient AllReduce per step, instead of per-layer AllGather/ReduceScatter across the entire cluster.
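The way HYBRID_SHARD splits the world into groups can be sketched in a few lines. The sketch assumes ranks are numbered node-contiguously (ranks 0..G-1 on node 0, and so on), as torchrun typically assigns them. `hybrid_shard_groups` is a hypothetical helper, not an FSDP API.

```python
def hybrid_shard_groups(num_nodes: int, gpus_per_node: int):
    """Return (shard_groups, replicate_groups) as lists of global ranks.

    Shard groups: the GPUs of one node, which split parameters FULL_SHARD-style
    over the fast intra-node link. Replicate groups: the same local rank on every
    node, which hold identical shards and only AllReduce gradients across nodes.
    """
    shard_groups = [
        [node * gpus_per_node + local for local in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(num_nodes=2, gpus_per_node=4)
print("shard groups:    ", shard)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print("replicate groups:", replicate)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```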
The training script defaults to FULL_SHARD and writes the strategy name to the CSV log, so switching strategies is a one-line change:
```python
from torch.distributed.fsdp import ShardingStrategy

SHARDING_STRATEGY = ShardingStrategy.FULL_SHARD
```
Everything from here on is about how to make ZeRO-3 work in practice through FSDP.
---
The FSDP Workflow at a Glance
Before diving into each piece, here is the end-to-end shape of an FSDP step:
1. Load. The model is materialized on CPU (full weights, one copy).
2. Wrap and shard. FSDP splits each layer's parameters into N equal shards across N ranks. Every rank now holds 1/N of every layer's weights, gradients, and optimizer state. The full unsharded tensors no longer exist anywhere.
3. Pre-forward (per layer). AllGather the shards of the layer about to run, reconstructing its full weights on every rank.
4. Forward (per layer). Run the layer. Free the gathered weights immediately; keep only the shard.
5. Pre-backward (per layer). AllGather that layer's weights again (they were freed after forward).
6. Backward (per layer). Compute gradients for the full layer, then ReduceScatter: sum gradients across ranks and keep only this rank's 1/N slice. Free the gathered weights and the full gradient.
7. Optimizer step. Each rank updates only its own parameter shards using its local gradient and optimizer state slices. No collective is needed.
At any instant the GPU holds: every layer's shard, plus the full weights of the *one* layer currently executing. That is the entire memory win.
```
Initial state (after wrap & shard, N=4 ranks, layers L1..LK):

  Rank 0       Rank 1       Rank 2       Rank 3
+--------+   +--------+   +--------+   +--------+
| L1.a   |   | L1.b   |   | L1.c   |   | L1.d   |   each rank holds
| L2.a   |   | L2.b   |   | L2.c   |   | L2.d   |   1/N of every
| ...    |   | ...    |   | ...    |   | ...    |   layer's params,
| LK.a   |   | LK.b   |   | LK.c   |   | LK.d   |   grads, opt state
+--------+   +--------+   +--------+   +--------+

FORWARD (for i = 1..K):
  step 1: AllGather Li shards from all ranks
          Rank0:[Li.a]  Rank1:[Li.b]  Rank2:[Li.c]  Rank3:[Li.d]
                              │
                              ▼
          every rank now has full Li = [Li.a | Li.b | Li.c | Li.d]
  step 2: run forward(Li) on local activations
  step 3: free full Li → only the shard (Li.x) remains

  GPU memory at step 2 (peak within forward):
    all shards (1/N of every layer) + one full layer (Li) + activations so far
  GPU memory at step 3 (steady state):
    all shards (1/N of every layer) + activations so far

BACKWARD (for i = K..1):
  step 1: AllGather Li shards (weights were freed after forward)
  step 2: compute grads dL/dLi (full tensor, on every rank)
  step 3: ReduceScatter dL/dLi
          sum across ranks, then split:
          Rank0 keeps dLi.a, Rank1 keeps dLi.b, ... (each rank: 1/N)
  step 4: free full Li and full dL/dLi → shard + grad-shard remain

  GPU memory at step 2 (peak within backward):
    all shards + one full layer (Li) + full gradient (dL/dLi) + remaining activations
  GPU memory at step 4 (steady state):
    all shards + grad shards for layers already processed + remaining activations

OPTIMIZER STEP:
  each rank updates ONLY its own shards using its own grad + opt-state shards
  no collective communication
```
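A small simulation can check the memory accounting in the diagram. This sketch counts parameter and gradient elements on one rank for a stack of layers. As simplifications, it ignores activations, optimizer state, FSDP's padding, and prefetching. The function names are hypothetical.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def full_shard_peak(layer_sizes, world_size):
    """Peak resident param+grad elements on one rank over a FULL_SHARD forward and backward."""
    shards = [ceil_div(s, world_size) for s in layer_sizes]
    resident = sum(shards)  # every layer's parameter shard, always resident
    peak = resident
    for size in layer_sizes:  # forward: layers 1..K
        resident += size      # AllGather into a full-size buffer
        peak = max(peak, resident)
        resident -= size      # free the gathered weights, keep the shard
    for size, shard in zip(reversed(layer_sizes), reversed(shards)):  # backward: K..1
        resident += 2 * size  # AllGather weights again + full local gradient
        peak = max(peak, resident)
        resident -= 2 * size  # ReduceScatter, then free full weights and full gradient
        resident += shard     # keep this rank's gradient shard
    return peak


layers = [100] * 4
print(full_shard_peak(layers, world_size=4))  # 375
print(2 * sum(layers))                        # 800 (replicated weights + gradients)
```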
The Communication Pattern: AllGather and ReduceScatter
The two collectives that make ZeRO-3 work are AllGather and ReduceScatter, woven through every layer of the forward and backward passes.
AllGather reconstructs a full parameter tensor from its N shards. Before each layer's forward pass, FSDP calls AllGather to assemble the complete weights. After the forward pass completes, the reconstructed weights are immediately discarded; only the shard is kept. Peak memory at any moment holds only the shards of all layers plus the full weights of whichever layer is currently executing.
ReduceScatter is the inverse. After each layer's backward pass computes gradients, FSDP runs ReduceScatter. Each rank contributes its full local gradient for that layer, and NCCL sums the contributions across ranks (FSDP scales the result to an average). Each rank then receives only the portion of the result that corresponds to its parameter shard. No rank holds the full averaged gradient; each holds 1/N of it.
This is what makes ZeRO-3 sensitive to interconnect bandwidth. The total AllGather/ReduceScatter volume scales with the parameter count, and it is issued as one set of collectives per layer. On NVLink (900 GB/s bidirectional), these collectives can overlap substantially with computation. On PCIe 4.0 x16 (about 32 GB/s per direction, 64 GB/s bidirectional), they dominate the step time.
Why this matters: FSDP vs TP+SP on PCIe
FSDP and → TP+SP solve the same problem -- fitting a large model across multiple GPUs -- but they move fundamentally different things over the interconnect, and that difference is everything on commodity hardware.
FSDP shards weights but must reconstruct each full layer on every GPU before forward and backward. Per training step it moves roughly 3x the model size: AllGather in forward, AllGather in backward, ReduceScatter of grads in backward. For Qwen3-4B in BF16 that is ~24 GB of weight traffic per step.
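The ~24 GB figure is three model-sized transfers. The sketch below reproduces that arithmetic. The `world_size` option is an added assumption about the collective algorithm: with ring AllGather and ReduceScatter, each rank sends (N-1)/N of every tensor, so on 2 GPUs each rank sends about half the headline figure. Either way, it remains far more than the activation traffic described next. The helper name is hypothetical.

```python
from typing import Optional


def fsdp_weight_traffic_bytes(num_params: float, bytes_per_param: int = 2,
                              world_size: Optional[int] = None) -> float:
    """Bytes moved per forward+backward under FULL_SHARD.

    Three model-sized collectives: AllGather (forward), AllGather (backward),
    ReduceScatter of gradients (backward). If world_size is given, apply the
    ring-algorithm factor (N-1)/N to get the bytes each rank actually sends.
    """
    total = 3 * num_params * bytes_per_param
    if world_size is not None:
        total *= (world_size - 1) / world_size
    return total


print(fsdp_weight_traffic_bytes(4e9) / 1e9)                # 24.0
print(fsdp_weight_traffic_bytes(4e9, world_size=2) / 1e9)  # 12.0 sent per rank
```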
TP+SP keeps each GPU's weight shard in place and only ever communicates *activations* — the [batch, seq, hidden] intermediate tensors produced inside each layer. At normal training shapes that comes out to 1-2 GB per step regardless of model size, an order of magnitude less than FSDP.
On NVLink (600-900 GB/s) FSDP's 24 GB is manageable and overlaps cleanly with compute. On PCIe (~25 GB/s effective on 2x 4090 with no NVLink) it becomes the bottleneck: weight traffic alone accounts for the majority of step time, and the AllGathers serialize because there is no second link to hide them behind. The result on our box is stark — FSDP runs at ~43s/step, TP+SP at ~7s/step, same model, same batch, same hardware.
The rule that falls out of this: if you do not have NVLink between your GPUs, → TP+SP is the correct choice for any model where parameters significantly outnumber activation volume, which is true of essentially every LLM at normal training batch sizes. FSDP is designed for the regime where the interconnect makes weight traffic free; PCIe is not that regime.
---
Loading the Model: CPU First, Then Shard
A subtle initialization detail that can double your peak GPU memory if you get it wrong.
The natural instinct is to load the model onto its GPU directly. With FSDP, that means every rank loads the full model onto its GPU *before* FSDP shards it. For Qwen3-4B in BF16, that's 8 GB per GPU just sitting there waiting to be sharded. On 2 GPUs, both briefly hold the full 8 GB before FSDP brings it down to ~4 GB each.
The fix is to load onto CPU and let FSDP handle the transfer:
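```python
import torch
from transformers import AutoModelForCausalLM

# MODEL_ID and HF_TOKEN are defined at the top of the training script.
# No device_map: the weights stay on CPU until FSDP shards them.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    low_cpu_mem_usage=True,  # avoid a second full-size copy in CPU RAM while loading
)
```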
Then FSDP wraps the model with device_id=torch.device(f"cuda:{local_rank}"). It moves one FSDP unit at a time to the GPU, shards it there, and frees the rest of that unit's full weights. Peak GPU memory during initialization therefore stays close to 1/N of the full model, plus the one unit currently being sharded.
For a 4B model this is a nice-to-have. For a 70B model across 8 GPUs, it's the difference between initialization succeeding or every GPU OOMing before training starts.
---
auto_wrap_policy: Granularity of Sharding
If the wrap policy matches zero modules, FSDP does not warn. It silently puts the entire model into a single flat parameter unit. AllGather then tries to reconstruct all 4B parameters at once instead of one layer at a time, and you OOM with a suspiciously large allocation that doesn't change with sequence length. Activation checkpointing, which uses the same layer class to decide what to checkpoint, also matches nothing and silently does nothing. The symptom is an OOM that looks like FSDP isn't sharding at all, because in effect it isn't.
This is the most common silent failure mode in FSDP setup. It happens when the layer class in the wrap policy doesn't match the actual model architecture. With PEFT, it also happens when transformer_auto_wrap_policy fails to recurse through the wrapper chain (PeftModel -> LoraModel -> Qwen3ForCausalLM -> ... -> Qwen3DecoderLayer).
To understand why, consider what the wrap policy controls. FSDP wraps the model in nested FSDP units. The outermost unit covers the whole model; inner units cover individual modules. AllGather and ReduceScatter happen per unit, not per parameter, so the unit boundaries decide how much is gathered at once. A single unit for the whole model gathers every weight together. One unit per decoder layer keeps only a single layer's full weights alive at a time.
Per-layer wrapping is almost always correct for transformers. The training script uses lambda_auto_wrap_policy, which runs a plain isinstance check and always recurses into every module (unlike transformer_auto_wrap_policy, which can miss layers behind PEFT wrappers):
```python
import functools

from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer

# Make every Qwen3DecoderLayer its own FSDP unit.
auto_wrap = functools.partial(
    lambda_auto_wrap_policy,
    lambda_fn=lambda module: isinstance(module, Qwen3DecoderLayer),
)
```
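A policy that matches nothing fails silently, so it is worth counting matches before wrapping. In a real script, this is a one-liner over nn.Module.modules(), e.g. `sum(isinstance(m, Qwen3DecoderLayer) for m in model.modules())`. The sketch below reproduces the same traversal on a toy module tree so it runs without PyTorch. The `Module`, `Wrapper`, `DecoderLayer`, and `OtherLayer` classes are hypothetical stand-ins.

```python
class Module:
    """Toy stand-in for nn.Module: just children and a recursive modules() iterator."""
    def __init__(self, *children):
        self.children = list(children)

    def modules(self):
        yield self
        for child in self.children:
            yield from child.modules()


class Wrapper(Module): ...       # plays the role of PeftModel / LoraModel
class DecoderLayer(Module): ...  # plays the role of Qwen3DecoderLayer
class OtherLayer(Module): ...    # a layer class the policy might wrongly name


def count_wrap_targets(root, layer_cls):
    """Count modules an isinstance-based policy would wrap; refuse to continue on zero."""
    n = sum(isinstance(m, layer_cls) for m in root.modules())
    if n == 0:
        raise ValueError(f"wrap policy matches no {layer_cls.__name__}; FSDP would build one giant unit")
    return n


model = Wrapper(Wrapper(Module(*[DecoderLayer(Module()) for _ in range(36)])))
print(count_wrap_targets(model, DecoderLayer))  # 36
```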
The layer class must match exactly. Qwen3 uses Qwen3DecoderLayer. Llama 3 uses LlamaDecoderLayer. Mistral uses MistralDecoderLayer. Verify before running:
```python
print(type(model.model.layers[0]))
```
---
MixedPrecision: Three Dtypes, Three Roles
FSDP controls mixed precision independently from torch.autocast. Its MixedPrecision config sets the dtype for three separate purposes:
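```python
import torch
from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype for AllGather and forward/backward compute
    reduce_dtype=torch.float32,   # dtype for the gradient ReduceScatter
    buffer_dtype=torch.bfloat16,  # dtype for registered buffers
)
```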
param_dtype sets the dtype that parameter shards are cast to before AllGather. It therefore also sets the dtype of the reconstructed full-layer weights used in forward and backward. Outside forward and backward, the shards stay in the dtype the model was loaded with (BF16 here, so the cast is a no-op in this script). BF16 halves the AllGather communication volume compared to FP32, and it halves the memory consumed by the reconstructed full-layer weights.
reduce_dtype sets the dtype used for ReduceScatter during the backward pass. Setting it to FP32 while param_dtype is BF16 means gradient reduction happens in higher precision than the parameters themselves. This matters because BF16 carries only 8 bits of precision. When contributions from several ranks are summed in BF16, small values can be rounded away entirely. After the FP32 reduction, each rank's gradient shard is cast back to the parameter dtype (BF16 here) for storage.
buffer_dtype sets the dtype of registered buffers: non-parameter tensors such as BatchNorm running statistics or precomputed positional tables. Layer norm scales are parameters and follow param_dtype. Buffers are typically small enough that the choice has minimal effect on memory, and keeping them in BF16 is consistent with the overall mixed-precision approach.
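The reduce_dtype effect is easy to reproduce without a GPU. The sketch below emulates BF16 by rounding a float32 bit pattern to its top 16 bits (round-to-nearest-even, finite inputs only). It then reduces one gradient element across eight hypothetical ranks in two ways. The first accumulates in BF16. The second accumulates in full precision and rounds only at the end, which is what reduce_dtype=torch.float32 buys. The values are illustrative, not taken from the training run.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even; finite inputs only)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]           # float32 bit pattern
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000      # round off the low 16 bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


# One gradient element as seen by 8 ranks: one large contribution, seven small ones.
contributions = [to_bf16(1.0)] + [to_bf16(0.003)] * 7

acc_bf16 = 0.0
for g in contributions:
    acc_bf16 = to_bf16(acc_bf16 + g)  # every partial sum rounded to BF16

acc_fp32 = to_bf16(sum(contributions))  # full-precision sum, rounded once at the end

print(acc_bf16)  # 1.0: each +0.003 is below half a BF16 step at 1.0 (2**-8) and vanishes
print(acc_fp32)  # 1.0234375
```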
torch.autocast handles activations and per-op compute dtypes; MixedPrecision handles parameters and communication. The two are complementary and can be combined, but neither requires the other. This training script uses MixedPrecision alone, for the reason below.
One gotcha when using both together: combining torch.autocast with non-reentrant activation checkpointing can cause a tensor-count mismatch. Autocast injects cast ops that differ between the original forward and the recomputation during backward. The training script avoids this by relying on FSDP's MixedPrecision policy for dtype management and not wrapping the forward pass in autocast:
```python
# No torch.autocast here: FSDP's MixedPrecision policy handles dtypes.
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss / GRAD_ACCUM  # scale for gradient accumulation
loss.backward()
```
---
backward_prefetch: Hiding the AllGather Latency
During the backward pass, FSDP needs the full parameters of a layer before it can compute that layer's gradients. Without prefetching the sequence is:
```
compute gradients for layer N -> AllGather layer N-1 -> compute gradients for layer N-1 -> ...
```
Each AllGather blocks gradient computation until it finishes. The GPU sits idle waiting for the network.
BACKWARD_PRE prefetching overlaps them:
```
compute gradients for layer N
  |__ AllGather layer N-1 (starts concurrently)
compute gradients for layer N-1
  |__ AllGather layer N-2 (starts concurrently)
```
The AllGather for layer N-1 starts while layer N's gradients are still computing. By the time backward reaches layer N-1, its weights are already reconstructed. This is the same principle as DDP's gradient bucket overlap: start communication before the GPU needs the result.
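A toy timeline can estimate the payoff. The sketch assumes one communication stream, one compute stream, and identical layers. It also assumes BACKWARD_PRE issues the next AllGather at the moment the current layer's gradient computation starts. Times are in arbitrary units, not measurements, and the function name is hypothetical.

```python
def backward_time(num_layers: int, gather: float, compute: float, prefetch: bool) -> float:
    """Finish time of a backward pass where each layer needs an AllGather before its compute."""
    comm_free = 0.0    # when the communication stream is next idle
    compute_end = 0.0  # when the previous layer's gradient computation finished
    ready = None       # finish time of an AllGather already issued for this layer
    for i in range(num_layers):
        if ready is None:  # nothing prefetched: issue the AllGather now
            start = max(comm_free, compute_end)
            ready = comm_free = start + gather
        compute_start = max(compute_end, ready)
        next_ready = None
        if prefetch and i + 1 < num_layers:  # BACKWARD_PRE: issue next gather now
            start = max(comm_free, compute_start)
            next_ready = comm_free = start + gather
        compute_end = compute_start + compute
        ready = next_ready
    return compute_end


for gather, compute in [(1.0, 2.0), (3.0, 1.0)]:
    serial = backward_time(4, gather, compute, prefetch=False)
    overlapped = backward_time(4, gather, compute, prefetch=True)
    print(f"gather={gather}, compute={compute}: {serial} -> {overlapped}")
```

In the first case, AllGather is cheaper than compute, and prefetching hides every gather except the first. In the second case, AllGather dominates, which is closer to the PCIe situation described earlier. Prefetching can then hide only the compute, so the pass can never finish faster than the total communication time.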
BACKWARD_POST starts the AllGather after the current layer's gradients are done. This gives less overlap, but lower peak memory, because you never hold two layers' full weights simultaneously. Use BACKWARD_PRE unless peak memory is critically tight.
```python
from torch.distributed.fsdp import BackwardPrefetch

backward_prefetch = BackwardPrefetch.BACKWARD_PRE
```
forward_prefetch=True applies the same idea to the forward pass: AllGather the next layer while the current is computing. It is off by default because the memory cost (two layers fully reconstructed simultaneously) is often not worth the throughput gain at normal batch sizes. The training script leaves it off.
---
CPU Offload
CPUOffload(offload_params=True) moves parameter shards to CPU after each FSDP unit finishes its forward or backward pass. Optimizer states live permanently on CPU. Parameters are fetched back to GPU just before the AllGather that needs them.
The memory reduction is substantial: GPU memory approaches just activations and a single layer's full parameters at any time. The throughput cost is also substantial. Before every AllGather, the shard must first be copied from host to device over PCIe (typically 16-32 GB/s on consumer hardware). Gradients are copied back to the host after ReduceScatter, and the optimizer step runs on the CPU. For a model with many layers, that adds up to a large number of PCIe round trips per step.
The rule of thumb: use CPU offload only when FULL_SHARD across all available GPUs still doesn't fit. The memory it saves is real; the throughput it costs is also real, and on bandwidth-limited systems it can slow training by 3-5x.
```python
from torch.distributed.fsdp import CPUOffload

cpu_offload = CPUOffload(offload_params=True) if CPU_OFFLOAD else None
```
---
use_orig_params: Required for PEFT
FSDP by default flattens all parameters within a unit into a single 1D tensor. This is an implementation detail that makes sharding simpler: one contiguous buffer per unit, sliced evenly across ranks. The downside is that the original named parameters cease to exist as addressable tensors.
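PEFT (and LoRA specifically) injects trainable adapter matrices by name. It reaches them by traversing the module hierarchy, e.g. model.layers[0].self_attn.q_proj.lora_A.weight. After FSDP flattens parameters, that path no longer leads to an independent nn.Parameter. It resolves to a view into the flat buffer, which PEFT cannot work with. FSDP also enforces a hard rule: with use_orig_params=False, all parameters in one flat unit must share the same requires_grad. A LoRA-wrapped decoder layer mixes frozen base weights with trainable adapters, so wrapping fails with ValueError: Must flatten tensors with uniform requires_grad when use_orig_params=False.
use_orig_params=True tells FSDP to preserve the original parameter structure. Each named parameter remains accessible through the module hierarchy, and FSDP internally maps it to the appropriate slice of the flat shard. The sharding is unchanged; only the visibility of individual parameters is preserved.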
This is also required when using torch.compile with FSDP, since the compiler inspects named parameters to build its graph.
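A simplified model of the flat-parameter layout shows why FSDP needs a name-to-slice mapping at all. Parameters are laid end to end, and the buffer is padded to a multiple of the world size and cut into equal shards. As a result, one named parameter can straddle two ranks. This sketches the idea, not FSDP's code (the real layout also inserts alignment padding, for example), and both function names are hypothetical.

```python
def flat_layout(named_sizes, world_size):
    """Place parameters back to back in one flat buffer and cut it into equal shards."""
    offsets, total = {}, 0
    for name, numel in named_sizes:
        offsets[name] = (total, total + numel)
        total += numel
    shard_numel = -(-total // world_size)  # ceil: pad so every rank gets the same size
    return offsets, shard_numel


def locate(name, offsets, shard_numel):
    """Return the (rank, local_start, local_end) pieces that hold a named parameter."""
    start, end = offsets[name]
    pieces, pos = [], start
    while pos < end:
        rank = pos // shard_numel
        stop = min(end, (rank + 1) * shard_numel)
        pieces.append((rank, pos - rank * shard_numel, stop - rank * shard_numel))
        pos = stop
    return pieces


layout = [("q_proj.weight", 3), ("q_proj.lora_A.weight", 4), ("q_proj.lora_B.weight", 2)]
offsets, shard_numel = flat_layout(layout, world_size=2)
for name, _ in layout:
    print(name, locate(name, offsets, shard_numel))
```

Here lora_A is split between rank 0 and rank 1. With use_orig_params=True, each rank exposes only its piece of that parameter.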
---
Putting It All Together: The FSDP Wrap
Here's the full FSDP wrapping call from the training script, combining every knob discussed above:
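```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

wrapped = FSDP(
    model,
    auto_wrap_policy=auto_wrap,           # one unit per Qwen3DecoderLayer
    sharding_strategy=SHARDING_STRATEGY,  # FULL_SHARD = ZeRO-3
    mixed_precision=mp_policy,
    backward_prefetch=backward_prefetch,
    cpu_offload=cpu_offload,
    device_id=torch.device(f"cuda:{local_rank}"),
    use_orig_params=True,                 # required for PEFT/LoRA
    forward_prefetch=False,
)
```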
---
LoRA with FSDP: Two Gotchas
Using LoRA adapters with FSDP introduces two issues that don't exist in simpler setups.
Mixed dtype in flat params. PEFT initializes LoRA adapter weights in float32 regardless of the base model dtype. FSDP requires uniform dtype within each flat param unit. The base model is BF16, the LoRA weights are FP32, and FSDP raises ValueError: Must flatten tensors with uniform dtype. The fix is to cast the entire model after applying PEFT:
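```python
import torch
from peft import get_peft_model

model = get_peft_model(model, lora_config)
model = model.to(torch.bfloat16)  # LoRA weights start in FP32; make every parameter BF16
```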
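A quick pre-wrap check makes the dtype mismatch visible before FSDP raises. In a real script, you would feed it (name, param.dtype) pairs from model.named_parameters(). The helper name and the example parameter names below are hypothetical, and the example uses strings so it runs without PyTorch.

```python
from collections import defaultdict


def dtype_groups(named_dtypes):
    """Group parameter names by dtype; more than one group can break FSDP1 flattening."""
    groups = defaultdict(list)
    for name, dtype in named_dtypes:
        groups[str(dtype)].append(name)
    return dict(groups)


before_cast = [
    ("layers.0.self_attn.q_proj.base_layer.weight", "torch.bfloat16"),
    ("layers.0.self_attn.q_proj.lora_A.default.weight", "torch.float32"),
    ("layers.0.self_attn.q_proj.lora_B.default.weight", "torch.float32"),
]
groups = dtype_groups(before_cast)
print(sorted(groups))  # ['torch.bfloat16', 'torch.float32']
if len(groups) > 1:
    print("mixed dtypes: cast the model before wrapping")
```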
KV cache vs activation checkpointing. With use_cache=True (the HuggingFace default, set for inference), the model creates KV cache tensors during forward that are not reproduced during activation checkpointing's recomputation pass. This causes CheckpointError: A different number of tensors was saved during the original forward and recomputation (87 vs 85). The KV cache is only useful for autoregressive inference; during training all tokens are processed in parallel via causal masking, so there's nothing to cache. The fix:
```python
model.config.use_cache = False  # no KV cache during training
```
Both of these must happen before FSDP wrapping.
---
Why No QLoRA
QLoRA (4-bit NF4 weights from bitsandbytes + LoRA adapters in BF16, used in the → DDP experiment) does not work with FSDP out of the box.
FSDP shards parameters by slicing tensors along a dimension and distributing the slices. It expects standard PyTorch tensors. bitsandbytes 4-bit tensors are not standard PyTorch tensors; they are custom quantized buffers that bitsandbytes dequantizes inside its own CUDA kernels at compute time. FSDP has no mechanism to slice them, AllGather them, or run ReduceScatter over them. The two libraries operate at incompatible levels of abstraction.
FSDP-compatible quantization does exist. torchao provides int4 and int8 weight-only quantization using standard PyTorch tensor subclasses, which FSDP can shard, AllGather, and ReduceScatter normally. Answer.ai also demonstrated QLoRA training with FSDP. Their approach has bitsandbytes store its packed 4-bit weights in a BF16 container dtype (the bnb_4bit_quant_storage option), so FSDP can shard them like ordinary tensors. PyTorch's newer FSDP2 implementation also has native tensor-subclass support. These are real options for production training where you need both sharding and compression to fit a large model.
This script uses neither. The per-GPU memory footprint is larger than in QLoRA setups (~8 GB for Qwen3-4B weights vs ~2.5 GB in 4-bit), which is precisely why FSDP is relevant here: FULL_SHARD across 2 GPUs brings the per-GPU weight footprint back down to ~4 GB. FSDP pays for that memory reduction with more communication; the tradeoff is explicit and measurable. Adding quantization on top of FSDP would conflate two variables (distributed strategy and model representation), making it impossible to attribute throughput or memory differences to either one cleanly.
---
Activation Checkpointing with FSDP
Standard gradient checkpointing via HuggingFace's model.gradient_checkpointing_enable() is not FSDP-aware. It hooks into the model's own forward pass without knowledge of FSDP's AllGather/ReduceScatter schedule, which can cause incorrect recomputation or memory issues when the two interact.
The correct approach is apply_activation_checkpointing from PyTorch's distributed algorithms:
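```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer


def apply_fsdp_activation_checkpointing(model: FSDP):
    # Non-reentrant checkpointing cooperates with FSDP's backward hooks.
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    # Checkpoint exactly the layers that are FSDP units (one per decoder layer).
    check_fn = lambda submodule: isinstance(submodule, Qwen3DecoderLayer)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=check_fn,
    )
```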
Two rules:
Call it after FSDP() wrapping, not before. The checkpoint wrapper needs to see already-wrapped FSDP modules so recomputation interleaves correctly with FSDP's communication schedule. Wrapping before FSDP means the checkpoint boundaries don't align with FSDP unit boundaries, which wastes the recomputation.
Use NO_REENTRANT. The reentrant implementation reruns the forward inside a nested call into the autograd engine during backward, which interferes with the FSDP hooks that fire during backward. NO_REENTRANT instead uses saved-tensor hooks to recompute activations on demand within the ordinary backward pass, without re-entering autograd.
---
no_sync() and Gradient Accumulation
Gradient accumulation runs multiple forward/backward passes before stepping the optimizer. By default, FSDP runs a gradient ReduceScatter on every backward call. With 8 accumulation steps, that is 7 rounds of gradient reduction that could have been deferred to the final micro-step.
model.no_sync() defers gradient synchronization on intermediate steps (the same API as → DDP's no_sync()). Inside the context, each rank accumulates its full, unsharded local gradients, and the ReduceScatter happens only on the final micro-step. What no_sync() cannot skip under FULL_SHARD is the AllGather. The weights exist only as shards, so every micro-step still has to gather each layer for forward and again for backward. The price is memory: until the final sync, each rank holds full-size gradients rather than 1/N shards, which is worth checking when memory is tight.
```python
import contextlib

is_accumulating = (step + 1) % GRAD_ACCUM != 0
# Skip the gradient ReduceScatter on intermediate micro-steps.
sync_context = model.no_sync() if is_accumulating else contextlib.nullcontext()
with sync_context:
    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs.loss / GRAD_ACCUM
    loss.backward()

if not is_accumulating:
    # FSDP's own clip_grad_norm_ computes the norm over all shards, not just this rank's.
    model.clip_grad_norm_(MAX_GRAD_NORM)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad(set_to_none=True)
```
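One detail: with sharded gradients, call the FSDP module's own clip_grad_norm_ method rather than torch.nn.utils.clip_grad_norm_, which would only see this rank's shards. The FSDP method computes each rank's local squared norm and all-reduces it to get the global norm. That is one extra small collective per optimizer step, transparent to the caller.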
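Counting collectives per optimizer step makes the effect of no_sync() concrete. This rough sketch assumes K wrapped decoder-layer units (36 for Qwen3-4B) and 8 accumulation steps. It ignores the root unit, which FSDP treats specially, and the clip-norm reduction. The function name is hypothetical.

```python
def collectives_per_optimizer_step(num_units: int, accum_steps: int, use_no_sync: bool):
    """Approximate (AllGather, ReduceScatter) counts per optimizer step under FULL_SHARD."""
    # Weights are sharded, so every micro-step gathers each unit for forward and again for backward.
    all_gathers = 2 * num_units * accum_steps
    # Gradients are reduced every micro-step, or only on the last one when using no_sync().
    reduce_scatters = num_units * (1 if use_no_sync else accum_steps)
    return all_gathers, reduce_scatters


print(collectives_per_optimizer_step(36, 8, use_no_sync=False))  # (576, 288)
print(collectives_per_optimizer_step(36, 8, use_no_sync=True))   # (576, 36)
```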
---
Saving: State Dict Modes
Each GPU holds only 1/N of every parameter tensor. model.state_dict() on a single rank returns a shard, not a complete checkpoint.
FSDP provides three state dict modes:
FULL_STATE_DICT: FSDP AllGathers all shards to rank 0. Rank 0 receives the complete state dict; all other ranks receive an empty dict. This is the right mode for saving an inference checkpoint or a final adapter.
FullStateDictConfig(offload_to_cpu=True, rank0_only=True) is almost always the right config to pair with it. offload_to_cpu streams each layer's gathered tensor to CPU immediately after AllGather, rather than accumulating all layers on rank 0's GPU (which would require the full model size in GPU memory on rank 0 just to save). rank0_only avoids sending the full state dict to every rank unnecessarily.
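```python
from peft import get_peft_model_state_dict
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    full_state = model.state_dict()  # collective: every rank must call this

if rank == 0:
    # Keep only the LoRA adapter tensors, then save in PEFT's format.
    adapter_state = get_peft_model_state_dict(model.module, full_state)
    model.module.save_pretrained(save_path, state_dict=adapter_state)
    tokenizer.save_pretrained(save_path)
```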
model.module unwraps the FSDP wrapper to get the underlying PEFT model. get_peft_model_state_dict filters the full state dict down to the LoRA adapter keys, which is all that needs to be saved for continued finetuning or inference.
SHARDED_STATE_DICT: each rank saves its own shard independently. It is fast and memory-efficient (no AllGather, no rank-0 bottleneck), but the checkpoint stays in FSDP's sharded layout. Per-rank torch.save files must be reloaded with the same world size, while torch.distributed.checkpoint can reshard on load. This mode is good for mid-run checkpoints during long training runs where resumability matters more than portability.
LOCAL_STATE_DICT: raw flat tensors without FSDP metadata. Mostly for debugging.
---
The Full Script
The complete training script is at experiments/fsdp-ft/train.py. Run it:
```bash
torchrun --nproc_per_node=2 experiments/fsdp-ft/train.py
```
To compare ZeRO stages, change SHARDING_STRATEGY at the top of the script and re-run. The CSV logger writes a strategy column (fsdp_zero2, fsdp_zero3, etc.) so all runs can be loaded into a single dataframe.
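The comparison can also be done with only the standard library. This sketch assumes each run's CSV has a header row containing the strategy column described above, plus a numeric metric column. The metric name you pass (for example, a step-time column) must match whatever the script actually logs, which is not shown here. `mean_by_strategy` is a hypothetical helper.

```python
import csv
from collections import defaultdict
from statistics import mean


def mean_by_strategy(files, metric):
    """Average a numeric CSV column per value of the `strategy` column, across run logs."""
    values = defaultdict(list)
    for f in files:  # any iterable of open text files (or io.StringIO objects)
        for row in csv.DictReader(f):
            values[row["strategy"]].append(float(row[metric]))
    return {strategy: mean(vals) for strategy, vals in values.items()}
```

Typical use is to open each run's CSV and pass the file objects together, for example `mean_by_strategy([f_zero2, f_zero3], "<metric column>")`, where the handles and column name are placeholders.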
---
Results
From our training runs on 2x RTX 4090 (Vast.ai, PCIe), using FULL_SHARD (ZeRO-3), compared against DDP and single-GPU baselines from the DDP post:
| Metric | Single GPU (QLoRA) | DDP 2x GPU (QLoRA) | FSDP 2x GPU (BF16) |
|---|---|---|---|
| Tokens/sec | ~1,650 | ~3,000-3,200 | ~140 |
| Step time | ~7.3s | ~4-5s | ~43.5s |
| Peak VRAM/GPU | 5.58 GB | ~6-7 GB | 14.8 GB |
The throughput gap is large and expected. DDP replicates the full model on each GPU and communicates only gradients (one AllReduce per step, overlapped with backward). FSDP shards parameters, gradients, and optimizer states, then reconstructs them on the fly -- AllGather before every forward layer, AllGather + ReduceScatter during every backward layer, plus activation checkpointing reruns the forward during backward. On 2 GPUs over PCIe, that communication dominates compute. The other half of the gap is the workload itself: DDP trains QLoRA (4-bit base, ~1% trainable), while FSDP trains full BF16 (100% trainable, 16x more gradient bytes per step).
VRAM settled at 14.8 GB after the first step and stayed flat. FSDP shards the 8 GB BF16 model to ~4 GB per GPU, but gradients, optimizer states, and activations add back up. Activation checkpointing is mandatory at this memory budget: without it, storing all 36 layers' activations pushes past 24 GB.
On NVLink (8+ GPUs within an NVSwitch-connected node), the communication-to-compute ratio becomes much more favorable and the throughput gap narrows. FSDP becomes the right tool when the model genuinely cannot fit on each GPU. That happens when quantization stops being applicable (models beyond roughly 30-40B parameters, where even 4-bit weights at about 0.5 bytes per parameter plus training overhead push past 24 GB), or when full finetuning is required and the model exceeds single-GPU memory.