askill
fsdp


PyTorch FSDP — sharding strategies, mixed precision, activation checkpointing, auto_wrap_policy, checkpointing, and HuggingFace integration. Use when training models too large for a single GPU with FSDP. NOT for DeepSpeed (see deepspeed).

Updated 3/7/2026

SKILL.md

FSDP (Fully Sharded Data Parallel)

PyTorch FSDP shards model parameters, gradients, and optimizer states across GPUs, enabling training of models that don't fit on a single GPU. It is part of torch.distributed.fsdp; this skill targets PyTorch 2.6+ (latest stable: 2.10).

FSDP2 (fully_shard): PyTorch 2.6+ includes FSDP2 — a per-parameter sharding API (from torch.distributed.fsdp import fully_shard) that replaces the wrapper-based FSDP1 (FullyShardedDataParallel). FSDP2 is now the recommended API: simpler usage, better frozen parameter support, and communication-free sharded state dicts. This skill covers both APIs; FSDP2 examples are noted where applicable.

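When to use FSDP vs DDP: Use DDP when the model, its gradients, and its optimizer state all fit on one GPU. Use FSDP when they don't. Budget for training state, not just weights: with AdamW in mixed precision (fp32 params, fp32 grads, two fp32 optimizer moments) that is about 16 bytes per parameter before activations, so full training outgrows an 80 GB GPU at roughly 5B parameters. FSDP also helps when freeing parameter memory is what lets the batch size you need fit.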

Core Concepts

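FSDP shards a model's parameters across N GPUs. For each wrapped unit (typically one layer), during forward and backward:

  1. All-gather the unit's parameters (each GPU briefly holds the full layer)
  2. Compute forward or backward for that layer
  3. In backward, reduce-scatter the gradients so each GPU keeps only its gradient shard
  4. Free the gathered parameters, keeping only the local shard

This trades communication for memory: each GPU persistently stores only 1/N of the parameters, gradients, and optimizer state.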
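The four steps can be traced in one process with plain Python lists standing in for tensors and list positions standing in for ranks. This is a toy under stated assumptions, not FSDP's implementation: padding is skipped, and each rank's loss is a made-up 0.5 * c * w**2 summed over the weights, so its gradient is c * w.

```python
# Single-process toy of one FSDP unit across N "ranks" (index = rank).
# Real FSDP runs NCCL collectives and pads shards; padding is omitted here.

def shard(flat, world_size):
    """Split a flat parameter list into equal contiguous shards, one per rank."""
    assert len(flat) % world_size == 0, "sketch assumes no padding is needed"
    size = len(flat) // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]

def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]

def reduce_scatter(full_grads_per_rank):
    """Sum full-size gradients across ranks, then give rank r only shard r."""
    summed = [sum(vals) for vals in zip(*full_grads_per_rank)]
    return shard(summed, len(full_grads_per_rank))

world_size = 3
param_shards = shard([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], world_size)  # resident state
data_scale = [1.0, 2.0, 3.0]  # stand-in for each rank seeing different data

# Forward: gather the full params, compute, then free the gathered copy.
full = all_gather(param_shards)
losses = [sum(0.5 * data_scale[r] * w ** 2 for w in full[r]) for r in range(world_size)]
del full

# Backward: gather again, compute each rank's full-size gradient (c * w),
# reduce-scatter, then free; only param_shards and grad_shards stay resident.
full = all_gather(param_shards)
local_grads = [[data_scale[r] * w for w in full[r]] for r in range(world_size)]
grad_shards = reduce_scatter(local_grads)
del full

print("losses:", losses)
print("gradient shard per rank:", grad_shards)
```

Each rank ends the step holding 2 of the 6 parameters and the matching 2 summed gradients, which is the 1/N footprint described above.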

Sharding Strategies

| Strategy | Memory savings | Communication | Use case |
|---|---|---|---|
| FULL_SHARD | Maximum (params + grads + optimizer state) | Highest | Default; models that don't fit |
| SHARD_GRAD_OP | Moderate (grads + optimizer state only) | Lower | Model fits but optimizer state doesn't |
| NO_SHARD | None (equivalent to DDP) | Lowest | Baseline / debugging |
| HYBRID_SHARD | Full shard within a node, replicate across nodes | Balanced | Multi-node with fast intra-node links |
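A rough way to compare the memory column is to count model state per GPU. The sketch below assumes AdamW with fp32 parameters, fp32 gradients and two fp32 moment buffers (16 bytes per parameter) and ignores activations, the temporarily gathered unit, and communication buffers; `model_state_bytes_per_gpu` is a hypothetical helper, not an FSDP API.

```python
# Rough per-GPU model-state memory per sharding strategy.
# Assumptions (not FSDP internals): 4 B fp32 param + 4 B fp32 grad
# + 8 B AdamW moments per parameter; activations and buffers ignored.

def model_state_bytes_per_gpu(strategy, num_params, world_size, gpus_per_node):
    params = 4 * num_params
    grads = 4 * num_params
    optim = 8 * num_params
    if strategy == "NO_SHARD":        # every GPU holds everything (like DDP)
        return params + grads + optim
    if strategy == "SHARD_GRAD_OP":   # full params held through forward/backward
        return params + (grads + optim) / world_size
    if strategy == "FULL_SHARD":      # everything split across all GPUs
        return (params + grads + optim) / world_size
    if strategy == "HYBRID_SHARD":    # split within a node, replicated across nodes
        return (params + grads + optim) / gpus_per_node
    raise ValueError(f"unknown strategy: {strategy}")

for s in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD"):
    gb = model_state_bytes_per_gpu(s, num_params=7e9, world_size=16, gpus_per_node=8) / 1e9
    print(f"{s:14s} {gb:7.2f} GB per GPU")
```

Under these assumptions a 7B-parameter model needs about 112 GB per GPU with NO_SHARD, which is why it cannot train on a single 80 GB GPU without sharding.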

Basic Setup

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    BackwardPrefetch,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)
import functools
import os

def train(local_rank: int):
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)

    model = build_model()  # placeholder: construct your model here

    # Auto-wrap policy: each transformer layer becomes an FSDP unit
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},  # placeholder: your layer class
    )

    # Mixed precision
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=local_rank,
        use_orig_params=True,       # required for torch.compile
        limit_all_gathers=True,     # limit memory from all-gathers
    )

    # Create the optimizer after wrapping so it sees FSDP's sharded parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(epochs):          # placeholder: epochs
        for batch in dataloader:         # placeholder: e.g. a DataLoader with a DistributedSampler
            optimizer.zero_grad(set_to_none=True)
            loss = model(batch).loss     # assumes the model output has a .loss attribute
            loss.backward()
            model.clip_grad_norm_(1.0)   # FSDP-aware grad clipping
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    # torchrun sets LOCAL_RANK for each process it launches
    train(int(os.environ["LOCAL_RANK"]))

Launch:

torchrun --nproc_per_node=4 train.py
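# Multi-node (2 nodes x 4 GPUs): run on every node, using --node_rank=0 on the
# first node and --node_rank=1 on the second
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
  --master_addr=10.0.0.1 --master_port=29500 train.py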
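train(local_rank) needs the rank values torchrun exports to each worker as environment variables (RANK, LOCAL_RANK, WORLD_SIZE). A small helper, hypothetical and named here only for illustration, that reads them with single-process defaults:

```python
import os

def torchrun_ranks(env=None):
    """Read the rank variables torchrun exports to every worker process.

    Falls back to a single-process run (rank 0 of 1) when they are absent;
    that fallback is an assumption for running the script without torchrun."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # nnodes * nproc_per_node
    }

# Example: third process on the second node of a 2 x 4 launch
print(torchrun_ranks({"RANK": "6", "LOCAL_RANK": "2", "WORLD_SIZE": "8"}))
print(torchrun_ranks({}))  # no torchrun variables: single process
```

With the static --node_rank launch above, global ranks are typically node_rank * nproc_per_node + local_rank, which is why the example process sees RANK=6.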

Wrap Policies

Transformer Auto-Wrap

Wraps each transformer layer as a separate FSDP unit — the standard for LLMs:

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# For HuggingFace models, import the layer class:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

Size-Based Wrap

Wraps a module once the parameters it holds, excluding those already inside wrapped children, reach a threshold:

auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,  # 1M params
)

Custom Wrap Policy

import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

def should_wrap(module: nn.Module) -> bool:
    # lambda_fn receives only the module; lambda_auto_wrap_policy already
    # handles recursion (it always recurses into children)
    return isinstance(module, (TransformerBlock, nn.Embedding))  # TransformerBlock: your layer class

auto_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=should_wrap)
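How an auto-wrap policy turns a module tree into FSDP units can be sketched on a toy tree. The traversal below assumes FSDP1's bottom-up behavior: ask the policy whether to recurse, wrap children first, then ask again whether to wrap the module using only the parameters its children did not already take, with the root always becoming the last unit. Mod, plan_units, by_type and by_size are hypothetical names, and by_size drops the real size-based policy's built-in exclusions (such as container modules).

```python
from dataclasses import dataclass, field

@dataclass
class Mod:
    """Toy stand-in for an nn.Module: a type name, its own params, and children."""
    kind: str
    own_params: int = 0
    children: list = field(default_factory=list)

    def numel(self):
        return self.own_params + sum(c.numel() for c in self.children)

def recursive_wrap(module, policy, units, only_wrap_children=False):
    """Bottom-up wrapping in the style of FSDP1 auto-wrap (assumed semantics).

    Returns how many of this subtree's params ended up inside FSDP units."""
    total = module.numel()
    if not policy(module, recurse=True, nonwrapped_numel=total):
        return 0
    wrapped = sum(recursive_wrap(c, policy, units) for c in module.children)
    remainder = total - wrapped  # params not already owned by a child unit
    if not only_wrap_children and policy(module, recurse=False, nonwrapped_numel=remainder):
        units.append((module.kind, remainder))
        return total
    return wrapped

def plan_units(root, policy):
    """The root is always wrapped and owns whatever no child unit took."""
    units = []
    wrapped = recursive_wrap(root, policy, units, only_wrap_children=True)
    units.append((root.kind + " (root)", root.numel() - wrapped))
    return units

def by_type(kinds):
    # Like the transformer / lambda policies: always recurse, wrap listed types.
    return lambda module, recurse, nonwrapped_numel: recurse or module.kind in kinds

def by_size(min_num_params):
    # Like the size-based policy minus its exclusions: recurse into and wrap
    # modules whose not-yet-wrapped params reach the threshold.
    return lambda module, recurse, nonwrapped_numel: nonwrapped_numel >= min_num_params

def block():
    return Mod("Block", children=[Mod("Attention", 400), Mod("MLP", 800)])

model = Mod("Model", children=[Mod("Embedding", 1000), block(), block(), Mod("LMHead", 1000)])
print(plan_units(model, by_type({"Block"})))
print(plan_units(model, by_size(1000)))
```

With the type-based policy the embedding and LM head stay in the root unit; with the size-based policy they reach the threshold on their own and become separate units, while Attention and MLP are too small to be wrapped individually.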

Mixed Precision

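import torch
from torch.distributed.fsdp import MixedPrecision

# bf16: recommended for A100/H100
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,    # dtype of unsharded params in forward/backward (sharded copy keeps its original dtype, typically fp32)
    reduce_dtype=torch.bfloat16,   # dtype of gradient reduction
    buffer_dtype=torch.bfloat16,   # dtype of buffers (e.g., BatchNorm running stats)
)

# fp16: for older GPUs (V100/T4); needs loss scaling, e.g. with
# torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler
mp_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# MixedPrecision casts parameters; it does not autocast individual ops.
# Keep numerically sensitive ops such as the loss in fp32 by upcasting
# explicitly, e.g. loss_fn(logits.float(), labels).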
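Why the dtype choices matter: bf16 keeps the fp32 exponent but only 8 bits of precision, so a small value added to a much larger running sum can round away entirely. The pure-Python sketch below emulates bf16 rounding (an assumption: round half to even, NaN/inf ignored) to show the effect; this is one reason some setups set reduce_dtype=torch.float32 while keeping param_dtype in bf16.

```python
import struct

def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even).

    bfloat16 is the top 16 bits of an IEEE float32: the same 8-bit exponent
    but only 7 explicit mantissa bits. NaN/inf handling is omitted here."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round on the 16 bits being dropped
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

def accumulate(values, rounding):
    """Running sum where every intermediate result is rounded by `rounding`."""
    total = 0.0
    for v in values:
        total = rounding(total + v)
    return total

contributions = [1.0] + [1e-3] * 1000
print("bf16 running sum:", accumulate(contributions, to_bf16))
print("fp64 running sum:", accumulate(contributions, float))
```

Near 1.0 the bf16 spacing is 2**-7 (about 0.0078), so each 0.001 contribution rounds back to 1.0 and the bf16 sum never moves, while the high-precision sum reaches about 2.0.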

Activation Checkpointing

Trade compute for memory — recompute activations during backward instead of storing them:

import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Option 1: Apply to the FSDP-wrapped model
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda submodule: isinstance(submodule, TransformerBlock),  # your layer class
)

# Option 2: With HuggingFace, just enable in TrainingArguments
# gradient_checkpointing=True (see HF integration below)
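The trade can be seen on a toy chain of scalar layers y = tanh(w * x). With checkpointing, forward keeps only each layer's input and backward re-runs that layer's forward to recover the dropped activation; without it, forward keeps both. ToyLayer and run are illustrative names, and this is a sketch of the idea, not PyTorch's checkpoint implementation.

```python
import math

class ToyLayer:
    """y = tanh(w * x) on scalars; counts how often forward runs."""

    def __init__(self, w):
        self.w = w
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        pre = self.w * x                   # internal activation backward needs
        return pre, math.tanh(pre)

    def backward(self, x, pre, grad_out):
        d_pre = grad_out * (1.0 - math.tanh(pre) ** 2)
        return d_pre * x, d_pre * self.w   # (grad wrt w, grad wrt input x)

def run(layers, x, checkpoint):
    saved = []
    for layer in layers:
        pre, out = layer.forward(x)
        # Checkpointing stores only the layer input; otherwise input + activation.
        saved.append((x,) if checkpoint else (x, pre))
        x = out
    output, grad, grads_w = x, 1.0, [0.0] * len(layers)  # treat output as the loss
    for i in reversed(range(len(layers))):
        if checkpoint:
            (inp,) = saved[i]
            pre, _ = layers[i].forward(inp)  # recompute the dropped activation
        else:
            inp, pre = saved[i]
        grads_w[i], grad = layers[i].backward(inp, pre, grad)
    return output, grads_w

plain = [ToyLayer(w) for w in (0.5, 1.5, -0.7)]
ckpt = [ToyLayer(w) for w in (0.5, 1.5, -0.7)]
print(run(plain, 0.8, checkpoint=False) == run(ckpt, 0.8, checkpoint=True))
print([layer.forward_calls for layer in plain], [layer.forward_calls for layer in ckpt])
```

Both runs produce identical outputs and gradients; the checkpointed layers each run forward twice, which is the extra compute paid for storing less.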

Checkpointing

Full State Dict (For Inference / Single-GPU Loading)

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, StateDictType

# Save
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, "model.pt")

# Load into a plain, unwrapped model (e.g., for inference on one device)
model.load_state_dict(torch.load("model.pt", map_location="cpu"))

Sharded State Dict (For Resuming Training)

Faster save/load: each rank saves its own shard. get_state_dict returns the model and optimizer state in sharded, FQN-keyed form (a plain optimizer.state_dict() is not converted this way), and works for both FSDP1 and FSDP2:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# Save: sharded model + optimizer state, one shard per rank
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optimizer": optim_sd}, checkpoint_id="checkpoint-epoch-1")

# Load: build state dicts with the right sharded layout, fill them in place, then apply
model_sd, optim_sd = get_state_dict(model, optimizer)
state_dict = {"model": model_sd, "optimizer": optim_sd}
dcp.load(state_dict, checkpoint_id="checkpoint-epoch-1")
set_state_dict(
    model,
    optimizer,
    model_state_dict=state_dict["model"],
    optim_state_dict=state_dict["optimizer"],
)
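Each rank writing only its own slice, plus where that slice sits, is what lets a sharded checkpoint be read back without gathering it on one rank, and torch.distributed.checkpoint can load it at a different world size. The toy below models that idea with offsets into a flat list; DCP's real storage format and planner are more involved, and save_sharded / load_resharded are hypothetical helpers.

```python
def save_sharded(flat, world_size):
    """Each rank 'writes' its contiguous slice of a flat tensor plus its offset."""
    size = -(-len(flat) // world_size)  # ceil division
    return {rank: (rank * size, flat[rank * size:(rank + 1) * size])
            for rank in range(world_size)}

def load_resharded(saved, numel, new_world_size):
    """Each new rank fills its own slice from whichever saved shards overlap it."""
    size = -(-numel // new_world_size)
    shards = []
    for rank in range(new_world_size):
        lo, hi = rank * size, min((rank + 1) * size, numel)
        out = [None] * max(hi - lo, 0)
        for offset, values in saved.values():
            for i, v in enumerate(values):
                if lo <= offset + i < hi:
                    out[offset + i - lo] = v
        shards.append(out)
    return shards

weights = list(range(10))
checkpoint = save_sharded(weights, world_size=4)        # written by 4 ranks
restored = load_resharded(checkpoint, len(weights), 3)  # read back by 3 ranks
print(restored)
```

No rank ever holds the full tensor: each reader touches only the saved pieces that overlap its new slice.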

CPU Offloading

Offload parameters and gradients to CPU when not in use:

model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    # Note: significantly slower but enables very large models on limited GPUs
)

HuggingFace Integration

With Trainer + Accelerate

The easiest way to use FSDP with HuggingFace models:

# fsdp_config.yaml (accelerate config)
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_use_orig_params: true
mixed_precision: bf16
num_machines: 1
num_processes: 4
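
In train.py, TrainingArguments carries the FSDP settings when launching with torchrun. Its fsdp_config takes a dict or a JSON file path, not the accelerate YAML above:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    fsdp="full_shard auto_wrap",
    fsdp_config={"transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"]},
    bf16=True,
    gradient_checkpointing=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    # ...
)

Launch:

accelerate launch --config_file fsdp_config.yaml train.py
# Or directly with torchrun (Trainer reads the FSDP settings from TrainingArguments)
torchrun --nproc_per_node=4 train.py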

With Accelerate (Manual)

import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision_policy=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    ),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    loss = model(**batch).loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

FSDP2 (fully_shard)

FSDP2, a composable per-parameter sharding API, first shipped in PyTorch 2.4 under torch.distributed._composable.fsdp and has been public as torch.distributed.fsdp.fully_shard since PyTorch 2.6. FSDP2 is the recommended path for new projects.

from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)

# Apply bottom-up: layers first, then root
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)  # groups remaining params (embeddings, output)

# Create the optimizer after fully_shard so it holds the sharded DTensor parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

Key differences from FSDP1:

  • Per-parameter dim-0 sharding via DTensor (vs flat-parameter concatenation) — more intuitive, relaxes frozen parameter constraints
  • No wrapper: fully_shard modifies the module in place and dynamically changes its class to also subclass FSDPModule (exposes .reshard(), .unshard())
  • FQNs preserved: state_dict() keys are unchanged, enabling seamless checkpoint compatibility
  • No full state dict API — use DTensor.full_tensor() or torch.distributed.checkpoint for conversion
  • Communication-free sharded state dicts — no all-gathers needed (FSDP1 required them)
  • Better torch.compile integration — composable with TP, CP, etc.
  • Used by TorchTitan and Megatron-LM Bridge
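The first difference can be illustrated by computing what each rank holds. This sketch assumes FSDP2 splits each parameter along dim-0 with torch.chunk semantics (ceil-sized chunks, trailing ranks possibly empty) and that FSDP1 concatenates a unit's parameters into one flat buffer split evenly across ranks; dim0_shard_rows and flat_param_owners are hypothetical helpers.

```python
def dim0_shard_rows(rows, world_size):
    """Rows per rank when one parameter is split like torch.chunk(p, world_size, dim=0)."""
    chunk = -(-rows // world_size)  # ceil division
    return [max(0, min(chunk, rows - r * chunk)) for r in range(world_size)]

def flat_param_owners(shapes, world_size):
    """FSDP1-style flat buffer: concatenate all params of a unit, split the
    (padded) buffer into equal ranges, and report which params each rank touches."""
    sizes = [r * c for r, c in shapes]
    per_rank = -(-sum(sizes) // world_size)
    owners = []
    for rank in range(world_size):
        lo, hi = rank * per_rank, (rank + 1) * per_rank
        held, start = [], 0
        for idx, n in enumerate(sizes):
            if start < hi and start + n > lo:  # param range overlaps this rank's range
                held.append(idx)
            start += n
        owners.append(held)
    return owners

print(dim0_shard_rows(10, 4))                     # uneven rows, no param straddles ranks
print(flat_param_owners([(4, 4), (2, 4)], 2))     # param 0 straddles both ranks
```

For a unit holding a 4x4 and a 2x4 weight on 2 ranks, the flat buffer splits the first weight across both ranks and puts all of the second on rank 1, whereas dim-0 sharding gives each rank its own rows of every parameter.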

FSDP2 reshard_after_forward

Control memory vs compute tradeoff per module:

# True = free params after forward (default, saves memory)
fully_shard(layer, reshard_after_forward=True)

# False = keep params unsharded (uses more memory, avoids re-allgather in backward)
fully_shard(layer, reshard_after_forward=False)

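# int = after forward, reshard to this smaller world size (a non-trivial divisor of
# the shard world size, e.g. GPUs per node): backward all-gathers over fewer ranks,
# at the cost of more memory than True
fully_shard(layer, reshard_after_forward=2)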
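The memory side of each setting can be estimated per module. Assuming a module whose unsharded parameters take param_bytes and a shard group of shard_world_size GPUs, this hypothetical helper returns how much of it each GPU keeps between forward and backward:

```python
def param_bytes_between_fwd_and_bwd(param_bytes, shard_world_size, reshard_after_forward):
    """Per-GPU parameter memory a module keeps after forward, until backward."""
    if reshard_after_forward is True:
        return param_bytes / shard_world_size   # only the 1/N shard remains
    if reshard_after_forward is False:
        return param_bytes                      # full copy kept; no re-all-gather
    k = reshard_after_forward
    if k <= 1 or k >= shard_world_size or shard_world_size % k:
        raise ValueError("an int must be a non-trivial divisor of the shard world size")
    return param_bytes / k                      # sharded over k GPUs (e.g. one node)

for setting in (True, 2, False):
    print(setting, param_bytes_between_fwd_and_bwd(800, 8, setting))
```

The int setting sits between the two extremes: more memory than True, but the backward all-gather only spans k GPUs instead of all N.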

Tensor Parallel + FSDP (2D Parallelism)

Combine FSDP (data parallel) with TP (model parallel) for very large models:

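from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

# 2 nodes x 4 GPUs: "dp" spans nodes, "tp" spans the GPUs within a node
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]

# 1. Apply tensor parallelism within each node (module names assume a Llama-style layer)
for layer in model.layers:
    parallelize_module(layer, tp_mesh, {
        "attention.q_proj": ColwiseParallel(),
        "attention.k_proj": ColwiseParallel(),
        "attention.v_proj": ColwiseParallel(),
        "attention.o_proj": RowwiseParallel(),
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(),
    })

# 2. Apply FSDP across nodes
for layer in model.layers:
    fully_shard(layer, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)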
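The tp_mesh and dp_mesh above are the two dimensions of one 2D device mesh; torch.distributed.device_mesh.init_device_mesh lays ranks out row-major over the (dp, tp) shape. The sketch below reproduces that layout in plain Python to show which ranks communicate with which, assuming contiguous ranks share a node as with the torchrun launch shown earlier; mesh_groups is a hypothetical helper.

```python
def mesh_groups(world_size, tp_size):
    """Row-major (dp, tp) layout, as in torch.arange(world_size).view(dp, tp).

    Returns (tp_groups, dp_groups): each TP group is a row of contiguous ranks,
    each DP/FSDP group is a column with one rank from every row."""
    if world_size % tp_size:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    mesh = [[d * tp_size + t for t in range(tp_size)] for d in range(dp_size)]
    dp_groups = [[mesh[d][t] for d in range(dp_size)] for t in range(tp_size)]
    return mesh, dp_groups

tp_groups, dp_groups = mesh_groups(world_size=8, tp_size=4)
print("TP groups:", tp_groups)
print("DP groups:", dp_groups)
```

With 2 nodes of 4 GPUs, each TP group is one node's GPUs (so TP's frequent collectives stay on fast intra-node links), and each FSDP group pairs the same local GPU across the two nodes.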

torch.compile with FSDP

# use_orig_params=True is required
model = FSDP(model, use_orig_params=True, ...)

# Compile after wrapping
model = torch.compile(model)

HYBRID_SHARD (Multi-Node)

Full shard within a node, replicate across nodes — reduces inter-node communication:

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    auto_wrap_policy=auto_wrap_policy,
    device_id=local_rank,
)

Best when: intra-node bandwidth >> inter-node bandwidth (e.g., NVLink within, ethernet across).

Debugging

See references/troubleshooting.md for:

  • FSDP hangs and deadlocks
  • OOM despite sharding
  • Checkpoint save/load issues
  • Mixed precision instability
  • torch.compile incompatibilities

References

Cross-References

  • pytorch — PyTorch distributed training fundamentals
  • ray-train — FSDP integration with Ray Train
  • megatron-lm — Alternative: Megatron-LM for very large models
  • torch-compile — torch.compile + FSDP integration
  • aws-efa — EFA networking for multi-node FSDP


Install

Requires askill CLI v1.0+

AI Quality Score

84/100 (analyzed 2/22/2026)

Comprehensive technical reference for PyTorch FSDP covering both FSDP1 (wrapper-based) and FSDP2 (composable) APIs. Includes sharding strategies, mixed precision policies, wrap policies, activation checkpointing, checkpointing (full/sharded), CPU offloading, HuggingFace integration, and tensor parallelism combination. Well-structured with tables, code examples, and clear explanations of core concepts. Actionable for practitioners with copy-pasteable code and launch commands. Not internal-only: a generic PyTorch reference applicable across projects.


Metadata

License: unknown
Version: -
Publisher: tylertitsworth

Tags

api