Tensor Parallelism and Sequence Parallelism

April 20, 2026

*Part 3 of 4 — Distributed Training series. Start with the overview.*

1. Distributed Data Parallel: How It Actually Works -- Scale throughput when the model fits on one GPU.
2. ZeRO and FSDP: Model Sharding -- Fit models that don't fit, by sharding weights, gradients, and optimizer state.
3. Tensor Parallelism and Sequence Parallelism (this post) -- Shard inside each layer; wins when interconnect is slow or a single layer is oversized.
4. Pipeline Parallelism: How It Actually Works -- Shard across layers; the axis that cheaply spans nodes.

Training a 70-billion-parameter model on a single GPU is not an engineering challenge. It is simply impossible. The weights alone (in bfloat16) occupy 140 GB. The activations during a forward pass can dwarf even that. Distributed training is not an optimization; it is a prerequisite.

Tensor Parallelism (TP) is one of the oldest and most effective answers to this problem: split the weight matrices themselves across multiple devices, so each GPU holds only a fraction. It works beautifully. But it has a cost that only became painful as context windows grew: even as the weights are sharded, the activations flowing into each layer are fully replicated on every device. As sequence lengths grew -- 4k, 8k, 32k, 128k tokens -- this hidden cost became impossible to ignore.

TP is not the only answer to the memory problem. → FSDP shards parameters, gradients, and optimizer state across data-parallel replicas and reconstructs full layers on demand via AllGather and ReduceScatter, which is a very different approach from splitting the computation *inside* each layer. The two are not substitutes. FSDP issues a few large batched collectives per step, while TP issues many small ones (sized by hidden_size, not by full layers). On GPUs connected over PCIe, the small AllReduces from TP are cheap; on NVLink, the batched collectives from FSDP amortize better. In production the two are usually composed: TP within a node, where NVLink makes the small collectives free, and FSDP across nodes, where batched collectives hide the inter-node latency. This 2D parallelism is exactly what Megatron-LM and PyTorch's composable distributed APIs are built for, and Sequence Parallelism is a standard piece of that stack. The TP versus FSDP throughput and VRAM numbers later in this post should be read with that in mind. They are a snapshot of one corner of the design space (two 4090s on PCIe, LoRA, short sequences), not a verdict that one beats the other.

Sequence Parallelism (SP) is the fix. Introduced in Megatron-LM's 2022 paper, it shards exactly those activations that TP leaves behind. Together, TP+SP achieves something neither can alone: both weights and activations sharded, with the same communication cost as TP alone.

This post walks through how TP works inside a transformer, how SP completes it, and what happens when you try to combine both with LoRA in practice. The code is from a training script (experiments/tp-ft/train.py) that finetunes Qwen3-4B on AMI meeting transcripts with LoRA on 2x RTX 4090s connected via PCIe.

| Setting | Value |
|---|---|
| Model | Qwen3-4B (BF16) |
| Method | LoRA (r=16, alpha=32, all linear layers) |
| Dataset | knkarthick/AMI (237 samples) |
| GPUs | 2x RTX 4090 (PCIe, Vast.ai) |
| TP degree | 2 |
| Effective batch size | 16 (batch=1, grad accum=16, no data parallelism) |
| Sequence length | 256 (reduced from 512 due to gradient checkpointing incompatibility; see below) |
| Optimizer | AdamW (lr=2e-4, cosine decay to 2e-5) |
| Epochs | 3 |

---

TP Inside a Transformer Layer

TP splits each weight matrix W across N GPUs (the TP degree) in one of two ways. Column-parallel shards W along its output dimension, so each GPU computes a different slice of the output with no communication. Row-parallel shards W along its input dimension, so each GPU computes a partial sum that an AllReduce then reduces into the full output. Pairing them column-then-row keeps the intermediate activation sharded across the element-wise op in the middle (GELU, SiLU) and collapses the whole sub-block to a single AllReduce at the boundary.
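To make the pairing concrete, here is a minimal pure-Python sketch with tiny integer matrices, so the comparison is exact. The helper names (`matmul`, `relu`, `mlp_full`, `mlp_tensor_parallel`) are illustrative and not from the training script. ReLU stands in for GELU/SiLU, weights use the math convention `[in, out]` rather than PyTorch's `[out, in]`, and the final sum across simulated ranks plays the role of the AllReduce.

```python
def matmul(a, b):
    """Plain-Python matrix product of a [m x k] and b [k x n]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    """Element-wise nonlinearity (stand-in for GELU/SiLU)."""
    return [[max(0, v) for v in row] for row in m]

def mlp_full(x, w1, w2):
    """Unsharded reference: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)

def mlp_tensor_parallel(x, w1, w2, n):
    """Shard w1 by columns and w2 by rows across n simulated ranks."""
    hidden = len(w2)                     # inner dimension shared by w1 and w2
    step = hidden // n
    partials = []
    for rank in range(n):
        lo, hi = rank * step, (rank + 1) * step
        w1_shard = [row[lo:hi] for row in w1]    # column-parallel: slice of the outputs
        w2_shard = w2[lo:hi]                     # row-parallel: slice of the inputs
        h_local = relu(matmul(x, w1_shard))      # stays sharded, no communication
        partials.append(matmul(h_local, w2_shard))  # partial sum of the final output
    # Element-wise sum across ranks: this is what the single AllReduce computes.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

x = [[1, -2, 3], [0, 4, -1]]                          # [tokens=2, d=3]
w1 = [[1, 0, -1, 2], [2, 1, 0, -1], [-1, 3, 1, 0]]    # [d=3, hidden=4]
w2 = [[1, 2, 0], [0, -1, 1], [3, 0, -2], [1, 1, 1]]   # [hidden=4, d=3]

print(mlp_tensor_parallel(x, w1, w2, 2) == mlp_full(x, w1, w2))
```

The element-wise op sits between the two matmuls and never needs the other rank's slice, which is why the intermediate activation can stay sharded.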

A transformer decoder layer has two blocks: self-attention and MLP. Each has a natural column-to-row structure.

Attention

The attention block has four projections (PyTorch nn.Linear weight shapes, [out, in]):

```
q_proj: [num_heads * head_dim, hidden_size]      -> ColwiseParallel
k_proj: [num_kv_heads * head_dim, hidden_size]   -> ColwiseParallel
v_proj: [num_kv_heads * head_dim, hidden_size]   -> ColwiseParallel
o_proj: [hidden_size, num_heads * head_dim]      -> RowwiseParallel
```

q_proj, k_proj, and v_proj are column-parallel. Each GPU computes projections for num_heads/N attention heads. The attention computation (softmax, value weighting) runs independently per head, so each GPU processes its local heads without communication. o_proj is row-parallel: it takes the head outputs (already split across GPUs, one slice per GPU) and produces the full hidden-size output via AllReduce.

One AllReduce per attention block. The Q/K/V projections, attention scores, and value weighting all run communication-free.

For Qwen3-4B with TP=2: 32 attention heads become 16 per GPU. 8 KV heads become 4 per GPU. Both divide evenly.
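The divisibility requirement is easy to check up front. The helper below is a hypothetical sketch (the name `heads_per_rank` is ours, not part of any library) that returns the per-rank head counts and fails loudly when a TP degree does not divide the head counts evenly.

```python
def heads_per_rank(num_heads: int, num_kv_heads: int, tp_degree: int) -> tuple[int, int]:
    """Return (query heads, KV heads) held by each rank, or raise if uneven."""
    for name, count in (("num_heads", num_heads), ("num_kv_heads", num_kv_heads)):
        if count % tp_degree != 0:
            raise ValueError(f"{name}={count} is not divisible by TP degree {tp_degree}")
    return num_heads // tp_degree, num_kv_heads // tp_degree

# Head counts quoted above for Qwen3-4B, with TP=2.
print(heads_per_rank(32, 8, 2))
```

With 8 KV heads, TP degrees of 2, 4, and 8 all divide evenly; a degree of 3 would raise.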

MLP (SwiGLU)

Qwen3's SwiGLU MLP follows the same pattern: gate_proj and up_proj are column-parallel, down_proj is row-parallel, one AllReduce per block.

Total Communication Per Layer

Two AllReduces per decoder layer: one in attention (after o_proj), one in MLP (after down_proj). The count is exactly two because of the column-then-row pairing: every other op in the layer (Q/K/V projections, per-head attention math, SiLU, the gate/up element-wise multiply, residual adds, RMSNorm) is either head-local, element-wise, or replicated, so AllReduce is only needed at the row-parallel boundary that closes each sub-block. Naively making every linear independently row- or column-parallel would need an AllReduce after all 7 projections per layer; the pairing collapses it to 2.

For Qwen3-4B with 36 layers, that's 72 AllReduces per forward pass and 72 during backward. Compare this to → FSDP, which does one AllGather + one ReduceScatter per FSDP unit (typically per layer) -- fewer but larger collectives.
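The counts above are simple arithmetic, sketched here so other layer counts can be plugged in. The constant and function names are illustrative; the "naive" figure assumes one AllReduce after each of the seven projections, as described in the previous paragraph.

```python
PROJECTIONS_PER_LAYER = 7          # q, k, v, o, gate, up, down
PAIRED_ALLREDUCES_PER_LAYER = 2    # one after o_proj, one after down_proj

def allreduces_per_forward(num_layers: int, per_layer: int) -> int:
    """Number of AllReduces issued in one forward pass."""
    return num_layers * per_layer

paired = allreduces_per_forward(36, PAIRED_ALLREDUCES_PER_LAYER)
naive = allreduces_per_forward(36, PROJECTIONS_PER_LAYER)
print(paired, naive)
```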

---

The Hidden Problem: Replicated Activations

Look at the forward pass again. Before the TP-sharded weights are touched, input X is broadcast to all N devices. After the AllReduce, the output also lives in full on every device. The LayerNorm and Dropout operations sit outside the TP region -- they operate on fully replicated activations.

For short sequences, this is fine. For a 128k-token context with hidden dimension 8192 in bfloat16:

```python
b = 1          # batch size
s = 131072     # 128k tokens
d = 8192       # hidden dim
dtype = 2      # bytes (bfloat16)

mem = b * s * d * dtype   # = 2 GB per device, N times over
```

With TP degree 8, that's 2 GB of identical data sitting on 8 devices. As layers stack, this becomes the dominant memory consumer -- not the weights. TP shards the weights perfectly but leaves the activations untouched.

---

Sequence Parallelism: Sharding What TP Leaves Whole

The key insight is simple: the operations outside the TP region (LayerNorm, Dropout, residual add) are independent across the sequence dimension. Token 0's LayerNorm result has no dependency on token 4096's. So why not shard along the sequence?

```
Residual stream:  [b, s/N, d]      <-- sequence-sharded (SP region)
        |
   All-Gather  -> [b, s, d]        <-- full sequence reconstructed
        |
Tensor-parallel Attention / MLP
        |
Reduce-Scatter -> [b, s/N, d]      <-- re-sharded
        |
Dropout + Residual (on s/N tokens, no redundancy)
        |
Residual stream:  [b, s/N, d]      <-- back to SP region
```

The transformer layer now has a natural rhythm: activations enter in sequence-sharded form, an AllGather reconstructs the full sequence for tensor-parallel compute, and a ReduceScatter at the output breaks it back into sequence shards. LayerNorm and Dropout operate on the sharded view, so each device processes only s/N tokens.
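The claim that per-token ops are safe to shard along the sequence can be checked directly. This is a minimal pure-Python sketch, assuming an RMSNorm without a learned scale for brevity; `rms_norm`, `norm_sequence`, and `shard_sequence` are illustrative helpers, not code from the scripts.

```python
import math

def rms_norm(token, eps=1e-6):
    """RMSNorm of one token vector (no learned scale, for brevity)."""
    rms = math.sqrt(sum(v * v for v in token) / len(token) + eps)
    return [v / rms for v in token]

def norm_sequence(seq):
    """Apply the norm independently to every token in a sequence."""
    return [rms_norm(tok) for tok in seq]

def shard_sequence(seq, n):
    """Split a sequence into n contiguous shards along the token dimension."""
    step = len(seq) // n
    return [seq[r * step:(r + 1) * step] for r in range(n)]

seq = [
    [1.0, 2.0, 3.0, 4.0],
    [0.5, -1.0, 2.0, 0.0],
    [3.0, 3.0, -3.0, 1.0],
    [0.1, 0.2, 0.3, 0.4],
]
n = 2

# Norm on the full sequence, then shard ...
full_then_shard = shard_sequence(norm_sequence(seq), n)
# ... versus shard first, then norm each shard on its own "device".
shard_then_norm = [norm_sequence(shard) for shard in shard_sequence(seq, n)]

print(full_then_shard == shard_then_norm)
```

Each token's result depends only on that token, so the two orders agree and no device ever needs the other's tokens for the norm.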

AllGather + ReduceScatter = AllReduce

This might sound like more communication. It isn't. The AllGather and ReduceScatter together are mathematically equivalent to an AllReduce in total bytes moved. Standard TP uses an AllReduce after each row-parallel layer. TP+SP replaces that AllReduce with an AllGather before the TP region and a ReduceScatter after. Same bytes, different layout.
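The equivalence can be seen with a toy simulation in which each "rank" is just a Python list. The sketch below shows that a ReduceScatter followed by an AllGather leaves every rank with the AllReduce result, and that the per-rank volume of the usual ring algorithms adds up the same way. The function names are ours; in TP+SP the two halves are placed on either side of the TP region rather than back to back.

```python
def all_reduce(bufs):
    """Every rank ends up with the element-wise sum of all ranks' buffers."""
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]

def reduce_scatter(bufs):
    """Rank r ends up with the r-th chunk of the element-wise sum."""
    n = len(bufs)
    total = [sum(vals) for vals in zip(*bufs)]
    step = len(total) // n
    return [total[r * step:(r + 1) * step] for r in range(n)]

def all_gather(chunks):
    """Every rank ends up with the concatenation of all ranks' chunks."""
    full = [v for chunk in chunks for v in chunk]
    return [list(full) for _ in chunks]

def ring_volume(op, n, msg):
    """Per-rank data sent by the usual ring algorithms (msg = full message size)."""
    per_phase = (n - 1) / n * msg
    return {
        "all_gather": per_phase,
        "reduce_scatter": per_phase,
        "all_reduce": 2 * per_phase,
    }[op]

bufs = [[1, 2, 3, 4], [10, 20, 30, 40]]   # two simulated ranks

print(all_gather(reduce_scatter(bufs)) == all_reduce(bufs))
print(ring_volume("all_gather", 8, 1.0) + ring_volume("reduce_scatter", 8, 1.0),
      ring_volume("all_reduce", 8, 1.0))
```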

| Property | TP only | TP + SP |
|---|---|---|
| Weight memory per device | 1/N | 1/N |
| TP-region activations per device | 1/N | 1/N |
| LayerNorm / Dropout activations | Full (replicated) | 1/N (seq-sharded) |
| Communication per layer | 2x AllReduce | 2x (AllGather + ReduceScatter) |
| Communication volume (per sub-block) | 2(N-1)/N x msg | 2(N-1)/N x msg |

The communication cost is identical. The memory saving is the entire point: activations outside the TP region drop from full replication to 1/N. For long-context training, this is the difference between fitting a batch and not fitting at all.
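Plugging the earlier 128k-token example into that 1/N factor gives a feel for the saving. This is a back-of-envelope sketch for a single `[b, s, d]` activation tensor outside the TP region; the function names are illustrative.

```python
def replicated_activation_bytes(b, s, d, dtype_bytes=2):
    """Bytes held per device when the activation is fully replicated (TP only)."""
    return b * s * d * dtype_bytes

def sp_activation_bytes(b, s, d, n, dtype_bytes=2):
    """Bytes held per device when the sequence is sharded n ways (TP + SP)."""
    return b * (s // n) * d * dtype_bytes

full = replicated_activation_bytes(1, 131072, 8192)
sharded = sp_activation_bytes(1, 131072, 8192, 8)
print(f"{full / 2**30:.0f} GiB replicated vs {sharded / 2**20:.0f} MiB with SP at N=8")
```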

---

The PyTorch API

PyTorch's torch.distributed.tensor.parallel.parallelize_module takes a module, a DeviceMesh, and a plan that maps submodule names to strategies. The plan for one Qwen3 decoder layer is just:

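```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# `model` and `device_mesh` are created earlier in the training script.
tp_plan = {
    "self_attn.q_proj": ColwiseParallel(),
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),
}
for layer in model.model.layers:
    parallelize_module(layer, device_mesh, tp_plan)
```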

parallelize_module replaces each targeted weight with a DTensor shard and the runtime inserts the AllReduces automatically. There is also a SequenceParallel() strategy for the RMSNorms that would shard the activations at the norm boundaries, which is the SP story from the previous section. In theory it completes the TP memory picture; in practice it breaks with LoRA, for reasons covered next.

use_local_output (default True) on ColwiseParallel/RowwiseParallel controls whether each layer's output is unwrapped from DTensor back to a regular torch.Tensor immediately after the collective. When True, downstream code (.view() for heads, .transpose(), RMSNorm, silu) works on plain tensors without needing DTensor support. When False, the output stays as a DTensor and every downstream operation must handle DTensor semantics.

This parameter is the precise reason the LoRA path collides with the base layer's output later. With use_local_output=True, ColwiseParallel returns a plain tensor, but the LoRA branch (operating on DTensor inputs with DTensor weights) produces a DTensor. PEFT's LoraLinear.forward adds the two together, and PyTorch refuses to mix types in the addition. Setting it to False would fix the addition but break every .view() and .transpose() in Qwen3's attention code.

---

TP vs FSDP: Different Tradeoffs

The collective profile and the production-stack story are in the introduction. Two more differences are worth pinning down before the LoRA section: what each strategy actually shards, and what each strategy expects of the data loader. (The → FSDP post covers the FSDP side of this comparison in its own "FSDP vs TP+SP on PCIe" section.)

What Gets Sharded

| Component | FSDP (FULL_SHARD) | TP |
|---|---|---|
| Linear weights | Sharded (flat 1D slices) | Sharded (column/row slices) |
| Embeddings | Sharded | Replicated |
| LM head | Sharded | Replicated |
| Layer norms | Sharded (as buffers) | Replicated |
| Optimizer states | Sharded | Per local shard only |
| Gradients | Sharded | Per local shard only |

FSDP shards everything. TP only shards linear layers. For Qwen3-4B, the embedding table (embed_tokens) is 151,936 x 2,560 x 2 bytes = ~740 MB, and lm_head is the same size. Both are fully replicated on every GPU with TP. Embedding parallelism is possible (PyTorch's TP API supports it) but adds complexity for a relatively small memory saving at 4B scale. This is why TP uses more memory per GPU than FSDP at the same model size, and why SP on the norms and activations matters so much.
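The embedding figure is straightforward to reproduce. A small sketch, using the vocabulary and hidden sizes quoted above and assuming, as the text does, that embed_tokens and lm_head are stored as two separate tables of the same size:

```python
def table_bytes(vocab_size, hidden_size, dtype_bytes=2):
    """Size of one [vocab, hidden] table in bytes (BF16 by default)."""
    return vocab_size * hidden_size * dtype_bytes

embed = table_bytes(151_936, 2_560)
print(f"{embed / 2**20:.0f} MiB per table, "
      f"{2 * embed / 2**30:.2f} GiB replicated per GPU for both tables")
```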

What Gets Communicated (and Why It Matters on PCIe)

FSDP and TP+SP solve the same problem — fitting a large model across multiple GPUs — but they move fundamentally different things over the interconnect, and that difference is everything on commodity hardware.

FSDP shards the model weights across GPUs but must reconstruct the full layer on every GPU before each forward and backward pass. Per training step it moves roughly 3x the model size across the interconnect: AllGather in forward, AllGather in backward, ReduceScatter of gradients in backward. For a 4B parameter model in BF16 that is ~24 GB of weight traffic per step.

TP+SP keeps each GPU's weight shard in place and never gathers it. The only thing that crosses the interconnect is *activations*: the intermediate [batch, seq, hidden] tensors produced inside each layer. At typical training shapes (batch=1, seq=512, hidden=2560, BF16) each per-layer collective is a few MB, and the total per step lands around 1-2 GB regardless of model size.
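The two quantities being compared can be written down in a few lines. This is a back-of-envelope sketch of payload sizes only, using the round figures from the text (4B parameters, BF16, batch=1, seq=512, hidden=2560); the function names are ours, and it does not model ring overheads, overlap, or gradient accumulation.

```python
def fsdp_weight_traffic_bytes(num_params, dtype_bytes=2, passes=3):
    """AllGather (fwd) + AllGather (bwd) + ReduceScatter (grads), each ~ model size."""
    return passes * num_params * dtype_bytes

def tp_collective_bytes(batch, seq, hidden, dtype_bytes=2):
    """Payload of one activation collective of shape [batch, seq, hidden]."""
    return batch * seq * hidden * dtype_bytes

fsdp = fsdp_weight_traffic_bytes(4_000_000_000)
tp_msg = tp_collective_bytes(1, 512, 2560)
print(f"FSDP weight traffic: {fsdp / 1e9:.0f} GB")
print(f"One TP+SP collective: {tp_msg / 2**20:.1f} MiB")
```

The FSDP figure scales with parameter count; the TP+SP figure scales with batch, sequence length, and hidden size, and is independent of how many parameters sit behind each layer.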

On NVLink (600-900 GB/s) FSDP's 24 GB is manageable and overlaps with compute. On PCIe (~25 GB/s effective on 2x 4090) it becomes the bottleneck: weight traffic alone accounts for the majority of step time, and the AllGathers serialize because there is no second link to overlap them on. TP+SP is nearly compute-bound on the same hardware. Measured on our box, FSDP runs at ~43s/step while TP+SP runs at ~7s/step on the same model and batch.

The rule that falls out of this: if you do not have NVLink between GPUs, TP+SP is the correct choice for any model where parameters significantly outnumber activation volume, which is true of essentially every LLM at normal training batch sizes. FSDP only catches up when you have an interconnect fast enough to make weight traffic free.

Data Semantics

→ FSDP is data parallel: each GPU processes a different batch. DistributedSampler splits the dataset (see → DDP: DistributedSampler). Effective batch size = per_gpu_batch * world_size * grad_accum.

TP is model parallel: all GPUs process the same batch. Each GPU holds only a slice of each weight matrix, so every GPU must run its slice of the matmul on the same input for the partial results to combine correctly in the AllReduce. No DistributedSampler. All ranks must see identical data in the same order. Effective batch size = per_gpu_batch * grad_accum.
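The two formulas differ only in whether the GPU count multiplies the batch. A tiny sketch, assuming a per-GPU batch of 1 on two GPUs for both setups (the function name is illustrative):

```python
def effective_batch(per_gpu_batch, grad_accum, data_parallel_ranks=1):
    """Samples contributing to one optimizer step."""
    return per_gpu_batch * data_parallel_ranks * grad_accum

# FSDP: two data-parallel ranks, each with its own batch.
fsdp_batch = effective_batch(1, 8, data_parallel_ranks=2)
# TP: both GPUs share one batch, so only grad accumulation multiplies it.
tp_batch = effective_batch(1, 16)
print(fsdp_batch, tp_batch)
```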

This is why our training script sets GRAD_ACCUM=16 instead of FSDP's 8: without data parallelism, the effective batch size would be halved. The explicit seeding ensures all ranks shuffle identically:

```python
import torch
from torch.utils.data import DataLoader

g = torch.Generator()
g.manual_seed(42)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True,
    num_workers=0,  # workers inherit RNG state; 0 avoids nondeterministic ordering
    generator=g,
)
```

The generator is then reseeded once per epoch inside the training loop so each epoch's shuffle is different but still identical across ranks:

```python
for epoch in range(EPOCHS):
    g.manual_seed(42 + epoch)
    ...
```

num_workers=0 avoids nondeterministic ordering from multiprocess data loading.
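The property we rely on is that the shuffle order is a pure function of the epoch number. The sketch below illustrates it with Python's `random.Random` as a stand-in for `torch.Generator` (an assumption for the sake of a dependency-free example); `epoch_order` is a hypothetical helper, and each call simulates what one rank would compute.

```python
import random

def epoch_order(num_samples, epoch, base_seed=42):
    """Shuffle sample indices with a seed that depends only on the epoch."""
    rng = random.Random(base_seed + epoch)
    order = list(range(num_samples))
    rng.shuffle(order)
    return order

# Two "ranks" computing their orderings independently for three epochs.
rank0 = [epoch_order(237, e) for e in range(3)]
rank1 = [epoch_order(237, e) for e in range(3)]

print(rank0 == rank1)          # identical across ranks
print(rank0[0] != rank0[1])    # different from epoch to epoch
```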

---

LoRA with TP: The PEFT Ordering Problem

TP and PEFT interact in ways that require careful ordering.

Apply TP First, Then LoRA

parallelize_module needs to see plain nn.Linear layers. PEFT wraps Linear layers in LoraLinear, which has a base_layer attribute and separate lora_A/lora_B modules. parallelize_module doesn't know how to descend into LoraLinear and shard the inner base layer.

The training script applies TP to the base model first, which converts Linear weights to DTensors. Then PEFT is applied on top. PEFT sees the sharded dimensions and creates LoRA adapters sized to match:

```python
import torch
from peft import get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    low_cpu_mem_usage=True,
)

model.config.use_cache = False

# Move to GPU before TP sharding.
model = model.to(f"cuda:{local_rank}")

# Apply TP to the base model before LoRA.
apply_tp(model, device_mesh)

# LoRA on top of the TP-sharded model.
model = get_peft_model(model, lora_config)
```

For a ColwiseParallel q_proj with TP=2, the base weight goes from [4096, 2560] to [2048, 2560] per GPU. PEFT creates lora_B with shape [2048, 16] instead of [4096, 16]. The LoRA adapters are inherently shard-sized.
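The shard-sized shapes follow a simple rule, sketched here as a hypothetical helper (`lora_shapes` is our name, not a PEFT function). It only restates the shape bookkeeping described above: the adapter dimension that faces the sharded side of the base weight shrinks by the TP degree, and the other dimension is left alone.

```python
def lora_shapes(out_features, in_features, r, tp_degree, style):
    """Per-rank (lora_A, lora_B) shapes for an adapter on an already-sharded Linear."""
    if style == "colwise":       # base weight is [out/N, in]
        out_features //= tp_degree
    elif style == "rowwise":     # base weight is [out, in/N]
        in_features //= tp_degree
    else:
        raise ValueError(f"unknown style: {style}")
    return (r, in_features), (out_features, r)

# q_proj from the text: [4096, 2560], r=16, TP=2, column-parallel.
print(lora_shapes(4096, 2560, 16, 2, "colwise"))
# o_proj: [2560, 4096], row-parallel.
print(lora_shapes(2560, 4096, 16, 2, "rowwise"))
```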

Why Sequence Parallelism Breaks With LoRA

We initially included SequenceParallel() on the RMSNorm layers -- the theoretically correct approach that would shard activations at norm boundaries and eliminate the redundant memory. It didn't work.

The first error was a DTensor/Tensor mismatch in the LoRA matmul. With SP active, the input x to each layer is a DTensor (sharded on the sequence dimension). The LoRA adapter weights are regular tensors (created by PEFT after TP). PyTorch refuses to mix DTensors and regular tensors in a matmul: RuntimeError: aten.mm.default got mixed torch.Tensor and DTensor.

We fixed this by converting all LoRA weights to Replicate DTensors after PEFT initialization:

```python
import torch
from torch.distributed.tensor import DTensor, Replicate

# Wrap every trainable (LoRA) parameter that is still a plain tensor
# as a replicated DTensor on the TP mesh.
for module in model.modules():
    for name, param in module.named_parameters(recurse=False):
        if param.requires_grad and not isinstance(param.data, DTensor):
            module._parameters[name] = torch.nn.Parameter(
                DTensor.from_local(param.data, device_mesh, [Replicate()]),
                requires_grad=True,
            )
```

That fixed the matmul. The next error was one level up -- in the addition. PEFT's LoraLinear.forward does:

```python
result = self.base_layer(x)                               # ColwiseParallel returns local tensor
result = result + lora_B(lora_A(dropout(x))) * scaling    # LoRA path returns DTensor
```

With use_local_output=True (the default), ColwiseParallel converts the base layer's output from DTensor back to a regular tensor. But the LoRA path operates on the DTensor input x with Replicate DTensor weights, producing a DTensor output. Adding a regular tensor and a DTensor crashes: RuntimeError: aten.add.Tensor got mixed torch.Tensor and DTensor.

The fundamental tension: use_local_output=True is needed for downstream operations (.view(), .transpose(), RMSNorm) to work on regular tensors. But it creates a type mismatch in LoRA's residual addition. Setting use_local_output=False would keep everything as DTensors, but then every downstream operation needs DTensor support -- and Qwen3's attention code (view, transpose, repeat_kv) wasn't written for DTensors.

We dropped SequenceParallel. Without it, all activations stay as regular tensors, and the LoRA path works without issues. The cost is that norms run redundantly on every GPU and activations outside the TP region are fully replicated.

This is a real limitation of the current TP + LoRA ecosystem. Megatron-LM's SP implementation works because it controls the entire model code and can manage the DTensor/Tensor boundary explicitly. When using HuggingFace models + PEFT + PyTorch's TP API, the abstractions collide at the LoRA addition.

There is no upstream fix as of late 2025. PEFT issue #2394 hits the same ColwiseParallel rejection of LoraLinear and was closed in April 2025 with three user-side workarounds (including applying LoRA before TP and manually rewriting the tp_plan to target base_layer), none of which address the SP case described above. The torchtune team has the same gap (discussion #2454, unresolved as of March 2025), and HuggingFace transformers' newer tp_plan API was explicitly scoped to inference, not training. Every team is pointing at every other team; nobody owns the integration. The hand-rolled script later in this post is the working reference for what a proper PEFT-side fix would look like.

Saving Sharded LoRA Weights

Because LoRA adapters are shard-sized, the checkpoint is not portable. Each rank saves its own shard:

```python
import os

import torch
import torch.distributed as dist


def save_adapter(model, tokenizer, save_path: str, rank: int):
    # Collect only the trainable (LoRA) parameters held by this rank.
    adapter_state = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            adapter_state[name] = param.detach().cpu()

    # Each rank writes its own shard directory.
    rank_path = f"{save_path}/rank_{rank}"
    os.makedirs(rank_path, exist_ok=True)
    torch.save(adapter_state, f"{rank_path}/adapter_model.bin")

    if rank == 0:
        tokenizer.save_pretrained(save_path)
        print(f"saved adapter shards -> {save_path}/rank_*/")

    dist.barrier()
```

To use the checkpoint for single-GPU inference, the shards must be reassembled. The inference script (experiments/tp-ft/infer.py) does this by concatenating:

  • ColwiseParallel lora_B shards along dim 0 (output features)
  • RowwiseParallel lora_A shards along dim 1 (input features)
  • All other LoRA matrices are identical across ranks
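The concatenation rules can be illustrated without any tensors at all. This is a sketch with nested lists standing in for matrices; `cat_dim0` and `cat_dim1` are hypothetical helpers, not the code in infer.py, and with real tensors the same thing would be a concatenation along dim 0 or dim 1.

```python
def cat_dim0(shards):
    """Stack row blocks: ColwiseParallel lora_B shards, each [out/N, r]."""
    return [row for shard in shards for row in shard]

def cat_dim1(shards):
    """Join column blocks: RowwiseParallel lora_A shards, each [r, in/N]."""
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]

# lora_B of a column-parallel layer, full shape [out=4, r=2], split over 2 ranks.
lora_b_shards = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
# lora_A of a row-parallel layer, full shape [r=2, in=4], split over 2 ranks.
lora_a_shards = [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]

print(cat_dim0(lora_b_shards))
print(cat_dim1(lora_a_shards))
```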
---

Gradient Checkpointing: A Memory Leak With TP

Gradient checkpointing saves memory by discarding activations during the forward pass and recomputing them during backward. It works well with → FSDP (using PyTorch's FSDP-aware apply_activation_checkpointing). It does not work well with TP.

The recomputed forward pass during backward triggers TP's functional collectives (AllReduce) a second time for every layer. These collectives are async -- implemented through c10d_functional, which returns future tensors. The output buffers from these recomputed collectives were not being freed between gradient accumulation micro-steps. Memory grew by ~1.3 GB per optimizer step, consistently, until OOM:

```
step 1: 5.7 GB
step 2: 7.0 GB
step 3: 8.3 GB
step 4: 9.6 GB
step 5: 10.7 GB
...
step 15: 22.3 GB -> OOM
```

We tried three progressively aggressive cleanup approaches after each micro-step:

1. del outputs, loss -- drop Python references to the computation graph
2. + torch.cuda.synchronize() -- force all pending async collectives to complete
3. + torch.cuda.empty_cache() -- return freed blocks to CUDA

None of them stopped the growth. The collective buffers from the recomputed forward were being held by something in the autograd graph that none of these cleanup approaches could reach.

The fix: disable gradient checkpointing entirely. Without it, activations are stored (not recomputed), so the recomputation-triggered collectives never happen. Memory is higher at step 1 but completely flat across all subsequent steps:

```
step 1: 4.6 GB
step 2: 4.6 GB
step 3: 4.6 GB
...
step 45: 4.6 GB
```

The tradeoff: without checkpointing, we had to reduce MAX_LENGTH from 512 (used in the FSDP experiment) to 256 to fit in 24 GB. TP replicates embeddings and norms (unlike FSDP which shards them), and without checkpointing the full activation stack must fit in memory. This is exactly the kind of problem SP is designed to solve -- and exactly the problem we couldn't use SP for because of LoRA.

---

Making SP Work: A Hand-Rolled Version

The two limitations we hit in experiments/tp-ft/ -- no Sequence Parallelism, and a MAX_LENGTH capped at 256 because activations dominate -- have the same root cause. PEFT's LoraLinear and PyTorch's parallelize_module operate at incompatible levels of abstraction. Every time you try to push them together, the DTensor/Tensor boundary cuts through code one of them owns and neither can fix.

Production training stacks (Megatron-LM, NeMo) don't have this problem because they don't try to compose two opaque abstractions. They write the parallel linears themselves, and they write LoRA -- if they use it -- directly on top of those linears. The base layer and the adapter share one coordinate system because they were designed together.

To see what that looks like in practice, experiments/tp-sp-ft/ is a hand-rolled version of the same Qwen3-4B + LoRA + TP + SP setup. It's about 800 lines, in one file, with no torch.distributed.tensor.parallel, no DTensor, and no PEFT.

The Idea

If you never use a DTensor, the DTensor/Tensor mismatch can't happen. So we give up the high-level API, store TP-sharded weights as plain nn.Parameters, and write the collectives ourselves as torch.autograd.Function subclasses. There are exactly three of them, and each is about ten lines:

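```python
import torch
import torch.distributed as dist

# tp_world() and tp_group() are helpers defined elsewhere in the script:
# they return the TP world size and the TP process group.


class _AllGatherSeq(torch.autograd.Function):
    """Forward: AllGather along dim 1 (seq). Backward: ReduceScatter along dim 1."""

    # NOTE: both collectives treat their buffers as flat, rank-ordered memory.
    # That matches a dim-1 gather/scatter only when dim 0 (batch) has size 1,
    # which is the case in this run (batch=1).

    @staticmethod
    def forward(ctx, x):
        out_shape = list(x.shape)
        out_shape[1] *= tp_world()
        out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x.contiguous(), group=tp_group())
        return out

    @staticmethod
    def backward(ctx, grad_out):
        in_shape = list(grad_out.shape)
        in_shape[1] //= tp_world()
        grad_in = torch.empty(in_shape, dtype=grad_out.dtype, device=grad_out.device)
        dist.reduce_scatter_tensor(grad_in, grad_out.contiguous(), group=tp_group())
        return grad_in
```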

_ReduceScatterSeq is the same with forward and backward swapped, because the dual of an AllGather along seq is a ReduceScatter along seq. _AllReduceSum has an identity backward, because after the forward AllReduce every rank already holds the same y and therefore the same ∂L/∂y, which is exactly the gradient each rank needs to return for its own input.
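The duality can be checked numerically with the standard adjoint test: for any per-rank inputs x and upstream gradients g, the inner product of AllGather(x) with g, summed over ranks, must equal the inner product of x with ReduceScatter(g). The sketch below simulates ranks as Python lists; the helpers are illustrative and unrelated to the script's autograd Functions.

```python
def all_gather(chunks):
    """Every rank ends up with the concatenation of all ranks' chunks."""
    full = [v for chunk in chunks for v in chunk]
    return [list(full) for _ in chunks]

def reduce_scatter(bufs):
    """Rank r ends up with the r-th chunk of the element-wise sum."""
    n = len(bufs)
    total = [sum(vals) for vals in zip(*bufs)]
    step = len(total) // n
    return [total[r * step:(r + 1) * step] for r in range(n)]

def dot_across_ranks(a, b):
    """Sum of per-rank inner products."""
    return sum(x * y for ra, rb in zip(a, b) for x, y in zip(ra, rb))

x = [[1, 2], [3, 4]]                    # per-rank inputs, s/N = 2 tokens each
g = [[1, 0, -1, 2], [2, 5, 0, -3]]      # per-rank upstream grads, s = 4 each

lhs = dot_across_ranks(all_gather(x), g)        # <AllGather(x), g>
rhs = dot_across_ranks(x, reduce_scatter(g))    # <x, ReduceScatter(g)>
print(lhs, rhs)
```

Equal inner products for arbitrary x and g are what it means for ReduceScatter to be the correct backward of AllGather.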

These three functions are the entire SP machinery. Everything else is plumbing.

LoRA Inside the Parallel Linear

The trick that makes LoRA work without the DTensor/Tensor mismatch is to define LoRAColumnParallelLinear and LoRARowParallelLinear as subclasses of the parallel linears, with lora_A and lora_B shaped to match the base weight's shard layout:

| Base layer | lora_A | lora_B |
|---|---|---|
| ColumnParallelLinear (weight [out/N, in]) | [r, in] replicated | [out/N, r] sharded along dim 0 |
| RowParallelLinear (weight [out, in/N]) | [r, in/N] sharded along dim 1 | [out, r] replicated |

The col-parallel case is trivial: both the base path and the LoRA path produce [b, s, local_out], so the addition is type-consistent.

The row-parallel case is the interesting one. lora_A only sees the local input shard, so its output is a partial sum over feature contributions. We could AllReduce the final [b, s, out] LoRA output, but that's wasteful -- it's much cheaper to AllReduce the small [b, s, r] intermediate and then do the second matmul:

```python
class LoRARowParallelLinear(RowParallelLinear):
    def forward(self, x):
        base_partial = F.linear(x, self.weight, None)                     # [b, s, out] partial
        lora_partial = F.linear(self.lora_dropout(x), self.lora_A)        # [b, s, r] partial
        lora_full = all_reduce_sum(lora_partial)                          # [b, s, r] full
        lora_out_full = F.linear(lora_full, self.lora_B) * self.scaling   # [b, s, out] full

        # Both paths must end up in the same SP-sharded layout.
        base = reduce_scatter_seq(base_partial)                           # [b, s/N, out]
        s_local = lora_out_full.shape[1] // tp_world()
        start = tp_rank() * s_local
        lora = lora_out_full[:, start:start + s_local, :]                 # [b, s/N, out]
        return base + lora
```

Two collectives in one forward (AllReduce on [b, s, r] plus ReduceScatter on [b, s, out]), where r is the LoRA rank (16) and out is hidden_size (2560). The r=16 AllReduce is essentially free.
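Reducing the small intermediate is valid because the second matmul is linear, so summing before or after it gives the same result. A minimal pure-Python check with two simulated ranks and tiny integer matrices follows; the helper names are ours, and the matrices are stored already transposed (`[in, r]` and `[r, out]`) to keep the matmuls readable.

```python
def matmul(a, b):
    """Plain-Python matrix product of a [m x k] and b [k x n]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    """Element-wise sum of two matrices (stands in for a 2-rank AllReduce)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

# Each rank holds a slice of the input features and the matching slice of lora_A.
x_shards = [[[1, 2]], [[3, -1]]]          # [s=1, in/N=2] per rank
a_shards = [[[1], [0]], [[2], [1]]]       # lora_A^T slices: [in/N=2, r=1]
b_t = [[1, -1, 2]]                        # lora_B^T: [r=1, out=3], replicated

partials = [matmul(x, a) for x, a in zip(x_shards, a_shards)]   # [s, r] partial sums

# Option 1: AllReduce the small [s, r] intermediate, then apply lora_B once.
reduce_small = matmul(add(*partials), b_t)
# Option 2: apply lora_B per rank, then AllReduce the large [s, out] result.
reduce_large = add(*[matmul(p, b_t) for p in partials])

def allreduce_elements(b, s, width):
    """Number of elements in an AllReduce over a [b, s, width] tensor."""
    return b * s * width

print(reduce_small == reduce_large)
print(allreduce_elements(1, 512, 2560) // allreduce_elements(1, 512, 16))
```

With r=16 and out=2560 the small AllReduce moves 160x fewer elements than reducing the full LoRA output would.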

The Decoder Layer Has To Be Rewritten Too

PEFT could be replaced surgically -- but parallelize_module couldn't, because the whole reason we're doing this is to keep activations in seq-sharded form between layers. HuggingFace's Qwen3DecoderLayer.forward doesn't know about that sharding and uses regular tensors throughout. So the decoder layer is rewritten as a TPDecoderLayer whose forward looks like:

```python
def forward(self, x_seq_sharded):
    h = x_seq_sharded + self.self_attn(self.input_layernorm(x_seq_sharded))
    h = h + self.mlp(self.post_attention_layernorm(h))
    return h
```

The norms operate on [b, s/N, hidden] (per-token, so SP-safe). The attention block's column-parallel q/k/v_proj call gather_seq internally to recover [b, s, hidden] for the matmul; o_proj calls reduce_scatter_seq to get back to [b, s/N, hidden]. The MLP block does the same with gate_proj/up_proj/down_proj. The seq dimension is sharded everywhere outside the attention/MLP TP regions, and the activations that tp-ft/ was replicating across all GPUs are now divided by N.

Add a hand-written RMSNorm, RoPE (the standard interleaved-pairs convention Qwen3 uses), GQA-aware attention with TP head splits, and a HuggingFace checkpoint loader that maps model.layers.0.self_attn.q_proj.weight into the right shard on each rank, and you have a complete training script.

What You Get

The two scripts now demonstrate complementary lessons:

| | tp-ft/ (PEFT + parallelize_module) | tp-sp-ft/ (hand-rolled) |
|---|---|---|
| Sequence Parallelism | No (LoRA + DTensor mismatch) | Yes |
| MAX_LENGTH | 256 (activations dominate) | 512 (SP shards them) |
| LoRA implementation | PEFT's LoraLinear | Baked into the parallel linears |
| TP backend | torch.distributed.tensor.parallel (DTensor) | Plain torch.Tensor + autograd Functions |
| Decoder layer | HuggingFace Qwen3 layer (unmodified) | Rewritten end-to-end |
| Lines of code | ~500 | ~800 |
| What you learn | How parallelize_module is used | What parallelize_module is hiding |

The hand-rolled version is not faster on a 4B model on two PCIe-connected RTX 4090s -- the overhead of vanilla PyTorch's RMSNorm and the un-fused QKV projection roughly cancels out the activation-memory savings from SP. What it buys is the ability to actually run at long context: at MAX_LENGTH=512 the activation stack fits because SP shards it across both GPUs, where tp-ft/ would OOM.

What You Give Up

The hand-rolled version is also a teaching artifact, not a production stack. Megatron-LM and NeMo wrap their parallel linears around fused CUDA kernels: Apex's RMSNorm, Transformer Engine's FP8 GEMMs, fused QKV projection (one matmul instead of three), flash-attention with TP-aware masking. We use vanilla PyTorch RMSNorm, a separate matmul for q_proj/k_proj/v_proj, and SDPA. On the same hardware, a real Megatron run would be ~2x faster than tp-sp-ft/. The script exists to show the structure clearly, not to set throughput records.

The other thing you give up is portability across model architectures. tp-ft/ works on any HuggingFace model whose Linear layers can be parallelize_module'd -- Llama, Mistral, Qwen, Phi, etc. -- because PyTorch's TP API treats them as black-box nn.Linears. tp-sp-ft/ only works on Qwen3 because the decoder layer is hand-coded against Qwen3's specific architecture (GQA ratio, per-head QK norms, RoPE convention, SwiGLU MLP). To support a new model you have to rewrite the decoder layer and the weight loader. This is exactly the trade-off that production stacks make: Megatron-LM has separate model files for Llama, Mistral, GPT, and so on, and adding a new architecture is a real engineering project.

This is the lesson. The high-level API gives you portability and saves code at the cost of expressiveness. The low-level approach gives you full control over the data flow at the cost of having to rewrite code per model. Neither is wrong; they're for different situations. Knowing both lets you pick.

---

Results

From our training run on 2x RTX 4090 (Vast.ai, PCIe):

| Metric | TP (tp-ft) | TP + SP (tp-sp-ft) | → FSDP (for comparison) |
|---|---|---|---|
| Tokens/sec | ~850 avg | ~1870 avg | ~140 avg |
| Step time | ~7.2s | ~6.6s | ~43.5s |
| VRAM per GPU | 4.6 GB (flat) | 6.0 GB (flat) | 14.8 GB (flat) |
| Sequence length | 256 | 512 | 512 |
| Gradient checkpointing | No (leaks memory) | No | Yes |
| Total training time | ~5.5 minutes | ~5 minutes | ~33 minutes |

The hand-rolled tp-sp-ft run is the apples-to-apples comparison against FSDP: same 512-token sequence length, no gradient checkpointing, and Sequence Parallelism doing exactly what tp-ft could not. At that setting it processes ~13x more tokens per second than FSDP with ~2.5x less VRAM.

The tp-ft column should be read carefully. It looks 6x faster than FSDP and 3x more memory efficient, but the sequence length is halved and there is no gradient checkpointing. Shorter sequences mean less work per token (smaller attention windows, smaller activation stack), and the 4.6 GB only fits because every activation is for 256 tokens instead of 512. If tp-ft tried to run at 512 it would OOM without checkpointing, and with checkpointing it leaks ~1.3 GB per step. The whole reason tp-sp-ft exists is that SP shards exactly those activations and lets the same hardware run at the longer context.
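The headline ratios quoted above come straight from the table. A short sketch that recomputes them from the reported averages (the dictionary layout and the `ratios` helper are ours):

```python
results = {
    "tp-ft":    {"tok_per_s": 850,  "vram_gb": 4.6},
    "tp-sp-ft": {"tok_per_s": 1870, "vram_gb": 6.0},
    "fsdp":     {"tok_per_s": 140,  "vram_gb": 14.8},
}

def ratios(run, baseline="fsdp"):
    """Return (throughput ratio, VRAM ratio) of `run` relative to `baseline`."""
    speedup = results[run]["tok_per_s"] / results[baseline]["tok_per_s"]
    vram_saving = results[baseline]["vram_gb"] / results[run]["vram_gb"]
    return round(speedup, 1), round(vram_saving, 1)

print("tp-sp-ft vs FSDP:", ratios("tp-sp-ft"))
print("tp-ft vs FSDP:   ", ratios("tp-ft"))
```

These ratios inherit every caveat in the paragraph above: the tp-ft run uses half the sequence length, and only the FSDP run uses gradient checkpointing.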

So the headline numbers are a snapshot of one corner of the design space, not a verdict. As the introduction laid out, Tensor Parallelism and FSDP are not substitutes: they solve different problems and in production they get composed, with Sequence Parallelism filling in the activation gap that tp-ft could not close because of the LoRA collision and that tp-sp-ft had to drop down to plain torch.Tensor to fix.

---

The Full Scripts

Two complete training scripts ship with this post.

experiments/tp-ft/train.py uses PyTorch's parallelize_module and PEFT. ~500 lines. No Sequence Parallelism (the LoRA + DTensor mismatch). Runs at MAX_LENGTH=256. This is the script the bulk of this post is about.

```bash
uv run torchrun --nproc_per_node=2 experiments/tp-ft/train.py
uv run python experiments/tp-ft/infer.py
```

experiments/tp-sp-ft/train.py is the hand-rolled, Megatron-style version. ~800 lines. Sequence Parallelism actually works. Runs at MAX_LENGTH=512. Read this one when you want to see what parallelize_module was hiding.

```bash
uv run torchrun --nproc_per_node=2 experiments/tp-sp-ft/train.py
uv run python experiments/tp-sp-ft/infer.py
```

The inference scripts reassemble the sharded LoRA checkpoints onto a single GPU -- tp-ft/infer.py builds a base model and attaches the reassembled adapter; tp-sp-ft/infer.py merges the adapter directly into the base weights so generation runs through stock model.generate().